In [2]:
import numpy as np
import pandas as pd
from scipy.optimize import linear_sum_assignment

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

from transformers import AutoTokenizer, EsmForMaskedLM

from tqdm import tqdm

## Load PiNUI human data

In [3]:
data = pd.read_csv("./data/PiNUI-human.csv")
data.head()

Unnamed: 0,seqA,seqB,interaction
0,MKRRASDRGAGETSARAKALGSGISGNNAKRAGPFILGPRLGNSPV...,MAASAARGAAALRRSINQPVAFVRRIPWTAASSQLKEHFAQFGHVR...,1
1,MEAPSGSEPGGDGAGDCAHPDPRAPGAAAPSSGPGPCAAARESERQ...,MKLFHTADWHLGKLVHGVYMTEDQKIVLDQFVQAVEEEKPDAVIIA...,1
2,MDQNSVPEKAQNEADTNNADRFFRSHSSPPHHRPGHSRALHHYELH...,MTHCCSPCCQPTCCRTTCWQPTTVTTCSSTPCCQPSCCVSSCCQPC...,1
3,MFADLDYDIEEDKLGIPTVPGKVTLQKDAQNLIGISIGGGAQYCPC...,MARTLRPSPLCPGGGKAQLSSASLLGAGLLLQPPTPPPLLLLLFPL...,1
4,MAEGNHRKKPLKVLESLGKDFLTGVLDNLVEQNVLNWKEEEKKKYY...,MASADSRRVADGGGAGGTFQPYLDTLRQELQQTDPTLLSVVVAVLA...,1


In [4]:
# Get the sequences
all_proteins = list(set(data['seqA']).union(set(data['seqB'])))
N = len(all_proteins)
N

30263

## Dataset 

In [5]:
class PinuiDataset(Dataset):
    def __init__(self, data):
        self.num_pairs = len(data)
        self.seqA = list(data['seqA'])
        self.seqB = list(data['seqB'])
        self.interactions = list(data['interaction'])
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
        self.model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
        self.tokenized_pairs = []
        
        for seq1, seq2 in zip(self.seqA, self.seqB):
            pair_input = seq1 + " [SEP] " + seq2
            tokens = self.tokenizer(pair_input, return_tensors="pt", truncation=True, padding='max_length', max_length=1022)
            self.tokenized_pairs.append(tokens)

    def __len__(self):
        return len(self.interactions)
    
    def __getitem__(self, idx):
        
        tokens = self.tokenized_pairs[idx] # Esm embedings 
        interaction = torch.tensor(self.interactions[idx]) # Interaction

        return tokens, interaction

In [6]:
class MLP(nn.Module):
    def __init__(self, input_channels, dropout=0.2):
        super().__init__()
        self.l1 = nn.Linear(input_channels, 1024)
        self.l2 = nn.Linear(1024, 256)
        self.l3 = nn.Linear(256, 1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.relu(self.l1(x))
        out = torch.relu(self.l2(out))
        out = self.dropout(out)
        out = self.l3(out)
        return out

In [7]:
def train(model, esm_model, train_loader, test_loader, optimizer, criterion, epochs, device):

    model.to(device)

    for epoch in tqdm(range(epochs)):

        # Training
        model.train()
        total_train_loss = 0.0
        for features, targets in train_loader:
            
            input_ids = features['input_ids'].squeeze(1).to(device)
            attention_mask = features['attention_mask'].squeeze(1).to(device)
            esm_outputs = esm_model(input_ids, attention_mask=attention_mask)
            embeddings = esm_outputs.last_hidden_state.mean(dim=1)

            targets = targets.to(device).float()

            optimizer.zero_grad()
            outputs = model(embeddings)
            loss = criterion(outputs.squeeze(), targets)
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        # Evaluation
        model.eval()
        total_eval_loss = 0.0
        with torch.no_grad():
            for features, targets in test_loader:
                input_ids = features['input_ids'].squeeze(1).to(device)
                attention_mask = features['attention_mask'].squeeze(1).to(device)
                esm_outputs = esm_model(input_ids, attention_mask=attention_mask)
                embeddings = esm_outputs.last_hidden_state.mean(dim=1)

                targets = targets.to(device).float()

                outputs = model(embeddings)
                loss = criterion(outputs.squeeze(), targets)
                total_eval_loss += loss.item()


        avg_train_loss = total_train_loss / len(train_loader)
        avg_eval_loss = total_eval_loss / len(test_loader)
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Eval Loss: {avg_eval_loss:.4f}")


In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters
epochs = 100
hidden_layers = 3
batch_size = 32

learning_rate = 1e-4
weight_decay = 1e-2


print("Loading data...")
data = pd.read_csv("./data/PiNUI-human.csv")
data = data[:5000]

data_set = PinuiDataset(data)
train_len = int(len(data_set)*0.7)
train_set, test_set = random_split(data_set, [train_len, len(data_set)-train_len])

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
print("Data Loaded.")

Loading data...
Data Loaded.


In [11]:
from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm_model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
input_channels = esm_model.config.hidden_size 
print(esm_model.config.hidden_size)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


320


In [12]:
output_channels = 1

model = MLP(input_channels)
# Loss function and Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
esm_model = EsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D")
train(model, esm_model, train_loader, test_loader, optimizer, criterion, epochs, device)

  0%|          | 0/100 [00:10<?, ?it/s]


KeyboardInterrupt: 

In [None]:
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
esm_model = AutoModel.from_pretrained("facebook/esm2_t6_8M_UR50D")
print(esm_model.config.hidden_size)

dummy_sequence = "MYPROTEINSEQUENCE"
encoded = tokenizer(dummy_sequence, return_tensors="pt", truncation=True, max_length=1022)
with torch.no_grad():
    outputs = esm_model(**encoded)
print(outputs.last_hidden_state.shape)  #  (1, sequence_length, hidden_size)

pooled = outputs.last_hidden_state.mean(dim=1)  # Now shape (1, hidden_size)
print(pooled.shape)

Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


320
torch.Size([1, 19, 320])
torch.Size([1, 320])


## GNN

## Hierarchical graph learning for PPI