In [4]:
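# Optional: save/load the embedding DataFrames (dat = train split, dat2 = test split).
# NOTE: /kaggle/input is normally mounted read-only, so the to_pickle lines need a
# writable location (e.g. /kaggle/working) before they are uncommented.
# dat.to_pickle('/kaggle/input/bert-embeddings-catelmo/epi_train_tiny_embeddings_768.pkl')
# dat2.to_pickle('/kaggle/input/bert-embeddings-catelmo/epi_test_tiny_embeddings_768.pkl')
# dat = pd.read_pickle('/kaggle/input/bert-embeddings-catelmo/epi_train_tiny_embeddings_768.pkl')
# dat2 = pd.read_pickle('/kaggle/input/bert-embeddings-catelmo/epi_test_tiny_embeddings_768.pkl')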

In [6]:
import re
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple

# Tokenizer
def tokenizer(sequence: str) -> List[str]:
    """Split an amino-acid string into one token per residue.

    The argument is coerced with ``str()``, every run of whitespace is
    removed, and each character that is not one of the uppercase letters
    ``ARNDCQEGHILKMFPSTWYVBZX`` is replaced by ``'*'`` (the match is
    case-sensitive, so lowercase letters also become ``'*'``).

    Returns:
        A list of single-character strings, in input order.
    """
    text = str(sequence)
    without_whitespace = re.sub(r'\s+', '', text)
    masked = re.sub(r'[^ARNDCQEGHILKMFPSTWYVBZX]', '*', without_whitespace)
    return [residue for residue in masked]

# Vocabulary mappings
# Residue letter -> integer index. Indices 0-19 are the single-letter residue
# codes A..V, 20-22 are B/Z/X, 23 is '*' (the token substituted for any
# character the tokenizer does not recognise) and 24 is the padding token.
AMINO_MAP = {
    '<pad>': 24, '*': 23, 'A': 0, 'C': 4, 'B': 20,
    'E': 6, 'D': 3, 'G': 7, 'F': 13, 'I': 9, 'H': 8,
    'K': 11, 'M': 12, 'L': 10, 'N': 2, 'Q': 5, 'P': 14,
    'S': 15, 'R': 1, 'T': 16, 'W': 17, 'V': 19, 'Y': 18,
    'X': 22, 'Z': 21
}

# Index -> residue letter (list position == AMINO_MAP value). The padding
# index 24 is written as '@' here rather than '<pad>'.
AMINO_MAP_REV = [
    'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K',
    'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V', 'B', 'Z', 'X', '*', '@'
]

# Variant of the reverse table in which index 20 (B) decodes to 'N',
# index 21 (Z) decodes to 'Q' and index 22 (X) decodes to '*'.
AMINO_MAP_REV_ = ['A','R','N','D','C','Q','E','G','H','I','L','K',
                 'M','F','P','S','T','W','Y','V','N','Q','*','*','@']

# Padding function
def pad_sequence(sequence: List[int], max_length: int, pad_type: str = "end") -> List[int]:
    """Force a list of token ids to exactly ``max_length`` entries.

    Input longer than ``max_length`` is cut to its first ``max_length`` ids,
    whatever ``pad_type`` is. Shorter input is filled with
    ``AMINO_MAP['<pad>']``:

    * ``"front"`` - filler goes before the sequence;
    * ``"mid"``   - the sequence is centred between two runs of filler (when
      the deficit is odd the extra filler id goes after the sequence);
    * anything else (including the default ``"end"``) - filler goes after.

    Returns:
        A new list of length ``max_length``.
    """
    filler_id = AMINO_MAP['<pad>']
    deficit = max_length - len(sequence)
    if deficit < 0:
        return sequence[:max_length]

    filler = [filler_id] * deficit
    if pad_type == "front":
        return filler + sequence
    if pad_type == "mid":
        left = deficit // 2
        return filler[:left] + sequence + filler[left:]
    return sequence + filler



def encode_sequence(sequence: List[str], max_length: int, pad_type: str) -> List[int]:
    """Convert residue tokens to vocabulary ids of fixed length.

    Each token is looked up in ``AMINO_MAP``; tokens that are not in the map
    get the id of ``'*'``. The id list is then padded or truncated to
    ``max_length`` by ``pad_sequence`` using ``pad_type``.
    """
    ids = []
    for residue in sequence:
        ids.append(AMINO_MAP.get(residue, AMINO_MAP['*']))
    return pad_sequence(ids, max_length, pad_type)


def load_embedding(filename):
    """Load a substitution-matrix text file as a list of embedding rows.

    Parameters:
    - filename: Path of the matrix file. ``None`` or the string ``'none'``
      (any case) selects '/kaggle/input/blosum/BLOSUM62.txt'.

    The first 7 lines of the file are skipped (presumably the comment lines
    plus the column-header line of the BLOSUM62 file - TODO confirm before
    using a different matrix file). On every remaining non-blank line the
    first whitespace-separated field (the row label) is dropped and the rest
    are parsed as floats.

    Returns:
    - A list of rows (lists of floats) with one extra all-zero row appended
      at the end, which serves as the embedding of the padding token.

    Raises:
    - OSError if the file cannot be opened.
    - ValueError if a field is not numeric.
    - IndexError if no data rows remain after the header.
    """
    if filename is None or filename.lower() == 'none':
        filename = '/kaggle/input/blosum/BLOSUM62.txt'

    # `with` guarantees the handle is closed even if reading raises.
    with open(filename, "r") as embedding_file:
        lines = embedding_file.readlines()[7:]

    # Blank lines (e.g. a trailing newline at end of file) are ignored so they
    # cannot become empty rows, which would make the matrix ragged.
    embedding = [[float(x) for x in l.split()[1:]] for l in lines if l.strip()]
    embedding.append([0.0] * len(embedding[0]))

    return embedding


In [7]:
import torch
import torch.nn as nn
# BLOSUM lookup table: one row per vocabulary index; load_embedding appends an
# all-zero row at the end, which is the row used for the padding index.
embedding = load_embedding(None)
num_amino = len(embedding)
embedding_dim = len(embedding[0])

# Randomly initialised layer of the same shape. It is not used to build
# `embedding_model` below; the name is kept so existing references still work.
nn_embedding = nn.Embedding(num_amino, embedding_dim, padding_idx=num_amino-1)

# `from_pretrained` is a classmethod that returns a brand-new layer, so nothing
# configured on `nn_embedding` (in particular padding_idx) carries over.
# padding_idx therefore has to be passed here; with freeze=False it keeps the
# padding row from receiving gradient updates if the layer is ever trained.
embedding_model = nn.Embedding.from_pretrained(
    torch.FloatTensor(embedding),
    freeze=False,
    padding_idx=num_amino - 1,
)

In [None]:
# Move the BLOSUM lookup layer onto the compute device. `device` is not defined
# in the visible cells - presumably it is set in an earlier cell (TODO confirm).
embedding_model.to(device)

In [9]:
# Fixed lengths (in residues) that every encoded sequence is padded/truncated to.
PEPTIDE_MAX_LEN = 15
TCR_MAX_LEN = 25

# Integer-encode each sequence of the training frame (`dat`). The tokenizer must
# receive the individual sequence string (the loop variable): passing the whole
# column would tokenise the printed summary of the Series and give every row
# the same meaningless encoding.
peptides = [encode_sequence(tokenizer(pep), PEPTIDE_MAX_LEN, "end") for pep in dat['antigen'].tolist()]
tcrs = [encode_sequence(tokenizer(tcr), TCR_MAX_LEN, "end") for tcr in dat['cdr3_sequence'].tolist()]

# Same encoding for the test frame (`dat2`), read exclusively from `dat2`.
peptides_test = [encode_sequence(tokenizer(pep), PEPTIDE_MAX_LEN, "end") for pep in dat2['antigen'].tolist()]
tcrs_test = [encode_sequence(tokenizer(tcr), TCR_MAX_LEN, "end") for tcr in dat2['cdr3_sequence'].tolist()]

In [12]:
# Replace every token id by its row of the BLOSUM embedding table. No gradients
# are needed for a pure lookup, and each result is converted straight to nested
# Python lists, which is the form the Dataset indexes into.
with torch.no_grad():
    pep_blosum = embedding_model(torch.tensor(peptides).to(device)).tolist()
    tcr_blosum = embedding_model(torch.tensor(tcrs).to(device)).tolist()
    pep_blosum_test = embedding_model(torch.tensor(peptides_test).to(device)).tolist()
    tcr_blosum_test = embedding_model(torch.tensor(tcrs_test).to(device)).tolist()

In [10]:
class TCRDataset(Dataset):
    """Index-aligned collection of per-sample features and labels.

    All five constructor arguments are sequences that are indexed with the
    same sample index. ``len()`` of the dataset is ``len(labels)``.

    Each item is a 5-tuple of float32 tensors, in this order:
    ``(tcr_embeds[i], epi_embeds[i], pep_blosum[i], tcr_blosum[i], labels[i])``
    - note that the peptide BLOSUM features come before the TCR ones.
    """

    def __init__(self, tcr_embeds, epi_embeds, pep_blosum, tcr_blosum, labels):
        self.tcr_embeds, self.epi_embeds = tcr_embeds, epi_embeds
        self.pep_blosum, self.tcr_blosum = pep_blosum, tcr_blosum
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        # Order of `fields` defines the order of the returned tuple.
        fields = (
            self.tcr_embeds,
            self.epi_embeds,
            self.pep_blosum,
            self.tcr_blosum,
            self.labels,
        )
        return tuple(torch.tensor(field[idx], dtype=torch.float32) for field in fields)

In [13]:
# Training split: language-model embeddings and labels come from `dat`, the
# BLOSUM features from the lists computed for the same rows.
train_dataset = TCRDataset(
    tcr_embeds=dat['tcr_embeds'].tolist(),
    epi_embeds=dat['epi_embeds'].tolist(),
    pep_blosum=pep_blosum,
    tcr_blosum=tcr_blosum,
    labels=dat['class'].tolist(),
)

# Test split: the same construction from `dat2` and the *_test BLOSUM lists.
test_dataset = TCRDataset(
    tcr_embeds=dat2['tcr_embeds'].tolist(),
    epi_embeds=dat2['epi_embeds'].tolist(),
    pep_blosum=pep_blosum_test,
    tcr_blosum=tcr_blosum_test,
    labels=dat2['class'].tolist(),
)

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BindingAffinityModel(nn.Module):
    """Binary classifier combining two sequence branches and two vector branches.

    Two single-layer bidirectional LSTMs each summarise one BLOSUM-encoded
    sequence, two feed-forward branches each expand one fixed-size embedding
    vector, and a final feed-forward head maps the concatenation of all four
    to a sigmoid score in [0, 1].

    Parameters:
    - size: Dimension of each fixed-size embedding vector (inputA / inputB).
    - blosum_embedding_dim: Feature dimension of each BLOSUM sequence step.
    - lstm_hidden_dim: Hidden size per LSTM direction.
    - ffn_hidden_dim: Accepted for interface compatibility; currently unused.

    NOTE(review): the training/evaluation cells of this notebook pass the
    peptide BLOSUM tensor as the first forward() argument and the TCR one as
    the second, so the branch names do not match the data they receive. The
    two LSTM branches are structurally identical, so this is harmless as long
    as every caller routes the tensors the same way.
    """

    def __init__(self, size, blosum_embedding_dim, lstm_hidden_dim=128, ffn_hidden_dim=256):
        super(BindingAffinityModel, self).__init__()
        self.n_dim = size

        # LSTM branch for the first BLOSUM sequence argument
        self.tcr_lstm = nn.LSTM(input_size=blosum_embedding_dim,
                                hidden_size=lstm_hidden_dim,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=True)

        # LSTM branch for the second BLOSUM sequence argument
        self.epitope_lstm = nn.LSTM(input_size=blosum_embedding_dim,
                                    hidden_size=lstm_hidden_dim,
                                    num_layers=1,
                                    batch_first=True,
                                    bidirectional=True)

        # Branch A (FFN for inputA)
        self.branchA = nn.Sequential(
            nn.Linear(self.n_dim, self.n_dim * 2),
            nn.BatchNorm1d(self.n_dim * 2),
            nn.Dropout(0.3),
            nn.SiLU()
        )

        # Branch B (FFN for inputB)
        self.branchB = nn.Sequential(
            nn.Linear(self.n_dim, self.n_dim * 2),
            nn.BatchNorm1d(self.n_dim * 2),
            nn.Dropout(0.3),
            nn.SiLU()
        )

        # Combined head: two FFN branches (2 * n_dim each) plus two
        # bidirectional LSTM summaries (2 * lstm_hidden_dim each).
        self.combined = nn.Sequential(
            nn.Linear(self.n_dim * 2 * 2 + 4 * lstm_hidden_dim, self.n_dim),
            nn.BatchNorm1d(self.n_dim),
            nn.Dropout(0.3),
            nn.SiLU(),
            nn.Linear(self.n_dim, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _final_states(lstm, sequence):
        """Return [batch, 2 * hidden]: final forward and final backward hidden states.

        For a bidirectional LSTM the per-step output at the LAST time step holds
        the backward direction after it has read only that one step, so slicing
        `output[:, -1, :]` discards almost all backward context (and with
        end-padded input that step is padding). `h_n` holds the true final state
        of each direction: h_n[-2] is forward, h_n[-1] is backward.
        """
        _, (h_n, _) = lstm(sequence)
        return torch.cat((h_n[-2], h_n[-1]), dim=1)

    def forward(self, tcr_blosum_seq, epitope_blosum_seq, inputA, inputB):
        """
        Parameters:
        - tcr_blosum_seq: Tensor [batch_size, seq_len, blosum_embedding_dim], fed to self.tcr_lstm
        - epitope_blosum_seq: Tensor [batch_size, seq_len, blosum_embedding_dim], fed to self.epitope_lstm
        - inputA: Tensor [batch_size, size], fed to self.branchA
        - inputB: Tensor [batch_size, size], fed to self.branchB

        Returns:
        - Tensor [batch_size, 1] of sigmoid scores.

        NOTE: weights saved from a version that pooled with `output[:, -1, :]`
        still load (the parameters are unchanged) but will score differently.
        """
        # Summarise each sequence with the final hidden state of both directions
        tcr_lstm_out = self._final_states(self.tcr_lstm, tcr_blosum_seq)  # [batch_size, 2*lstm_hidden_dim]
        epitope_lstm_out = self._final_states(self.epitope_lstm, epitope_blosum_seq)  # [batch_size, 2*lstm_hidden_dim]

        # Process the fixed-size embeddings through the FFN branches
        x = self.branchA(inputA)  # [batch_size, n_dim * 2]
        y = self.branchB(inputB)  # [batch_size, n_dim * 2]

        # Concatenate FFN outputs with LSTM summaries
        combined = torch.cat((x, y, tcr_lstm_out, epitope_lstm_out), dim=1)  # [batch_size, n_dim*4 + 4*lstm_hidden_dim]

        # Pass through the combined head
        z = self.combined(combined)
        return z


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
BATCH_SIZE = 32

FFnmodel = BindingAffinityModel(768, 24).to(device)
optimizer = torch.optim.Adam(FFnmodel.parameters(), lr=0.001)
criterion = nn.BCELoss()

# BatchNorm1d cannot normalise a single-sample batch in training mode, so the
# final training batch is dropped only in the case where it would hold exactly
# one sample.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=(len(train_dataset) % BATCH_SIZE == 1),
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Training loop
for epoch in range(5):
    FFnmodel.train()
    epoch_loss = 0
    train_preds, train_labels = [], []

    # TCRDataset yields (tcr_embed, epi_embed, pep_blosum, tcr_blosum, label).
    # The batch variables carry a `_batch` suffix so they do not overwrite the
    # notebook-level `pep_blosum` / `tcr_blosum` lists, which later cells
    # (e.g. the hyperparameter search) still need in full.
    for tcr, epi, pep_blosum_batch, tcr_blosum_batch, label in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        tcr, epi, label = tcr.to(device), epi.to(device), label.to(device)
        pep_blosum_batch = pep_blosum_batch.to(device)
        tcr_blosum_batch = tcr_blosum_batch.to(device)

        optimizer.zero_grad()
        # Argument routing is deliberately unchanged: the peptide tensor goes to
        # the model's first sequence input and the TCR tensor to the second.
        # squeeze(1) removes only the trailing size-1 dim, so a batch of one
        # sample stays 1-D and still matches `label`.
        outputs = FFnmodel(pep_blosum_batch, tcr_blosum_batch, tcr, epi).squeeze(1)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        train_preds.extend(outputs.detach().cpu().numpy())
        train_labels.extend(label.cpu().numpy())

    # Metrics over all training batches of this epoch (scores thresholded at 0.5)
    train_preds_binary = [1 if p >= 0.5 else 0 for p in train_preds]
    acc = accuracy_score(train_labels, train_preds_binary)
    prec = precision_score(train_labels, train_preds_binary)
    rec = recall_score(train_labels, train_preds_binary)
    auc = roc_auc_score(train_labels, train_preds)

    print(f"Epoch {epoch+1} - Loss: {epoch_loss/len(train_loader):.4f}, Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, AUC: {auc:.4f}")

# Save the trained model
model_save_path = 'modified_catelmo_epi_weights.pth'
torch.save(FFnmodel.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

In [None]:
FFnmodel.eval()
test_preds, test_labels = [], []

with torch.no_grad():
    # TCRDataset yields (tcr_embed, epi_embed, pep_blosum, tcr_blosum, label).
    # The `_batch` suffix keeps these loop variables from overwriting the
    # notebook-level `pep_blosum` / `tcr_blosum` lists.
    for tcr, epi, pep_blosum_batch, tcr_blosum_batch, label in tqdm(test_loader, desc="Testing"):
        tcr, epi, label = tcr.to(device), epi.to(device), label.to(device)
        pep_blosum_batch = pep_blosum_batch.to(device)
        tcr_blosum_batch = tcr_blosum_batch.to(device)
        # Same argument routing as in training (peptide tensor first, TCR second).
        # squeeze(1) keeps a 1-D result even when the final batch holds a single
        # sample; a bare squeeze() would return a 0-d tensor that cannot be
        # passed to list.extend().
        outputs = FFnmodel(pep_blosum_batch, tcr_blosum_batch, tcr, epi).squeeze(1)
        test_preds.extend(outputs.cpu().numpy())
        test_labels.extend(label.cpu().numpy())

# Test metrics (scores thresholded at 0.5 for the binary metrics)
test_preds_binary = [1 if p >= 0.5 else 0 for p in test_preds]
acc = accuracy_score(test_labels, test_preds_binary)
prec = precision_score(test_labels, test_preds_binary)
rec = recall_score(test_labels, test_preds_binary)
auc = roc_auc_score(test_labels, test_preds)

print(f"Test Metrics - Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, AUC: {auc:.4f}")

# Persist per-sample labels, raw scores and thresholded predictions
df = pd.DataFrame({
    'True_Labels': test_labels,
    'Predicted_Scores': test_preds,
    'Predicted_Binary': test_preds_binary
})

csv_filename = 'modified_epi_test_predictions.csv'
df.to_csv(csv_filename, index=False)

print(f"Predictions and labels saved to: {csv_filename}")

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
import itertools
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score

# Grid used when hyperparameter_tuning is called without `param_grid`.
_DEFAULT_PARAM_GRID = {
    'learning_rate': [0.001, 0.0001],
    'batch_size': [32, 64],
    'lstm_hidden_dim': [64, 128],
    'dropout': [0.2, 0.3, 0.4],
}


class _TunedBindingAffinityModel(nn.Module):
    """Two-LSTM / two-FFN binding classifier with tunable LSTM width and dropout.

    Each bidirectional LSTM summarises one BLOSUM-encoded sequence by the final
    hidden state of both directions, each FFN branch expands one fixed-size
    embedding vector, and the combined head outputs a sigmoid score.
    `ffn_hidden_dim` is accepted but unused.
    """

    def __init__(self, size, blosum_embedding_dim,
                 lstm_hidden_dim=128,
                 ffn_hidden_dim=256,
                 dropout_rate=0.3):
        super(_TunedBindingAffinityModel, self).__init__()
        self.n_dim = size

        self.tcr_lstm = nn.LSTM(input_size=blosum_embedding_dim,
                                hidden_size=lstm_hidden_dim,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=True)

        self.epitope_lstm = nn.LSTM(input_size=blosum_embedding_dim,
                                    hidden_size=lstm_hidden_dim,
                                    num_layers=1,
                                    batch_first=True,
                                    bidirectional=True)

        self.branchA = nn.Sequential(
            nn.Linear(self.n_dim, self.n_dim * 2),
            nn.BatchNorm1d(self.n_dim * 2),
            nn.Dropout(dropout_rate),
            nn.SiLU()
        )

        self.branchB = nn.Sequential(
            nn.Linear(self.n_dim, self.n_dim * 2),
            nn.BatchNorm1d(self.n_dim * 2),
            nn.Dropout(dropout_rate),
            nn.SiLU()
        )

        self.combined = nn.Sequential(
            nn.Linear(self.n_dim * 2 * 2 + 4 * lstm_hidden_dim, self.n_dim),
            nn.BatchNorm1d(self.n_dim),
            nn.Dropout(dropout_rate),
            nn.SiLU(),
            nn.Linear(self.n_dim, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _final_states(lstm, sequence):
        """Return [batch, 2 * hidden]: final forward (h_n[-2]) and backward (h_n[-1]) states.

        The per-step output at the last time step would contain the backward
        direction after reading only that single (typically padding) step, so
        `h_n` is used instead.
        """
        _, (h_n, _) = lstm(sequence)
        return torch.cat((h_n[-2], h_n[-1]), dim=1)

    def forward(self, tcr_blosum_seq, epitope_blosum_seq, inputA, inputB):
        tcr_lstm_out = self._final_states(self.tcr_lstm, tcr_blosum_seq)
        epitope_lstm_out = self._final_states(self.epitope_lstm, epitope_blosum_seq)
        x = self.branchA(inputA)
        y = self.branchB(inputB)
        combined = torch.cat((x, y, tcr_lstm_out, epitope_lstm_out), dim=1)
        return self.combined(combined)


def _run_training_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass of `model` over every batch of `loader`."""
    model.train()
    # TCRDataset yields (tcr_embed, epi_embed, pep_blosum, tcr_blosum, label).
    for tcr, epi, pep_seq, tcr_seq, label in loader:
        tcr, epi, label = tcr.to(device), epi.to(device), label.to(device)
        pep_seq, tcr_seq = pep_seq.to(device), tcr_seq.to(device)

        optimizer.zero_grad()
        # Routing kept as in the rest of the notebook: peptide tensor to the
        # first sequence input, TCR tensor to the second. squeeze(1) drops only
        # the trailing size-1 dim so the result always matches `label`.
        outputs = model(pep_seq, tcr_seq, tcr, epi).squeeze(1)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()


def _collect_predictions(model, loader, device):
    """Return (scores, labels) lists for every sample of `loader`, in eval mode."""
    model.eval()
    scores, targets = [], []
    with torch.no_grad():
        for tcr, epi, pep_seq, tcr_seq, label in loader:
            tcr, epi = tcr.to(device), epi.to(device)
            pep_seq, tcr_seq = pep_seq.to(device), tcr_seq.to(device)
            outputs = model(pep_seq, tcr_seq, tcr, epi).squeeze(1)
            scores.extend(outputs.cpu().numpy())
            targets.extend(label.cpu().numpy())
    return scores, targets


def hyperparameter_tuning(dat, pep_blosum, tcr_blosum,
                           param_grid=None,
                           n_splits=5,
                           epochs=5):
    """
    Perform hyperparameter tuning using grid search with cross-validation

    Parameters:
    - dat: DataFrame with columns 'tcr_embeds', 'epi_embeds' and 'class'
    - pep_blosum: Peptide BLOSUM embeddings, indexable by row position of `dat`
    - tcr_blosum: TCR BLOSUM embeddings, indexable by row position of `dat`
    - param_grid: Dict mapping 'learning_rate', 'batch_size', 'dropout' and
      'lstm_hidden_dim' to lists of candidate values. The legacy key
      'hidden_dim' is accepted as an alias for 'lstm_hidden_dim'; if neither
      is present 128 is used. Defaults to a small built-in grid when None.
    - n_splits: Number of stratified cross-validation folds
    - epochs: Number of training epochs per fold

    Each configuration is trained from scratch on every fold for `epochs`
    epochs and scored ONCE per fold, after the final epoch, so the reported
    mean/std are taken over exactly `n_splits` values.

    Returns:
    - (results, best_result): `results` is a list with one dict of mean/std
      metrics (plus 'params') per configuration, sorted by mean AUC in
      descending order; `best_result` is its first element.
    """
    if param_grid is None:
        param_grid = _DEFAULT_PARAM_GRID

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Generate all hyperparameter combinations
    param_combinations = [
        dict(zip(param_grid.keys(), v))
        for v in itertools.product(*param_grid.values())
    ]

    # Column data is extracted once; folds select from it by row position.
    labels = dat['class'].tolist()
    tcr_embeds = dat['tcr_embeds'].tolist()
    epi_embeds = dat['epi_embeds'].tolist()

    # The split is seeded, so it is identical for every configuration and can
    # be computed a single time.
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    folds = list(skf.split(dat, labels))

    def take(values, indices):
        return [values[i] for i in indices]

    def make_dataset(indices):
        return TCRDataset(
            take(tcr_embeds, indices),
            take(epi_embeds, indices),
            take(pep_blosum, indices),
            take(tcr_blosum, indices),
            take(labels, indices),
        )

    # Store results for each hyperparameter combination
    results = []

    for params in param_combinations:
        print(f"\n--- Hyperparameter Configuration ---")
        for k, v in params.items():
            print(f"{k}: {v}")

        lstm_hidden_dim = params.get('lstm_hidden_dim', params.get('hidden_dim', 128))
        batch_size = params['batch_size']

        # Fold-wise performance metrics (one entry per fold)
        fold_metrics = {
            'accuracies': [],
            'precisions': [],
            'recalls': [],
            'aucs': []
        }

        for fold, (train_index, val_index) in enumerate(folds, 1):
            print(f"\nFold {fold}")

            train_dataset = make_dataset(train_index)
            val_dataset = make_dataset(val_index)

            # BatchNorm1d cannot handle a single-sample batch in training
            # mode; drop the final batch only when it would hold one sample.
            train_loader = DataLoader(train_dataset,
                                      batch_size=batch_size,
                                      shuffle=True,
                                      drop_last=(len(train_dataset) % batch_size == 1))
            val_loader = DataLoader(val_dataset,
                                    batch_size=batch_size,
                                    shuffle=False)

            # Fresh model and optimizer for every fold
            model = _TunedBindingAffinityModel(
                size=768,
                blosum_embedding_dim=24,
                lstm_hidden_dim=lstm_hidden_dim,
                dropout_rate=params['dropout']
            ).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=params['learning_rate'])
            criterion = nn.BCELoss()

            for epoch in range(epochs):
                _run_training_epoch(model, train_loader, optimizer, criterion, device)

            # Score the fully trained model on the held-out fold
            val_preds, val_labels = _collect_predictions(model, val_loader, device)
            val_preds_binary = [1 if p >= 0.5 else 0 for p in val_preds]
            fold_metrics['accuracies'].append(accuracy_score(val_labels, val_preds_binary))
            fold_metrics['precisions'].append(precision_score(val_labels, val_preds_binary))
            fold_metrics['recalls'].append(recall_score(val_labels, val_preds_binary))
            fold_metrics['aucs'].append(roc_auc_score(val_labels, val_preds))

        # Compute average metrics for this hyperparameter configuration
        avg_metrics = {
            'mean_accuracy': np.mean(fold_metrics['accuracies']),
            'std_accuracy': np.std(fold_metrics['accuracies']),
            'mean_precision': np.mean(fold_metrics['precisions']),
            'std_precision': np.std(fold_metrics['precisions']),
            'mean_recall': np.mean(fold_metrics['recalls']),
            'std_recall': np.std(fold_metrics['recalls']),
            'mean_auc': np.mean(fold_metrics['aucs']),
            'std_auc': np.std(fold_metrics['aucs']),
            'params': params
        }

        results.append(avg_metrics)

        print("\nAverage Metrics:")
        print(f"Accuracy: {avg_metrics['mean_accuracy']:.4f} ± {avg_metrics['std_accuracy']:.4f}")
        print(f"Precision: {avg_metrics['mean_precision']:.4f} ± {avg_metrics['std_precision']:.4f}")
        print(f"Recall: {avg_metrics['mean_recall']:.4f} ± {avg_metrics['std_recall']:.4f}")
        print(f"AUC: {avg_metrics['mean_auc']:.4f} ± {avg_metrics['std_auc']:.4f}")

    # Sort results by AUC in descending order
    results.sort(key=lambda x: x['mean_auc'], reverse=True)

    print("\n--- Best Hyperparameters ---")
    best_result = results[0]
    for k, v in best_result['params'].items():
        print(f"{k}: {v}")

    print("\nBest Performance Metrics:")
    print(f"Accuracy: {best_result['mean_accuracy']:.4f} ± {best_result['std_accuracy']:.4f}")
    print(f"Precision: {best_result['mean_precision']:.4f} ± {best_result['std_precision']:.4f}")
    print(f"Recall: {best_result['mean_recall']:.4f} ± {best_result['std_recall']:.4f}")
    print(f"AUC: {best_result['mean_auc']:.4f} ± {best_result['std_auc']:.4f}")

    return results, best_result

In [None]:
# Search space for the grid search: 3 values for each of 4 hyperparameters,
# i.e. 81 configurations, each trained on every cross-validation fold.
custom_param_grid = dict(
    learning_rate=[0.001, 0.005, 0.000001],
    batch_size=[32, 64, 128],
    lstm_hidden_dim=[64, 128, 256],
    dropout=[0.2, 0.3, 0.4],
)

# Run the search on the training frame and its BLOSUM features; `results`
# receives every configuration's metrics and `best_params` the top entry.
results, best_params = hyperparameter_tuning(
    dat=dat,
    pep_blosum=pep_blosum,
    tcr_blosum=tcr_blosum,
    param_grid=custom_param_grid,
    n_splits=5,
    epochs=5,
)