In [1]:
from transformers import AutoTokenizer, AutoModel
import torch

In [2]:
pip install transformers

[0mNote: you may need to restart the kernel to use updated packages.


In [12]:
import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scipy
from scipy import stats

import gc

In [13]:
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

HIDDEN_UNITS_POS_CONTACT = 5

class ESMForSingleMutationPosOuter(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.fc1 = nn.Linear(self.model.config.hidden_size * 2, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def forward(self, sequence1, sequence2, pos):
        # Assuming sequence1 and sequence2 are tensors of token ids
        # We don't need to use the tokenizer
        
        output1 = self.model(sequence1)
        output2 = self.model(sequence2)

        outputs1_pos = output1.last_hidden_state[:, pos + 1]
        outputs2_pos = output2.last_hidden_state[:, pos + 1]

        outputs_pos_concat = torch.cat((outputs1_pos, outputs2_pos), 2)
        fc1_outputs = F.relu(self.fc1(outputs_pos_concat))
        logits = self.fc2(fc1_outputs)
        return logits



In [14]:
from transformers import AutoTokenizer
from torch.utils.data import Dataset

class ProteinDataset(Dataset):
    def __init__(self, df, tokenizer_name):
        self.df = df
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __getitem__(self, idx):
        sequence1 = ''.join(self.df.iloc[idx]['wt_seq'])[:1022]
        sequence2 = ''.join(self.df.iloc[idx]['mut_seq'])[:1022]
        pos = self.df.iloc[idx]['mutation_pos']
        
        token_ids1 = self.tokenizer(sequence1, return_tensors='pt', truncation=True, padding='max_length', max_length=1022)['input_ids']
        token_ids2 = self.tokenizer(sequence2, return_tensors='pt', truncation=True, padding='max_length', max_length=1022)['input_ids']
        
        return token_ids1, token_ids2, pos, torch.unsqueeze(torch.FloatTensor([self.df.iloc[idx]['ddg']]), 0)

    def __len__(self):
        return len(self.df)


In [15]:
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    model.train()
    device = 'cuda:0'
    for idx, batch in enumerate(training_loader):
        input_ids1, input_ids2, pos, labels = batch            
        input_ids1 = input_ids1[0].to(device)
        input_ids2 = input_ids2[0].to(device)
        pos=pos.to(device)
        labels=labels.to(device)
        logits = model(sequence1 = input_ids1, sequence2 = input_ids2, pos = pos).to(device)
        loss = torch.nn.functional.mse_loss(logits, labels).to(device)
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)
        
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")

In [16]:
import os

# Define directory
directory = 'weights/'

# Create directory if it does not exist
if not os.path.exists(directory):
    os.makedirs(directory)

# Now you can


In [18]:
lr = 1e-5
EPOCHS = 2
device = 'cuda:0'

models = ['ESMForSingleMutationPosOuter']



full_df = pd.read_csv('https://storage.googleapis.com/indaba-data/train/train.csv')


# Filter the DataFrame to keep only rows with positive 'ddg' values
df_positive_ddg = full_df[full_df['ddg'] < 0]

# Shuffle the DataFrame
df_positive_ddg_shuffled = df_positive_ddg.sample(frac=1, random_state=42)

# Select the first 70K rows
df_selected = df_positive_ddg_shuffled.head(70000)




# Filter the DataFrame to keep only rows with positive 'ddg' values
df_positive_ddg = full_df[full_df['ddg'] < 0]

df_positive_ddg=pd.merge(df_positive_ddg, df_selected)

# Shuffle the DataFrame
df_positive_ddg_shuffled = df_positive_ddg.sample(frac=1, random_state=42)

# Select the first 70K rows
df_selected = df_positive_ddg_shuffled.head(70000)





full_df = full_df[full_df['ddg'] > 0]

preds = {n:[] for n in models} 
true = [None]*5

for model_name in models:
    model_class = globals()[model_name]
    print(f'Training model {model_name}')
    train_df = full_df
    train_ds = ProteinDataset(train_df,'facebook/esm2_t6_8M_UR50D')
        
    model = model_class('facebook/esm2_t6_8M_UR50D')                        
    model.to(device) 
    optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)
    training_loader = DataLoader(train_ds, batch_size=128, num_workers = 2, shuffle = True)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=lr, steps_per_epoch=len(training_loader), epochs=EPOCHS)
        
    for epoch in range(EPOCHS):
        train(epoch)
         
    model.to('cpu')
    
    torch.save(model, 'weights/' + model_name)
    
    del model

Training model ESMForSingleMutationPosOuter


Some weights of the model checkpoint at facebook/esm2_t6_8M_UR50D were not used when initializing EsmModel: ['lm_head.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing EsmModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.weight', 'esm.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  lo

Training loss epoch: 0.14196304729686374


  loss = torch.nn.functional.mse_loss(logits, labels).to(device)


Training loss epoch: 0.10878333113098566


In [33]:
torch.cuda.is_available()

True

In [42]:
full_df

Unnamed: 0,ID,pdb_id,mutation,wt_aa,mutation_pos,mut_aa,wt_seq,mut_seq,ddg
0,0,1GYZ,W1Q,W,1,Q,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,QIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,0.228775
1,1,1GYZ,W1E,W,1,E,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,EIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,0.496896
2,2,1GYZ,W1N,W,1,N,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,NIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,0.163002
3,3,1GYZ,W1H,W,1,H,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,HIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,0.209013
4,4,1GYZ,W1D,W,1,D,WIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,DIARINAAVRAYGLNYSTFINGLKKAGIELDRKILADMAVRDPQAF...,0.407602
...,...,...,...,...,...,...,...,...,...
339753,339753,r7-562-TrROS-Hall,K47I,K,47,I,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKV,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEIV,0.659345
339756,339756,r7-562-TrROS-Hall,K47F,K,47,F,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKV,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEFV,0.077370
339757,339757,r7-562-TrROS-Hall,K47P,K,47,P,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKV,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEPV,0.141990
339758,339758,r7-562-TrROS-Hall,K47C,K,47,C,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIEKV,MKKYKITVYDEKTGEKHTIEIEMSEEELEELAKKLAEKHNVKVRIECV,0.148234
