In [1]:
#download the dataset

!mkdir TAPE_benchmark
%cd TAPE_benchmark
!git clone https://github.com/songlab-cal/tape.git
!wget http://s3.amazonaws.com/songlabdata/proteindata/data_pytorch/stability.tar.gz
!tar -xf stability.tar.gz


#import the libraries

from TAPE_benchmark.tape.tape.datasets import *
import numpy as np
from transformers import AutoTokenizer, EsmModel
import torch

from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
import numpy as np

## Read the dataset and create the dataloaders

Use the esm2's tokenizer to create dataset and dataloader that fit esm2's model & token format


In [2]:

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
tokenizer.get_vocab()

{'<cls>': 0,
 '<pad>': 1,
 '<eos>': 2,
 '<unk>': 3,
 'L': 4,
 'A': 5,
 'G': 6,
 'V': 7,
 'S': 8,
 'E': 9,
 'R': 10,
 'T': 11,
 'I': 12,
 'D': 13,
 'P': 14,
 'K': 15,
 'Q': 16,
 'N': 17,
 'F': 18,
 'Y': 19,
 'M': 20,
 'H': 21,
 'W': 22,
 'C': 23,
 'X': 24,
 'B': 25,
 'U': 26,
 'Z': 27,
 'O': 28,
 '.': 29,
 '-': 30,
 '<null_1>': 31,
 '<mask>': 32}

In [3]:
test_lmdb_dataset = dataset_factory('./TAPE_benchmark/stability/stability_test.lmdb', in_memory=False)
val_lmdb_dataset = dataset_factory('./TAPE_benchmark/stability/stability_valid.lmdb', in_memory=False)
train_lmdb_dataset = dataset_factory('./TAPE_benchmark/stability/stability_train.lmdb', in_memory=False)

In [4]:
print(f'number of samples in train: {len(train_lmdb_dataset)}, validation: {len(val_lmdb_dataset)}, test: {len(test_lmdb_dataset)}')

number of samples in train: 53614, validation: 2512, test: 12851


In [5]:

# create dataset for the esm model

class StabilityDataset(Dataset):
    def __init__(self, LMDB_dataset, tokenizer):
        self.tokenizer = tokenizer
        self.dataset = LMDB_dataset

    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, index):
        item = self.dataset[index]
        token_ids = np.array(self.tokenizer.encode(item['primary']))
        input_mask = np.ones_like(token_ids)
        stability_score = float(item['stability_score'])
        return token_ids, input_mask, stability_score


        

def collate_fn(batch):

    token_ids, input_mask, stability_score = tuple(zip(*batch)) 
    token_ids = torch.from_numpy(pad_sequences(np.array(token_ids), 0))
    input_mask = torch.from_numpy(pad_sequences(input_mask, 0))
    stability_true_value = torch.FloatTensor(stability_score) 
    stability_true_value = stability_true_value.unsqueeze(1)

    return {'input_ids': token_ids,
            'input_mask': input_mask,
            'targets': stability_true_value}




train_dataset = StabilityDataset(train_lmdb_dataset, tokenizer)
val_dataset = StabilityDataset(val_lmdb_dataset, tokenizer)
test_dataset = StabilityDataset(test_lmdb_dataset, tokenizer)


In [6]:
# create dataloaders
BATCHSIZE = 64
train_dataloader = DataLoader(train_dataset,batch_size=BATCHSIZE,collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset,batch_size=BATCHSIZE,collate_fn=collate_fn)
test_dataloader = DataLoader(test_dataset,batch_size=BATCHSIZE,collate_fn=collate_fn)

## Import the pretrained model and define the model

ESM2 model: Evolutionary Scale Modeling 
- Transformer protein language models
- This model was trained on masked amino acid sequences

In [8]:
pretrainedESM = EsmModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
pretrainedESM


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmModel: ['lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.bias']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


EsmModel(
  (embeddings): EsmEmbeddings(
    (word_embeddings): Embedding(33, 320, padding_idx=1)
    (dropout): Dropout(p=0.0, inplace=False)
    (position_embeddings): Embedding(1026, 320, padding_idx=1)
  )
  (encoder): EsmEncoder(
    (layer): ModuleList(
      (0): EsmLayer(
        (attention): EsmAttention(
          (self): EsmSelfAttention(
            (query): Linear(in_features=320, out_features=320, bias=True)
            (key): Linear(in_features=320, out_features=320, bias=True)
            (value): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (rotary_embeddings): RotaryEmbedding()
          )
          (output): EsmSelfOutput(
            (dense): Linear(in_features=320, out_features=320, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
          (LayerNorm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        )
        (intermediate): EsmIntermediate(
          

In [9]:
#define the fine tuning model

class StabilityPredictor(nn.Module):
    def __init__(self,
                 esm: nn.Module,
                 enc_hid_dim=320, 
                 outputs=1):
        super().__init__()

        self.esm = esm
        self.enc_hid_dim = enc_hid_dim
        self.pre_predictor = nn.Linear(self.enc_hid_dim, self.enc_hid_dim)
        self.predictor = nn.Linear(self.enc_hid_dim, outputs)
        
    def forward(self,
                seq):
        esm_output = self.esm(seq)
        last_hidden_state = esm_output.last_hidden_state ##(batch, seq_len, dim)
        pooled_output = last_hidden_state[:,0,:] ##(batch, dim)
        pooled_output = self.pre_predictor(pooled_output)
        stability = self.predictor(pooled_output) ## (batch, num_labels)

        return stability

In [10]:
def init_classification_head_weights(m: nn.Module, hidden_size=320):
    k = 1/hidden_size
    for name, param in m.named_parameters():
        if name in ["pre_predictor.weight","pre_predictor.bias","predictor.weight","predictor.bias"]:
            if 'weight' in name:
                nn.init.uniform_(param.data, a=-1*k**0.5, b=k**0.5)
            else:
                nn.init.uniform_(param.data, 0)

## Model training

In [11]:
#define hyperparameters

LR = 1e-4
N_EPOCHS = 10


#define models, move to device, and initialize weights
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = StabilityPredictor(esm=pretrainedESM).to(device)
model.apply(init_classification_head_weights)
model.to(device)

pre_predictor.weight
pre_predictor.bias
predictor.weight
predictor.bias
Model Initialized


In [13]:
#define training function

def train(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0.0
    
    for databatch in dataloader:
        input_ids, input_masks, stability = databatch['input_ids'], databatch['input_mask'], databatch['targets']
        optimizer.zero_grad()
        pred = model(input_ids.to(device))
        loss = F.mse_loss(pred, stability.to(device))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        
    return total_loss/len(dataloader)
        
        

#define evaluation function

def evaluate(model, dataloader, device):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for databatch in dataloader:
            input_ids, input_masks, stability = databatch['input_ids'], databatch['input_mask'], databatch['targets']
            pred = model(input_ids.to(device))
            loss = F.mse_loss(pred, stability.to(device))
            total_loss += loss.item()
    return total_loss / len(dataloader)
            
    

In [14]:
optimizer = optim.Adam(model.parameters(), lr=LR)

train_loss = evaluate(model, train_dataloader, device)
valid_loss = evaluate(model, val_dataloader, device)

print(f'Initial Train Loss: {train_loss:.3f}')
print(f'Initial Valid Loss: {valid_loss:.3f}')

for epoch in range(N_EPOCHS):

    train_loss = train(model, train_dataloader, optimizer, device)
    valid_loss = evaluate(model, val_dataloader, device)
    

    print(f'\tTrain Loss: {train_loss:.3f}')
    print(f'\tValid Loss: {valid_loss:.3f}')


  token_ids = torch.from_numpy(pad_sequences(np.array(token_ids), 0))


Initial Train Loss: 0.329
Initial Valid Loss: 0.466
	Train Loss: 0.207
	Valid Loss: 0.462
	Train Loss: 0.168
	Valid Loss: 0.338
	Train Loss: 0.141
	Valid Loss: 0.341
	Train Loss: 0.119
	Valid Loss: 0.372
	Train Loss: 0.104
	Valid Loss: 0.159
	Train Loss: 0.083
	Valid Loss: 0.171
	Train Loss: 0.067
	Valid Loss: 0.150
	Train Loss: 0.057
	Valid Loss: 0.178
	Train Loss: 0.052
	Valid Loss: 0.267
	Train Loss: 0.046
	Valid Loss: 0.202


### Predictions and Analysis

In [15]:

from torcheval.metrics.functional import r2_score
from torchmetrics import SpearmanCorrCoef


def check_metrices(pred, gold):
    model.eval()
    gold = gold
    mse = F.mse_loss(pred, gold) 
    rmse = torch.sqrt(mse)
    r_square = r2_score(pred, gold)
    cal_spearman = SpearmanCorrCoef()
    spearman = cal_spearman(pred[:,0], gold[:,0])
    print(f'mse:{mse}, rmse:{rmse}, r_square:{r_square}, spearman:{spearman}')
    


# predict the whole dataset

def make_predictions(model, dataloader):

    with torch.no_grad():
        model.eval()
        model.cpu()
        predictions = torch.tensor([])
        golds = torch.tensor([])
        for databatch in dataloader:
            input_ids, input_masks, stability = databatch['input_ids'], databatch['input_mask'], databatch['targets']
            pred = model(input_ids)
            predictions = torch.cat((predictions, pred), dim=0)
            golds = torch.cat((golds, stability), dim=0)
    return predictions, golds
            

In [16]:
#get predictions 

train_preds, train_golds = make_predictions(model, train_dataloader)
val_preds, val_golds = make_predictions(model, val_dataloader)
test_preds, test_golds = make_predictions(model, test_dataloader)
            
        

  token_ids = torch.from_numpy(pad_sequences(np.array(token_ids), 0))


In [17]:
#print the metrices


print(f'train metrices: -----------------------------------------------------------')

print(check_metrices(train_preds, train_golds))

print(f'validation metrices: -----------------------------------------------------------')

print(check_metrices(val_preds, val_golds))

print(f'test metrices: -----------------------------------------------------------')

print(check_metrices(test_preds, test_golds))


train metrices: -----------------------------------------------------------




mse:0.11966745555400848, rmse:0.3459298312664032, r_square:0.6267509460449219, spearman:0.860129177570343
None
validation metrices: -----------------------------------------------------------
mse:0.20187589526176453, rmse:0.44930601119995117, r_square:0.5303531885147095, spearman:0.7541908621788025
None
test metrices: -----------------------------------------------------------
mse:0.250195175409317, rmse:0.5001951456069946, r_square:-0.49777042865753174, spearman:0.6967188715934753
None
