In [None]:
import af2_analysis as af2


# Parameters

In [None]:
# Title written on the first slide of the generated PowerPoint report.
TITLE_REPPORT = "Modelling repport XXXX  vs. YYYY"
# Project folder that contains the "predictions/" directory. The notebook calls
# os.chdir(WORKDIR) in several cells, so this should be an absolute path.
WORKDIR = 'WORKDIR_FOLDER'
# Distance cutoff used for the contact search and the contact counts
# (the contact-map slide notes state this value in nm).
CUTOFF_DISTANCE = 0.5
# Residue pairs in contact in fewer models than this are removed from the contact maps.
MINIMUM_CONTACTS = 5
# Default mdtraj selections for the two partners; used for every model that has
# no entry in SELECTION_MDTRAJ.
SELECTION1 = "chainid 0"
SELECTION2 = "chainid 1 "
# Only residues whose contact count summed over all models exceeds this value
# appear in the "most interacting residues" bar plot.
CUTOFF_CONTACTS_GRAPH = 20
# Script executed once per prediction folder (with "-s N -f <folder>").
GENERATE_ALL_PAE_SCRIPT = "~/CNRS2022/dev/AFToolkit/PAE/generate_all_pae.py"
# PowerPoint template the report is built from.
PPTX_TEMPLATE = "~/CNRS2022/dev/AFToolkit/misc/template_repporting.pptx"

In [None]:
import os
import subprocess

# Later relative paths ("predictions/", "figures/", ...) are resolved from WORKDIR.
os.chdir(WORKDIR)

# One sub-folder of predictions/ per modelled complex.
folders = [f.path for f in os.scandir("predictions/") if f.is_dir()]

# "~" was previously expanded by the shell; it is expanded explicitly now that
# the script is launched without one.
pae_script = os.path.expanduser(GENERATE_ALL_PAE_SCRIPT)
for folder in folders:
    # An argument list (instead of an interpolated shell string) passes folder
    # names containing spaces or shell metacharacters verbatim to the script.
    completed = subprocess.run(["python", pae_script, "-s", "N", "-f", folder], check=False)
    if completed.returncode != 0:
        print(f"PAE generation failed for {folder} (exit code {completed.returncode})")



In [None]:

# Per-model mdtraj selections. Models that are not listed here use SELECTION1 and SELECTION2.
# Example : "ORF2-1-40_TKB1x2":["chainid 0", "chainid 1 or chainid 2"],
SELECTION_MDTRAJ = {

}

# Per-model chain pairs used for the actifPTM plot. Each value must be a list.
# When this dictionary is left empty, ["A-B"] is used for every model.
# Example : "ORF2-1-40_TKB1x2":["A-B", "A-C"],
SELECTION_actifPTM = {

}  

# Display label of each model. When left empty, the label of MOD1_MOD2 defaults to "MOD1 vs MOD2".
# Example : "ORF2-1-40_TKB1x2": "$\mathrm{ORF2}_{1-40}$ vs. $\mathrm{TKB1}_{\mathrm{dimer}}$",
labels_models = {


}

# Order of the models in the final graphs. When left empty, the order in which
# the models appear in the results table is used.
# Example :
# "ORF2-1-40_TKB1x2",
# "ORF2cut-1-40_TKB1x2", 
labels_order = [


]

# Samuel's tool wrapper

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import json
import glob
import os
from tqdm.notebook import tqdm
from af_analysis import analysis
from af_analysis import docking


def plot_single_PAE(data, index, ax=None, cmap = 'bwr'):
    """Draw the PAE matrix of one model as a heat map with chain separators.

    Parameters
    ----------
    data : object
        Prediction container exposing ``df`` (one row per model) and
        ``chain_length`` (chain lengths indexed by query name).
    index : int
        Row of ``data.df`` to plot.
    ax : matplotlib axes, optional
        Axes to draw on; a new 4x4 inch figure is created when omitted.
    cmap : str
        Colormap passed to ``imshow``.

    Returns
    -------
    The axes the matrix was drawn on.
    """
    # Identity test: "== None" is evaluated element-wise when an array of axes
    # is passed by mistake.
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(4, 4))

    # The score file column is read as "json" here and as "data_file" in
    # get_all_data; accept whichever one the table provides.
    json_column = "json" if "json" in data.df.columns else "data_file"
    json_file = data.df[json_column][index]
    with open(json_file) as f:
        json_data = json.load(f)

    query = data.df.iloc[index]["query"]

    borders = data.chain_length[query]
    res_max = sum(borders)

    PAE = json_data["pae"]

    # Fixed 0-30 colour scale so that all panels are comparable.
    ax.imshow(PAE, cmap=cmap,
        vmin=0.0,
        interpolation='nearest',
        vmax=30.0,)

    # Black lines at the cumulative chain lengths mark the chain boundaries.
    chain_limits = np.cumsum(borders[:-1]) - 0.5
    ax.hlines(
            chain_limits,
            xmin=-0.5,
            xmax=res_max,
            colors="black",
        )

    ax.vlines(
            chain_limits,
            ymin=-0.5,
            ymax=res_max,
            colors="black",
        )

    ax.set_xlim(-0.5, res_max - 0.5)
    ax.set_ylim(res_max - 0.5, -0.5)

    modelNumber = data.df["model"][index]
    ax.set_title(f"Rank {modelNumber}")
    return ax
        
        
        
def save_all_PAE(data, save=True):
    """Plot the PAE matrix of every model of ``data`` side by side.

    One panel is created per row of ``data.df`` (4 inches wide each), so the
    figure adapts to the number of models instead of assuming exactly five.

    Parameters
    ----------
    data : object
        Prediction container accepted by ``plot_single_PAE``; ``data.dir`` is
        the folder the picture is written to.
    save : bool
        When true, the figure is written to ``<data.dir>/PAE.png`` at 300 dpi.
    """
    n_models = len(data.df)
    if n_models == 0:
        # A figure with zero panels cannot be created.
        print("No model to plot, PAE figure skipped")
        return

    # squeeze=False keeps a 2-D array of axes even for a single model.
    fig, axes = plt.subplots(1, n_models, figsize=(4 * n_models, 4), squeeze=False)
    for i, ax in enumerate(axes[0]):
        plot_single_PAE(data, i, cmap='bwr', ax=ax)

    fig.tight_layout()

    if save:
        fig.savefig(f"{data.dir}/PAE.png", dpi=300)

    



def save_plddt(data, ax=None):
    """Plot the per-residue pLDDT of every model and save it as ``plddt.png``.

    Parameters
    ----------
    data : object
        Prediction container exposing ``df``, ``chain_length``, ``get_plddt``
        and ``dir`` (folder the picture is written to).
    ax : matplotlib axes, optional
        Axes to draw on; a new 4x4 inch figure is created when omitted.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(4, 4))

    n_models = len(data.df)
    for i in range(n_models):
        query = data.df.iloc[i]["query"]
        borders = data.chain_length[query]

        plddt = data.get_plddt(i)
        ax.plot(plddt, label=f"Model {i}", linewidth=0.5)

        # Vertical lines at the cumulative chain lengths mark the chain boundaries.
        ax.vlines(
                np.cumsum(borders[:-1]) - 0.5,
                ymin=-0.5,
                ymax=100,
                colors="black",
            )

    if n_models == 0:
        # Nothing was drawn: no file is written.
        return

    ax.set_ylim(0,100)

    # Legend on the right-hand side of the plot.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))

    # Layout and export are done once, after all curves are drawn, and on the
    # figure that owns `ax` (which is not necessarily pyplot's current figure).
    fig = ax.get_figure()
    fig.tight_layout()
    fig.savefig(f"{data.dir}/plddt.png", dpi=300)

import re

def extract_rank(text):
    """Return the number found between "_rank_" and the following "_" in ``text``.

    The digits are converted to ``int`` (leading zeros are dropped). ``None`` is
    returned when ``text`` contains no such pattern.
    """
    found = re.search(r'_rank_(\d+)_', text)
    return int(found.group(1)) if found else None


def get_all_data(workdir):
    """Load every prediction folder of ``workdir`` and gather all scores in one table.

    The current working directory is changed to ``workdir``. Each sub-folder is
    loaded with ``af2.Data``; the docking scores are then computed and, for
    every model, the mean PAE, the mean pLDDT and the ``per_chain_ptm``,
    ``pairwise_actifptm`` and ``actifptm`` entries of the model score file are
    added as columns.

    Parameters
    ----------
    workdir : str
        Folder containing one sub-folder per prediction.

    Returns
    -------
    pandas.DataFrame
        Concatenation of the per-folder tables.

    Raises
    ------
    ValueError
        If no folder could be loaded.
    """
    os.chdir(workdir)

    folders = [f for f in glob.glob("*") if os.path.isdir(f)]

    print(folders)

    # (message printed on failure, module providing the scorer, scorer name).
    # The attribute is looked up inside the try block so that a scorer missing
    # from the installed package is reported like any other failure.
    scoring_steps = [
        ("pdockQ failed", analysis, "pdockq"),
        ("pdockq2 failed", analysis, "pdockq2"),
        ("mpdockq failed", analysis, "mpdockq"),
        ("inter_chain_pae failed", analysis, "inter_chain_pae"),
        ("LIS_matrix failed", analysis, "LIS_matrix"),
        ("LIS PEP failed", docking, "LIS_pep"),
    ]
    # Entries copied as-is from each model score file.
    json_fields = ("per_chain_ptm", "pairwise_actifptm", "actifptm")

    def add_value(data, column, row, value):
        # Create the column on first use, then fill a single cell.
        if column not in data.df.columns:
            data.df[column] = pd.NA
        data.df.at[row, column] = value

    errors = []
    list_of_scores = []

    for model in tqdm(folders):

        try:
            data = af2.Data(model+"/")
        except Exception as e:
            errors.append(model)
            print(f"Error occurred for model {model}: {str(e)}")
            continue
        print(model)

        # Scores are best effort: a failing score is reported (with its cause)
        # and the remaining ones are still computed.
        for failure_message, module, scorer_name in scoring_steps:
            try:
                getattr(module, scorer_name)(data)
            except Exception as e:
                print(f"{failure_message}: {e}")

        if "rank" not in data.df.columns:
            # Recover the rank from the PDB file name.
            data.df["rank"] = data.df["pdb"].apply(extract_rank)

        for i in range(len(data.df)):

            json_file = data.df["data_file"][i]
            with open(json_file) as f:
                json_data = json.load(f)

            add_value(data, "pae_mean", i, np.mean(json_data["pae"]))
            add_value(data, "plddt_mean", i, np.mean(data.get_plddt(i)))
            for field in json_fields:
                add_value(data, field, i, json_data[field])

        list_of_scores.append(data.df)

    if errors:
        print(f"Folders that could not be loaded: {errors}")
    if not list_of_scores:
        # pd.concat would raise a ValueError too, without saying where to look.
        raise ValueError(f"No prediction folder could be loaded from {workdir}")

    results = pd.concat(list_of_scores)

    return results

def re_create_logfile(workdir, logs):
    """Rebuild a ``log.txt`` in every sub-folder of ``workdir`` that lacks one.

    The current working directory is changed to ``workdir``. For each folder
    without ``log.txt``, the batch logs are scanned for a line containing both
    the folder name and "Query"; that line and the following ones, up to the
    next "Query" line, are copied after three fixed header lines. The file is
    created (possibly empty) even when no matching section is found.

    Parameters
    ----------
    workdir : str
        Folder containing one sub-folder per model.
    logs : list of str
        Paths of the batch log files to search.
    """
    os.chdir(workdir)

    folders = [f for f in glob.glob("*") if os.path.isdir(f)]

    # Each batch log is read at most once, and only when a folder needs it.
    log_cache = {}

    def read_log(logfile):
        if logfile not in log_cache:
            with open(logfile, 'r') as f:
                log_cache[logfile] = f.readlines()
        return log_cache[logfile]

    for model in folders:
        target = f"{model}/log.txt"
        if os.path.exists(target):
            continue
        with open(target, 'w') as logout:
            for logfile in logs:
                # Reset for every log file: a section that runs to the end of
                # one file must not continue with the first lines of the next.
                in_section = False
                for line in read_log(logfile):
                    is_query_line = "Query" in line
                    if is_query_line and model in line:
                        in_section = True
                        # Fixed header lines written before the query section.
                        logout.write("2025-01-08 15:48:02,483 Running colabfold 1.5.5 (00de5b40adeec5368906b9f754ccb4212d05c64d)\n")
                        logout.write("2025-01-08 15:48:06,356 Running on GPU\n")
                        logout.write("2025-01-08 15:48:06,847 Found 5 citations for tools or databases\n")
                        logout.write(line)
                    elif is_query_line and in_section:
                        # Next query reached: this model's section is complete.
                        in_section = False
                        break
                    elif in_section:
                        logout.write(line)
                            

# Analysis

In [None]:
os.chdir(WORKDIR)
# Uncomment only if the predictions come with one (or several) batch-level log files and a log.txt must be recreated in each model folder (the af_analysis pipeline relies on it to detect the models).
# 
# logs=[
#     f"{WORKDIR}/predictions/log_batch1.txt",
#     f"{WORKDIR}/predictions/log_batch2.txt",
#     f"{WORKDIR}/predictions/log_batch3.txt",
#     f"{WORKDIR}/predictions/log_batch4.txt",

# ]
# re_create_logfile(f"{WORKDIR}/predictions", logs)
# os.chdir(WORKDIR)


In [None]:
results = get_all_data(f"{WORKDIR}/predictions")

# interLIS: mean of the [0][1] and [1][0] entries of each model's LIS value.
results["interLIS"] = results["LIS"].apply(lambda lis: np.mean([lis[0][1], lis[1][0]]))

# Mean and standard deviation of interLIS over the models of each query.
per_query_interLIS = results.groupby("query")["interLIS"]
results["LIS_average"] = per_query_interLIS.transform("mean")
results["LIS_std"] = per_query_interLIS.transform("std")
results.to_excel("all_results.xlsx")

# Rank 1 model of every query, best LIS_pep_rec first: exported, then displayed.
first_ranked_models = results.query("rank == 1").sort_values(by="LIS_pep_rec", ascending=False)
first_ranked_models.to_excel("first_ranked_models.xlsx")
first_ranked_models

# Contact analysis

In [None]:
import mdtraj as md
from collections import defaultdict
import seaborn as sns


def get_residue_label(traj, atom_index, add=0, with_chain = False):
    """Build the label of the residue that owns atom ``atom_index`` of ``traj``.

    ``add`` is added to the residue number. The label is "NAME NUMBER", or
    "NAME-NUMBER-CHAIN" when ``with_chain`` is true, the chain letter being
    derived from the chain index (0 -> "A", 1 -> "B", ...).
    """
    residue = traj.topology.atom(atom_index).residue
    name = residue.name
    number = residue.resSeq + add
    chain_letter = chr(ord('A') + residue.chain.index)
    if not with_chain:
        return f"{name} {number}"
    return f"{name}-{number}-{chain_letter}"

def compute_contacts(pdbs, selection1, selection2, add1, add2, cutoff=0.4, with_chain=False):
    """Shortest distance between residue pairs of two selections, for every model.

    All PDB files are loaded as the frames of one trajectory (the first file
    provides the topology). Only atoms of one selection that are within
    ``cutoff`` of the other selection in at least one frame are kept; this
    subset is also written to "subtraj.pdb" in the current directory.

    Parameters
    ----------
    pdbs : sequence of str
        PDB files, one per model.
    selection1, selection2 : str
        mdtraj selection strings of the two partners.
    add1, add2 : int
        Offsets added to the residue numbers of selection 1 / selection 2 in
        the labels.
    cutoff : float
        Distance cutoff of the neighbour search (mdtraj distance unit).
    with_chain : bool
        Append the chain letter to the residue labels.

    Returns
    -------
    pandas.DataFrame
        Index: (res1, res2) residue labels; one column per frame; values: the
        smallest atom-atom distance of the pair in that frame. Empty when no
        atom of one selection is within ``cutoff`` of the other.
    """
    pdbs = list(pdbs)
    traj = md.load(pdbs, top=pdbs[0])

    chainA = traj.topology.select(selection1)
    chainB = traj.topology.select(selection2)

    # Atoms of each partner close to the other one, frame by frame.
    contactsB = md.compute_neighbors(traj, cutoff, query_indices=chainA, haystack_indices=chainB)
    contactsA = md.compute_neighbors(traj, cutoff, query_indices=chainB, haystack_indices=chainA)

    neighbour_lists = [np.asarray(found, dtype=int) for found in list(contactsB) + list(contactsA)]
    if neighbour_lists:
        contacts_atoms = np.unique(np.concatenate(neighbour_lists))
    else:
        contacts_atoms = np.array([], dtype=int)

    index_names = ["res1", "res2"]
    if contacts_atoms.size == 0:
        # No interface atom at all: return an empty, correctly labelled table.
        return pd.DataFrame(index=pd.MultiIndex.from_arrays([[], []], names=index_names))

    subtraj = traj.atom_slice(contacts_atoms)

    # Interface atoms kept on disk for inspection.
    subtraj.save_pdb("subtraj.pdb")

    chain1_atoms = subtraj.topology.select(selection1)
    chain2_atoms = subtraj.topology.select(selection2)
    # All (atom of selection 1, atom of selection 2) combinations.
    chain1_chain2 = np.array(np.meshgrid(chain1_atoms, chain2_atoms)).T.reshape(-1, 2)
    all_distances = md.compute_distances(subtraj, atom_pairs=chain1_chain2)

    # Residue labels do not depend on the frame: compute them once per atom
    # rather than once per frame and atom pair.
    labels1 = {atom: get_residue_label(subtraj, atom, add=add1, with_chain=with_chain) for atom in chain1_atoms}
    labels2 = {atom: get_residue_label(subtraj, atom, add=add2, with_chain=with_chain) for atom in chain2_atoms}
    residue_pairs = pd.MultiIndex.from_arrays(
        [
            [labels1[atom] for atom in chain1_chain2[:, 0]],
            [labels2[atom] for atom in chain1_chain2[:, 1]],
        ],
        names=index_names,
    )

    # One row per atom pair and one column per frame, reduced to the minimum
    # distance of every residue pair.
    per_atom_pair = pd.DataFrame(all_distances.T, index=residue_pairs)
    df_shortest_distances = per_atom_pair.groupby(level=index_names, sort=False).min()

    return df_shortest_distances

def plot_contacts(df, cutoff, outputname="output.png", minimum_contacts=3, xaxis_label="selection 1", yaxis_label="selection 2", title="", vmax=None):
    """Draw and save the heat map of the number of models in which residues are in contact.

    Parameters
    ----------
    df : pandas.DataFrame
        Shortest distances indexed by (res1, res2), one column per model, as
        returned by ``compute_contacts``.
    cutoff : float
        A residue pair is in contact in a model when its distance is below
        this value.
    outputname : str
        Path of the saved picture.
    minimum_contacts : int
        Pairs in contact in fewer models than this are removed; rows and
        columns left without any pair are dropped.
    xaxis_label, yaxis_label, title : str
        Texts of the axes and of the figure.
    vmax : number, optional
        Upper bound of the colour scale; defaults to the largest count.

    Returns
    -------
    pandas.DataFrame
        Filtered count table (res1 as rows, res2 as columns), removed entries
        set to 0.
    """
    # Number of models in contact per residue pair, res2 spread over the columns.
    count_table = (df < cutoff).sum(axis=1).unstack(fill_value=0)

    # Sort the columns by residue number. Raw strings: "\d" is not a valid
    # escape sequence in a normal string literal.
    numeric_part_columns = count_table.columns.to_series().str.extract(r'(\d+)').astype(int)
    sorted_columns = numeric_part_columns[0].argsort().to_numpy()
    count_table = count_table.iloc[:, sorted_columns]

    # Sort the rows by residue number.
    numeric_part_index = count_table.index.to_series().str.extract(r'(\d+)').astype(int)
    sorted_index = numeric_part_index[0].argsort().to_numpy()
    count_table = count_table.iloc[sorted_index]

    # Keep only the pairs with at least `minimum_contacts` contacts.
    count_table = count_table[count_table >= minimum_contacts].dropna(how='all', axis=0).dropna(how='all', axis=1)

    fig, ax = plt.subplots(figsize=(12,10))
    if vmax is None:
        vmax = count_table.max().max()
    g = sns.heatmap(count_table, cmap="Blues", annot=False,xticklabels=True, yticklabels=True, ax=ax, vmin=0, vmax=vmax)

    # One tick per column so that every label is displayed.
    g.set_xticks(np.arange(0.5, len(count_table.columns), 1))
    g.set_xlabel(xaxis_label)
    g.set_ylabel(yaxis_label)
    g.set_xticklabels(g.get_xticklabels(), rotation=90, horizontalalignment='center', fontsize=8)

    g.set_title(title)
    plt.tight_layout()
    g.figure.savefig(outputname)

    return count_table.replace(np.nan,0)


In [None]:
models_name = results["query"].unique()

os.chdir(WORKDIR+"/predictions")


def _residue_offset(fragment_name):
    """Offset added to residue numbers for a name such as "ORF2-1-40" (second "-" field minus 1).

    Returns 0 when the name has no second field or when that field is not an integer.
    """
    fields = fragment_name.split("-")
    try:
        return int(fields[1]) - 1
    except (IndexError, ValueError):
        return 0


output_countact_table = {}
for model in models_name:
    # Model-specific selections take precedence over the default ones.
    if model in SELECTION_MDTRAJ:
        sel1 = SELECTION_MDTRAJ[model][0]
        sel2 = SELECTION_MDTRAJ[model][1]
    else:
        sel1 = SELECTION1
        sel2 = SELECTION2

    # Model names are "<y axis partner>_<x axis partner>".
    xaxis = model.split("_")[1]
    yaxis = model.split("_")[0]

    decalX = _residue_offset(xaxis)
    decalY = _residue_offset(yaxis)

    pdbs = results.query("query == @model")["pdb"].values
    # A pair can be in contact in at most one frame per model, so the number of
    # models is both the natural colour-scale maximum and the count in the title.
    n_models = len(pdbs)

    df_shortest_distances = compute_contacts(pdbs, sel1, sel2, decalY, decalX, CUTOFF_DISTANCE, with_chain=False)
    ct = plot_contacts(df_shortest_distances,
              CUTOFF_DISTANCE,
              xaxis_label=xaxis,
              yaxis_label=yaxis,
              minimum_contacts=MINIMUM_CONTACTS,
              title=f"Number of contacts among the {n_models} models (minimum contact = {MINIMUM_CONTACTS})",
              outputname=f"{model}/{model}_contacts.png",
              vmax=n_models)
    output_countact_table[model] = ct


# Additional graphics

In [None]:
#If no definition of specific 

print("Dictionnary for renaming models")

# Prints one dictionary entry per model, to be pasted into `labels_models`.
# NOTE(review): with LEAVE_EMPTY = True the printed value is the model name,
# not an empty string, and the line has no trailing comma - confirm the intent.
LEAVE_EMPTY = True
for model in results["query"].unique():
    if LEAVE_EMPTY:
        print(f'"{model}":"{model}"')
    else:
        print(f'"{model}":"",')


# Take default values where the user gave no input.

if len(labels_order) == 0:
    labels_order = list(results["query"].unique())

# Keep the labels defined in the parameter cell (they used to be discarded by an
# unconditional reset) and only fill in the missing or empty ones.
labels_models = dict(labels_models)
for model in labels_order:
    if not labels_models.get(model):
        labels_models[model] = model.split("_")[0] + " vs " + model.split("_")[1]

print("list for label order")
labels_models

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
import pandas as pd
import numpy as np
import os

sns.set_theme(style="whitegrid", context="paper", font="sans-serif", font_scale=1.5)

os.chdir(WORKDIR)

# Folder receiving every figure of this cell.
if not os.path.exists("figures"):
    os.makedirs("figures")

############################################
# Strip plot of the LIS scores
############################################

# Rank 1 model of each query, taken before `results` is relabelled and sorted.
best = results.query("rank == 1")
fig, ax = plt.subplots(figsize=(10,7))
# Display label of each query.
results['query_label'] = results['query'].map(labels_models)
# An ordered categorical makes sort_values follow the custom model order.
results['query'] = pd.Categorical(results['query'], categories=labels_order, ordered=True)
results = results.sort_values('query')
g = sns.stripplot(x="query_label", y="interLIS", data=results, ax=ax)
plt.xticks(rotation=90)
ax.set_title("LIS scores for each query", fontsize=16, weight="bold")
ax.set_ylabel("LIS score", fontsize=14, weight="bold")
ax.set_xlabel("")
fig.tight_layout()
fig.savefig("figures/LISscore.png", dpi=300)

############################################
# LIS score of the best (rank 1) model of each query
############################################
fig, ax = plt.subplots(figsize=(10,7))
# `best` is a filtered slice of `results`: build a new table with assign()
# instead of setting columns on the slice (chained assignment).
best = best.assign(
    query_label=best['query'].map(labels_models),
    query=pd.Categorical(best['query'], categories=labels_order, ordered=True),
).sort_values('query')
g = sns.barplot(x="query_label", y="interLIS", data=best, ax=ax)
plt.xticks(rotation=90)
plt.title("Best LIS scores for each query", fontsize=16, weight="bold")
plt.ylabel("LIS score", fontsize=14, weight="bold")
plt.xlabel("")
plt.tight_layout()
plt.savefig("figures/best_LISscore.png", dpi=300)

############################################
# Number of contacts among all models
############################################
# Row totals of each model's contact table, then summed over all models.
contact_sums = [table.sum(axis=1) for table in output_countact_table.values()]
number_of_contact = pd.concat(contact_sums, axis=1).sum(axis=1).sort_values(ascending=False)

sns.set_style("whitegrid")
fig, ax = plt.subplots(figsize=(8.1, 10))
# Only the residues above the configured total are displayed.
dataplot = number_of_contact[number_of_contact > CUTOFF_CONTACTS_GRAPH]
g = sns.barplot(y=dataplot.index, x=dataplot.values, ax=ax)
plt.xticks(rotation=90)
ax.set_title(f"Most interacting ORF2P residues (number of contacts minimum = {CUTOFF_CONTACTS_GRAPH})", fontsize=16, weight="bold")
ax.set_xlabel("Number of contacts among all models ", fontsize=14, weight="bold")
ax.set_ylabel("")
fig.tight_layout()
fig.savefig("figures/AA_best_interacting_residues.png", dpi=300)

############################################
#                ACTIF PTM PLOT
############################################
models = results["query"].unique()
# User-defined pairs are kept; every model without an entry uses ["A-B"].
SELECTION_actifPTM = {**{model: ["A-B"] for model in models}, **SELECTION_actifPTM}

actifPTM = defaultdict(list)
for model in models:
    data_model = results.query("query == @model")
    actifPTM_pair = SELECTION_actifPTM[model]
    for i in range(len(data_model)):
        pairwise = data_model.iloc[i]["pairwise_actifptm"]
        actifptm_values = [pairwise[pair] for pair in actifPTM_pair]
        # One value per model: the mean over all selected pairs, appended once
        # the pair list is complete (not once per pair).
        if actifptm_values:
            actifPTM[model].append(np.mean(actifptm_values))

data = pd.Series(actifPTM)
# Sort according to labels_order.
data = data[labels_order]

# Long format: one row per (model, value).
data_long = pd.DataFrame([
    {"Model": key, "Model_label": labels_models[key], "actifPTM": value}
    for key, values in data.items()
    for value in values
])

# Ordered categorical so that sort_values follows the custom model order.
data_long['Model'] = pd.Categorical(data_long['Model'], categories=labels_order, ordered=True)
data_long = data_long.sort_values('Model')


fig, ax = plt.subplots(figsize=(10,7))

# First plot the violin
sns.violinplot(x="Model_label", y="actifPTM", data=data_long, ax=ax,
               color='lightgray', # Light color for the violin
               inner=None)  # No inner box plot

# Then add the points on top
sns.stripplot(x="Model_label", y="actifPTM", data=data_long, ax=ax,
              jitter=True, size=3,
              color='black',  # Dark points for contrast
              alpha=1)  # Fully opaque points

plt.title("actifPTM scores for each models", fontsize=16, weight="bold")
plt.ylabel("actifPTM", fontsize=18, weight="bold")
plt.ylim(0,1)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
plt.xticks(rotation=90, fontsize=16, ha="center")
plt.yticks(fontsize=12)
plt.xlabel("")
plt.tight_layout()
plt.savefig("figures/actifPTM.png", dpi=300)



# Automatic generation of the PowerPoint report

In [None]:
import json
def get_models_parameters(jsonfile):
    """Read ``jsonfile`` and return its parsed JSON content."""
    with open(jsonfile) as handle:
        return json.load(handle)

# Content of predictions/config.json; passed to add_methodology_slide when the report is built.
models_parameters = get_models_parameters(f"{WORKDIR}/predictions/config.json")

In [None]:
import re
from pptx import Presentation
from pptx.util import Inches, Pt  # For size adjustment



os.chdir(WORKDIR)
# Create a PowerPoint presentation from a template.
# "~" is expanded explicitly: the path is opened from Python, not through a
# shell, so a leading "~" would otherwise be taken literally.
template_path = os.path.expanduser(PPTX_TEMPLATE)  # Specify your template path
prs = Presentation(template_path)

# Or if you want to set the aspect ratio to 16:9 manually, you can create a blank presentation and set the size
# prs = Presentation()
# prs.slide_width = Inches(13.33)  # 16:9 ratio width
# prs.slide_height = Inches(7.5)   # 16:9 ratio height

#Add a title slide


def add_title_slide(prs, title_input, subtitle_text):
    """Append a slide built on layout 0 and fill its title and its placeholder 1.

    ``title_input`` goes to the slide title, ``subtitle_text`` to placeholder 1.
    """
    new_slide = prs.slides.add_slide(prs.slide_layouts[0])
    title_shape, subtitle_shape = new_slide.shapes.title, new_slide.placeholders[1]

    title_shape.text = title_input
    subtitle_shape.text = subtitle_text



def add_methodology_slide(prs, models_parameters):
    """Append a "Methodology" slide listing the main modelling parameters.

    Parameters
    ----------
    prs : presentation object
        Presentation the slide is appended to (layout 1 is used).
    models_parameters : dict
        Must provide the keys "version", "commit", "model_type", "msa_mode",
        "num_queries", "num_models" and "num_seeds".
    """
    # Create a new slide with a title and content layout
    slide_layout = prs.slide_layouts[1]  # Assuming this layout fits your template
    slide = prs.slides.add_slide(slide_layout)
    title = slide.shapes.title
    content = slide.shapes.placeholders[1]

    # Set the title
    title.text = "Methodology"
    title.text_frame.paragraphs[0].font.size = Pt(36)  # Set the font size for the title

    # Models per query = models per seed x number of seeds.
    models_per_query = models_parameters["num_models"] * models_parameters["num_seeds"]

    # Add bullet points with adjustable font size
    bullet_points = [
        f'ColabFold Version: {models_parameters["version"]} ({models_parameters["commit"]})',
        f'AlphaFold Model: {models_parameters["model_type"]}',
        # A stray double quote after "MSA mode" was removed from this bullet.
        f'MSA mode: {models_parameters["msa_mode"]}',
        f'Number of Queries: {models_parameters["num_queries"]}',
        f'Number of Models per Query: {models_per_query}',
    ]

    for point in bullet_points:
        p = content.text_frame.add_paragraph()
        p.text = point
        p.font.size = Pt(20)  # Adjust the font size for bullet points




def generate_pymol_figure(pdbfile, pymolpath="/home/thibault/miniconda3/bin/pymol"):
    """Render the PyMOL pictures of a model folder.

    Must be called from the folder containing the PDB files: only the base name
    of ``pdbfile`` is used. Two PyMOL runs are launched in command-line mode:

    * on ``pdbfile``: writes "model1_plddt.png" and "model1_bychain.png";
    * on every "*.pdb" of the folder: writes "all_models_bychain.png".

    Parameters
    ----------
    pdbfile : str
        Path of the model to render (only its base name is used).
    pymolpath : str
        PyMOL executable. When it cannot be found, the "pymol" executable
        available on PATH is used instead, if any.
    """
    import shlex
    import shutil

    # The default location only exists on one machine: fall back to PATH.
    if shutil.which(pymolpath) is None:
        pymolpath = shutil.which("pymol") or pymolpath

    pdbfile = os.path.basename(pdbfile)
    pymolcmd = '''as cartoon
orient
spectrum b, rainbow_rev, minimum=10, maximum=90

alias raysetting, set ambient, 0.5; set specular, 0; set ray_trace_mode, 1; set ray_trace_gain, 0.01; set antialias,2; set ray_trace_color, black
raysetting

scene plddt, store
png model1_plddt.png, ray=1, width=1080

util.cbc
scene bychain, store
png model1_bychain.png, ray=1, width=1080
'''
    # The file name is quoted so that spaces or shell metacharacters are safe.
    os.system(f"{pymolpath} -c {shlex.quote(pdbfile)} -d '{pymolcmd}'")

    pymol_allfigures_cmd = '''
models = cmd.get_object_list()
cmd.alignto(models[0], "super")

orient
util.cbc
bgwhite
draw 2160
png all_models_bychain.png, width=1080
'''
    # "*.pdb" is left unquoted on purpose: the shell expands it.
    os.system(f"{pymolpath} -c *.pdb -d '{pymol_allfigures_cmd}'")




def add_section(prs, title):
    """Append a section slide (layout 6) whose title is ``title``."""
    section_slide = prs.slides.add_slide(prs.slide_layouts[6])
    section_slide.shapes.title.text = title

def add_notes(slide, notes):
    """Set the speaker notes of ``slide`` to ``notes``."""
    notes_frame = slide.notes_slide.notes_text_frame
    notes_frame.text = notes

def add_models_result_slide(prs, model_name, results_dataframe):
    """Append the result slide of one model (PAE, coverage, pLDDT and PyMOL pictures).

    Must be called from the folder of the model: every picture is read from
    the current directory. The PyMOL pictures are generated first when one of
    them is missing.

    Parameters
    ----------
    prs : presentation object
        Presentation the slide is appended to (layout 2 is used).
    model_name : str
        Name of the model, used for the slide title and the picture names.
    results_dataframe : pandas.DataFrame
        Results table with the columns "query", "rank" and "pdb".

    Raises
    ------
    ValueError
        If the table has no rank 1 model for ``model_name`` while the PyMOL
        pictures have to be generated.
    """
    ## Check all placeholders in the current slide
    #for placeholder in slide.placeholders:
    #    print(f"Placeholder index: {placeholder.placeholder_format.idx}, Placeholder type: {placeholder.placeholder_format.type}")  

    slide_layout = prs.slide_layouts[2]
    slide = prs.slides.add_slide(slide_layout)
    slide.shapes.title.text = f"Model {model_name}"

    # Pictures are named after the function argument, not after a global
    # variable that happens to be defined by the calling loop.
    slide.placeholders[10].insert_picture("all_PAE.png")
    slide.placeholders[11].insert_picture(f"{model_name}_coverage.png")
    slide.placeholders[12].insert_picture(f"{model_name}_plddt.png")

    # Generate the PyMOL pictures if one of them does not exist yet.
    img_model_plddt = "model1_plddt.png"
    img_model_bychain = "model1_bychain.png"
    needed_pictures = (img_model_plddt, img_model_bychain, "all_models_bychain.png")
    if not all(os.path.exists(picture) for picture in needed_pictures):
        first_ranked = results_dataframe.query("query == @model_name and rank == 1")["pdb"].values
        if len(first_ranked) == 0:
            # Show what is available for this model before failing explicitly.
            print(results_dataframe.query("query == @model_name"))
            raise ValueError(f"No rank 1 model found for {model_name}")
        generate_pymol_figure(first_ranked[0])

    slide.placeholders[13].insert_picture(img_model_plddt)
    slide.placeholders[14].insert_picture(img_model_bychain)

    note = """DEFINITIONS OF THE METRICS 
- The pLDDT Score is a confidence score for each residue in the model. It ranges from 0 to 100, where higher scores indicate higher confidence. 
- The coverage plot shows the number of sequence in the multiple sequence aligment for each amino acid (Higher the better for the modelling). The colors depends sequence idendity. 
- PAE stands for "Predicted Aligned Error". This metric is a measure of the deviation between the predicted model and the experimental structure. Lower values are better. Low values coldspots between residues in chains A and B could indicate potential interaction sites.
"""

    add_notes(slide, note)

def add_all_models(prs, model_name):
    """Append a slide (layout 5) showing "all_models_bychain.png" for ``model_name``.

    The picture is read from the current directory and inserted in placeholder 13.
    """
    overview_slide = prs.slides.add_slide(prs.slide_layouts[5])
    overview_slide.shapes.title.text = f"Model {model_name} - All models"
    overview_slide.placeholders[13].insert_picture("all_models_bychain.png")






def add_slide_interactions(prs):
    """Append the comparison slide (layout 3) with the three summary figures.

    The figures are read from "figures/" relative to the current directory:
    placeholder 11 receives the most interacting residues, placeholder 12 the
    LIS scores and placeholder 13 the actifPTM scores.
    """
    slide_layout = prs.slide_layouts[3]
    slide = prs.slides.add_slide(slide_layout)
    slide.shapes.title.text="Interactions between MSN and other targets"

    # Local names now match the picture each one points to.
    figure_best_residues = "figures/AA_best_interacting_residues.png"
    figure_lis = "figures/LISscore.png"
    figure_actifptm = "figures/actifPTM.png"

    slide.placeholders[11].insert_picture(figure_best_residues)
    slide.placeholders[12].insert_picture(figure_lis)
    slide.placeholders[13].insert_picture(figure_actifptm)
        

def add_contact_maps_slide(prs, model):
    """Append the contact-map slide (layout 4) of ``model``.

    ``model`` is split on "_": the first two fields name the two partners in
    the slide title. The picture "<model>_contacts.png" is read from the
    current directory and the contact definition is stored in the slide notes.
    """
    name_fields = model.split("_")
    prot, target = name_fields[0], name_fields[1]

    contact_slide = prs.slides.add_slide(prs.slide_layouts[4])
    contact_slide.shapes.title.text = f"Contacts between {prot} and {target}"
    contact_slide.placeholders[11].insert_picture(f"{model}_contacts.png")

    # Speaker notes reminding how a contact is defined.
    add_notes(contact_slide, f""" DEFINITION OF THE CONTACTS : 
a contact is define as one amino acid beeing at a distance of less than {CUTOFF_DISTANCE} nm from another amino acid.
Contacts between 2 amino acids are counted only one time per model.

Higher value means more contacts between the 2 amino acids
""")


#----------------------

add_title_slide(prs, TITLE_REPPORT, "AlphaFold repport modelling")

add_section(prs,"Results per models")

for model in models_name:
    # The slide builders read the pictures of a model from the current directory.
    os.chdir(os.path.join("predictions", model))
    try:
        add_models_result_slide(prs, model, results)
        add_contact_maps_slide(prs, model)
        add_all_models(prs, model)
    finally:
        # Always return to WORKDIR, even when a slide fails, so that the
        # notebook is not left inside a model folder.
        os.chdir(WORKDIR)


add_section(prs,"Comparison between models")
add_slide_interactions(prs)

add_section(prs,"Models Parameters")
add_methodology_slide(prs, models_parameters)

add_section(prs,"Conclusion")


# Save the PowerPoint
output_file = f'{WORKDIR}/repport.pptx'
prs.save(output_file)
print(f"PowerPoint saved to {output_file}")
os.chdir(WORKDIR)
