In [2]:
import re
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import TFBertModel,BertModel, BertForPreTraining, BertTokenizer, BertConfig
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Tokenizer and encoder are loaded from the same local checkpoint directory;
# do_lower_case=False leaves the residue letters' case untouched.
tokenizer = BertTokenizer.from_pretrained("/kaggle/input/catelmo-bert-tiny/", do_lower_case=False)
model = BertModel.from_pretrained("/kaggle/input/catelmo-bert-tiny/").to(device)

def pad_sequence(sequence, tokenizer, max_length=44):
    """Tokenize `sequence` to a fixed length and return its 1-D input_ids tensor.

    The elements of `sequence` are space-joined so each residue becomes its own
    token, then the tokenizer pads/truncates the result to `max_length`.
    """
    spaced = " ".join(sequence)
    encoded = tokenizer(
        spaced,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    batch_ids = encoded['input_ids']
    # A single string is encoded as a batch of one; drop that batch axis.
    return batch_ids.squeeze(0)

def BERT_embedding(x, TOKENIZER, DEVICE, max_length=44, spad=True, MODEL=None):
    """Encode a residue sequence with BERT and return the raw model output.

    Parameters
    ----------
    x : str or iterable of str
        Residue sequence; its elements are space-joined so each residue is a token.
    TOKENIZER :
        Tokenizer used to encode `x`. (Previously ignored in favour of the
        global `tokenizer`.)
    DEVICE : torch.device
        Device the encoded inputs are moved to. (Previously ignored in favour
        of the global `device`.)
    max_length : int
        Fixed padded length; only used when `spad` is True.
    spad : bool
        True  -> pad/truncate to `max_length` via `pad_sequence`, with an
                 attention mask that hides the padding positions.
        False -> encode at natural length, mapping the residues U, Z, O, B to X.
    MODEL : optional
        Model to run; defaults to the global `model`.

    Returns
    -------
    Whatever the model returns. For a Hugging Face BertModel, element 0 is
    presumably the last hidden state of shape (1, seq_len, hidden) — confirm.
    """
    bert = model if MODEL is None else MODEL
    # Inference only: no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        if spad:
            # NOTE(review): unlike the branch below, U/Z/O/B are not mapped to X here.
            input_ids = pad_sequence(x, TOKENIZER, max_length).unsqueeze(0).to(DEVICE)
            # Without a mask, the [PAD] positions would take part in attention.
            attention_mask = (input_ids != TOKENIZER.pad_token_id).long()
            return bert(input_ids=input_ids, attention_mask=attention_mask)

        seq = re.sub(r"[UZOB]", "X", " ".join(x))
        encoded_input = TOKENIZER(seq, return_tensors='pt').to(DEVICE)
        return bert(**encoded_input)

column_names = ['antigen', 'cdr3_sequence', 'class']
# Read without a header row, naming the columns explicitly, and add two
# placeholder columns that the embedding loop fills in row by row.
dat = pd.read_csv(
    '/kaggle/input/tcrdata/data/BAP/tcr_split/train.csv',
    header=None,
    names=column_names,
).assign(tcr_embeds=None, epi_embeds=None)



In [None]:
# Mean-pooled BERT embeddings for every row. Sequences (epitopes in particular)
# repeat across many rows, so each distinct sequence is embedded only once and
# the result is reused.
EMBED_DIM = 768  # hidden size assumed by the reshape below and by CombinedModel(768)
_embedding_cache = {}


def _mean_embedding(sequence):
    """Return the token-averaged last-layer embedding of `sequence` as a list of floats."""
    if sequence not in _embedding_cache:
        with torch.no_grad():  # inference only; skip building an autograd graph
            hidden_states = BERT_embedding(sequence, tokenizer, device, 16, False)[0]
        # Assumed (1, seq_len, EMBED_DIM) -> (seq_len, EMBED_DIM), then average over tokens.
        _embedding_cache[sequence] = hidden_states.reshape(-1, EMBED_DIM).mean(dim=0).tolist()
    return _embedding_cache[sequence]


for i in tqdm(range(len(dat))):
    dat.at[i, 'epi_embeds'] = _mean_embedding(dat.antigen[i])
    dat.at[i, 'tcr_embeds'] = _mean_embedding(dat.cdr3_sequence[i])

In [None]:
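# Optional: cache the embedding DataFrames so the slow BERT step can be skipped on re-runs.
# NOTE: /kaggle/input is read-only on Kaggle; save to /kaggle/working and re-upload as a dataset.
# dat.to_pickle('/kaggle/input/bert-embeddings-catelmo/epi_train_tiny_embeddings_768.pkl')
# dat2.to_pickle('/kaggle/input/bert-embeddings-catelmo/epi_test_tiny_embeddings_768.pkl')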
# dat = pd.read_pickle('/kaggle/input/bert-embeddings-catelmo/epi_train_tiny_embeddings_768.pkl')
# dat2 = pd.read_pickle('/kaggle/input/bert-embeddings-catelmo/epi_test_tiny_embeddings_768.pkl')

In [5]:
# Plain Python lists of the per-row embeddings and class labels.
tcr_embeds, epi_embeds, labels = (
    dat[column].tolist() for column in ('tcr_embeds', 'epi_embeds', 'class')
)

In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
from tqdm import tqdm

# Define the model
class CombinedModel(nn.Module):
    """Two-branch feed-forward classifier for paired (TCR, epitope) embeddings.

    Each input is projected by its own branch (Linear -> BatchNorm1d -> Dropout
    -> SiLU) to ``size * hidden_multiplier`` features. The two projections are
    concatenated and mapped to a single sigmoid probability.

    Parameters
    ----------
    size : int
        Feature dimension of each input (both inputs share it).
    dropout : float, optional
        Dropout probability used in every block (default 0.3).
    hidden_multiplier : int, optional
        Width of each branch relative to ``size`` (default 2).
    """

    def __init__(self, size, dropout=0.3, hidden_multiplier=2):
        super(CombinedModel, self).__init__()
        self.n_dim = size
        hidden = self.n_dim * hidden_multiplier

        # Attribute names and layer order are unchanged, so previously saved
        # state_dicts (e.g. FFnmodel_weights.pth) still load.
        self.branchA = self._branch(self.n_dim, hidden, dropout)
        self.branchB = self._branch(self.n_dim, hidden, dropout)
        self.combined = nn.Sequential(
            nn.Linear(hidden * 2, self.n_dim),
            nn.BatchNorm1d(self.n_dim),
            nn.Dropout(dropout),
            nn.SiLU(),
            nn.Linear(self.n_dim, 1),
            nn.Sigmoid(),
        )

    @staticmethod
    def _branch(in_dim, hidden, dropout):
        """Build the Linear -> BatchNorm1d -> Dropout -> SiLU projection for one input."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.BatchNorm1d(hidden),
            nn.Dropout(dropout),
            nn.SiLU(),
        )

    def forward(self, inputA, inputB):
        """Return probabilities of shape (batch, 1) for inputs of shape (batch, size)."""
        x = self.branchA(inputA)
        y = self.branchB(inputB)
        combined = torch.cat((x, y), dim=1)
        return self.combined(combined)

In [7]:
class TCRDataset(Dataset):
    """Map-style dataset over three parallel per-sample sequences.

    Item ``idx`` is a ``(tcr, epitope, label)`` triple of float32 tensors,
    built on demand from the stored sequences.
    """

    def __init__(self, tcr_embeds, epi_embeds, labels):
        self.tcr_embeds, self.epi_embeds, self.labels = tcr_embeds, epi_embeds, labels

    def __len__(self):
        # The label sequence defines how many samples there are.
        return len(self.labels)

    def __getitem__(self, idx):
        columns = (self.tcr_embeds, self.epi_embeds, self.labels)
        return tuple(torch.tensor(column[idx], dtype=torch.float32) for column in columns)

In [8]:
# Training split: one (tcr, epitope, label) sample per row of `dat`.
train_dataset = TCRDataset(
    tcr_embeds=dat['tcr_embeds'].tolist(),
    epi_embeds=dat['epi_embeds'].tolist(),
    labels=dat['class'].tolist(),
)

In [9]:
# `dat2` (test-split embeddings) is otherwise only created by a commented-out
# line in the save/load cell, so on a fresh kernel this cell raised NameError.
# Load it from the cached pickle when it is not already in memory.
if 'dat2' not in globals():
    dat2 = pd.read_pickle('/kaggle/input/bert-embeddings-catelmo/epi_test_tiny_embeddings_768.pkl')

test_dataset = TCRDataset(tcr_embeds=dat2['tcr_embeds'].tolist(),
                          epi_embeds=dat2['epi_embeds'].tolist(),
                          labels=dat2['class'].tolist())

In [None]:
# Seed weight initialisation and DataLoader shuffling for reproducible runs.
torch.manual_seed(42)

FFnmodel = CombinedModel(768).to(device)
optimizer = torch.optim.Adam(FFnmodel.parameters(), lr=0.001)
criterion = nn.BCELoss()

BATCH_SIZE = 32
# BatchNorm1d raises on a single-sample batch in train mode, so drop the last
# batch only when the split would leave exactly one sample in it.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          drop_last=len(train_dataset) % BATCH_SIZE == 1)

test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Training loop
for epoch in range(5):
    FFnmodel.train()
    epoch_loss = 0.0
    train_preds, train_labels = [], []

    for tcr, epi, label in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        tcr, epi, label = tcr.to(device), epi.to(device), label.to(device)

        optimizer.zero_grad()
        # (batch, 1) -> (batch,); squeeze(1) keeps the batch axis even for a batch of one.
        outputs = FFnmodel(tcr, epi).squeeze(1)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        train_preds.extend(outputs.detach().cpu().numpy())
        train_labels.extend(label.cpu().numpy())

    # Running metrics: collected in train mode (dropout active) while weights change.
    train_preds_binary = [1 if p >= 0.5 else 0 for p in train_preds]
    acc = accuracy_score(train_labels, train_preds_binary)
    prec = precision_score(train_labels, train_preds_binary, zero_division=0)
    rec = recall_score(train_labels, train_preds_binary, zero_division=0)
    auc = roc_auc_score(train_labels, train_preds)

    print(f"Epoch {epoch+1} - Loss: {epoch_loss/len(train_loader):.4f}, Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, AUC: {auc:.4f}")

# Save the trained model
model_save_path = 'FFnmodel_weights.pth'
torch.save(FFnmodel.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

In [None]:
# Evaluate the trained model on the held-out test split and export predictions.
FFnmodel.eval()
test_preds, test_labels = [], []

with torch.no_grad():
    for tcr, epi, label in tqdm(test_loader, desc="Testing"):
        tcr, epi = tcr.to(device), epi.to(device)
        # squeeze(1), not squeeze(): a final batch of one sample must stay 1-D,
        # otherwise .numpy() yields a 0-d array that extend() cannot iterate.
        outputs = FFnmodel(tcr, epi).squeeze(1)
        test_preds.extend(outputs.cpu().numpy())
        test_labels.extend(label.numpy())

# Test metrics (threshold 0.5 for the binary metrics; AUC uses raw scores)
test_preds_binary = [1 if p >= 0.5 else 0 for p in test_preds]
acc = accuracy_score(test_labels, test_preds_binary)
prec = precision_score(test_labels, test_preds_binary, zero_division=0)
rec = recall_score(test_labels, test_preds_binary, zero_division=0)
auc = roc_auc_score(test_labels, test_preds)

print(f"Test Metrics - Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, AUC: {auc:.4f}")

df = pd.DataFrame({
    'True_Labels': test_labels,
    'Predicted_Scores': test_preds,
    'Predicted_Binary': test_preds_binary
})

csv_filename = 'base_epi_test_predictions.csv'
df.to_csv(csv_filename, index=False)

print(f"Predictions and labels saved to: {csv_filename}")

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import ParameterGrid, KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
import numpy as np
import itertools

def _tuned_branch(n_dim, hidden, dropout_rate):
    """Build one input branch: Linear -> BatchNorm1d -> Dropout -> SiLU."""
    return nn.Sequential(
        nn.Linear(n_dim, hidden),
        nn.BatchNorm1d(hidden),
        nn.Dropout(dropout_rate),
        nn.SiLU(),
    )


def _make_tuned_model_class(base_class, dropout_rate, hidden_multiplier):
    """Subclass `base_class`, rebuilding its layers with the given dropout and width.

    Assumes `base_class(size)` sets `self.n_dim` and defines `branchA`,
    `branchB` and `combined` modules, as `CombinedModel` does.
    """
    class TunedCombinedModel(base_class):
        def __init__(self, size):
            super().__init__(size)
            hidden = self.n_dim * hidden_multiplier
            self.branchA = _tuned_branch(self.n_dim, hidden, dropout_rate)
            self.branchB = _tuned_branch(self.n_dim, hidden, dropout_rate)
            self.combined = nn.Sequential(
                nn.Linear(hidden * 2, self.n_dim),
                nn.BatchNorm1d(self.n_dim),
                nn.Dropout(dropout_rate),
                nn.SiLU(),
                nn.Linear(self.n_dim, 1),
                nn.Sigmoid(),
            )

    return TunedCombinedModel


def _predict(model, loader, device):
    """Run `model` in eval mode over `loader` and return (labels, scores) lists."""
    model.eval()
    labels, scores = [], []
    with torch.no_grad():
        for tcr, epi, label in loader:
            # squeeze(1) keeps a 1-D result even when the batch holds one sample.
            out = model(tcr.to(device), epi.to(device)).squeeze(1)
            scores.extend(out.cpu().numpy())
            labels.extend(label.numpy())
    return labels, scores


def _binary_metrics(labels, scores, threshold=0.5):
    """Compute accuracy/precision/recall at `threshold` plus ROC AUC of the raw scores."""
    preds = [1 if s >= threshold else 0 for s in scores]
    return {
        'accuracies': accuracy_score(labels, preds),
        'precisions': precision_score(labels, preds, zero_division=0),
        'recalls': recall_score(labels, preds, zero_division=0),
        'aucs': roc_auc_score(labels, scores),
    }


def hyperparameter_tuning(dat, model_class, device, param_grid=None, num_folds=5):
    """
    Perform hyperparameter tuning using grid search with cross-validation

    Parameters:
    -----------
    dat : pandas.DataFrame
        DataFrame with 'tcr_embeds', 'epi_embeds' and 'class' columns
    model_class : class
        PyTorch model class to be tuned; it is subclassed so that its
        `branchA`/`branchB`/`combined` layers use the tuned dropout and width
        (it must set `self.n_dim` in `__init__(size)`, like CombinedModel)
    device : torch.device
        Device to run the model on (cuda/cpu)
    param_grid : dict, optional
        Dictionary of hyperparameters to tune. Default is a predefined grid
    num_folds : int, optional
        Number of folds for cross-validation (default: 5)

    Returns:
    --------
    dict
        'best_params' (None if no configuration produced a comparable AUC) and
        'all_results' (per-configuration averaged metrics, sorted by AUC)
    """
    # Default parameter grid if not provided
    if param_grid is None:
        param_grid = {
            'learning_rate': [0.001, 0.005, 0.01],
            'batch_size': [16, 32, 64],
            'dropout_rate': [0.2, 0.3, 0.4],
            'hidden_multiplier': [1, 2, 4],
            'epochs': [3, 5, 7]
        }

    full_dataset = TCRDataset(
        dat['tcr_embeds'].tolist(),
        dat['epi_embeds'].tolist(),
        dat['class'].tolist()
    )

    best_params = None
    best_avg_auc = float('-inf')
    all_results = []

    for params in ParameterGrid(param_grid):
        print("\nTesting Hyperparameters:")
        for k, v in params.items():
            print(f"{k}: {v}")

        # The tuned class depends only on `params`, so build it once per configuration.
        TunedModel = _make_tuned_model_class(
            model_class,
            params.get('dropout_rate', 0.3),
            params.get('hidden_multiplier', 2),
        )
        batch_size = params['batch_size']

        kfold = KFold(n_splits=num_folds, shuffle=True, random_state=42)
        fold_metrics = {'accuracies': [], 'precisions': [], 'recalls': [], 'aucs': []}

        for train_idx, val_idx in kfold.split(full_dataset):
            train_loader = DataLoader(
                Subset(full_dataset, train_idx),
                batch_size=batch_size,
                shuffle=True,
                # BatchNorm1d raises on a single-sample batch in train mode.
                drop_last=len(train_idx) % batch_size == 1,
            )
            val_loader = DataLoader(Subset(full_dataset, val_idx), batch_size=batch_size, shuffle=False)

            FFnmodel = TunedModel(768).to(device)
            optimizer = torch.optim.Adam(FFnmodel.parameters(), lr=params['learning_rate'])
            criterion = nn.BCELoss()

            for _ in range(params['epochs']):
                FFnmodel.train()
                for tcr, epi, label in train_loader:
                    tcr, epi, label = tcr.to(device), epi.to(device), label.to(device)
                    optimizer.zero_grad()
                    loss = criterion(FFnmodel(tcr, epi).squeeze(1), label)
                    loss.backward()
                    optimizer.step()

            val_labels, val_scores = _predict(FFnmodel, val_loader, device)
            for name, value in _binary_metrics(val_labels, val_scores).items():
                fold_metrics[name].append(value)

        # Average metrics for this hyperparameter set
        avg_metrics = {k: np.mean(v) for k, v in fold_metrics.items()}
        avg_metrics['params'] = params
        all_results.append(avg_metrics)

        if avg_metrics['aucs'] > best_avg_auc:
            best_avg_auc = avg_metrics['aucs']
            best_params = params

        print("\nAverage Metrics:")
        for metric, value in avg_metrics.items():
            if metric != 'params':
                print(f"{metric.capitalize()}: {value:.4f}")

    all_results.sort(key=lambda x: x['aucs'], reverse=True)

    print("\nTop 3 Hyperparameter Configurations:")
    for result in all_results[:3]:
        print("\nParameters:")
        for k, v in result['params'].items():
            print(f"{k}: {v}")
        print("Metrics:")
        for metric, value in result.items():
            if metric not in ['params', 'aucs']:
                print(f"{metric.capitalize()}: {value:.4f}")
        print(f"AUC: {result['aucs']:.4f}")

    if best_params is None:
        # Only possible with an empty grid or when every average AUC is NaN.
        print("\nNo hyperparameter configuration produced a comparable AUC.")
    else:
        print("\nBest Hyperparameters:")
        for k, v in best_params.items():
            print(f"{k}: {v}")

    return {
        'best_params': best_params,
        'all_results': all_results
    }

In [None]:
# Grid search (default grid, default 5 folds) over CombinedModel variants.
results = hyperparameter_tuning(dat=dat, model_class=CombinedModel, device=device)