In [1]:
import torchdrug

In [2]:
# Record the installed TorchDrug version so the results below can be tied to a release.
getattr(torchdrug, "__version__")

'0.2.0'

In [6]:
from torchdrug import datasets

import os

# Root directory into which TorchDrug downloads and extracts the dataset archive.
# Overridable through the ESM_DATA_DIR environment variable so the notebook is not tied
# to one machine; the default keeps the original location.
BASE_PATH = os.environ.get("ESM_DATA_DIR", "/home/ec2-user/esm/")


In [7]:
from torchdrug import transforms
from torchdrug import datasets

# Preprocessing applied to every protein: truncate to at most MAX_PROTEIN_LENGTH residues
# (random=False presumably crops deterministically from the start — confirm against the
# TorchDrug docs), then switch the graph to its residue-level view.
MAX_PROTEIN_LENGTH = 1024
truncate_transform = transforms.TruncateProtein(max_length=MAX_PROTEIN_LENGTH, random=False)
protein_view_transform = transforms.ProteinView(view="residue")
transform = transforms.Compose([truncate_transform, protein_view_transform])

# Fluorescence regression dataset. On first use the archive is downloaded to and
# extracted under BASE_PATH. Only residue features are built; atom/bond features are off.
dataset = datasets.Fluorescence(BASE_PATH, atom_feature=None, bond_feature=None, residue_feature="default", transform=transform)

03:13:29   Downloading http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/fluorescence.tar.gz to /home/ec2-user/esm/fluorescence.tar.gz
03:13:30   Extracting /home/ec2-user/esm/fluorescence.tar.gz to /home/ec2-user/esm


Constructing proteins from sequences: 100%|██████████| 54025/54025 [01:12<00:00, 747.43it/s]


In [8]:
# Peek at one sample: a dict holding the protein graph and its regression target.
first_sample = dataset[0]
first_sample

{'graph': Protein(num_atom=0, num_bond=0, num_residue=237),
 'log_fluorescence': 3.8237006664276123}

In [9]:
# Split with dataset.split() (default arguments) into three subsets, unpacked in order.
splits = dataset.split()
train_set, valid_set, test_set = splits

In [10]:
# Sanity check: the regression target of the first sample and the size of each split.
target_name = dataset.target_fields[0]
print("The label of first sample: ", dataset[0][target_name])
split_sizes = tuple(len(subset) for subset in (train_set, valid_set, test_set))
print("train samples: %d, valid samples: %d, test samples: %d" % split_sizes)

The label of first sample:  3.8237006664276123
train samples: 21446, valid samples: 5362, test samples: 27217


In [42]:
# Amino-acid string of the first protein, with any '.' characters stripped out.
first_graph = dataset[0]['graph']
prot_seq = first_graph.to_sequence().replace('.', '')

In [16]:
import pandas as pd
from tqdm import tqdm

def _split_records(subset, split_name):
    """Flatten one TorchDrug subset into row dicts for a DataFrame.

    Each row holds the residue sequence (with '.' characters removed), the
    ``log_fluorescence`` target and the name of the split it came from.
    """
    records = []
    for item in tqdm(subset, desc=split_name):
        records.append({
            'seq': item['graph'].to_sequence().replace('.', ''),
            # NOTE: the target is stored under the column name 'loc' (kept because the
            # saved CSV and the modelling cells read it); it holds log_fluorescence.
            'loc': item['log_fluorescence'],
            'split': split_name,
        })
    return records


rows = []
for split_name, subset in (('train', train_set), ('val', valid_set), ('test', test_set)):
    rows.extend(_split_records(subset, split_name))

seq = pd.DataFrame(rows)

100%|██████████| 21446/21446 [00:11<00:00, 1885.85it/s]
100%|██████████| 5362/5362 [00:02<00:00, 1949.96it/s]
100%|██████████| 27217/27217 [00:14<00:00, 1939.00it/s]


In [17]:
# Cache the flattened (seq, loc, split) table so the modelling section can start from disk.
seq.to_csv(path_or_buf='protein_lf.csv')

## Train fluorescence model (one-hot encoding + ridge regression baseline)

In [18]:
import pandas as pd

# Reload the cached table; column 0 is the row index written by to_csv above.
seq = pd.read_csv(filepath_or_buffer='protein_lf.csv', index_col=0)

In [28]:
# Spot-check the first sequence after the CSV round trip.
seq['seq'].iat[0]

'SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHKIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDERYK'

In [56]:
from sklearn.preprocessing import OneHotEncoder
import numpy as np

# Encoding vocabulary: the 20 standard amino-acid letters plus 'X'
# (conventionally an unknown residue).
conversion = 'ARNDCQEGHILKMFPSTWYVX'
amino_acids = np.array(list(conversion))

# Explicit categories fix the encoder's columns (one per entry of `amino_acids`)
# regardless of which sequence it is fit on.
# `sparse=` was renamed to `sparse_output=` in scikit-learn 1.2 and removed in 1.4:
# try the new keyword first and fall back for older releases.
try:
    onehot_encoder = OneHotEncoder(sparse_output=False, categories=[amino_acids])
except TypeError:
    onehot_encoder = OneHotEncoder(sparse=False, categories=[amino_acids])


In [57]:
# Smoke-test the encoder on the first sequence: input is an (L, 1) column of
# single-character strings, output has one row per residue.
sequence = seq['seq'].iloc[0]
sequence_array = np.array(list(sequence))[:, np.newaxis]
onehot_encoded = onehot_encoder.fit_transform(sequence_array)

In [86]:
from tqdm import tqdm

# One-hot encode every sequence into a zero-padded (n_sequences, max_len, n_categories)
# array. Both sizes come from the data and the encoder rather than hard-coded 237 / 21,
# so a longer sequence can no longer overflow the buffer.
max_len = int(seq['seq'].str.len().max())
n_categories = len(onehot_encoder.categories[0])

# The categories are fixed up front, so the fitted state does not depend on the data:
# fit once on the category list itself and only call `transform` per sequence.
onehot_encoder.fit(np.asarray(onehot_encoder.categories[0]).reshape(-1, 1))

# The entries are exact 0/1 indicators, so uint8 stores them losslessly at 1/8 of the
# float64 footprint (~270 MB instead of ~2.15 GB for 54,025 x 237 x 21). scikit-learn
# estimators convert integer input to float64 themselves.
embed = np.zeros((len(seq), max_len, n_categories), dtype=np.uint8)

for i, sequence in enumerate(tqdm(seq['seq'], total=len(seq))):
    sequence_array = np.array(list(sequence)).reshape(-1, 1)
    onehot_encoded = onehot_encoder.transform(sequence_array)
    # Shorter sequences leave their trailing positions as zero padding.
    embed[i, :onehot_encoded.shape[0]] = onehot_encoded

54025it [00:43, 1241.75it/s]


In [87]:
# Flatten each (position x amino-acid) grid into a single feature vector per sequence.
embed = np.reshape(embed, (len(seq), -1))

In [88]:
# (n_sequences, max_len * n_categories)
np.shape(embed)

(54025, 4977)

In [89]:
def _split_xy(split_name):
    """Return (features, targets) for the rows of `seq` whose 'split' equals split_name.

    Features are the matching rows of the flattened one-hot matrix `embed`. Targets are
    the 'loc' column (log fluorescence) as a Series that keeps the original row index.
    """
    mask = (seq['split'] == split_name).to_numpy()
    return embed[mask], seq.loc[mask, 'loc']


X_train, y_train = _split_xy('train')
X_val, y_val = _split_xy('val')
X_test, y_test = _split_xy('test')

# The three splits must partition the table exactly.
assert len(X_train) + len(X_val) + len(X_test) == len(seq)

In [90]:
# Training design matrix: (n_train, n_features)
np.shape(X_train)

(21446, 4977)

In [116]:
from sklearn.linear_model import Ridge

# Linear baseline: L2-regularised least squares on the flattened one-hot features.
# NOTE(review): alpha is set by hand; X_val / y_val are never used to tune it.
alpha = 0.5
clf = Ridge(alpha=alpha)
clf.fit(X=X_train, y=y_train)

Ridge(alpha=0.5)

In [120]:
# Predict log fluorescence for the held-out test split.
y_pred = clf.predict(X=X_test)

In [121]:
from scipy.stats import spearmanr

# Rank correlation between true and predicted test targets (the reported metric).
spearmanr(a=y_test, b=y_pred)

SpearmanrResult(correlation=0.6788691646387355, pvalue=0.0)