In [1]:
import pandas as pd
import pickle as pkl
import os
import torch
import numpy as np


Read the MMseqs2 clusters and remove every cluster that contains at least one ProteinMPNN training-data protein

In [2]:
# Load the MMseqs2 cluster table: a header-less, tab-separated file whose two
# columns are named 'Cluster' and 'UniProt' here.
clusters_unfiltered = pd.read_csv(
    './ProteinMPNN_HotProtein_mmseqs2/ProteinMPNN_HotProtein.Res_cluster.tsv',
    sep='\t',
    header=None,
    names=['Cluster', 'UniProt'],
)
print(clusters_unfiltered['Cluster'].unique().shape)

# One entry per line; the trailing newline is stripped so the values compare
# equal to the entries of the 'UniProt' column.
with open('./ProteinMPNN_Data/ProteinMPNN_TrainingData_Unique_UniProtAccessions.txt') as f:
    MPNN_data = pd.Series(f.readlines()).str.rstrip('\n')

# 1 if the UniProt is in MPNN_data, 0 if not
clusters_unfiltered['MPNN'] = clusters_unfiltered['UniProt'].isin(MPNN_data).astype(int)

# Remove all clusters that have any MPNN proteins.
# The per-cluster maximum of the 0/1 flag is 0 exactly when every member is 0.
# `transform` broadcasts that maximum back onto each row, so the selection is a
# single vectorised mask instead of one Python lambda call per cluster.
cluster_max_mpnn_flag = clusters_unfiltered.groupby('Cluster')['MPNN'].transform('max')
# .copy() makes `clusters` an independent frame rather than a slice of clusters_unfiltered.
clusters = clusters_unfiltered[cluster_max_mpnn_flag == 0].copy()
print(clusters['Cluster'].unique().shape)

# 92124 total clusters (MPNN and HotProtein combined)
# 72695 clusters with only HotProtein proteins


(92124,)
(72695,)


For the remaining clusters, remove observations that don't have an accompanying AF2 structure and add temperature class info

In [3]:
# Adding temperature class labels
# NOTE: unpickling can execute arbitrary code; only load this file from a trusted source.
with open('S/S_target_classification.pkl', 'rb') as f:
    S_target_classification = pkl.load(f)
S_target_train_classification = pd.DataFrame(
    {key: S_target_classification[key] for key in ['train_names', 'train_labels']}
)
S_target_test_classification = pd.DataFrame(
    {key: S_target_classification[key] for key in ['test_names', 'test_labels']}
)
# Stack the train and test tables under common column names ('names', 'labels').
S_target_classification_concat = pd.concat([
    S_target_train_classification.rename(columns={'train_names': 'names', 'train_labels': 'labels'}),
    S_target_test_classification.rename(columns={'test_names': 'names', 'test_labels': 'labels'}),
])
print(S_target_classification_concat.shape)

# Saving proteins that have AF2 structures
# The value matched against 'names' is the second '-'-separated field of each file name.
S_target_AF2 = os.listdir("/stor/work/Ellington/ProteinMPNN/HotProtein/S/AF2/PDB")
S_target_AF2 = pd.Series(S_target_AF2)
S_target_AF2 = S_target_AF2.str.split('-').str[1]
# A single membership mask serves both the "missing" list and the filtered table.
has_AF2_structure = S_target_classification_concat['names'].isin(S_target_AF2)
S_target_names_not_in_AF2 = S_target_classification_concat[~has_AF2_structure]['names']
S_target_classification_concat = S_target_classification_concat[has_AF2_structure]
print(S_target_classification_concat.shape)

# 182305 total HotProteins
# 181914 total HotProteins with a mapped structure

# NOTE: this cell rebinds `clusters` and `clusters_unfiltered` and removes their 'MPNN'
# column, so it cannot be re-run on its own; re-run the cluster-loading cell first.
print(clusters['Cluster'].unique().shape)
# Left-merge the labels onto the cluster members, then drop members that received no
# label (i.e. no row in the AF2-filtered label table matched their 'UniProt' value).
clusters = (
    clusters
    .merge(S_target_classification_concat, left_on='UniProt', right_on='names', how='left')
    .drop(columns=['names', 'MPNN'])
    .dropna(subset=['labels'])
)
print(clusters['Cluster'].unique().shape)

# 72695 clusters before filtering
# 72531 clusters after filtering


### Adding temperature class labels to clusters_unfiltered, for reference
clusters_unfiltered = (
    clusters_unfiltered
    .merge(S_target_classification_concat, left_on='UniProt', right_on='names', how='left')
    .drop(columns=['names', 'MPNN'])
)
# Members without a label get -1. Assign the result back explicitly: an in-place
# fillna on the selected column acts on a temporary under pandas Copy-on-Write and
# would leave the NaNs in the frame.
clusters_unfiltered['labels'] = clusters_unfiltered['labels'].fillna(-1)



(182305, 2)
(181914, 2)
(72695,)
(72531,)


Selecting one representative UniProt accession per cluster

For each temperature range in s2c5:
* Make an 80/10/10 train/test/valid split (note: the cells below currently write only an 80/20 train/valid split; no `test_clusters.txt` is generated)
* Load the `.pt` file information and create list.csv, train_clusters.txt, test_clusters.txt, valid_clusters.txt
* Add ProteinMPNN validation clusters to valid_clusters.txt to gauge how well the new model performs on generic proteins during training.


In [22]:
# For every temperature class label 0-4: pick one member per cluster, shuffle, split
# 80/20 into train/valid, and write list.csv plus the two cluster-id files.
for i in range(5):
    # First member of each cluster among the rows carrying label i.
    temp_range = clusters[clusters['labels'] == i].groupby('Cluster').first()
    # Rebuild from the raw values so the 'Cluster' group index is discarded.
    temp_range = pd.DataFrame(temp_range.values, columns=temp_range.columns)
    print(temp_range.shape)
    # Shuffle the rows reproducibly, then renumber them 0..n-1; these row numbers
    # are the cluster ids written to list.csv and to the split files below.
    temp_range = temp_range.sample(frac=1, random_state=1).reset_index(drop=True)
    n_train = int(temp_range.shape[0] * 0.8)
    temp_range_train = temp_range.iloc[:n_train]  # Taking the first 80% of the rows
    temp_range_valid = temp_range.iloc[n_train:]  # Taking the last 20% of the rows

    # Collect the rows in a list and build the frame once; appending to a DataFrame
    # row by row re-allocates on every insert and is quadratic in the row count.
    list_rows = []
    for cluster_id, accession in enumerate(temp_range['UniProt']):
        pt = torch.load('/stor/work/Ellington/ProteinMPNN/HotProtein/S/PDB_pt/' + accession + '.pt')
        list_rows.append((
            pt['id'] + '_A',            # CHAINID
            '2017-02-27',               # DEPOSITION: same constant for every entry
            1.0,                        # RESOLUTION: same constant for every entry
            str(cluster_id).zfill(6),   # HASH: zero-padded row number
            cluster_id,                 # CLUSTER: row number in the shuffled table
            pt['seq'][0][0],            # SEQUENCE
        ))
    hotprotein_list_csv = pd.DataFrame(
        list_rows,
        columns=['CHAINID', 'DEPOSITION', 'RESOLUTION', 'HASH', 'CLUSTER', 'SEQUENCE'],
    )

    out_dir = './Models/class_' + str(i)
    os.makedirs(out_dir, exist_ok=True)  # the writers below do not create missing directories
    hotprotein_list_csv.to_csv(out_dir + '/list.csv', sep=',', header=False, index=False)
    np.savetxt(out_dir + '/train_clusters.txt', temp_range_train.index.values, fmt='%s')
    # No test_clusters.txt is written: only the train/valid split above is produced.
    np.savetxt(out_dir + '/valid_clusters.txt', temp_range_valid.index.values, fmt='%s')


(5103, 2)
(17985, 2)
(17321, 2)
(31044, 2)
(14676, 2)


In [9]:
# Two merged label groups: labels 0-2 form one group and labels 3-4 the other.
s2c2 = [[0,1,2],[3,4]]
# For each group: pick one member per cluster, shuffle, split 80/20 into train/valid,
# and write list.csv plus the two cluster-id files.
for i, label_group in enumerate(s2c2):
    # First member of each cluster among the rows whose label is in this group.
    temp_range = clusters[clusters['labels'].isin(label_group)].groupby('Cluster').first()
    # Rebuild from the raw values so the 'Cluster' group index is discarded.
    temp_range = pd.DataFrame(temp_range.values, columns=temp_range.columns)
    print(temp_range.shape)
    # Shuffle the rows reproducibly, then renumber them 0..n-1; these row numbers
    # are the cluster ids written to list.csv and to the split files below.
    temp_range = temp_range.sample(frac=1, random_state=1).reset_index(drop=True)
    n_train = int(temp_range.shape[0] * 0.8)
    temp_range_train = temp_range.iloc[:n_train]  # Taking the first 80% of the rows
    temp_range_valid = temp_range.iloc[n_train:]  # Taking the last 20% of the rows

    # Collect the rows in a list and build the frame once; appending to a DataFrame
    # row by row re-allocates on every insert and is quadratic in the row count.
    list_rows = []
    for cluster_id, accession in enumerate(temp_range['UniProt']):
        pt = torch.load('/stor/work/Ellington/ProteinMPNN/HotProtein/S/PDB_pt/' + accession + '.pt')
        list_rows.append((
            pt['id'] + '_A',            # CHAINID
            '2017-02-27',               # DEPOSITION: same constant for every entry
            1.0,                        # RESOLUTION: same constant for every entry
            str(cluster_id).zfill(6),   # HASH: zero-padded row number
            cluster_id,                 # CLUSTER: row number in the shuffled table
            pt['seq'][0][0],            # SEQUENCE
        ))
    hotprotein_list_csv = pd.DataFrame(
        list_rows,
        columns=['CHAINID', 'DEPOSITION', 'RESOLUTION', 'HASH', 'CLUSTER', 'SEQUENCE'],
    )

    # Directory name joins the group's labels with '_', e.g. class_0_1_2.
    out_dir = './Models/class_' + '_'.join(map(str, label_group))
    os.makedirs(out_dir, exist_ok=True)  # the writers below do not create missing directories
    hotprotein_list_csv.to_csv(out_dir + '/list.csv', sep=',', header=False, index=False)
    np.savetxt(out_dir + '/train_clusters.txt', temp_range_train.index.values, fmt='%s')
    # No test_clusters.txt is written: only the train/valid split above is produced.
    np.savetxt(out_dir + '/valid_clusters.txt', temp_range_valid.index.values, fmt='%s')

(36359, 2)
(42727, 2)
