In [1]:
%env TORCH_HOME=/torch_hub 

env: TORCH_HOME=/torch_hub


In [2]:
!pip install Bio
!pip install fair-esm
!pip install pandas
!pip install scikit-learn
!pp install scipy

[0m/bin/bash: pp: command not found


In [3]:
import pandas as pd
import numpy as np
import Bio
from Bio import SeqIO
import os
import torch
import math
import esm

from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import scipy
from scipy import stats

import gc

In [4]:
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F

HIDDEN_UNITS_POS_CONTACT = 5

class ESMForSingleMutationPosOuter(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.fc1 = nn.Linear(self.model.config.hidden_size * 2, HIDDEN_UNITS_POS_CONTACT)
        self.fc2 = nn.Linear(HIDDEN_UNITS_POS_CONTACT, 1)

    def forward(self, sequence1, sequence2, pos):
        # Assuming sequence1 and sequence2 are tensors of token ids
        # We don't need to use the tokenizer
        
        output1 = self.model(sequence1)
        output2 = self.model(sequence2)

        outputs1_pos = output1.last_hidden_state[:, pos + 1]
        outputs2_pos = output2.last_hidden_state[:, pos + 1]

        outputs_pos_concat = torch.cat((outputs1_pos, outputs2_pos), 2)
        fc1_outputs = F.relu(self.fc1(outputs_pos_concat))
        logits = self.fc2(fc1_outputs)
        return logits



In [5]:
from transformers import AutoTokenizer
from torch.utils.data import Dataset

class ProteinDataset(Dataset):
    def __init__(self, df, tokenizer_name):
        self.df = df
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def __getitem__(self, idx):
        sequence1 = ''.join(self.df.iloc[idx]['wt_seq'])[:1022]
        sequence2 = ''.join(self.df.iloc[idx]['mut_seq'])[:1022]
        pos = self.df.iloc[idx]['mutation_pos']
        
        token_ids1 = self.tokenizer(sequence1, return_tensors='pt', truncation=True, padding='max_length', max_length=1022)['input_ids']
        token_ids2 = self.tokenizer(sequence2, return_tensors='pt', truncation=True, padding='max_length', max_length=1022)['input_ids']
        
        return token_ids1, token_ids2, pos

    def __len__(self):
        return len(self.df)

In [6]:
def train(epoch):
    tr_loss, tr_accuracy = 0, 0
    nb_tr_examples, nb_tr_steps = 0, 0
    tr_preds, tr_labels = [], []
    model.train()
    device = 'cuda:0'
    for idx, batch in enumerate(training_loader):
        input_ids1, input_ids2, pos, labels = batch            
        input_ids1 = input_ids1[0].to(device)
        input_ids2 = input_ids2[0].to(device)
        pos=pos.to(device)
        labels=labels.to(device)
        logits = model(sequence1 = input_ids1, sequence2 = input_ids2, pos = pos).to(device)
        loss = torch.nn.functional.mse_loss(logits, labels).to(device)
        tr_loss += loss.item()

        nb_tr_steps += 1
        nb_tr_examples += labels.size(0)
        
        torch.nn.utils.clip_grad_norm_(
            parameters=model.parameters(), max_norm=0.1
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

    epoch_loss = tr_loss / nb_tr_steps
    print(f"Training loss epoch: {epoch_loss}")

In [7]:
test = pd.read_csv('https://storage.googleapis.com/indaba-data/test/test.csv')


In [8]:
test['mutation_pos'] = 35

In [10]:
# Create an instance of the ProteinDataset class for the test dataset
test_ds = ProteinDataset(test,'facebook/esm2_t6_8M_UR50D')

# Create a data loader for the test dataset
test_loader = DataLoader(test_ds, batch_size=128, num_workers=2, shuffle=False)

In [11]:
model = torch.load('weights/ESMForSingleMutationPosOuter')  # Replace 'weights/ESMForSingleMutationPosConcat' with the path to your trained model


In [12]:
device = 'cuda:0'
model.to(device)
    


ESMForSingleMutationPosOuter(
  (model): EsmModel(
    (embeddings): EsmEmbeddings(
      (word_embeddings): Embedding(33, 320, padding_idx=1)
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(1026, 320, padding_idx=1)
    )
    (encoder): EsmEncoder(
      (layer): ModuleList(
        (0): EsmLayer(
          (attention): EsmAttention(
            (self): EsmSelfAttention(
              (query): Linear(in_features=320, out_features=320, bias=True)
              (key): Linear(in_features=320, out_features=320, bias=True)
              (value): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
              (rotary_embeddings): RotaryEmbedding()
            )
            (output): EsmSelfOutput(
              (dense): Linear(in_features=320, out_features=320, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (LayerNorm): LayerNorm((320,), eps=1e-05, ele

In [13]:
predictions = []
for idx, batch in enumerate(test_loader):
    input_ids1, input_ids2, pos = batch            
    input_ids1 = input_ids1[0].to(device)
    input_ids2 = input_ids2[0].to(device)
    logits = model(sequence1 = input_ids1, sequence2 = input_ids2, pos = pos)
    print(logits)
    predictions.append(logits)


tensor([[[0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.3422],
         [0.

In [14]:
values_list = [tensor.tolist() for tensor in predictions]

In [15]:
pred=[]

In [16]:
for i in values_list:
    for j in i[0]:
        pred.append(j)

In [20]:
predictions=[float(i[0]) for i in pred]

In [18]:
pred[0]

[0.34216126799583435]

In [21]:
id=test['ID'].values

In [22]:
df = pd.DataFrame({'ID': id, 'ddg': predictions})

In [23]:
df.to_csv('submit.csv',index=False)

In [23]:
float(predictions[0])

-0.6735116243362427

In [24]:
predictions=[float(i) for i in predictions]

In [26]:
id=test['ID'].values

In [27]:
id

array([   0,    1,    2, ..., 2410, 2411, 2412])

In [28]:
len(id)

1907

In [29]:
len(predictions)

1907

In [30]:
df = pd.DataFrame({'ID': id, 'ddg': predictions})

In [32]:
df.to_csv('my_last_hope.csv',index=False)
