<a href="https://colab.research.google.com/github/mitiau/PROSTATA/blob/HSE_seminar/PROSTATA_tool.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Install dependencies and download weights

In [None]:
!pip install transformers
!pip install fair-esm
!pip install biopython
!pip install gdown==4.5.4

In [None]:
from google.colab import drive, files

import torch
from torch.utils.data import Dataset
from torch import nn

import transformers
from transformers.modeling_outputs import SequenceClassifierOutput

import pandas as pd
import numpy as np
import random

import esm
from esm import ProteinBertModel
from esm.pretrained import load_model_and_alphabet_hub

from Bio import SeqIO
from io import StringIO, BytesIO
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# Fine-tuned PROSTATA checkpoints. Each file name matches an entry of
# `model_names`; predict_ddg loads every one of them with torch.load().
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutationPosConcat
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutationPosOuter
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutation_cls
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutation_pos
!wget https://a025generative-modeling-for-design.obs.ru-moscow-1.hc.sbercloud.ru/hse_protein_seminar/ESMForSingleMutation_pos_cat_cls

In [None]:
# Clone the PROSTATA repository on the HSE_seminar branch. The test set read
# later (PROSTATA/cross_validation_datasets/test_1LNIA.csv) comes from it.
!git clone https://github.com/mitiau/PROSTATA.git
!git -C PROSTATA checkout HSE_seminar
!git -C PROSTATA pull

In [None]:
%%time
#@title install ESMfold
#@markdown install ESMFold, OpenFold and download Params (~2min 30s)

import os, time
# Skip the whole setup when the ESMFold weights file is already present.
if not os.path.isfile("esmfold.model"):
  # download esmfold params
  # The trailing "&" starts aria2c in the background, so the weights download
  # runs while the pip installs below proceed.
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://colabfold.steineggerlab.workers.dev/esm/esmfold.model &")

  # install libs
  os.system("pip install -q omegaconf pytorch_lightning biopython ml_collections einops py3Dmol")
  os.system("pip install -q git+https://github.com/NVIDIA/dllogger.git")

  # install openfold
  # OpenFold is pinned to a fixed commit.
  commit = "6908936b68ae89f67755240e2f588c09ec31d4c8"
  os.system(f"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}")

  # install esmfold
  # (sokrypton's fork of the esm package)
  os.system(f"pip install -q git+https://github.com/sokrypton/esm.git")

  # wait for Params to finish downloading...
  if not os.path.isfile("esmfold.model"):
    # backup source!
    # The background download never created the file, so fetch it from the
    # fallback URL, this time in the foreground (no trailing "&").
    os.system("aria2c -q -x 16 https://files.ipd.uw.edu/pub/esmfold/esmfold.model")
  else:
    # Poll every 5 s until aria2's control file esmfold.model.aria2 is gone,
    # which this cell treats as "download finished".
    while os.path.isfile("esmfold.model.aria2"):
      time.sleep(5)

In [None]:
import torch.nn.functional as F

HIDDEN_UNITS_POS_CONTACT = 5
class ESMForSingleMutationPosConcat(nn.Module):
    """Score a point mutation from ESM-2 embeddings at the mutated residue.

    The wild-type and mutant last-layer embeddings of the mutated residue are
    concatenated and fed through a small ReLU MLP
    (2 * 1280 -> HIDDEN_UNITS_POS_CONTACT -> 1).

    In this notebook instances are restored as whole modules with
    ``torch.load``. The attribute names created in ``__init__`` must therefore
    stay as they are.
    """

    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self.fc1 = nn.Linear(1280 * 2, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def forward(self, token_ids1, token_ids2, pos):
        """Return the regression output for wild-type / mutant token batches.

        ``pos`` is the 0-based residue index. ``+ 1`` turns it into a token
        index; this assumes the tokenizer prepends exactly one special token
        (TODO confirm). With ``pos`` a 1-element LongTensor, as predict_ddg
        passes it, the result has shape ``(batch, 1, 1)``.
        """
        site = pos + 1
        wt_site, mut_site = (
            self.esm2.forward(ids, repr_layers=[33])['representations'][33][:, site]
            for ids in (token_ids1, token_ids2)
        )
        # Join the two embeddings along the feature axis before the MLP.
        hidden = F.relu(self.fc1(torch.cat((wt_site, mut_site), 2)))
        return self.fc2(hidden)
    
HIDDEN_UNITS_POS_OUTER = 5
class ESMForSingleMutationPosOuter(nn.Module):
    """Score a point mutation from the outer product of the wild-type and
    mutant ESM-2 embeddings at the mutated residue.

    The 1280 x 1280 outer product is flattened and fed through a ReLU MLP
    (1280 * 1280 -> HIDDEN_UNITS_POS_OUTER -> 1). Most of the backbone is
    frozen at construction time (see ``_freeze_esm2_layers``).

    In this notebook instances are restored as whole modules with
    ``torch.load``. The attribute names created in ``__init__`` must therefore
    stay as they are.
    """

    def __init__(self):
        super().__init__()
        self.esm2, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self._freeze_esm2_layers()
        self.fc1 = nn.Linear(1280 * 1280, HIDDEN_UNITS_POS_OUTER)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_OUTER, 1)

    def _freeze_esm2_layers(self):
        """Turn off gradients for the leading parameter tensors of ESM-2.

        The first ``2 + 16 * (33 - 3)`` parameters, in registration order,
        are frozen. By the original naming these are 2 "initial" tensors plus
        16 tensors per block for the first 30 of 33 blocks, which leaves the
        last 3 blocks trainable. This assumes 16 tensors per ESM-2 block
        (TODO confirm).
        """
        n_blocks, n_trainable_blocks = 33, 3
        n_leading, per_block = 2, 16
        n_frozen = n_leading + per_block * (n_blocks - n_trainable_blocks)
        for param in list(self.esm2.parameters())[:n_frozen]:
            param.requires_grad = False

    def forward(self, token_ids1, token_ids2, pos):
        """Return the regression output for wild-type / mutant token batches.

        ``pos`` is the 0-based residue index. ``+ 1`` turns it into a token
        index; this assumes one leading special token (TODO confirm). With
        ``pos`` a 1-element LongTensor, as predict_ddg passes it, the result
        has shape ``(batch, 1, 1)``.
        """
        site = pos + 1
        wt_site = self.esm2.forward(token_ids1, repr_layers=[33])['representations'][33][:, site]
        mut_site = self.esm2.forward(token_ids2, repr_layers=[33])['representations'][33][:, site]
        # (B, k, D, 1) @ (B, k, 1, D) -> (B, k, D, D), then flatten the last two axes.
        outer = wt_site.unsqueeze(3) @ mut_site.unsqueeze(2)
        flat = outer.view(outer.shape[0], outer.shape[1], -1)
        return self.fc2(F.relu(self.fc1(flat)))
    
class ESMForSingleMutation_pos(nn.Module):
    """Linear head on a learned mix of the wild-type and mutant embeddings of
    the mutated residue: ``classifier(const1 * wt + const2 * mut)``.

    In this notebook instances are restored as whole modules with
    ``torch.load``. The attribute names created in ``__init__`` must therefore
    stay as they are.
    """

    def __init__(self):
        super().__init__()
        # Despite its name, ``esm1v`` holds the ESM-2 650M backbone. The name
        # is kept because pickled checkpoints refer to it.
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.classifier = nn.Linear(1280, 1)
        # Per-channel weights initialised to +1 / -1, i.e. starting as wt - mut.
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))

    def forward(self, token_ids1, token_ids2, pos):
        """Return the regression output for wild-type / mutant token batches.

        ``pos`` is the 0-based residue index. ``+ 1`` turns it into a token
        index; this assumes one leading special token (TODO confirm). With
        ``pos`` a 1-element LongTensor the result is ``(batch, 1, 1)``.
        """
        site = pos + 1
        reps = [self.esm1v.forward(ids, repr_layers=[33])['representations'][33]
                for ids in (token_ids1, token_ids2)]
        mixed = self.const1 * reps[0][:, site, :] + self.const2 * reps[1][:, site, :]
        return self.classifier(mixed)
    
class ESMForSingleMutation_cls(nn.Module):
    """Linear head on a learned mix of the wild-type and mutant embeddings of
    token 0: ``classifier(const1 * wt[:, 0] + const2 * mut[:, 0])``.

    ``pos`` is accepted for interface compatibility with the position-based
    models but is not used.

    In this notebook instances are restored as whole modules with
    ``torch.load``. The attribute names created in ``__init__`` must therefore
    stay as they are.
    """

    def __init__(self):
        super().__init__()
        # Despite its name, ``esm1v`` holds the ESM-2 650M backbone. The name
        # is kept because pickled checkpoints refer to it.
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.classifier = nn.Linear(1280, 1)
        # Per-channel weights initialised to +1 / -1, i.e. starting as wt - mut.
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))

    def forward(self, token_ids1, token_ids2, pos):
        """Return a ``(batch, 1, 1)`` regression output."""
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        outputs = self.const1 * outputs1[:, 0, :] + self.const2 * outputs2[:, 0, :]
        # (B, 1280) -> (B, 1, 1280): insert the singleton axis after the batch
        # axis so the output is (B, 1, 1), like the position-based models.
        # unsqueeze(0) produced (1, B, 1), which only agreed for B == 1.
        logits = self.classifier(outputs.unsqueeze(1))
        return logits
    
class ESMForSingleMutation_pos_cat_cls(nn.Module):
    """Linear head on the concatenation of two learned wt/mutant mixes: one
    taken at token 0 and one at the mutated residue.

    The same ``const1`` / ``const2`` weights are used for both mixes.

    In this notebook instances are restored as whole modules with
    ``torch.load``. The attribute names created in ``__init__`` must therefore
    stay as they are.
    """

    def __init__(self):
        super().__init__()
        # Despite its name, ``esm1v`` holds the ESM-2 650M backbone. The name
        # is kept because pickled checkpoints refer to it.
        self.esm1v, self.esm1v_alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.classifier = nn.Linear(1280*2, 1)
        # Per-channel weights initialised to +1 / -1, i.e. starting as wt - mut.
        self.const1 = torch.nn.Parameter(torch.ones((1,1280)))
        self.const2 = torch.nn.Parameter(-1 * torch.ones((1,1280)))

    def forward(self, token_ids1, token_ids2, pos):
        """Return the regression output for wild-type / mutant token batches.

        ``pos`` is the 0-based residue index and is expected to be a 1-element
        LongTensor, as predict_ddg passes it. ``+ 1`` turns it into a token
        index; this assumes one leading special token (TODO confirm). The
        result has shape ``(batch, 1, 1)``.
        """
        outputs1 = self.esm1v.forward(token_ids1, repr_layers=[33])['representations'][33]
        outputs2 = self.esm1v.forward(token_ids2, repr_layers=[33])['representations'][33]
        cls_out = self.const1 * outputs1[:,0,:] + self.const2 * outputs2[:,0,:]
        pos_out = self.const1 * outputs1[:,pos+1,:] + self.const2 * outputs2[:,pos+1,:]
        # cls_out is (B, 1280) and pos_out is (B, 1, 1280). Insert the singleton
        # axis after the batch axis so both line up for any batch size.
        # unsqueeze(0) gave (1, B, 1280), which only worked for B == 1.
        outputs = torch.cat([cls_out.unsqueeze(1), pos_out], dim=-1)
        logits = self.classifier(outputs)
        return logits
    

In [None]:
# Checkpoints averaged by predict_ddg. Each name is both a file fetched by the
# wget cell and the name of one of the model classes defined above.
model_names = [
    'ESMForSingleMutationPosOuter',
    'ESMForSingleMutationPosConcat',
    'ESMForSingleMutation_pos_cat_cls',
    'ESMForSingleMutation_pos',
    'ESMForSingleMutation_cls',
]

# Predict ΔΔG for the test set and compare with experimental data

In [None]:
# Load one checkpoint on the CPU only to borrow the alphabet stored on it,
# and build the tokenizer (batch converter) that predict_ddg uses.
model = torch.load('ESMForSingleMutation_cls', map_location='cpu')
esm2_alphabet = model.esm1v_alphabet
esm2batch_converter = esm2_alphabet.get_batch_converter()

# Sequences are truncated to their first 1022 residues before tokenisation
# (presumably ESM-2's 1024-token limit minus the added special tokens -
# confirm), so a mutation must lie inside that window.
_ESM_MAX_RESIDUES = 1022


def _apply_mutation(seq, mutation_code, pos=None):
    """Return ``(mutant_seq, mut_pos)`` for one point substitution of *seq*.

    *mutation_code* has the form ``<wt aa><1-based position><mutant aa>``
    (e.g. ``'V10N'``). If *pos* is given, it is used as the 0-based position
    instead of the number embedded in the code.

    Raises:
        ValueError: if the position falls outside *seq*, or the residue there
            differs from the wild-type letter of *mutation_code*.
    """
    wt_aa, mut_aa = mutation_code[0], mutation_code[-1]
    # ``pos`` may legitimately be 0, so test for None rather than truthiness.
    mut_pos = int(mutation_code[1:-1]) - 1 if pos is None else int(pos)
    if not 0 <= mut_pos < len(seq):
        raise ValueError(
            f'{mutation_code}: position {mut_pos} (0-based) is outside a '
            f'sequence of length {len(seq)}')
    if seq[mut_pos] != wt_aa:
        raise ValueError(
            f'{mutation_code}: expected {wt_aa!r} at 0-based position '
            f'{mut_pos}, found {seq[mut_pos]!r}')
    return seq[:mut_pos] + mut_aa + seq[mut_pos + 1:], mut_pos


def predict_ddg(seqs, mutation_codes, poss = None):
    """Predict ΔΔG for point mutations, averaged over every checkpoint in ``model_names``.

    Args:
        seqs: wild-type protein sequences.
        mutation_codes: one code per sequence, ``<wt aa><1-based pos><mutant aa>``.
        poss: optional 0-based mutation positions, one per sequence. A ``None``
            entry, or omitting the list, means the position is parsed from
            the code.

    Returns:
        1-D numpy array with one averaged prediction per mutation (empty for
        empty input).

    Raises:
        ValueError: if a mutation does not match its sequence, or lies beyond
            the first 1022 residues that are passed to the models.

    Uses the notebook globals ``esm2batch_converter`` and ``model_names``. Each
    model name is loaded with ``torch.load`` from a file of the same name.
    Runs on CUDA when available, otherwise on the CPU.
    """
    if poss is None:
        poss = [None] * len(seqs)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Validate and tokenize every (wild type, mutant) pair once; the token
    # tensors are reused for all checkpoints.
    inp = []
    for seq, mutation_code, pos in zip(seqs, mutation_codes, poss):
        mut, mut_pos = _apply_mutation(seq, mutation_code, pos)
        if mut_pos >= _ESM_MAX_RESIDUES:
            # Otherwise the model would index past the truncated token tensor
            # (on CUDA this is a device-side assert that poisons the context).
            raise ValueError(
                f'{mutation_code}: position {mut_pos} (0-based) lies beyond the '
                f'first {_ESM_MAX_RESIDUES} residues passed to the model')
        _, _, wt_tokens = esm2batch_converter([('', seq[:_ESM_MAX_RESIDUES])])
        _, _, mut_tokens = esm2batch_converter([('', mut[:_ESM_MAX_RESIDUES])])
        inp.append((wt_tokens.to(device), mut_tokens.to(device), mut_pos))

    if not inp:
        return np.empty(0)

    res = []
    for model_name in model_names:
        # Checkpoints are whole pickled modules: load on the CPU, then move.
        net = torch.load(model_name, map_location=torch.device('cpu'))
        net.eval()
        net.to(device)
        with torch.no_grad():
            res.append([
                net(token_ids1=t1, token_ids2=t2,
                    pos=torch.LongTensor([p])).cpu().numpy()
                for t1, t2, p in inp
            ])
        # Release this checkpoint before the next one is loaded.
        del net
    return np.mean(res, axis=0).ravel()
    

In [None]:
# Held-out test set from the cloned PROSTATA repo. Predictions are stored in a
# new 'ddg_pred' column. The 'pos' column is passed as the 0-based mutation
# position (predict_ddg checks that the residue there matches mut_info).
test_df = pd.read_csv('PROSTATA/cross_validation_datasets/test_1LNIA.csv')
test_df['ddg_pred'] = predict_ddg(
    seqs=test_df['wt_seq'].tolist(),
    mutation_codes=test_df['mut_info'].tolist(),
    poss=test_df['pos'].tolist(),
)

In [None]:
# Predicted (x) vs. experimental (y) ΔΔG on the test set. The dashed diagonal
# marks perfect agreement, and the title reports the Pearson correlation.
y = test_df.ddg.to_list()
x = test_df.ddg_pred.to_list()
pearson_r = np.corrcoef(x, y)[0, 1]
plt.scatter(x, y, alpha=0.5)
lo, hi = min(x + y), max(x + y)
plt.plot([lo, hi], [lo, hi], 'k--', linewidth=1)
plt.xlabel('predicted ΔΔG')
plt.ylabel('experimental ΔΔG')
plt.title(f'Pearson r = {pearson_r:.3f}')
plt.show()

In [None]:
# Example query for predict_ddg: a list of wild-type sequences and one mutation
# code per sequence, written as <wt aa><1-based position><mutant aa>.
seqs = ['VINTFDGVADYLQTYHKLPDNYITKSEAQALGWVASKGNLADVAPGKSIGGDIFSNREGKLPGKSGRTWREADINYTSGFRNSDRILYSSDWLIYKTTDHYQTFTKIR']
mutation_codes = ['V1N'] #@param {type:"string"}

In [None]:
# Averaged prediction over all checkpoints in model_names for the example inputs.
predict_ddg(seqs, mutation_codes)

# Find the best mutation

In [None]:
# Wild-type sequence to scan, and the distinct residue letters it contains.
wildtype = 'VINTFDGVADYLQTYHKLPDNYITKSEAQALGWVASKGNLADVAPGKSIGGDIFSNREGKLPGKSGRTWREADINYTSGFRNSDRILYSSDWLIYKTTDHYQTFTKIR'
wt_aas = [*set(wildtype)]

In [None]:
# TODO: change this to iterate over candidate mutations instead of a single
# hand-picked substitution.

pos = 1  # 1-based position in `wildtype`
mut_acid = 'A'
mutation_code = '{}{}{}'.format(wildtype[pos - 1], pos, mut_acid)
predict_ddg([wildtype], [mutation_code])

# Visualisation

In [None]:
#@title ##predict 3d structure with **ESMFold**
%%time
from string import ascii_uppercase, ascii_lowercase
import hashlib, re, os
import numpy as np
from jax.tree_util import tree_map
import matplotlib.pyplot as plt
from scipy.special import softmax

def parse_output(output):
  """Summarise an ESMFold prediction whose tensors are already numpy arrays.

  Only residues whose atom37 slot 1 exists are kept (presumably CA - confirm).

  Returns a dict with:
    pae: (n, n) mean over the 64 aligned-error bins of probability * bin
         index, scaled by 31.
    plddt: (n,) values from channel 1 of output["plddt"].
    sm_contacts: (n, n) softmax probability mass in the distogram bins whose
         value in the 0, 2.3125 ... 21.6875 grid is below 8.
    xyz: (n, 3) coordinates of atom slot 1 from the last entry of
         output["positions"] (presumably the final structure-module iteration).
  """
  keep = output["atom37_atom_exists"][0, :, 1] == 1

  def square(mat):
    # Restrict both residue axes of a pairwise matrix to the kept residues.
    return mat[keep][:, keep]

  probs = output["aligned_confidence_probs"][0]
  expected_pae = (probs * np.arange(64)).mean(-1) * 31

  bin_values = np.append(0, np.linspace(2.3125, 21.6875, 63))
  contact_mass = softmax(output["distogram_logits"], -1)[0][..., bin_values < 8].sum(-1)

  return {
    "pae": square(expected_pae),
    "plddt": output["plddt"][0, :, 1][keep],
    "sm_contacts": square(contact_mass),
    "xyz": output["positions"][-1, 0, :, 1][keep],
  }

def get_hash(x):
  """Return the hex SHA-1 digest of string *x* (encoded with str.encode())."""
  digest = hashlib.sha1(x.encode())
  return digest.hexdigest()

# 'A'..'Z' followed by 'a'..'z', one character per list entry.
alphabet_list = [*ascii_uppercase, *ascii_lowercase]

# jobname = "test" #@param {type:"string"}
jobname = 'test'
# jobname becomes part of the output directory name below, so keep only word
# characters (at most 50 of them).
jobname = re.sub(r'\W+', '', jobname)[:50]

sequence = "VINTFDGVADYLQTYHKLPDNYITKSEAQALGWVASKGNLADVAPGKSIGGDIFSNREGKLPGKSGRTWREADINYTSGFRNSDRILYSSDWLIYKTTDHYQTFTKIR" #@param {type:"string"}
# Normalise the input:
# - "/" is accepted as an alternative chain separator to ":";
# - the sequence is upper-cased, then anything that is not A-Z or ":" is
#   dropped;
# - repeated, leading and trailing ":" are collapsed or removed.
sequence = re.sub("[^A-Z:]", "", sequence.replace("/",":").upper())
sequence = re.sub(":+",":",sequence)
sequence = re.sub("^[:]+","",sequence)
sequence = re.sub("[:]+$","",sequence)
copies = 1 #@param {type:"integer"}
if copies == "" or copies <= 0: copies = 1
# Repeat the whole (possibly multi-chain) sequence `copies` times as separate chains.
sequence = ":".join([sequence] * copies)
num_recycles = 3 #@param ["0", "1", "2", "3", "6", "12", "24"] {type:"raw"}
chain_linker = 25

ID = jobname+"_"+get_hash(sequence)[:5]
seqs = sequence.split(":")
lengths = [len(s) for s in seqs]
length = sum(lengths)
print("length",length)

u_seqs = list(set(seqs))
if len(seqs) == 1: mode = "mono"
elif len(u_seqs) == 1: mode = "homo"
else: mode = "hetero"

# Load ESMFold once per session; re-running the cell reuses the global model_f.
if "model_f" not in dir():
  import torch
  model_f = torch.load("esmfold.model")
  model_f.eval().cuda().requires_grad_(False)

# optimized for Tesla T4
if length > 700:
  model_f.set_chunk_size(64)
else:
  model_f.set_chunk_size(128)

torch.cuda.empty_cache()
output = model_f.infer(sequence,
                     num_recycles=num_recycles,
                     chain_linker="X"*chain_linker,
                     residue_index_offset=512)

pdb_str = model_f.output_to_pdb(output)[0]
output = tree_map(lambda x: x.cpu().numpy(), output)
ptm = output["ptm"][0]
plddt = output["plddt"][0,...,1].mean()
O = parse_output(output)
print(f'ptm: {ptm:.3f} plddt: {plddt:.3f}')
# Create the output directory directly instead of shelling out to mkdir.
os.makedirs(ID, exist_ok=True)
prefix = f"{ID}/ptm{ptm:.3f}_r{num_recycles}_default"
np.savetxt(f"{prefix}.pae.txt",O["pae"],"%.3f")
with open(f"{prefix}.pdb","w") as out:
  out.write(pdb_str)

In [None]:
import py3Dmol
import gc
# Run a full garbage-collection pass before building the 3D viewer.
gc.collect()


In [None]:
#@title display mutation {run: "auto"}
view = py3Dmol.view(width=400, height=300)
view.addModelsAsFrames(pdb_str)


mutation = "V10N" #@param {type:"string"}
# Residue number of the mutation as written in the code (e.g. "10" for "V10N").
# It is compared, as a string, with the PDB residue-number field below.
pos = mutation.strip()[1:-1]

show_sidechains = False #@param {type:"boolean"}
show_mainchains = False #@param {type:"boolean"}

if show_sidechains:
    BB = ['C','O','N']
    view.addStyle({'and':[{'resn':["GLY","PRO"],'invert':True},{'atom':BB,'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"GLY"},{'atom':'CA'}]},
                  {'sphere':{'colorscheme':"WhiteCarbon",'radius':0.3}})
    view.addStyle({'and':[{'resn':"PRO"},{'atom':['C','O'],'invert':True}]},
                  {'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})
if show_mainchains:
    BB = ['C','O','N','CA']
    view.addStyle({'atom':BB},{'stick':{'colorscheme':"WhiteCarbon",'radius':0.3}})

# Style every ATOM record. The mutated residue gets a yellow cartoon plus
# sticks; everything else gets a green cartoon.
# PDB records are fixed-width, so the atom serial (columns 7-11) and the
# residue number (columns 23-26) are read by column.
# Each atom is addressed by its own serial number rather than a running
# counter. A counter drifts when serials are not contiguous; per the PDB
# format, a TER record between chains takes a serial number of its own.
for line in pdb_str.split("\n"):
    if not line.startswith("ATOM"):
        continue
    atom = {'model': -1, 'serial': int(line[6:11])}
    if line[22:26].strip() == pos:
        view.addStyle(atom, {"cartoon": {'color': "yellow"}})
        view.addStyle(atom, {"stick": {"colorscheme": "yellowCarbon"}})
    else:
        view.addStyle(atom, {"cartoon": {'color': "green"}})
view.zoomTo()
view.show()