Simple AF3 analysis
===================

**Note**: This notebook is made for Google Colab  
<a target="_blank" href="https://colab.research.google.com/github/tubiana/practicals_AI-biology-genetics/blob/main/Display_AF3_results.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

This notebook was made for the practical session of the "AI & Biology and Genetics" summer school, to ease the analysis of AlphaFold 3 (AF3) outputs.  
Special thanks to Samuel Murail for the development of his [AlphaFold analysis tool](https://github.com/samuelmurail/af2_analysis).

In [None]:
#@title Installation of dependencies
# Install the plotting / 3D-viewer stack and Samuel Murail's analysis package
# (imported later as `af_analysis`). The exit status of each pip call is checked
# so that a failed install stops here, instead of surfacing later as a confusing
# ImportError in another cell.
import os

_INSTALL_COMMANDS = [
    "pip install -q plotly pandas numpy matplotlib py3Dmol",
    "pip install -q git+https://github.com/samuelmurail/af2_analysis.git@main",
]
for _cmd in _INSTALL_COMMANDS:
    _status = os.system(_cmd)
    if _status != 0:
        raise RuntimeError(f"Installation failed (exit status {_status}): {_cmd}")



In [None]:
#@markdown Paste your filebin URL of the zipfile (from `right-click -> copy link`)

#@markdown   example `https://filebin.net/3na756769l21pial/fold_polr2a_ercc3_tfiib.zip`
url = "" #@param {type:"string"}

# Import libraries
import os
import zipfile
from urllib.request import urlretrieve

# Download the zip file#https://filebin.net/3na756769l21pial/fold_polr2a_ercc3_tfiib.zip

!wget $url

filename = os.path.basename(url)
print(filename)

# Extract the zip file
datafolder = os.path.splitext(os.path.basename(filename))[0]
with zipfile.ZipFile(filename, 'r') as zip_ref:
    zip_ref.extractall(datafolder)

print("Data folder:", datafolder)


In [None]:
#@title Reading data

import json
import af_analysis

# Load every AF3 model found in the extracted folder (quietly), then gather the
# per-residue pLDDT array of each model, following the DataFrame row order.
mymodels = af_analysis.Data(datafolder, verbose=False)
plddts = [mymodels.get_plddt(row) for row in range(len(mymodels.df))]

In [None]:
#@title Interactive pLDDT plot

import plotly.graph_objects as go
import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px

fig = go.Figure()

# The chain layout is taken from the first model and applied to every model below.
query = mymodels.df.iloc[0]["query"]
chain_lengths = mymodels.chain_length[query]
# Residue indices at which each chain after the first one begins (chain boundaries).
cumsum_values = np.cumsum(chain_lengths[:-1])
# Chain identifiers reported by af_analysis; assumed to follow the order of
# chain_lengths (they are paired the same way for the PAE tick labels).
chain_ids = list(mymodels.chains[query])

# Fallback chain letter (A, B, C, ...) used when no identifier is available.
def get_chain_letter(index):
    return chr(65 + index)  # A, B, C, ...

def get_chain_label(index):
    """Return the real chain ID of the index-th chain, or a fallback letter."""
    if index < len(chain_ids):
        return str(chain_ids[index])
    return get_chain_letter(index)

for i, model_plddt in enumerate(plddts):
    x_values = list(range(len(model_plddt)))
    hover_text = []
    current_chain = 0
    residue_counter = 1

    for j, plddt in enumerate(model_plddt):
        # Crossing a chain boundary: switch chain and restart residue numbering at 1.
        if current_chain < len(cumsum_values) and j >= cumsum_values[current_chain]:
            current_chain += 1
            residue_counter = 1

        chain_letter = get_chain_label(current_chain)
        hover_text.append(f"<b>Protéine {chain_letter}, residue {residue_counter}</b><br>pLDDT={plddt:.2f}")
        residue_counter += 1

    fig.add_trace(go.Scatter(
        x=x_values,
        y=model_plddt,
        mode='lines',
        line=dict(width=0.5),
        name=f"Model {i}",
        text=hover_text,
        hoverinfo='text'
    ))

# Add vertical lines at the chain boundaries, spanning the full plot height.
for x_val in cumsum_values:
    fig.add_shape(
        type="line",
        x0=x_val,
        y0=0,
        x1=x_val,
        y1=1,
        line=dict(color="black", width=1),
        xref='x',
        yref='paper'
    )

fig.update_layout(
    title="Predicted LDDT for each model",
    xaxis_title="Residue",
    yaxis_title="Predicted LDDT",
)


fig.show()

In [None]:
#@title Non interactive pLDDT plot

import matplotlib.pyplot as plt
import numpy as np

chain_lengths = mymodels.chain_length[mymodels.df.iloc[0]["query"]]
# Residue indices at which each chain after the first one begins (chain boundaries).
cumsum_values = np.cumsum(chain_lengths[:-1])

plt.figure(figsize=(10, 6))

# One line per model, in DataFrame row order.
for i, values in enumerate(plddts):
    plt.plot(values, label=f"rank {i}",)

plt.title("pLDDT Confidence Scores")
plt.xlabel("Residue Index")
plt.ylabel("pLDDT Score")
plt.ylim(0, 100)
plt.legend()

# Dashed separators at chain boundaries (unlabelled, so they stay out of the legend).
for x_val in cumsum_values:
    plt.axvline(
        x=x_val,
        color='black',
        linestyle='--',
        linewidth=0.5,
        )

plt.show()

In [None]:
#@title PAE matrices
import json

import matplotlib.pyplot as plt
import numpy as np

# Load the PAE matrix of every model into a list of arrays called "paes"
# (one entry per row of mymodels.df, same order as the pLDDT plots).
paes = []
for jsonfile in mymodels.df["json"]:
    with open(jsonfile, "r") as f:
        jsondata = json.load(f)
    paes.append(np.array(jsondata["pae"]))

# Code adapted from https://github.com/samuelmurail/af2_analysis/blob/main/src/af2_analysis/data.py
# One 5x5-inch panel per model (previously hard-coded to exactly 5 panels).
# squeeze=False keeps `axs` indexable even when there is a single model.
n_models = len(paes)
n_panels = max(n_models, 1)
fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
axs = axs[0]

# Chain boundaries (start index of every chain after the first), from the first model.
cumsum_values = np.cumsum(mymodels.chain_length[mymodels.df.iloc[0]["query"]][:-1])

for i in range(n_models):
    query = mymodels.df.iloc[i]["query"]
    resmax = sum(mymodels.chain_length[query])
    axs[i].imshow(paes[i], cmap='bwr', vmin=0, vmax=30)
    axs[i].set_title(f'Model {i+1}')

    # Black lines delimiting the chain-vs-chain blocks of the matrix.
    axs[i].vlines(
        cumsum_values,
        ymin=-0.5,
        ymax=resmax,
        colors="black",
    )

    axs[i].hlines(
        cumsum_values,
        xmin=-0.5,
        xmax=resmax,
        colors="black",
    )
    axs[i].set_xlim(-0.5, resmax - 0.5)
    # Inverted y limits so residue 0 sits at the top-left corner.
    axs[i].set_ylim(resmax - 0.5, -0.5)

    # Centre each chain label on its block along the y axis.
    chain_pos = []
    len_sum = 0
    for chain_len in mymodels.chain_length[query]:
        chain_pos.append(len_sum + chain_len / 2)
        len_sum += chain_len
    axs[i].set_yticks(chain_pos)
    axs[i].set_yticklabels(mymodels.chains[query])

plt.show()






In [None]:
#@title VISUALISATION OF THE MODELS (press run at each changes)
#Code adapted from https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/AlphaFold2.ipynb#scrollTo=KK7X9T44pWb7
import py3Dmol
import glob
import matplotlib.pyplot as plt

# Row of mymodels.df to display (0-based index, used as `rank` in the form).
rank_num = 1 #@param ["0", "1", "2", "3", "4"] {type:"raw"}
# Colouring scheme: one colour per chain, gradient on the B-factor column ("lDDT"), or rainbow.
color = "lDDT" #@param ["chain", "lDDT", "rainbow"]
# Also draw side-chain atoms as sticks.
show_sidechains = False #@param {type:"boolean"}
# Also draw backbone atoms (C, O, N, CA) as sticks.
show_mainchains = False #@param {type:"boolean"}


def plot_plddt_legend(dpi=100):
  """Draw a tiny stand-alone horizontal legend for the pLDDT colour scale.

  Zero-height bars provide one coloured legend handle per label; the first,
  white entry serves as the 'plDDT:' caption. The `pyplot` module is returned
  so the caller can chain `.show()`.
  """
  labels = ['plDDT:', 'Very low (<50)', 'Low (60)', 'OK (70)', 'Confident (80)', 'Very high (>90)']
  swatches = ["#FFFFFF", "#FF0000", "#FFFF00", "#00FF00", "#00FFFF", "#0000FF"]
  plt.figure(figsize=(1, 0.1), dpi=dpi)
  for swatch in swatches:
    plt.bar(0, 0, color=swatch)
  legend_options = dict(frameon=False, loc='center', ncol=6,
                        handletextpad=1, columnspacing=1, markerscale=0.5)
  plt.legend(labels, **legend_options)
  plt.axis(False)
  return plt

# Per-chain colours used when the viewer is coloured by chain; the n-th chain
# gets the n-th colour (presumably mirroring PyMOL's chain palette — verify).
pymol_color_list = ["#33ff33","#00ffff","#ff33cc","#ffff00","#ff9999","#e5e5e5","#7f7fff","#ff7f00",
                    "#7fff7f","#199999","#ff007f","#ffdd5e","#8c3f99","#b2b2b2","#007fff","#c4b200",
                    "#8cb266","#00bfbf","#b27f7f","#fcd1a5","#ff7f7f","#ffbfdd","#7fffff","#ffff7f",
                    "#00ff7f","#337fcc","#d8337f","#bfff3f","#ff7fff","#d8d8ff","#3fffbf","#b78c4c",
                    "#339933","#66b2b2","#ba8c84","#84bf00","#b24c66","#7f7f7f","#3f3fa5","#a5512b"]
from string import ascii_uppercase,ascii_lowercase
# Chain identifiers assumed by the chain colouring, in order: A-Z then a-z.
alphabet_list = list(ascii_uppercase+ascii_lowercase)



def show_pdb(pdb_file, mymodels, show_sidechains=False, show_mainchains=False, color="lDDT"):
  """Build a py3Dmol view of one AF3 model.

  Args:
    pdb_file: path to the model structure, parsed as mmCIF text.
    mymodels: af_analysis data object; only its `chains` mapping is used, to
      count chains when colouring by chain.
    show_sidechains: also draw side-chain atoms as sticks.
    show_mainchains: also draw backbone atoms (C, O, N, CA) as sticks.
    color: "lDDT" (roygb gradient on the B-factor column, 50 -> 90),
      "rainbow" (spectrum colouring) or "chain" (one colour per chain).

  Returns:
    The configured py3Dmol view, zoomed on the model.
  """
  view = py3Dmol.view(js='https://tubiana.me/files/3Dmol-min.js',)
  # Context manager so the file handle is closed as soon as it has been read.
  with open(pdb_file, 'r') as handle:
    view.addModel(handle.read(), 'cif')

  if color == "lDDT":
    view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':50,'max':90}}})
  elif color == "rainbow":
    view.setStyle({'cartoon': {'color':'spectrum'}})
  elif color == "chain":
    # Chains are assumed to be named A, B, C, ... in order; their number is taken
    # from the first entry of mymodels.chains.
    n_chains = len(list(mymodels.chains.values())[0])
    # Dedicated loop variable so the `color` argument is not overwritten.
    for chain_id, chain_color in zip(alphabet_list[:n_chains], pymol_color_list):
      view.setStyle({'chain': chain_id}, {'cartoon': {'color': chain_color}})

  if show_sidechains:
    BB = ['C','O','N']
    # Side chains of every residue except GLY/PRO, backbone atoms excluded.
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # Glycine has no side chain: mark its CA with a sphere.
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                        {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    # Proline: every atom except the carbonyl C and O.
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                        {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
  if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

  view.zoomTo()
  return view



# Pick the structure file of the selected row and render it; the pLDDT colour
# legend is only meaningful for the "lDDT" colouring.
selected_model = mymodels.df.iloc[int(rank_num)]
pdbfile = selected_model["pdb"]
viewer = show_pdb(pdbfile, mymodels, show_sidechains, show_mainchains, color)
viewer.show()
if color == "lDDT":
  plot_plddt_legend().show()


In [None]:
#@title Best model PAE, but bigger 😉

# Enlarged PAE heat map of the model in row 0 of mymodels.df, reusing the `paes`
# list loaded by the "PAE matrices" cell (which must have been run first).
# NOTE(review): assumes row 0 is the top-ranked model — TODO confirm af_analysis ordering.
# Code adapted from https://github.com/samuelmurail/af2_analysis/blob/main/src/af2_analysis/data.py
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(1,1, figsize=(20,20))
# Chain boundaries (start index of every chain after the first one).
cumsum_values = np.cumsum(mymodels.chain_length[mymodels.df.iloc[0]["query"]][:-1])


i = 0
query = mymodels.df.iloc[i]["query"]
resmax = sum(mymodels.chain_length[query])
ax.imshow(paes[i], cmap='bwr', vmin=0, vmax=30)
ax.set_title(f'Model {i+1}')

# Black lines delimiting the chain-vs-chain blocks of the matrix.
_= ax.vlines(
    cumsum_values,
    ymin=-0.5,
    ymax=resmax,
    colors="black",
)

_= ax.hlines(
    cumsum_values,
    xmin=-0.5,
    xmax=resmax,
    colors="black",
)
_= ax.set_xlim(-0.5, resmax - 0.5)
# Inverted y limits so residue 0 sits at the top-left corner.
_= ax.set_ylim(resmax - 0.5, -0.5)

# Centre each chain label on its block.
chain_pos = []
len_sum = 0
for chain_len in mymodels.chain_length[query]:
    chain_pos.append(len_sum + chain_len / 2)
    len_sum += chain_len
# The matrix is residue x residue, so label both axes with the chain IDs.
_=ax.set_xticks(chain_pos)
_=ax.set_xticklabels(mymodels.chains[query])
_=ax.set_yticks(chain_pos)
_=ax.set_yticklabels(mymodels.chains[query])
