In [1]:
import torch
import esm
import numpy as np

# Cache of loaded ESM resources keyed by model name, so that repeated calls
# (e.g. one call per chunk of sequences) do not rebuild the network each time.
_ESM_MODEL_CACHE = {}


def _load_esm_model():
    """Return the cached ``(model, batch_converter)`` pair, loading it on first use."""
    key = "esm2_t33_650M_UR50D"
    if key not in _ESM_MODEL_CACHE:
        model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        model.eval()  # inference only
        _ESM_MODEL_CACHE[key] = (model, alphabet.get_batch_converter())
    return _ESM_MODEL_CACHE[key]


def get_esm_embeddings(sequences):
    """Compute per-residue ESM-2 embeddings for a batch of protein sequences.

    All sequences are tokenized and pushed through the model as a single
    batch, and the layer-33 representations are sliced per sequence.

    Args:
        sequences: Iterable of amino-acid strings.

    Returns:
        A list with one NumPy array per input sequence, in input order. Each
        array holds the rows ``1 .. len(seq)`` of that sequence's token
        representations (first axis = residue position). An empty input
        yields an empty list without loading the model.
    """
    sequences = list(sequences)
    if not sequences:
        # The batch converter cannot build a batch from zero sequences.
        return []

    model, batch_converter = _load_esm_model()

    data = [(f"protein{i}", seq) for i, seq in enumerate(sequences)]
    _, _, batch_tokens = batch_converter(data)

    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[33], return_contacts=False)
    token_embeddings = results["representations"][33]

    # Index 0 is skipped and exactly len(seq) positions are kept, which drops
    # any padding after shorter sequences.
    # NOTE(review): presumably index 0 is the start token the batch converter
    # prepends and each character maps to one token -- confirm against the
    # ESM documentation.
    return [
        token_embeddings[i, 1:1 + len(seq)].cpu().numpy()
        for i, seq in enumerate(sequences)
    ]

def embed_and_pool(sequences):
    """Turn each sequence into a single fixed-size vector by mean pooling.

    Args:
        sequences: Sequences forwarded unchanged to ``get_esm_embeddings``.

    Returns:
        A NumPy array built from one pooled vector per sequence, where each
        vector is the mean over axis 0 of that sequence's embedding array.
    """
    per_residue = get_esm_embeddings(sequences)
    # Averaging over axis 0 collapses the per-position rows into one vector.
    return np.array([residues.mean(axis=0) for residues in per_residue])

In [2]:
from torch.utils.data import Dataset, DataLoader

class ProteinDataset(Dataset):
    """Dataset wrapping a feature matrix and, optionally, integer labels.

    Args:
        X: Features as a NumPy array or a tensor; stored as a float tensor.
        y: Optional labels as a NumPy array or a tensor; stored as a long
            tensor. When omitted, items are features only.
    """

    def __init__(self, X, y=None):
        # NumPy input is wrapped first; anything else is assumed to already
        # expose the tensor ``.float()`` / ``.long()`` methods.
        features = torch.from_numpy(X) if isinstance(X, np.ndarray) else X
        self.X = features.float()

        if y is None:
            self.y = None
        else:
            labels = torch.from_numpy(y) if isinstance(y, np.ndarray) else y
            self.y = labels.long()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Unlabeled datasets yield the bare feature row; labeled ones a pair.
        if self.y is None:
            return self.X[idx]
        return self.X[idx], self.y[idx]

In [3]:
import torch.nn as nn

class MLP(nn.Module):
    """Feed-forward classifier: a Linear + ReLU pair per hidden width, then a Linear head.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dims: Widths of the hidden layers, in order.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim, hidden_dims=[512,256,128,64], num_classes=2):
        super().__init__()
        # Consecutive pairs of widths give the (in, out) sizes of each hidden layer.
        widths = [input_dim, *hidden_dims]
        modules = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            modules.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output head produces raw logits (no activation).
        modules.append(nn.Linear(widths[-1], num_classes))
        self.network = nn.Sequential(*modules)

    def forward(self, x):
        """Return the logits produced by the stacked layers for input ``x``."""
        return self.network(x)

In [4]:
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import WeightedRandomSampler
from sklearn.metrics import roc_auc_score

def cross_val_score_mlp(X, y, 
                        hidden_dims,
                        lr,
                        batch_size,
                        device,
                        n_splits=5,
                        max_epochs=30,
                        patience=5):
    """Score an MLP configuration with stratified K-fold cross-validation.

    For every fold a fresh ``MLP`` is trained with Adam and cross-entropy
    loss on class-balanced resampled batches, and evaluated after each epoch
    by ROC-AUC on the held-out fold. Training of a fold stops early once the
    validation AUC has not improved for ``patience`` consecutive epochs.

    Args:
        X: Feature matrix indexable by integer index arrays, with
            ``X.shape[1]`` giving the feature dimension.
        y: Integer class labels (used with ``np.bincount`` and as an index
            into the per-class weights). The probability of class 1 is the
            score passed to ``roc_auc_score``.
        hidden_dims: Hidden layer widths passed to ``MLP``.
        lr: Adam learning rate.
        batch_size: Batch size for both training and validation loaders.
        device: Torch device on which the model and batches are placed.
        n_splits: Number of stratified folds.
        max_epochs: Upper bound on training epochs per fold.
        patience: Epochs without AUC improvement tolerated before stopping.

    Returns:
        The mean over folds of each fold's best validation ROC-AUC.

    Note:
        The per-fold score is the maximum AUC seen on the same fold that
        drives early stopping, so the estimate is optimistic.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    input_dim = X.shape[1]  # identical for every fold
    auc_scores = []

    for train_idx, val_idx in skf.split(X, y):
        X_train, X_val = X[train_idx], X[val_idx]
        y_train, y_val = y[train_idx], y[val_idx]

        train_dataset = ProteinDataset(X_train, y_train)
        val_dataset = ProteinDataset(X_val, y_val)

        # Inverse-frequency weights: each class is drawn about equally often.
        # The epsilon guards against division by zero for absent label values.
        class_sample_counts = np.bincount(y_train)
        weights_per_class = 1.0 / (class_sample_counts + 1e-8)
        samples_weight = weights_per_class[y_train]

        sampler = WeightedRandomSampler(
            weights=samples_weight,
            num_samples=len(samples_weight),
            replacement=True
        )

        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        model = MLP(input_dim, hidden_dims=hidden_dims, num_classes=2).to(device)

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # Only the best score is tracked. The function returns scores, not
        # weights, so no checkpoint is kept. (``model.state_dict()`` returns
        # references to the live parameters, so keeping one would need an
        # explicit copy.)
        best_auc = 0.0
        epochs_no_improve = 0

        for epoch in range(max_epochs):
            model.train()
            for batch_x, batch_y in train_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)

                optimizer.zero_grad()
                outputs = model(batch_x)
                loss = criterion(outputs, batch_y)
                loss.backward()
                optimizer.step()

            model.eval()
            val_probs = []
            val_targets = []
            with torch.no_grad():
                for batch_x, batch_y in val_loader:
                    batch_x = batch_x.to(device)
                    outputs = model(batch_x)
                    # Probability assigned to class 1.
                    probs = torch.softmax(outputs, dim=1)[:, 1]
                    val_probs.extend(probs.cpu().numpy())
                    val_targets.extend(batch_y.numpy())

            val_auc = roc_auc_score(val_targets, val_probs)

            if val_auc > best_auc:
                best_auc = val_auc
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                break

        auc_scores.append(best_auc)

    return np.mean(auc_scores)

In [5]:
import optuna

def objective(trial):
    """Optuna objective: sample MLP hyperparameters and return their mean CV AUC.

    Samples four hidden-layer widths (64..512 in steps of 64), a
    log-uniformly distributed learning rate in [1e-4, 1e-2] and a batch size
    from {16, 32, 64}, then scores the configuration with
    ``cross_val_score_mlp``.

    Relies on the notebook globals ``X_trainval``, ``y_trainval`` and
    ``device`` being defined before the study is run.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The mean cross-validated ROC-AUC for the sampled configuration.
    """
    hidden_dims = [
        trial.suggest_int(f"hidden_dim_{i}", 64, 512, step=64)
        for i in range(4)
    ]

    # suggest_float(..., log=True) is the supported replacement for the
    # deprecated suggest_loguniform and samples from the same distribution.
    lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
    batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])

    mean_auc = cross_val_score_mlp(
        X_trainval, y_trainval,
        hidden_dims=hidden_dims,
        lr=lr,
        batch_size=batch_size,
        device=device,
        n_splits=5,
        max_epochs=30,
        patience=5
    )

    return mean_auc

In [6]:
import pandas as pd
import math
def predict_on_csv_in_chunks(csv_file, model_file, device, output_csv="new_data_with_pred.csv", chunk_size=100,
                             input_dim=1280, hidden_dims=(128, 192, 192, 64)):
    """Score every sequence in a CSV with a saved MLP and write the results.

    Sequences are read from the ``orf_sequence`` column, embedded and pooled
    ``chunk_size`` at a time via ``embed_and_pool``, and scored by the MLP.
    The class-1 softmax probability is stored in a new ``pred_score`` column
    and the whole frame is written to ``output_csv``.

    Args:
        csv_file: Path of the input CSV; must contain an ``orf_sequence`` column.
        model_file: Path of the saved ``MLP`` state dict.
        device: Torch device used for the MLP and its inputs.
        output_csv: Path of the CSV to write.
        chunk_size: Number of sequences embedded and scored per step.
        input_dim: Input width of the MLP; must match the saved weights.
        hidden_dims: Hidden layer widths of the MLP; must match the saved weights.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    df_new = pd.read_csv(csv_file)
    seqs_new = df_new["orf_sequence"].tolist()

    model = MLP(input_dim=input_dim, hidden_dims=list(hidden_dims), num_classes=2).to(device)
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    model.load_state_dict(torch.load(model_file, map_location=device))
    model.eval()

    num_seqs = len(seqs_new)
    num_chunks = math.ceil(num_seqs / chunk_size)

    print(f"Total sequences: {num_seqs}. We will process {num_chunks} chunks (chunk_size={chunk_size}).")

    all_probs = []
    for chunk_i, idx_start in enumerate(range(0, num_seqs, chunk_size)):
        # Slicing clamps at the end of the list, so the last chunk may be shorter.
        sub_seqs = seqs_new[idx_start:idx_start + chunk_size]

        X_new_sub = embed_and_pool(sub_seqs)
        X_tensor = torch.from_numpy(X_new_sub).float().to(device)

        with torch.no_grad():
            outputs = model(X_tensor)
            # Probability assigned to class 1.
            probs = torch.softmax(outputs, dim=1)[:, 1]

        all_probs.extend(probs.cpu().numpy().tolist())
        print(f"Chunk {chunk_i+1}/{num_chunks} done. Current total predictions={len(all_probs)}")

    # Every input row must have received exactly one score before the column is attached.
    assert len(all_probs) == len(df_new), (
        f"Prediction count ({len(all_probs)}) does not match row count ({len(df_new)})"
    )

    df_new["pred_score"] = all_probs
    df_new.to_csv(output_csv, index=False)
    print(f"Saved to {output_csv}. Final predictions: {len(all_probs)}")

In [7]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): this bare expression is not the last statement of the cell,
# so the notebook does not display it.
device
# Score every sequence in the input CSV with the saved MLP weights and write
# the frame, with the added prediction column, to the output CSV.
predict_on_csv_in_chunks(
    csv_file="../microprotein_db_seq.csv",
    model_file="best_mlp.pth",
    device=device,
    output_csv="../microprotein_db_predicted_stability.csv",
    chunk_size=200
)

Total sequences: 10153. We will process 51 chunks (chunk_size=200).
Chunk 1/51 done. Current total predictions=200
Chunk 2/51 done. Current total predictions=400
Chunk 3/51 done. Current total predictions=600
Chunk 4/51 done. Current total predictions=800
Chunk 5/51 done. Current total predictions=1000
Chunk 6/51 done. Current total predictions=1200
Chunk 7/51 done. Current total predictions=1400
Chunk 8/51 done. Current total predictions=1600
Chunk 9/51 done. Current total predictions=1800
Chunk 10/51 done. Current total predictions=2000
Chunk 11/51 done. Current total predictions=2200
Chunk 12/51 done. Current total predictions=2400
Chunk 13/51 done. Current total predictions=2600
Chunk 14/51 done. Current total predictions=2800
Chunk 15/51 done. Current total predictions=3000
Chunk 16/51 done. Current total predictions=3200
Chunk 17/51 done. Current total predictions=3400
Chunk 18/51 done. Current total predictions=3600
Chunk 19/51 done. Current total predictions=3800
Chunk 20/51 do