In [1]:
import torchdrug

In [2]:
torchdrug.__version__

'0.2.0'

In [6]:
from torchdrug import datasets

BASE_PATH = "/home/ec2-user/esm/"


In [7]:
from torchdrug import transforms
from torchdrug import datasets

# Preprocessing applied to every sample, in this order: truncation
# (max_length=1024, random=False) followed by the "residue" protein view.
preprocessing_steps = [
    transforms.TruncateProtein(max_length=1024, random=False),
    transforms.ProteinView(view="residue"),
]
truncate_transform, protein_view_transform = preprocessing_steps
transform = transforms.Compose(preprocessing_steps)

# Only residue-level features are requested; atom and bond features are off.
# (An earlier variant of this cell loaded datasets.SubcellularLocalization
# with the same options instead of the Fluorescence dataset.)
dataset_options = dict(
    atom_feature=None,
    bond_feature=None,
    residue_feature="default",
    transform=transform,
)
dataset = datasets.Fluorescence(BASE_PATH, **dataset_options)

03:13:29   Downloading http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/fluorescence.tar.gz to /home/ec2-user/esm/fluorescence.tar.gz
03:13:30   Extracting /home/ec2-user/esm/fluorescence.tar.gz to /home/ec2-user/esm


Constructing proteins from sequences: 100%|██████████| 54025/54025 [01:12<00:00, 747.43it/s]


In [8]:
dataset[0]

{'graph': Protein(num_atom=0, num_bond=0, num_residue=237),
 'log_fluorescence': 3.8237006664276123}

In [9]:
train_set, valid_set, test_set = dataset.split()

In [10]:
# Report the first sample's target value and the size of each split.
first_label = dataset[0][dataset.target_fields[0]]
split_sizes = (len(train_set), len(valid_set), len(test_set))
print("The label of first sample: ", first_label)
print("train samples: %d, valid samples: %d, test samples: %d" % split_sizes)

The label of first sample:  3.8237006664276123
train samples: 21446, valid samples: 5362, test samples: 27217


In [42]:
prot_seq = dataset[0]['graph'].to_sequence().replace('.', '')

In [16]:
import pandas as pd
from tqdm import tqdm

# Flatten the three splits into one table with one row per protein:
#   'seq'   - the string from to_sequence() with every '.' removed
#   'loc'   - the sample's 'log_fluorescence' value
#   'split' - which split the sample came from ('train' / 'val' / 'test')
records = []
for split_name, subset in (('train', train_set), ('val', valid_set), ('test', test_set)):
    for item in tqdm(subset):
        records.append({
            'seq': item['graph'].to_sequence().replace('.', ''),
            'loc': item['log_fluorescence'],
            'split': split_name,
        })

seq = pd.DataFrame(records)

100%|██████████| 21446/21446 [00:11<00:00, 1885.85it/s]
100%|██████████| 5362/5362 [00:02<00:00, 1949.96it/s]
100%|██████████| 27217/27217 [00:14<00:00, 1939.00it/s]


In [17]:
seq.to_csv('protein_lf.csv')

## Train fluorescence model

In [4]:
import pandas as pd

seq = pd.read_csv('protein_lf.csv', index_col=0)

In [26]:
seq

Unnamed: 0,seq,loc,split
0,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,3.823701,train
1,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,3.752084,train
2,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,3.540156,train
3,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,3.691572,train
4,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,3.688143,train
...,...,...,...
54020,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,1.565922,test
54021,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,1.532945,test
54022,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,1.529521,test
54023,SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFI...,1.301030,test


In [5]:
seq['seq'].iloc[0]

'SKGEELFTGVVPILVELDGDVNGHKFSVSGEGEGDATYGKLTLKFICTTGKLPVPWPTLVTTLSYGVQCFSRYPDHMKQHDFFKSAMPEGYVQERTIFFKDDGNYKTRAEVKFEGDTLVNRIELKGIDFKEDGNILGHKLEYNYNSHNVYIMADKQKNGIKVNFKIRHKIEDGSVQLADHYQQNTPIGDGPVLLPDNHYLSTQSALSKDPNEKRDHMVLLEFVTAAGITHGMDERYK'

In [6]:
from sklearn.preprocessing import OneHotEncoder
import numpy as np

# Fixed 21-letter alphabet; the order of the letters fixes the order of the
# one-hot columns because it is passed to the encoder as explicit categories.
conversion = 'ARNDCQEGHILKMFPSTWYVX'
amino_acids = np.array(list(conversion))
aa = seq['seq'].iloc[0]

# scikit-learn 1.2 renamed `sparse` to `sparse_output` and later releases
# dropped the old keyword, so try the new spelling first and fall back to the
# old one on versions that do not know it. Either way the encoder returns a
# dense array.
try:
    onehot_encoder = OneHotEncoder(sparse_output=False, categories=[amino_acids])
except TypeError:
    onehot_encoder = OneHotEncoder(sparse=False, categories=[amino_acids])


In [7]:
# Try the encoder on the first sequence: one input row per residue letter,
# giving one one-hot row per residue.
sequence = seq.iloc[0]['seq']
sequence_array = np.array(list(sequence))[:, np.newaxis]
onehot_encoded = onehot_encoder.fit(sequence_array).transform(sequence_array)

In [8]:
from tqdm import tqdm
import numpy as np

# Each protein gets a fixed (max_residues, alphabet size) slot. Sequences
# shorter than max_residues keep trailing all-zero rows.
max_residues = 237
embed = np.zeros((len(seq), max_residues, len(amino_acids)))

# The categories are fixed up front, so fitting is loop-invariant: fit once on
# the alphabet itself and only call transform() per sequence.
onehot_encoder.fit(amino_acids.reshape(-1, 1))

# enumerate() hides the length from tqdm; pass it explicitly so the bar shows
# progress and an ETA instead of a bare counter.
for i, sequence in tqdm(enumerate(seq['seq']), total=len(seq)):
    if len(sequence) > max_residues:
        # Fail with a readable message instead of a NumPy broadcast error.
        raise ValueError(
            "sequence %d has %d residues, more than the %d-residue slot"
            % (i, len(sequence), max_residues)
        )
    sequence_array = np.array(list(sequence)).reshape(-1, 1)
    onehot_encoded = onehot_encoder.transform(sequence_array)
    embed[i, :onehot_encoded.shape[0]] = onehot_encoded

54025it [00:43, 1248.88it/s]


In [9]:
embed = embed.reshape(len(seq), -1)

In [10]:
embed.shape

(54025, 4977)

In [11]:
# Slice the feature matrix and the 'loc' labels with one boolean row mask per
# split; `embed` rows and `seq` rows are in the same order.
split_column = seq['split']

index = split_column.eq('train').to_numpy()
X_train, y_train = embed[index], seq.loc[index, 'loc']

index = split_column.eq('val').to_numpy()
X_val, y_val = embed[index], seq.loc[index, 'loc']

index = split_column.eq('test').to_numpy()
X_test, y_test = embed[index], seq.loc[index, 'loc']

In [12]:
X_train.shape

(21446, 4977)

In [13]:
from sklearn.linear_model import Ridge

# Fit a Ridge regressor (alpha=0.5) on the training split; the trailing
# expression displays the fitted estimator as the cell output.
alpha = 0.5
clf = Ridge(alpha=alpha).fit(X_train, y_train)
clf

Ridge(alpha=0.5)

In [14]:
y_pred = clf.predict(X_test)

In [16]:
from scipy.stats import spearmanr

spearmanr(y_test, y_pred)

SpearmanrResult(correlation=0.6788691646387355, pvalue=0.0)

In [18]:
# Pull the fitted coefficients and intercept off the Ridge estimator
weights, intercept = clf.coef_, clf.intercept_

In [19]:
# Write both parameter arrays to .npy files in the working directory
for npy_path, params in (('weights_logf.npy', weights), ('intercept_logf.npy', intercept)):
    np.save(npy_path, params)

In [22]:
import torch

# Load the parameters
weights = np.load('weights_logf.npy')
intercept = np.load('intercept_logf.npy')

# Convert to PyTorch tensors.
# The saved arrays keep NumPy's dtype and a flat coefficient vector, while the
# layer created below holds float32 parameters shaped (out_features,
# in_features) and (out_features,). Cast and reshape explicitly so the layer
# keeps its declared shapes and accepts ordinary float32 inputs.
weights_torch = torch.from_numpy(weights).to(torch.float32).reshape(1, -1)
intercept_torch = torch.from_numpy(np.array([intercept])).to(torch.float32).reshape(1)

# Define a linear layer with one input per coefficient and a single output
linear_layer = torch.nn.Linear(weights_torch.shape[1], 1)

# Set the weights and bias.
# copy_() writes into the existing parameters (checking shapes) instead of
# swapping their storage through `.data`; no_grad keeps autograd out of it.
with torch.no_grad():
    linear_layer.weight.copy_(weights_torch)
    linear_layer.bias.copy_(intercept_torch)

In [23]:
linear_layer

Linear(in_features=4977, out_features=1, bias=True)

## Define potential

In [25]:
import os, sys
from sklearn.preprocessing import OneHotEncoder
import numpy as np



# TEMPLATE CLASS
class Potential:
    """
    Template base class for potentials.

    Subclasses override ``get_gradients``; the base implementation only
    reports that nothing has been implemented.
    """

    def get_gradients(self, seq=None):
        '''
            EVERY POTENTIAL CLASS MUST RETURN GRADIENTS

            The base version terminates via ``sys.exit`` (raising
            ``SystemExit``) with an error message.

            ``self`` is now part of the signature so that the usual call
            ``instance.get_gradients(seq)`` reaches the message below instead
            of failing with a TypeError about the argument count. ``seq``
            defaults to None so the call forms that already worked
            (``instance.get_gradients()`` and ``Potential.get_gradients(seq)``)
            keep working.
        '''

        sys.exit('ERROR POTENTIAL HAS NOT BEEN IMPLEMENTED')

In [None]:
class GFP_log_flourescence(Potential):
    """
    Potential for GFP log fluorescence.

    Wraps the coefficients and intercept of a linear model in a frozen
    ``torch.nn.Linear`` and returns the gradient of that model's output with
    respect to the softmax probabilities of the input logits.
    """
    def __init__(self, args, features, potential_scale, DEVICE):
        """
        Build the linear scoring layer.

        Arguments
        ---------
        args : dict
            Must contain 'weights' (the model coefficients; flattened to a
            single row here) and 'intercept' (a single value).
        features : any
            Accepted as part of the constructor signature; not used here.
        potential_scale : number
            Factor the gradients are multiplied by before being returned.
        DEVICE : str or torch.device
            Device the linear layer is moved to.
        """
        # Cast to float32 and give the weights their (out_features,
        # in_features) shape; the raw arrays may be float64 and flat, which
        # would not match float32 inputs.
        weights_torch = torch.as_tensor(args['weights'], dtype=torch.float32).reshape(1, -1)
        intercept_torch = torch.as_tensor(args['intercept'], dtype=torch.float32).reshape(1)

        # Define a linear layer with one input per coefficient and one output
        linear_layer = torch.nn.Linear(weights_torch.shape[1], 1)

        # Copy the fitted values into the layer's own parameters
        with torch.no_grad():
            linear_layer.weight.copy_(weights_torch)
            linear_layer.bias.copy_(intercept_torch)

        # The model is fixed; only gradients w.r.t. the input are needed.
        linear_layer.requires_grad_(False)

        self.linear_layer = linear_layer.to(DEVICE)
        self.potential_scale = potential_scale
        # Number of residue rows fed to the linear layer.
        self.sequence_length = 237

    def get_gradients(self, seq):
        """
        Calculate gradients of the predicted log F for a sequence.

        Arguments
        ---------
        seq : tensor
            L X 21 logits after saving seq_out from xt

        Returns
        -------
        gradients : tensor
            Same shape as ``seq`` and on the same device: the gradient of the
            linear model's output with respect to ``softmax(seq, dim=1)``,
            multiplied by ``potential_scale``. Rows beyond
            ``sequence_length`` do not reach the model and get zero gradient.
        """
        weight = self.linear_layer.weight

        # Callers may run under torch.no_grad(); gradients are needed here
        # regardless, so re-enable them locally.
        with torch.enable_grad():
            # Detach from any outer graph and match the layer's device/dtype,
            # then make the probabilities a leaf that autograd differentiates.
            # NOTE(review): the gradient is taken w.r.t. the softmax output,
            # as the original `soft_seq.grad` indicated, not w.r.t. the raw
            # logits - confirm this is what the caller expects.
            soft_seq = torch.softmax(
                seq.detach().to(device=weight.device, dtype=weight.dtype), dim=1
            ).requires_grad_(True)

            # Fit the probabilities to the fixed number of rows the model was
            # trained on: truncate extra rows, or pad with zero rows.
            n_residues = soft_seq.shape[0]
            if n_residues > self.sequence_length:
                model_input = soft_seq[:self.sequence_length]
            elif n_residues < self.sequence_length:
                # new_zeros keeps the padding on soft_seq's device and dtype.
                zeros = soft_seq.new_zeros(self.sequence_length - n_residues, soft_seq.shape[1])
                model_input = torch.cat((soft_seq, zeros), 0)
            else:
                model_input = soft_seq

            # Use this instance's layer, not a notebook-level global.
            score = self.linear_layer(model_input.reshape(-1))

            # autograd.grad returns the gradient directly, flowing back
            # through the slice/concat above to the full-length soft_seq, and
            # leaves no accumulated .grad state behind between calls.
            (gradients,) = torch.autograd.grad(score.sum(), soft_seq)

        return gradients.to(seq.device) * self.potential_scale