In [None]:
import os

import esm
import numpy as np
import pandas as pd
import plotly.express as px
import torch
from fuzzywuzzy import fuzz
from umap import UMAP
# pip install tables
%matplotlib inline

In [None]:
# Read the protein table and keep only the columns the rest of the notebook uses.
# (An earlier variant read 'USE__mutant_seqswoscore.csv' instead.)
columns_of_interest = [
    '#Pdb',
    'Mutation(s)_PDB',
    'Affinity_mut_parsed',
    'Affinity_wt_parsed',
    'Protein 1',
    'Protein 2',
    'wild_seq1',
    'wild_seq2',
    'mutant_seq',
]
df = pd.read_csv('skempiwmutants_nanincl.csv', index_col=0)
df = df[columns_of_interest]
df

In [None]:
# A ':' inside any of the three sequence columns marks a multi-chain entry;
# those rows are set aside for now.
colon_in_wt1 = df['wild_seq1'].str.contains(':')
colon_in_wt2 = df['wild_seq2'].str.contains(':')
colon_in_mut = df['mutant_seq'].str.contains(':')
ignore = colon_in_wt1 | colon_in_wt2 | colon_in_mut
reduced_df = df[~ignore]
reduced_df

In [None]:
# Remove exact duplicate rows and add underscore-named copies of the sequence
# and protein columns (these copies are the ones rewritten by the swap step).
#
# drop_duplicates() is used without inplace=True so the filtered `reduced_df`
# slice is never mutated: no SettingWithCopyWarning and the cell can be re-run.
#
# Rows containing NaN are deliberately kept: per the dataset notes, NaN means
# the mutant lost its binding ability, so dropna() must not be applied here.
df_ = (
    reduced_df
    .drop_duplicates()
    .assign(
        wild_seq_1=lambda frame: frame['wild_seq1'],
        wild_seq_2=lambda frame: frame['wild_seq2'],
        Protein_1=lambda frame: frame['Protein 1'],
        Protein_2=lambda frame: frame['Protein 2'],
    )
    .reset_index(drop=True)
)
df_

In [None]:
"""
    This chunk re-organizes the dataframe so that:
    1. the sequence stored in wild_seq_2 is always the one that carries the mutated positions
    2. whenever wild_seq_1 and wild_seq_2 are flipped, the columns [Protein_1, Protein_2] are flipped as well
"""
for index, row in df_.iterrows():
    # Score each wild-type chain against the mutant exactly once per row;
    # both scores are reused by every branch below.
    ratio_seq1 = fuzz.ratio(row['wild_seq1'], row['mutant_seq'])
    ratio_seq2 = fuzz.ratio(row['wild_seq2'], row['mutant_seq'])
    if ratio_seq1 > ratio_seq2:
        # wild_seq1 is the chain closer to the mutant, so swap the pair
        # (and the protein names with it) to put the mutated chain in slot 2.
        df_.at[index, 'wild_seq_1'] = row['wild_seq2']
        df_.at[index, 'wild_seq_2'] = row['wild_seq1']
        df_.at[index, 'Protein_1'] = row['Protein 2']
        df_.at[index, 'Protein_2'] = row['Protein 1']
    elif ratio_seq1 == ratio_seq2:
        # Tie: cannot tell which chain was mutated; report the row for manual review.
        print(index, ratio_seq1, ratio_seq2)
        print('mutate both sequences?')
    # Otherwise wild_seq2 is already the closer chain and nothing changes.
df_

In [None]:
# Keep only the final, underscore-named columns in a fixed order.
# Selecting `cols` already discards 'Protein 1', 'Protein 2', 'wild_seq1' and
# 'wild_seq2', so no separate in-place drop is needed (that drop raised a
# KeyError when the cell was run a second time).
cols = ['#Pdb', 'Mutation(s)_PDB', 'Affinity_mut_parsed', 'Affinity_wt_parsed','Protein_1',
        'Protein_2','wild_seq_1','wild_seq_2', 'mutant_seq']
# .copy() gives an independent frame so later column assignments on df_
# are not made against a flagged slice of the previous frame.
df_ = df_[cols].copy()
df_

In [None]:
# Collect the distinct sequences of each kind; every unique sequence only
# needs to be embedded once.
seqs_wt1 = set(df_['wild_seq_1'].values.tolist())
seqs_wt2 = set(df_['wild_seq_2'].values.tolist())
seqs_mut = set(df_['mutant_seq'].values.tolist())
seqs_mut

In [None]:
"""lazy to write function, may need to modify in the future"""
def label_sequences(seqs):
    """Pair every sequence with a synthetic label.

    Parameters
    ----------
    seqs : iterable
        Sequences to label; they are visited in the iterable's own order.

    Returns
    -------
    list of tuple
        ``('seq<i>', sequence)`` pairs, where ``i`` counts from 0.
    """
    return [('seq' + str(position), seq) for position, seq in enumerate(seqs)]


# One labelled list per sequence collection (previously three copies of the same loop).
seqs_labeled_wt1 = label_sequences(seqs_wt1)
seqs_labeled_wt2 = label_sequences(seqs_wt2)
seqs_labeled_mut = label_sequences(seqs_mut)

In [None]:
# Load ESM-2 model
# The loader returns two objects: the network itself and the alphabet used to tokenize sequences.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
# Batch converter built from the model's own alphabet.
# NOTE(review): the data-loader cell visible in this notebook builds its own converter via
# Alphabet.from_architecture(...) and does not use this variable -- confirm whether it is still needed.
batch_converter = alphabet.get_batch_converter()
# Put the model in evaluation mode. Being the last expression of the cell,
# this call's return value is also what the cell displays.
model.eval()

In [None]:
# alternative way to generate batches
from torch.utils.data import TensorDataset
from esm import Alphabet, FastaBatchedDataset

# Handed to FastaBatchedDataset.get_batch_indices() as its first argument.
# NOTE(review): in fair-esm that argument is a per-batch token budget rather than
# a number of sequences -- confirm against the installed version.
batch_size = 1000


def build_loader(labeled_seqs, toks_per_batch=batch_size):
    """Build the dataset, batch indices and DataLoader for one labelled sequence list.

    Parameters
    ----------
    labeled_seqs : list of (label, sequence) tuples
        Must be non-empty.
    toks_per_batch : int
        Forwarded to ``get_batch_indices``; defaults to ``batch_size``.

    Returns
    -------
    tuple
        ``(dataset, batch_indices, data_loader)``.
    """
    # Transpose once into parallel label / sequence tuples.
    seq_labels, seq_strs = zip(*labeled_seqs)
    fasta_dataset = FastaBatchedDataset(seq_labels, seq_strs)
    batch_indices = fasta_dataset.get_batch_indices(toks_per_batch, extra_toks_per_seq=1)
    loader = torch.utils.data.DataLoader(
        fasta_dataset,
        collate_fn=Alphabet.from_architecture("roberta_large").get_batch_converter(),
        batch_sampler=batch_indices,
        pin_memory=True,
    )
    return fasta_dataset, batch_indices, loader


# Same nine variables as before, now produced by one shared helper.
dataset, batches, data_loader = build_loader(seqs_labeled_wt1)
dataset_seq2, batches_seq2, data_loader_seq2 = build_loader(seqs_labeled_wt2)
dataset_mut, batches_mut, data_loader_mut = build_loader(seqs_labeled_mut)

In [None]:
# Clear PyTorch's CUDA cache, then move the model onto the GPU when one is present.
torch.cuda.empty_cache()
gpu_ready = torch.cuda.is_available()
if gpu_ready:
    model = model.to('cuda')
    print('Transferred model to GPU')

In [None]:
# QC: list the labels that ended up in each batch of the first loader
for batch_idx, batch in enumerate(data_loader):
    labels, strs, toks = batch
    print(batch_idx, labels)

In [None]:
# Per-position layer-33 representations for every sequence served by `data_loader`,
# stored under the sequence string itself.
representation_store_dict = {}
for batch_idx, (labels, strs, toks) in enumerate(data_loader):
    if torch.cuda.is_available():
        toks = toks.to(device='cuda', non_blocking=True)
    with torch.no_grad():
        # Only the layer-33 representations are read from the output, so the
        # contact-map prediction is switched off instead of being computed and discarded.
        results = model(toks, repr_layers=[33], return_contacts=False)['representations'][33]
    results_cpu = results.to(device='cpu')
    for i, str_ in enumerate(strs):
        # Skip position 0 and keep exactly len(sequence) positions
        # (presumably dropping the leading special token and trailing padding -- TODO confirm).
        representation_store_dict[str_] = results_cpu[i, 1:len(str_) + 1].numpy()

In [None]:
# Collapse each stored per-position matrix into a single row by averaging
# along axis 0 (keepdims=True keeps the leading axis). Used for the UMAP later.
sequence_embeddings = {}
for seq_key, residue_matrix in representation_store_dict.items():
    sequence_embeddings[seq_key] = np.mean(residue_matrix, axis=0, keepdims=True)
sequence_embeddings

In [None]:
def embed_batches(loader):
    """Run the loaded model over every batch of ``loader``.

    Parameters
    ----------
    loader : iterable
        Yields ``(labels, strs, toks)`` batches, as the data loaders in this notebook do.

    Returns
    -------
    tuple of dict
        ``(per_position, pooled)``, both keyed by the sequence string.
        ``per_position`` holds the layer-33 representation sliced to the sequence length;
        ``pooled`` holds its mean along axis 0 (with ``keepdims=True``).
    """
    per_position = {}
    for labels, strs, toks in loader:
        if torch.cuda.is_available():
            toks = toks.to(device='cuda', non_blocking=True)
        with torch.no_grad():
            # Contacts are never used, so they are not requested.
            layer33 = model(toks, repr_layers=[33], return_contacts=False)['representations'][33]
        layer33_cpu = layer33.to(device='cpu')
        for i, seq in enumerate(strs):
            # Skip position 0 and keep exactly len(seq) positions
            # (presumably dropping the leading special token and padding -- TODO confirm).
            per_position[seq] = layer33_cpu[i, 1:len(seq) + 1].numpy()
    pooled = {key: np.mean(value, axis=0, keepdims=True) for key, value in per_position.items()}
    return per_position, pooled


# Same four dictionaries as before, built by one shared routine instead of two copied loops.
representation_store_dict_seq2, sequence_embeddings_seq2 = embed_batches(data_loader_seq2)
representation_store_dict_mut, sequence_embeddings_mut = embed_batches(data_loader_mut)

In [None]:
def update_embeddings(row, embedding_dict):
    """
    Look up the embedding that belongs to one sequence.

    Embeddings are attached by sequence rather than by protein name: because of
    mislabelling, several different protein names share the same sequence, but
    as long as the sequences are correct, so are the embeddings.

    Parameters
    ----------
    row : hashable
        The cell value passed in by ``Series.apply`` (a sequence string).
    embedding_dict : dict
        Mapping from sequence to its embedding.

    Returns
    -------
    The stored embedding, or ``None`` when the sequence is not in the mapping.
    """
    # Direct hash lookup instead of comparing against every key in turn.
    return embedding_dict.get(row)
# (sequence column, new embedding column, lookup table) for each kind of sequence
embedding_targets = [
    ('wild_seq_1', 'wild_seq_1_embeddings', sequence_embeddings),
    ('wild_seq_2', 'wild_seq_2_embeddings', sequence_embeddings_seq2),
    ('mutant_seq', 'mutant_seq_embeddings', sequence_embeddings_mut),
]
for seq_col, emb_col, lookup in embedding_targets:
    df_[emb_col] = df_[seq_col].apply(update_embeddings, embedding_dict=lookup)
df_

In [None]:
# Save the annotated table. The target folder is created first so the write
# does not fail on a fresh checkout where ./outputs/dataframes is missing.
os.makedirs('./outputs/dataframes', exist_ok=True)
df_.to_hdf('./outputs/dataframes/proteins_embeddings_meta.hdf', key='df', mode='w')

In [None]:
# Pull out one Series per attribute so the table can be rebuilt in long form
# (one row per sequence instead of one row per interaction) for the UMAP.
n_rows = len(df_)

# attributes shared by all three rows derived from one interaction
pdbs = df_['#Pdb']
mut = df_['Mutation(s)_PDB']
ppi = df_['Protein_1'] + '---' + df_['Protein_2']

# protein names; the mutant is named after protein 2
prot_1 = df_['Protein_1']
prot_2 = df_['Protein_2']
prot_mut = prot_2 + '_' + 'mut'

# sequences and their embeddings
seq_1 = df_['wild_seq_1']
seq_2 = df_['wild_seq_2']
seq_mut = df_['mutant_seq']
embed_1 = df_['wild_seq_1_embeddings']
embed_2 = df_['wild_seq_2_embeddings']
embed_mut = df_['mutant_seq_embeddings']

# affinities
affinity_wt = df_['Affinity_wt_parsed']
affinity_mut = df_['Affinity_mut_parsed']

# constant columns: a category label per block, and "no mutation" for wild types
label_wt_1 = pd.Series(['wt1'] * n_rows)
label_wt_2 = pd.Series(['wt2'] * n_rows)
label_mut = pd.Series(['mut'] * n_rows)
mut_status_wt = pd.Series([np.nan] * n_rows)

In [None]:
# Long-form table: the wt1 rows, then the wt2 rows, then the mutant rows.
# Every column stacks its three pieces in that same order.
df_meta = pd.DataFrame({
    'PDB': pd.concat([pdbs, pdbs, pdbs], ignore_index=True),
    'Protein': pd.concat([prot_1, prot_2, prot_mut], ignore_index=True),
    'Mutation': pd.concat([mut_status_wt, mut_status_wt, mut], ignore_index=True),
    'Sequence': pd.concat([seq_1, seq_2, seq_mut], ignore_index=True),
    'PPI': pd.concat([ppi, ppi, ppi], ignore_index=True),
    'Label': pd.concat([label_wt_1, label_wt_2, label_mut], ignore_index=True),
    # Both wild-type blocks get the wild-type affinity; only the mutant block
    # gets the mutant affinity. (The wt2 block used to receive affinity_mut,
    # which made wild-type rows show mutant affinities.)
    'Affinity': pd.concat([affinity_wt, affinity_wt, affinity_mut], ignore_index=True),
    'Embedding': pd.concat([embed_1, embed_2, embed_mut], ignore_index=True),
})
df_meta

In [None]:
# Spread each stored embedding over numbered columns 0..D-1.
# Every 'Embedding' cell is an array whose first row is the embedding, so stacking
# the cells vertically gives the whole (n_rows, D) matrix in one call -- much faster
# than building a pd.Series for every row through apply().
embedding_matrix = np.vstack(df_meta['Embedding'].tolist())
expanded_embeddings = pd.DataFrame(embedding_matrix, index=df_meta.index)
# Metadata columns first, numbered embedding columns after; the array-valued
# 'Embedding' column itself is left out.
df_umap = pd.concat([df_meta.drop(columns=['Embedding']), expanded_embeddings], axis=1)
df_umap

In [None]:
# 1) Remove rows that are identical in every column.
df_umap = df_umap.drop_duplicates(ignore_index=True)
# 2) Remove repeated (Sequence, Affinity) pairs, keeping the first occurrence.
#    The result is assigned back: the call used to discard its return value, so
#    nothing was actually removed. ignore_index=True keeps a clean 0..n-1 index,
#    which the later column-wise concat with the UMAP coordinates relies on.
df_umap = df_umap.drop_duplicates(subset=['Sequence', 'Affinity'], keep='first', ignore_index=True)
df_umap

In [None]:
# Select the UMAP input: a label-based column slice that starts at the column
# labelled 0 and runs to the last column.
# NOTE(review): this assumes the integer-labelled embedding columns sit after all
# metadata columns, as produced by the concat that builds df_umap -- confirm if that cell changes.
features = df_umap.loc[:,0:]
features

In [None]:
# Project the feature matrix to two dimensions; random_state is fixed.
umap_2d = UMAP(n_components=2, init='random', random_state=0)
proj_2d = umap_2d.fit_transform(features)

# One named column per projected axis.
results = {}
for axis_no, axis_name in enumerate(('umap1', 'umap2')):
    results[axis_name] = proj_2d[:, axis_no]
umap_results = pd.DataFrame(data=results)
umap_results

In [None]:
# Put the two UMAP coordinates next to the metadata and embedding columns
# (rows are matched on the index).
frames_side_by_side = [df_umap, umap_results]
df_umap_ = pd.concat(frames_side_by_side, axis=1)
df_umap_

In [None]:
# Fixed colour for each sequence category.
color_dict = {
    'wt1': '#72B7B2',
    'wt2': '#54A24B',
    'mut': '#E45756',
}

# Hover box contents: True shows a field, False hides it.
hover_fields = {
    'Protein': False,
    'PDB': True,
    'Label': True,
    'PPI': True,
    'Mutation': True,
    'Affinity': True,
    'umap1': False,
    'umap2': False,
}

fig = px.scatter(
    df_umap_,
    x='umap1',
    y='umap2',
    color='Label',
    color_discrete_map=color_dict,
    hover_name='Protein',
    hover_data=hover_fields,
)

layout_options = dict(
    template='simple_white',
    title='SKEMPI Protein-Protein embeddings',
    title_x=0.5,
)
fig.update_layout(**layout_options)
fig.show()

In [None]:
# Export the interactive figure. The target folder is created first so the write
# does not fail on a fresh checkout where ./outputs/figures is missing.
os.makedirs('./outputs/figures', exist_ok=True)
fig.write_html("./outputs/figures/skempi_ppi_plotly.html")