In [None]:
pip install pandas

# Preprocessing

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

## Reading hgvs4variation_filtered.csv

In [None]:
# HGVS expressions per allele (one row per expression).
hgvs_var = pd.read_csv(filepath_or_buffer="hgvs4variation_filtered.csv")

## Reading the variant summary and creating the dictionary

In [None]:
# ClinVar variant summary; supplies the ClinicalSignificance for each AlleleID.
variant_summ = pd.read_csv(filepath_or_buffer='variant_summary_filtered.csv')

## Creating dictionaries

In [None]:
# Gene symbols to analyse: every Symbol in the HGVS table except TERC.
gene_names = {symbol for symbol in hgvs_var["Symbol"] if symbol != 'TERC'}
# One empty allele dictionary per gene, populated further down.
dic = {gene: {} for gene in gene_names}

# AlleleID of every HGVS row (table order, duplicates kept) and the set of distinct IDs.
IDs = list(hgvs_var["AlleleID"])
AlleleIDs = set(IDs)
    

In [None]:
# Three-letter -> one-letter codes for the 20 standard amino acids.
# The keys are paired positionally with the characters of the one-letter string,
# and the insertion order is the same as in the original literal.
daatol = dict(zip(
    "Cys Asp Ser Gln Lys Ile Pro Thr Phe Asn Gly His Leu Arg Trp Ala Val Glu Tyr Met".split(),
    "CDSQKIPTFNGHLRWAVEYM",
))

### Gene sequence dictionary and alphabet

In [None]:
gene_sequences = pd.read_csv("gene_seqs.csv")
# Gene symbol -> protein sequence; if a gene appears twice, the later row wins.
seq_dic = dict(zip(gene_sequences["Gene"], gene_sequences["Sequence"]))

# Sanity check: length of the POLE reference sequence.
print(len(seq_dic['POLE']))

### Nested dictionary: gene Symbol → {AlleleID → {"Change": protein change, "ClinicalSignificance": clinical significance}}

In [None]:
import re
# Simple substitutions only: 3-letter wild type + position + 3-letter mutant, e.g. "Arg123Cys".
regex = re.compile('^[A-Za-z]{3}[0-9]+[A-Za-z]{3}$')

# Pass 1: record the protein change of every qualifying allele.
for row in hgvs_var.itertuples():
    if row.Symbol not in gene_names:
        continue
    change_expr = row.ProteinChange
    # '-' marks alleles without a protein change; NaN (missing) values are not strings
    # and would crash on .split, so both are skipped.
    if not isinstance(change_expr, str) or change_expr == '-':
        continue
    # Drop the HGVS prefix, e.g. "p.Arg123Cys" -> "Arg123Cys"; values without '.' are skipped.
    parts = change_expr.split('.')
    if len(parts) < 2:
        continue
    proteinChange = parts[1]
    # fullmatch: unlike match + '$', this does not accept a trailing newline.
    # A wild-type 'Ter' is excluded, as in the original filter.
    if regex.fullmatch(proteinChange) and proteinChange[:3] != 'Ter':
        # NOTE(review): if an AlleleID has several protein expressions (e.g. different
        # isoforms), the last one read wins. Confirm it matches the sequence in gene_seqs.csv.
        dic[row.Symbol][row.AlleleID] = {"Change": proteinChange}

# Pass 2: attach the ClinVar clinical significance to the alleles kept above.
for row in variant_summ.itertuples():
    if row.GeneSymbol in gene_names and row.AlleleID in dic[row.GeneSymbol]:
        dic[row.GeneSymbol][row.AlleleID]["ClinicalSignificance"] = row.ClinicalSignificance


In [None]:
# Global tallies: does the HGVS wild-type residue agree with the reference sequence?
num_correct = 0
num_incorrect = 0
# ClinVar ClinicalSignificance strings grouped into three buckets; values outside
# all three sets are not counted. The combined 'Pathogenic/Likely pathogenic' label
# is listed next to its benign counterpart 'Benign/Likely benign'.
pathogenic = {
    'Pathogenic',
    'Pathogenic/Likely pathogenic',
    'Pathogenic/Likely pathogenic; risk factor',
    'Likely pathogenic',
}
benign = {'Benign', 'Benign/Likely benign', 'Likely benign'}
vus = {'Uncertain significance', 'Conflicting interpretations of pathogenicity'}
# Per-gene counts, zero-initialised so every gene has all three keys even if
# none of its alleles qualify.
pathogenicity = {
    gene: {'Num of Pathogenic': 0, 'Num of Benign': 0, 'Num of VUS': 0}
    for gene in gene_names
}

In [None]:
# Alleles beyond this 1-based position are ignored; main() also truncates
# sequences to 1022 residues before scoring them.
MAX_MODEL_POSITION = 1022

# Reset the global tallies so re-running this cell does not accumulate counts.
num_correct = 0
num_incorrect = 0

for g in dic:
    # Zero-initialise so every gene ends up with all three keys, even with no qualifying alleles.
    gene_counts = pathogenicity.setdefault(g, {})
    gene_counts['Num of Pathogenic'] = 0
    gene_counts['Num of Benign'] = 0
    gene_counts['Num of VUS'] = 0
    # Genes missing from gene_seqs.csv get an empty sequence, so all their alleles are skipped.
    gene_seq = seq_dic.get(g, "")

    for record in dic[g].values():
        change = record['Change']      # e.g. "Arg123Cys"
        wild_type = change[:3]
        loc = int(change[3:-3])        # 1-based protein position
        # Skip positions outside the model window or past the end of the reference.
        # loc == 0 is rejected too, because gene_seq[-1] would silently be the last residue.
        if not (1 <= loc <= MAX_MODEL_POSITION and loc <= len(gene_seq)):
            continue
        if 'ClinicalSignificance' not in record:
            continue

        significance = record['ClinicalSignificance']
        if significance in pathogenic:
            gene_counts['Num of Pathogenic'] += 1
        elif significance in benign:
            gene_counts['Num of Benign'] += 1
        elif significance in vus:
            gene_counts['Num of VUS'] += 1

        # Check the HGVS wild-type residue against the reference sequence. Three-letter
        # codes missing from daatol count as misplaced instead of raising KeyError.
        if daatol.get(wild_type) == gene_seq[loc - 1]:
            num_correct += 1
        else:
            num_incorrect += 1

print("Number of correctly placed amino acids:",
      num_correct,'\n',"Number of misplaced amino acids:",
      num_incorrect)

## Visualisation


In [None]:
genes = list(pathogenicity)
# .get(..., 0) tolerates genes whose counts were never filled in.
totals = {
    key: [pathogenicity[ge].get(key, 0) for ge in genes]
    for key in ('Num of VUS', 'Num of Pathogenic', 'Num of Benign')
}
print(sum(totals['Num of VUS']), "VUSs")
print(sum(totals['Num of Pathogenic']), 'Pathogenic')
print(sum(totals['Num of Benign']), 'Benign')

# (count key, legend label, colour). The order and colours match the single-gene plots:
# yellow VUS, blue Benign, red Pathogenic.
categories = [
    ('Num of VUS', 'VUS', '#fcff42'),
    ('Num of Benign', 'Benign', '#426eff'),
    ('Num of Pathogenic', 'Pathogenic', 'r'),
]
bar_width = 0.25

X = np.arange(len(genes))
fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
ax.set_title('Distribution of Pathogenicity per gene')
ax.set_ylabel('Number of cases')

for slot, (key, legend_label, colour) in enumerate(categories):
    # Attach the label to its own bars so the legend cannot drift out of sync
    # with the plotting order.
    ax.bar(X + slot * bar_width, totals[key], color=colour, width=bar_width, label=legend_label)
# Centre each gene name under its group of three bars.
ax.set_xticks(X + bar_width)
ax.set_xticklabels(genes, rotation=90)
ax.legend()

plt.show()

In [None]:
#CDK4

In [None]:
# Single-gene view for ATM: one group of three bars (VUS, Benign, Pathogenic).
X = np.arange(1)
fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
ax.set_title('Distribution of Pathogenicity for ATM gene')
ax.set_ylabel('Number of cases')

atm_counts = pathogenicity['ATM']
for shift, count_key, colour in ((0.00, 'Num of VUS', '#fcff42'),
                                 (0.25, 'Num of Benign', '#426eff'),
                                 (0.50, 'Num of Pathogenic', 'r')):
    ax.bar(X + shift, atm_counts[count_key], color=colour, width=0.25)
ax.set_xticks([])

plt.xlabel('ATM')
plt.legend(['VUS','Benign', 'Pathogenic'])
plt.show()

In [None]:
# Plain-text totals behind the ATM bar chart.
atm_totals = pathogenicity['ATM']
for count_key, suffix in (('Num of VUS', "VUSs"),
                          ('Num of Benign', 'Benign'),
                          ('Num of Pathogenic', 'Pathogenic')):
    print(atm_totals[count_key], suffix)

In [None]:
# Single-gene view for POLE. `label` stays a global because the next cell reads it.
label = 'POLE'
X = np.arange(1)
fig = plt.figure()
ax = fig.add_axes([0,0,1,1])
ax.set_title('Distribution of Pathogenicity for POLE gene')
ax.set_ylabel('Number of cases')

pole_counts = pathogenicity[label]
for shift, count_key, colour in ((0.00, 'Num of VUS', '#fcff42'),
                                 (0.25, 'Num of Benign', '#426eff'),
                                 (0.50, 'Num of Pathogenic', 'r')):
    ax.bar(X + shift, pole_counts[count_key], color=colour, width=0.25)
ax.set_xticks([])

plt.xlabel(label)
plt.legend(['VUS','Benign', 'Pathogenic'])
plt.show()

In [None]:
# Plain-text totals for the gene plotted in the previous cell (`label`).
gene_totals = pathogenicity[label]
for count_key, suffix in (('Num of VUS', "VUSs"),
                          ('Num of Benign', 'Benign'),
                          ('Num of Pathogenic', 'Pathogenic')):
    print(gene_totals[count_key], suffix)

# The Model

### Imports

In [None]:
pip install fair-esm

In [None]:
import torch
import esm

### Utils

In [None]:
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.



def predict_pathogenic(
    wt_sequence, mut_idx, mut_aa, model, alphabet, batch_converter, mode="wt"
):
    """Score one amino-acid substitution as a log-probability difference.

    The sequence is run through ``model``. The result is
    ``log p(mut_aa) - log p(wild type)`` at position ``mut_idx``; more negative
    values mean the model assigns the mutant residue lower likelihood than the
    wild-type residue.

    Args:
        wt_sequence: wild-type protein sequence as one-letter codes.
        mut_idx: 0-based position of the substitution in ``wt_sequence``.
        mut_aa: one-letter code of the mutant residue.
        model: callable with ``model(tokens)["logits"]``; it must be on the same
            device as the tokens (they are moved to CUDA when available).
        alphabet: object whose ``get_idx(token)`` gives the vocabulary index.
        batch_converter: callable taking ``[(label, sequence)]`` and returning
            ``(labels, strs, tokens)``.
        mode: ``"wt"`` scores the unmodified sequence; ``"mask"`` replaces the
            residue at ``mut_idx`` with ``"<mask>"`` first.

    Returns:
        The score as a Python float.

    Raises:
        ValueError: if ``mode`` is not ``"wt"`` or ``"mask"``.
        IndexError: if ``mut_idx`` is outside ``wt_sequence``. A negative index
            would otherwise wrap around silently and score the wrong position.
    """
    if mode == "wt":
        sequence = wt_sequence
    elif mode == "mask":
        sequence = wt_sequence[:mut_idx] + "<mask>" + wt_sequence[mut_idx + 1 :]
    else:
        raise ValueError("Invalid mode")

    if not 0 <= mut_idx < len(wt_sequence):
        raise IndexError(
            f"mut_idx {mut_idx} out of range for sequence of length {len(wt_sequence)}"
        )

    # Prepare sequence for the model
    data = [("protein", sequence)]
    _, _, batch_tokens = batch_converter(data)

    # Run sequence through the model
    with torch.no_grad():
        if torch.cuda.is_available():
            batch_tokens = batch_tokens.cuda()
        token_probs = torch.log_softmax(model(batch_tokens)["logits"], dim=-1)

    # Compare wild-type probability to the probability of the mutated amino acid
    wt = wt_sequence[mut_idx]
    mt = mut_aa
    wt_encoded, mt_encoded = alphabet.get_idx(wt), alphabet.get_idx(mt)

    # +1 skips the BOS token. This assumes batch_converter prepends exactly one token.
    score = (
        token_probs[0, 1 + mut_idx, mt_encoded]
        - token_probs[0, 1 + mut_idx, wt_encoded]
    )
    return score.item()

# Main

In [None]:
def main(gene="ALK", allele_id=33123, mode="wt"):
    """Score one ClinVar substitution with ESM-1v and return the score.

    The variant is read from ``dic[gene][allele_id]["Change"]`` (e.g. ``"Arg123Cys"``).
    It is scored against the first 1022 residues of ``seq_dic[gene]``.

    Args:
        gene: gene symbol used as the key in ``dic`` and ``seq_dic``.
        allele_id: ClinVar AlleleID used as the key in ``dic[gene]``.
        mode: forwarded to ``predict_pathogenic`` (``"wt"`` or ``"mask"``).

    Returns:
        The float score from ``predict_pathogenic``.

    Raises:
        KeyError: if ``gene`` / ``allele_id`` is not in ``dic`` or ``seq_dic``.
        ValueError: if the mutant residue code is not in ``daatol``, or the position
            lies outside the scored part of the sequence.
    """
    max_len = 1022  # number of leading residues passed to the model

    # Parse and validate the variant before loading the (large) model, so bad input fails fast.
    change = dic[gene][allele_id]['Change']
    wild_type, mutation = change[:3], change[-3:]
    # HGVS protein positions are 1-based; predict_pathogenic needs a 0-based int index.
    mut_idx = int(change[3:-3]) - 1
    if mutation not in daatol:
        raise ValueError(f"Unsupported mutant residue {mutation!r} in {change!r}")
    mut_aa = daatol[mutation]
    wt_sequence = seq_dic[gene][:max_len]
    if not 0 <= mut_idx < len(wt_sequence):
        raise ValueError(
            f"Position {mut_idx + 1} of {change!r} is outside the first "
            f"{len(wt_sequence)} residues of {gene}"
        )
    if daatol.get(wild_type) != wt_sequence[mut_idx]:
        print(f"Warning: {change!r} expects {wild_type} at position {mut_idx + 1}, "
              f"but the reference sequence has {wt_sequence[mut_idx]!r}")

    # IMPORTANT: optionally choose where the downloaded weights are stored:
    # torch.hub.set_dir("path/to/your/folder")
    model, alphabet = esm.pretrained.esm1v_t33_650M_UR90S_1()
    model.eval()
    if torch.cuda.is_available():
        model = model.cuda()
        print("Transferred model to GPU")

    batch_converter = alphabet.get_batch_converter()

    score = predict_pathogenic(
        wt_sequence, mut_idx, mut_aa, model, alphabet, batch_converter, mode=mode
    )
    print(f"{gene} {change} (AlleleID {allele_id}): score = {score:.4f}")
    return score


if __name__ == "__main__":
    main()