In [2]:
import os,sys,re
import argparse, json
import copy
import random
import pickle
import math
import torch
from torch import nn
import numpy as np
import pandas as pd
from torch.utils.data import Dataset
from tqdm import tqdm
#from tqdm.notebook import tqdm
from Bio.PDB.PDBParser import PDBParser
from Bio.PDB.Polypeptide import one_to_index
from Bio.PDB import Selection
from Bio import SeqIO
from Bio.PDB.Residue import Residue
from easydict import EasyDict
import enum
sys.path.append('/data/cb/scratch/varun/esm-multimer/esm-multimer/')
import esm, gzip
from Bio import SeqIO
from esm.model.esm2 import ESM2
from collections import OrderedDict
from sklearn.metrics import mean_squared_error
import scipy.stats

In [9]:
X_train = torch.load(f'embeddings/nofreeze_filtered50_70000_Intra2_embeddings.pt').numpy()

In [10]:
X_train.shape

(52048, 2560)

In [6]:
163019*2560 

417328640

## Data processing

In [2]:
fasta_dictionary = SeqIO.to_dict(SeqIO.parse('human_swissprot_oneliner.fasta', "fasta"))

def make_new_dfs(data_name):
    data_df_neg = pd.read_csv(f'{data_name}_neg_rr.txt', sep=' ', header=None)
    data_df_pos = pd.read_csv(f'{data_name}_pos_rr.txt', sep=' ', header=None)
    
    data_df = pd.concat([data_df_neg, data_df_pos], ignore_index=True)
    labels = [0]*len(data_df_neg) + [1]*len(data_df_pos)
    
    seq1 = []
    seq2 = []
    new_labels = []
    for index, row in data_df.iterrows():
        if row[0] in fasta_dictionary and row[1] in fasta_dictionary:
            seq1.append(str(fasta_dictionary[row[0]].seq))
            seq2.append(str(fasta_dictionary[row[1]].seq))
            new_labels.append(labels[index])
    seq_df = pd.DataFrame({'seq1': seq1, 'seq2': seq2, 'labels':new_labels})
    seq_df.to_csv(f'{data_name}_seqs.csv', index=False)
    print(f'{len(seq_df)} out of {len(data_df)} data points found')

make_new_dfs('Intra0')
make_new_dfs('Intra1')
make_new_dfs('Intra2')

59260 out of 59260 data points found
163019 out of 163192 data points found
52048 out of 52048 data points found


## Get embeddings

In [6]:
class PPIDataset(Dataset):
    def __init__(
        self, 
        data_name, 
    ):
        super().__init__()
        self.data_df = pd.read_csv(f'{data_name}_seqs.csv')
        
    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, index):
        row = self.data_df.iloc[index]
        s0, s1 = row["seq1"], row["seq2"]
        label = int(row["labels"])
        
        return s0, s1, label

In [4]:
class PPICollateFn:
    
    def __init__(self, truncation_seq_length=None):
        self.alphabet = esm.data.Alphabet.from_architecture("ESM-1b")
        self.truncation_seq_length = truncation_seq_length

    def __call__(self, batches):
        batch_size = len(batches)
        heavy_chain, light_chain, labels = zip(*batches)
        
        chains = [self.convert(c) for c in [heavy_chain, light_chain]]
        chain_ids = [torch.ones(c.shape, dtype=torch.int32) * i for i, c in enumerate(chains)]
        chains = torch.cat(chains, -1)
        chain_ids = torch.cat(chain_ids, -1)
        labels = torch.from_numpy(np.stack(labels, 0))
        
        return chains, chain_ids, labels

    def convert(self, seq_str_list):
        batch_size = len(seq_str_list)
        seq_encoded_list = [self.alphabet.encode('<cls>' + seq_str.replace('J', 'L') + '<eos>') for seq_str in seq_str_list]
        if self.truncation_seq_length:
            for i in range(batch_size):
                seq = seq_encoded_list[i]
                if len(seq) > self.truncation_seq_length:
                    start = random.randint(0, len(seq) - self.truncation_seq_length + 1)
                    seq_encoded_list[i] = seq[start:start+self.truncation_seq_length]
        max_len = max(len(seq_encoded) for seq_encoded in seq_encoded_list)
        if self.truncation_seq_length:
            assert max_len <= self.truncation_seq_length
        tokens = torch.empty((batch_size, max_len), dtype=torch.int64)
        tokens.fill_(self.alphabet.padding_idx)
        
        for i, seq_encoded in enumerate(seq_encoded_list):
            seq = torch.tensor(seq_encoded, dtype=torch.int64)
            tokens[i,:len(seq_encoded)] = seq
        return tokens

In [5]:
def upgrade_state_dict(state_dict):
    """Removes prefixes 'model.encoder.sentence_encoder.' and 'model.encoder.'."""
    prefixes = ["encoder.sentence_encoder.", "encoder."]
    pattern = re.compile("^" + "|".join(prefixes))
    state_dict = {pattern.sub("", name): param for name, param in state_dict.items()}
    return state_dict

class PPIWrapper(nn.Module):
    def __init__(self, cfg, checkpoint_path, freeze_percent=0.0, use_multimer=True, device='cuda:0'):
        super().__init__()
        self.cfg = cfg
        self.model = ESM2(
            num_layers=cfg.encoder_layers,
            embed_dim=cfg.encoder_embed_dim,
            attention_heads=cfg.encoder_attention_heads,
            token_dropout=cfg.token_dropout,
            use_multimer = use_multimer,
        )
        checkpoint = torch.load(checkpoint_path, map_location=device)

        if use_multimer:
            # remove 'model.' in keys
            new_checkpoint = OrderedDict((key.replace('model.', ''), value) for key, value in checkpoint['state_dict'].items())
            self.model.load_state_dict(new_checkpoint)
        else:
            new_checkpoint = upgrade_state_dict(checkpoint['model'])
            self.model.load_state_dict(new_checkpoint)
        total_layers = 33
        for name, param in self.model.named_parameters():
            if 'embed_tokens.weight' in name or '_norm_after' in name or 'lm_head' in name:
                param.requires_grad = False
            else:
                layer_num = name.split('.')[1]
                if int(layer_num) <= math.floor(total_layers*freeze_percent):
                    param.requires_grad = False


    def forward(self, chains, chain_ids):
        mask = (~chains.eq(self.model.cls_idx)) & (~chains.eq(self.model.eos_idx)) & (~chains.eq(self.model.padding_idx))
        chain_out = self.model(chains, chain_ids, repr_layers=[33])["representations"][33]
        mask_expanded = mask.unsqueeze(-1).expand_as(chain_out)
        masked_chain_out = chain_out * mask_expanded
        sum_masked = masked_chain_out.sum(dim=1)
        mask_counts = mask.sum(dim=1, keepdim=True).float()  # Convert to float for division
        mean_chain_out = sum_masked / mask_counts
        return mean_chain_out

In [6]:
@torch.no_grad()
def evaluate(model, loader, name):

    device = 'cuda:7'
    model.to(device)

    full_embs = []

    for step, eval_batch in enumerate(tqdm(loader)):
        
        chains, chain_ids, target = eval_batch
        chains = chains.to(device)
        chain_ids = chain_ids.to(device)
        target = target.to(device)

        embs = model(chains, chain_ids)

        full_embs.append(embs.squeeze(-1).detach().cpu())

    full_embs = torch.cat(full_embs).ravel()
    torch.save(full_embs, f'./embeddings/{name}_embeddings.pt')

## Train MLP

In [2]:
from sklearn.metrics import accuracy_score, precision_recall_curve, auc, f1_score, precision_score, recall_score, confusion_matrix, matthews_corrcoef

In [3]:
class PPI_MLP_Dataset(Dataset):
    def __init__(
        self, 
        x, y
    ):
        super().__init__()
        self.x = x
        self.y = y
        
    def __len__(self):
        return len(self.y)

    def __getitem__(self, index):
        return self.x[index], self.y[index]


class SimpleMLP(nn.Module):
    def __init__(self, input_size, output_size, num_layers, hidden_size, dropout, activation):
        super(SimpleMLP, self).__init__()
        self.layers = nn.ModuleList()
        
        self.layers.append(nn.Linear(input_size, hidden_size))
        
        for _ in range(1, num_layers):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
        
        self.layers.append(nn.Linear(hidden_size, output_size))
        
        self.dropout = nn.Dropout(dropout)
        
        if activation == "relu": self.activation = torch.nn.ReLU()
        elif activation == "silu": self.activation = torch.nn.SiLU()
        elif activation == "identity": self.activation = torch.nn.Identity()
        
    def forward(self, x):
        # Apply layers with activation and dropout
        for layer in self.layers[:-1]:
            x = layer(x)
            x = self.activation(x)
            x = self.dropout(x)
        # Output layer, no activation and dropout
        x = self.layers[-1](x)
        return x

In [7]:
y_train = PPIDataset('Intra1').data_df['labels'].tolist()
y_val = PPIDataset('Intra0').data_df['labels'].tolist()
y_test = PPIDataset('Intra2').data_df['labels'].tolist()

In [8]:
X_train = torch.load('embeddings/nofreeze_Intra1_embeddings.pt').numpy().reshape(len(y_train), -1)
X_valid = torch.load('embeddings/nofreeze_Intra0_embeddings.pt').numpy().reshape(len(y_val), -1)
X_test = torch.load('embeddings/nofreeze_Intra2_embeddings.pt').numpy().reshape(len(y_test), -1)

In [25]:
def normalize_train_test(X_train, X_valid, X_test):
    scaler = StandardScaler()
    scaler.fit(X_train)
    X_train = scaler.transform(X_train)
    X_valid = scaler.transform(X_valid)
    X_test = scaler.transform(X_test)
    return X_train, X_valid, X_test

from sklearn.preprocessing import StandardScaler
X_train_norm, X_valid_norm, X_test_norm = normalize_train_test(X_train, X_valid, X_test)

In [26]:
from sklearn.linear_model import Ridge

clf = Ridge(alpha=0.1)

In [27]:
clf.fit(X_train_norm, y_train)

Ridge(alpha=0.1)

In [28]:
preds = clf.predict(X_test)

In [33]:
classification_metrics(y_test, preds, threshold=0.5)

{'Accuracy': 0.5022287119581924,
 'AUPRC': 0.5190343066629095,
 'F1 Score': 0.023003243080171962,
 'Precision': 0.6174089068825911,
 'Recall': 0.011719950814632648,
 'Specificity': 0.9927374731017522,
 'MCC': 0.02298599623136244}

In [13]:
def classification_metrics(targets, predictions, threshold=0.5):
    # Convert probabilities to binary predictions based on a threshold
    binary_predictions = (predictions >= threshold).astype(int)
    
    # Calculate accuracy
    accuracy = accuracy_score(targets, binary_predictions)
    
    # Calculate precision, recall, and F1 score
    precision = precision_score(targets, binary_predictions)
    recall = recall_score(targets, binary_predictions)
    f1 = f1_score(targets, binary_predictions)
    mcc = matthews_corrcoef(targets, binary_predictions)
    
    # Calculate specificity (True Negative Rate)
    tn, fp, fn, tp = confusion_matrix(targets, binary_predictions).ravel()
    specificity = tn / (tn + fp)
    
    # Calculate AUPRC
    precision_vals, recall_vals, _ = precision_recall_curve(targets, predictions)
    auprc = auc(recall_vals, precision_vals)

    
    return {
        'Accuracy': accuracy,
        'AUPRC': auprc,
        'F1 Score': f1,
        'Precision': precision,
        'Recall': recall,
        'Specificity': specificity,
        'MCC': mcc
    }

@torch.no_grad()
def evaluate(model, loader, device='cuda'):

    preds = []
    targets = []

    for step, eval_batch in enumerate(tqdm(loader)):
        embs, target = eval_batch
        embs = embs.to(device)
        target = target.to(device)
        pred = model(embs).squeeze(-1)  

        pred = torch.sigmoid(pred)

        preds.append(pred.detach().cpu().numpy())
        targets.append(target.cpu().numpy())

    preds = np.concatenate(preds)
    targets = np.concatenate(targets)

    metrics_dict = classification_metrics(targets, preds)
    return metrics_dict


def train(model, train_loader, val_loader, num_epochs, device='cuda'):
    optimizer = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, model.parameters()), 
            lr=1e-3)
    loss_fn = torch.nn.BCEWithLogitsLoss()

    model.to(device)

    for epoch in range(num_epochs):
        print(f'Training at epoch {epoch}')
        loss_accum = 0
        for step, train_batch in enumerate(tqdm(train_loader)):

            model.train()
            optimizer.zero_grad()

            embs, target = train_batch
            embs = embs.to(device)
            target = target.to(device)

            pred = model(embs).squeeze(-1)   
            loss = loss_fn(pred, target.float())

            loss.backward()
            optimizer.step()
            loss_accum += loss.detach().cpu().item()
        print(f'Loss at end of epoch {epoch}: {loss_accum/(step+1)}')

        print(f'Evaluating at epoch {epoch}')
        metrics_dict = evaluate(model, val_loader)
        print(metrics_dict)


In [103]:
train_dataset = PPI_MLP_Dataset(X_train, y_train)
val_dataset = PPI_MLP_Dataset(X_valid, y_val)
test_dataset = PPI_MLP_Dataset(X_test, y_test)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=512, shuffle=True
)
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=512, shuffle=False
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=512, shuffle=False
)

In [104]:
model =  SimpleMLP(1280, 1, 2, 128, 0.5, "relu")

In [105]:
train(model, train_loader, test_loader, num_epochs=11)

Training at epoch 0


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 160.87it/s]


Loss at end of epoch 0: 0.6926425403935782
Evaluating at epoch 0


100%|████████████████████████████████████████| 102/102 [00:00<00:00, 318.55it/s]


{'Accuracy': 0.5256494005533354, 'AUPRC': 0.535426084216603, 'F1 Score': 0.45161147020279424, 'Precision': 0.535137126914776, 'Recall': 0.39063940977559175, 'Specificity': 0.660659391331079, 'MCC': 0.05327782073846981}
Training at epoch 1


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 172.39it/s]


Loss at end of epoch 1: 0.6895941010089504
Training at epoch 2


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 180.86it/s]


Loss at end of epoch 2: 0.6856047546601968
Training at epoch 3


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 180.65it/s]


Loss at end of epoch 3: 0.6827683809408948
Training at epoch 4


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 186.74it/s]


Loss at end of epoch 4: 0.6804738140031462
Training at epoch 5


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 182.80it/s]


Loss at end of epoch 5: 0.6783258232950791
Evaluating at epoch 5


100%|████████████████████████████████████████| 102/102 [00:00<00:00, 332.26it/s]


{'Accuracy': 0.5773708884106978, 'AUPRC': 0.6140395752187259, 'F1 Score': 0.5562526477174154, 'Precision': 0.5855098314010277, 'Recall': 0.5297802028896403, 'Specificity': 0.6249615739317553, 'MCC': 0.15544751674512236}
Training at epoch 6


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 176.54it/s]


Loss at end of epoch 6: 0.6758756245191568
Training at epoch 7


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 180.49it/s]


Loss at end of epoch 7: 0.6752051077666328
Training at epoch 8


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 167.62it/s]


Loss at end of epoch 8: 0.6732667721924737
Training at epoch 9


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 187.22it/s]


Loss at end of epoch 9: 0.6725723214657703
Training at epoch 10


100%|████████████████████████████████████████| 319/319 [00:01<00:00, 186.07it/s]


Loss at end of epoch 10: 0.6717646546124665
Evaluating at epoch 10


100%|████████████████████████████████████████| 102/102 [00:00<00:00, 242.82it/s]

{'Accuracy': 0.587573009529665, 'AUPRC': 0.624404507589362, 'F1 Score': 0.5518206113245364, 'Precision': 0.604197147037308, 'Recall': 0.5078004918536735, 'Specificity': 0.6673455272056563, 'MCC': 0.17741863626585555}





In [33]:
param_grid = {
    "activation": ["logistic", "relu", "identity"],
    "alpha": [0.0001, 0.001, 0.01],
    "learning_rate": ["adaptive"],
    "solver": ["adam"],
    "learning_rate_init": [0.001, 0.01],
    "max_iter": [1000, 2000],
    "hidden_layer_sizes": [
        (64,), (128,), (512,),
        (64, 64), (128, 128),
        (64, 64, 64),
        ],
    "early_stopping": [True],
    "validation_fraction": [0.1],
    "tol": [1e-4, 1e-5],
}

param_grid = {"max_iter": [1]}

In [34]:
valid_index = np.concatenate([-1 * np.ones(X_train_norm.shape[0]), np.zeros(X_valid_norm.shape[0])], axis = 0)

In [35]:
cv = PredefinedSplit(valid_index)

In [36]:
mlp = MLPClassifier()

In [37]:
scoring = ["accuracy", "average_precision", "f1", "recall", "precision", "roc_auc"]
refit = "average_precision"
verbose = 10
n_jobs = -1

clsf = GridSearchCV(mlp, param_grid, 
                    cv=cv, 
                    scoring=scoring, 
                    verbose=verbose, 
                    n_jobs=n_jobs, 
                    refit=refit)

In [38]:
clsf.fit(X_trainval_norm, y_trainval)

Fitting 1 folds for each of 1 candidates, totalling 1 fits




GridSearchCV(cv=PredefinedSplit(test_fold=array([-1, -1, ...,  0,  0])),
             estimator=MLPClassifier(), n_jobs=-1, param_grid={'max_iter': [1]},
             refit='average_precision',
             scoring=['accuracy', 'average_precision', 'f1', 'recall',
                      'precision', 'roc_auc'],
             verbose=10)

In [41]:
best_estimator = clsf.best_estimator_

In [42]:
y_test_pred = best_estimator.predict_proba(X_test_norm)[:, 1]

[CV 1/1; 1/1] START max_iter=1..................................................
[CV 1/1; 1/1] END max_iter=1; accuracy: (test=0.537) average_precision: (test=0.555) f1: (test=0.528) precision: (test=0.539) recall: (test=0.518) roc_auc: (test=0.554) total time=   3.0s


