In [2]:
import re
import pandas as pd
import numpy as np
from tqdm import tqdm
from transformers import TFBertModel,BertModel, BertForPreTraining, BertTokenizer, BertConfig
import torch

# Pick the accelerator first, then load the character-level BERT tokenizer and
# encoder from the local Kaggle dataset and place the encoder on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained("/kaggle/input/catelmo-bert-tiny/", do_lower_case=False )
model = BertModel.from_pretrained("/kaggle/input/catelmo-bert-tiny/").to(device)  # Module.to returns the module itself

def pad_sequence(sequence, tokenizer, max_length=44):
    """Encode ``sequence`` with a HuggingFace-style tokenizer at a fixed length.

    The residues are joined with spaces (presumably so each residue becomes its
    own token), then padded or truncated to exactly ``max_length`` positions.

    Returns:
        1-D tensor of input ids (the tokenizer's batch dimension is dropped).
    """
    spaced = " ".join(sequence)
    encoding = tokenizer(
        spaced,
        padding='max_length',
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    # A single string yields a batch of one; remove that leading axis.
    return encoding['input_ids'].squeeze(0)

def BERT_embedding(x, TOKENIZER, DEVICE, max_length=44, spad=True, MODEL=None):
    """Run one amino-acid sequence through the BERT encoder and return its raw output.

    Parameters
    ----------
    x : str
        Residue sequence; characters are joined with spaces before tokenizing.
    TOKENIZER : callable
        HuggingFace-style tokenizer returning a mapping of tensors
        (e.g. the ``BertTokenizer`` loaded at the top of the notebook).
    DEVICE : torch.device
        Device the encoded inputs are moved to; must match the encoder's device.
    max_length : int
        Fixed padded/truncated length; only used when ``spad`` is True.
    spad : bool
        True -> pad/truncate to ``max_length``; False -> natural length.
    MODEL : callable, optional
        Encoder to call; defaults to the module-level ``model``.

    Returns
    -------
    Whatever the encoder returns (for ``BertModel``, index 0 is the last hidden state).
    """
    # Use the explicit arguments rather than globals: later cells redefine the
    # names `tokenizer` and `pad_sequence`, which broke the old global lookups.
    net = MODEL if MODEL is not None else model

    # Rare/ambiguous residues (U, Z, O, B) are mapped to X in both modes.
    seq = re.sub(r"[UZOB]", "X", " ".join(x))
    if spad:
        encoded = TOKENIZER(seq, padding='max_length', truncation=True,
                            max_length=max_length, return_tensors='pt')
    else:
        encoded = TOKENIZER(seq, return_tensors='pt')

    # Forward every tensor the tokenizer produced, including attention_mask,
    # so padded positions are masked out instead of attended to.
    inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    # Feature extraction only: no autograd graph is needed.
    with torch.no_grad():
        return net(**inputs)

# Headerless CSV: one (antigen, CDR3, label) triple per row. The two embedding
# columns start out as None and are filled per row by the embedding cell.
column_names = ['antigen', 'cdr3_sequence', 'class']
dat = (
    pd.read_csv('/kaggle/input/tcrdata/data/BAP/tcr_split/train.csv', header=None, names=column_names)
    .assign(tcr_embeds=None, epi_embeds=None)
)



In [None]:
def _mean_pooled_embedding(sequence):
    """Embed one residue string with BERT and mean-pool over its token vectors.

    Returns a list of floats of length hidden_size (the original code assumed 768).
    Any special tokens the tokenizer adds (presumably [CLS]/[SEP]) are included
    in the average.
    """
    last_hidden = BERT_embedding(sequence, tokenizer, device, 16, False)[0]
    # Flatten batch/token axes, then average the token vectors.
    return last_hidden.reshape(-1, last_hidden.shape[-1]).mean(dim=0).tolist()


# Inference only: disable autograd so no computation graph is built per row.
with torch.no_grad():
    for row in tqdm(dat.index):
        # Store the embeddings as Python lists inside the object-dtype columns.
        dat.at[row, 'epi_embeds'] = _mean_pooled_embedding(dat.at[row, 'antigen'])
        dat.at[row, 'tcr_embeds'] = _mean_pooled_embedding(dat.at[row, 'cdr3_sequence'])


In [4]:
# optional: To save/load files
# dat.to_pickle('/kaggle/input/bert-embeddings-catelmo/epi_train_tiny_embeddings_768.pkl')
# dat.to_pickle('/kaggle/input/bert-embeddings-catelmo/epi_train_tiny_embeddings_768.pkl')
# dat = pd.read_pickle('/kaggle/input/bert-embeddings-catelmo/epi_train_tiny_embeddings_768.pkl')
# dat2 = pd.read_pickle('/kaggle/input/bert-embeddings-catelmo/epi_test_tiny_embeddings_768.pkl')

In [5]:
# Plain Python lists of the per-row BERT features and the binary targets.
tcr_embeds, epi_embeds, labels = (
    dat[column].tolist() for column in ('tcr_embeds', 'epi_embeds', 'class')
)

In [8]:
# NOTE(review): the training dataset is built in the train/test split cell further
# down, once the BLOSUM features (`pep_blosum`, `tcr_blosum`) exist. A call here
# would run before `TCRDataset` is defined, and it would need all five of its
# arguments, so it cannot succeed under Restart & Run All.

In [9]:
# NOTE(review): `dat2` (a separate held-out frame) is only loaded by the commented-out
# pickle line above, and `TCRDataset` also requires `pep_blosum` / `tcr_blosum`.
# The test dataset is built from `dat` in the train/test split cell further down.

In [6]:
import re
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple

# Tokenizer
def tokenizer(sequence: str) -> List[str]:
    """Split a residue string into single-character tokens.

    The input is passed through ``str`` and all whitespace is removed. Every
    remaining character that is not one of the 23 upper-case residue codes
    ``ARNDCQEGHILKMFPSTWYVBZX`` is replaced by ``'*'``.
    """
    allowed = set("ARNDCQEGHILKMFPSTWYVBZX")
    compact = re.sub(r'\s+', '', str(sequence))
    return [ch if ch in allowed else '*' for ch in compact]

# Vocabulary mappings
# Residue symbol -> integer id, used by encode_sequence. The ids index rows of the
# lookup table built from load_embedding's output. load_embedding appends an
# all-zero row last, so '<pad>' = 24 presumably lands on that zero row when the
# BLOSUM file has exactly 24 data rows.
# NOTE(review): this also assumes the file lists its rows in the order
# A R N D C Q E G H I L K M F P S T W Y V B Z X * — confirm.
AMINO_MAP = {
    '<pad>': 24, '*': 23, 'A': 0, 'C': 4, 'B': 20,
    'E': 6, 'D': 3, 'G': 7, 'F': 13, 'I': 9, 'H': 8,
    'K': 11, 'M': 12, 'L': 10, 'N': 2, 'Q': 5, 'P': 14,
    'S': 15, 'R': 1, 'T': 16, 'W': 17, 'V': 19, 'Y': 18,
    'X': 22, 'Z': 21
}

# Inverse of AMINO_MAP: position i holds the symbol whose id is i ('@' stands in for '<pad>').
AMINO_MAP_REV = [
    'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K',
    'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V', 'B', 'Z', 'X', '*', '@'
]

# Same as AMINO_MAP_REV, except that ambiguous codes collapse on decoding:
# id 20 (B) -> 'N', id 21 (Z) -> 'Q', id 22 (X) -> '*'.
AMINO_MAP_REV_ = ['A','R','N','D','C','Q','E','G','H','I','L','K',
                 'M','F','P','S','T','W','Y','V','N','Q','*','*','@']

# Padding function
def pad_sequence(sequence: List[int], max_length: int, pad_type: str = "end") -> List[int]:
    """Truncate or pad a list of token ids to exactly ``max_length`` items.

    Longer inputs keep their first ``max_length`` ids. Shorter inputs are filled
    with ``AMINO_MAP['<pad>']``:
      * ``"front"`` puts all padding before the ids;
      * ``"mid"`` puts floor(n/2) pads before and the rest after;
      * anything else (default ``"end"``) appends the padding.
    """
    pad_id = AMINO_MAP['<pad>']
    if len(sequence) > max_length:
        return sequence[:max_length]

    missing = max_length - len(sequence)
    if pad_type == "front":
        return [pad_id] * missing + sequence
    if pad_type == "mid":
        before = missing // 2
        return [pad_id] * before + sequence + [pad_id] * (missing - before)
    return sequence + [pad_id] * missing



def encode_sequence(sequence: List[str], max_length: int, pad_type: str) -> List[int]:
    """Map residue tokens to AMINO_MAP ids, then pad/truncate to ``max_length``.

    Tokens missing from AMINO_MAP are encoded with the id of ``'*'``. Padding
    and truncation are delegated to ``pad_sequence`` with ``pad_type``.
    """
    ids = list(map(lambda tok: AMINO_MAP.get(tok, AMINO_MAP['*']), sequence))
    return pad_sequence(ids, max_length, pad_type)


def load_embedding(filename):
    if filename is None or filename.lower() == 'none':
        filename = '/kaggle/input/blosum/BLOSUM62.txt'
    
    embedding_file = open(filename, "r")
    lines = embedding_file.readlines()[7:]
    embedding_file.close()

    embedding = [[float(x) for x in l.strip().split()[1:]] for l in lines]
    embedding.append([0.0] * len(embedding[0]))

    return embedding


In [7]:
import torch
import torch.nn as nn
# BLOSUM62 rows plus a trailing all-zero padding row become a trainable lookup table.
embedding = load_embedding(None)
num_amino = len(embedding)
embedding_dim = len(embedding[0])
# The padding row is the last one appended by load_embedding.
pad_index = num_amino - 1
# from_pretrained is a classmethod, so padding_idx must be passed to it directly.
# Setting it on a separately constructed nn.Embedding instance has no effect on
# the returned module.
embedding_model = nn.Embedding.from_pretrained(
    torch.FloatTensor(embedding), freeze=False, padding_idx=pad_index
)
nn_embedding = embedding_model  # name kept for backward compatibility

In [None]:
embedding_model.to(device=device)  # move the BLOSUM lookup table next to the data

In [9]:
# Integer-encode every sequence and pad/truncate it to a fixed length (15 for
# epitopes, 25 for CDR3s) so each set stacks into one tensor. `tokenizer` here is
# the residue-level function defined above, which shadows the BERT tokenizer.
PEP_MAX_LEN = 15
TCR_MAX_LEN = 25
peptides = [encode_sequence(tokenizer(pep), PEP_MAX_LEN, "end") for pep in dat['antigen'].tolist()]
# Each CDR3 string is tokenized on its own (not the whole column Series).
tcrs = [encode_sequence(tokenizer(tcr), TCR_MAX_LEN, "end") for tcr in dat['cdr3_sequence'].tolist()]

In [12]:
# Look up per-residue BLOSUM vectors for every id list (no gradients needed), then
# convert each [N, seq_len, dim] result to nested Python lists for TCRDataset.
with torch.no_grad():
    pep_blosum, tcr_blosum = (
        embedding_model(torch.tensor(id_rows).to(device)).tolist()
        for id_rows in (peptides, tcrs)
    )

In [10]:
class TCRDataset(Dataset):
    """Map-style dataset that indexes five parallel feature lists.

    Each item is a 5-tuple of float32 tensors, in this order:
    ``(tcr_embed, epi_embed, pep_blosum, tcr_blosum, label)``.
    The dataset length is taken from ``labels``.
    """

    def __init__(self, tcr_embeds, epi_embeds,pep_blosum, tcr_blosum, labels):
        self.tcr_embeds, self.epi_embeds = tcr_embeds, epi_embeds
        self.pep_blosum, self.tcr_blosum = pep_blosum, tcr_blosum
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        sources = (self.tcr_embeds, self.epi_embeds, self.pep_blosum,
                   self.tcr_blosum, self.labels)
        return tuple(torch.tensor(src[idx], dtype=torch.float32) for src in sources)

In [13]:
from sklearn.model_selection import train_test_split

# Split row positions once and reuse them for every feature source, so the frame
# and the two BLOSUM lists stay aligned by construction. With the same
# random_state this gives the same permutation as splitting each source separately.
train_idx, test_idx = train_test_split(np.arange(len(dat)), test_size=0.2, random_state=42)

train_df, test_df = dat.iloc[train_idx], dat.iloc[test_idx]
train_pep, test_pep = [pep_blosum[i] for i in train_idx], [pep_blosum[i] for i in test_idx]
train_tcr, test_tcr = [tcr_blosum[i] for i in train_idx], [tcr_blosum[i] for i in test_idx]

train_dataset = TCRDataset(train_df['tcr_embeds'].tolist(),
                           train_df['epi_embeds'].tolist(),
                           train_pep,
                           train_tcr,
                           train_df['class'].tolist())

test_dataset = TCRDataset(test_df['tcr_embeds'].tolist(),
                          test_df['epi_embeds'].tolist(),
                          test_pep,
                          test_tcr,
                          test_df['class'].tolist())

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BindingAffinityModel(nn.Module):
    """Binary TCR-epitope binding classifier combining BERT and BLOSUM features.

    Two bidirectional LSTMs summarise the BLOSUM-embedded TCR and epitope
    sequences. Two feed-forward branches project the pooled BERT embeddings. An
    MLP head maps the concatenation to a probability through a sigmoid.

    Args:
        size: Dimensionality of each BERT embedding (``inputA`` / ``inputB``).
        blosum_embedding_dim: Per-residue BLOSUM feature size (LSTM input size).
        lstm_hidden_dim: Hidden size of each LSTM direction.
        ffn_hidden_dim: Not used by the layers; kept for interface compatibility.
    """

    def __init__(self, size, blosum_embedding_dim, lstm_hidden_dim=128, ffn_hidden_dim=256):
        super(BindingAffinityModel, self).__init__()
        self.n_dim = size

        # BiLSTM over the BLOSUM vectors of the TCR sequence.
        self.tcr_lstm = nn.LSTM(input_size=blosum_embedding_dim,
                                hidden_size=lstm_hidden_dim,
                                num_layers=1,
                                batch_first=True,
                                bidirectional=True)

        # BiLSTM over the BLOSUM vectors of the epitope sequence.
        self.epitope_lstm = nn.LSTM(input_size=blosum_embedding_dim,
                                    hidden_size=lstm_hidden_dim,
                                    num_layers=1,
                                    batch_first=True,
                                    bidirectional=True)

        # Branch A: projects the TCR BERT embedding to 2 * size features.
        self.branchA = nn.Sequential(
            nn.Linear(self.n_dim, self.n_dim * 2),
            nn.BatchNorm1d(self.n_dim * 2),
            nn.Dropout(0.3),
            nn.SiLU()
        )

        # Branch B: projects the epitope BERT embedding to 2 * size features.
        self.branchB = nn.Sequential(
            nn.Linear(self.n_dim, self.n_dim * 2),
            nn.BatchNorm1d(self.n_dim * 2),
            nn.Dropout(0.3),
            nn.SiLU()
        )

        # Head over [branchA | branchB | TCR BiLSTM | epitope BiLSTM] -> probability.
        self.combined = nn.Sequential(
            nn.Linear(self.n_dim * 2 * 2 + 4 * lstm_hidden_dim, self.n_dim),
            nn.BatchNorm1d(self.n_dim),
            nn.Dropout(0.3),
            nn.SiLU(),
            nn.Linear(self.n_dim, 1),
            nn.Sigmoid()
        )

    @staticmethod
    def _summarize(lstm, seq):
        """Fixed-size summary of ``seq`` from a bidirectional LSTM.

        Concatenates the final forward hidden state (after the last timestep)
        with the final backward hidden state (after reading back to the first
        timestep), so both halves have seen the whole sequence. Taking
        ``output[:, -1, :]`` instead would give a backward half that has only
        seen the last (often padding) position.

        Returns a tensor of shape [batch_size, 2 * hidden_size].
        """
        _, (h_n, _) = lstm(seq)
        # h_n is [num_layers * 2, batch, hidden]; the last two rows are the top
        # layer's forward and backward final states.
        return torch.cat((h_n[-2], h_n[-1]), dim=1)

    def forward(self, tcr_blosum_seq, epitope_blosum_seq, inputA, inputB):
        """
        Parameters:
        - tcr_blosum_seq: Tensor [batch_size, seq_len, blosum_embedding_dim] (BLOSUM TCR)
        - epitope_blosum_seq: Tensor [batch_size, seq_len, blosum_embedding_dim] (BLOSUM Epitope)
        - inputA: Tensor [batch_size, size] (TCR BERT Embedding)
        - inputB: Tensor [batch_size, size] (Epitope BERT Embedding)

        Returns:
        - Tensor [batch_size, 1] of binding probabilities.
        """
        tcr_summary = self._summarize(self.tcr_lstm, tcr_blosum_seq)              # [B, 2*lstm_hidden_dim]
        epitope_summary = self._summarize(self.epitope_lstm, epitope_blosum_seq)  # [B, 2*lstm_hidden_dim]

        x = self.branchA(inputA)  # [B, n_dim * 2]
        y = self.branchB(inputB)  # [B, n_dim * 2]

        combined = torch.cat((x, y, tcr_summary, epitope_summary), dim=1)  # [B, n_dim*4 + 4*lstm_hidden_dim]
        return self.combined(combined)


In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score

# Seed so weight initialisation and batch shuffling repeat across Restart & Run All.
torch.manual_seed(42)
FFnmodel = BindingAffinityModel(768, 24).to(device)
optimizer = torch.optim.Adam(FFnmodel.parameters(), lr=0.001)
criterion = nn.BCELoss()

# drop_last=True: a trailing one-sample batch makes BatchNorm1d raise in train mode.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Training loop
for epoch in range(5):
    FFnmodel.train()
    epoch_loss = 0.0
    train_preds, train_labels = [], []

    # TCRDataset yields (tcr_bert, epi_bert, pep_blosum, tcr_blosum, label). The batch
    # names must not reuse `pep_blosum` / `tcr_blosum`: those globals hold the full
    # BLOSUM lists built earlier in the notebook.
    for tcr, epi, pep_blosum_batch, tcr_blosum_batch, label in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        tcr, epi, label = tcr.to(device), epi.to(device), label.to(device)
        pep_blosum_batch = pep_blosum_batch.to(device)
        tcr_blosum_batch = tcr_blosum_batch.to(device)

        optimizer.zero_grad()
        # NOTE(review): the peptide BLOSUM batch fills the model's `tcr_blosum_seq` slot
        # and the TCR batch the `epitope_blosum_seq` slot. The two LSTM branches are
        # identical, so this only mislabels them. The evaluation cell uses the same
        # order; change both together.
        # view(-1) keeps a [batch] shape even for a one-sample batch (squeeze() gives a scalar).
        outputs = FFnmodel(pep_blosum_batch, tcr_blosum_batch, tcr, epi).view(-1)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        train_preds.extend(outputs.detach().cpu().tolist())
        train_labels.extend(label.cpu().tolist())

    # Metrics over this epoch's training predictions (threshold 0.5).
    train_preds_binary = [1 if p >= 0.5 else 0 for p in train_preds]
    acc = accuracy_score(train_labels, train_preds_binary)
    prec = precision_score(train_labels, train_preds_binary, zero_division=0)
    rec = recall_score(train_labels, train_preds_binary, zero_division=0)
    auc = roc_auc_score(train_labels, train_preds)

    print(f"Epoch {epoch+1} - Loss: {epoch_loss/max(len(train_loader), 1):.4f}, Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, AUC: {auc:.4f}")

In [None]:
FFnmodel.eval()
test_preds, test_labels = [], []

with torch.no_grad():
    # TCRDataset yields (tcr_bert, epi_bert, pep_blosum, tcr_blosum, label). Use batch
    # names that do not overwrite the global `pep_blosum` / `tcr_blosum` lists.
    for tcr, epi, pep_blosum_batch, tcr_blosum_batch, label in tqdm(test_loader, desc="Testing"):
        tcr, epi = tcr.to(device), epi.to(device)
        pep_blosum_batch = pep_blosum_batch.to(device)
        tcr_blosum_batch = tcr_blosum_batch.to(device)
        # NOTE(review): same argument order as the training cell (peptide batch in the
        # model's `tcr_blosum_seq` slot); keep the two cells in sync.
        # view(-1) keeps a [batch] shape even for a one-sample final batch.
        outputs = FFnmodel(pep_blosum_batch, tcr_blosum_batch, tcr, epi).view(-1)
        test_preds.extend(outputs.cpu().tolist())
        test_labels.extend(label.tolist())

# Test metrics (threshold 0.5 for the hard predictions; AUC uses the probabilities).
test_preds_binary = [1 if p >= 0.5 else 0 for p in test_preds]
acc = accuracy_score(test_labels, test_preds_binary)
prec = precision_score(test_labels, test_preds_binary, zero_division=0)
rec = recall_score(test_labels, test_preds_binary, zero_division=0)
auc = roc_auc_score(test_labels, test_preds)

print(f"Test Metrics - Acc: {acc:.4f}, Prec: {prec:.4f}, Rec: {rec:.4f}, AUC: {auc:.4f}")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, SubsetRandomSampler
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
from tqdm import tqdm

def _cv_subset(dat, pep_blosum, tcr_blosum, indices):
    """Build a TCRDataset from the rows of ``dat`` and BLOSUM lists at positional ``indices``."""
    rows = dat.iloc[indices]
    return TCRDataset(
        rows['tcr_embeds'].tolist(),
        rows['epi_embeds'].tolist(),
        [pep_blosum[i] for i in indices],
        [tcr_blosum[i] for i in indices],
        rows['class'].tolist(),
    )


def _cv_train_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation epoch over ``loader`` and return the mean batch loss."""
    model.train()
    total_loss = 0.0
    # Batches follow TCRDataset order: (tcr_bert, epi_bert, pep_blosum, tcr_blosum, label).
    # Names differ from the enclosing `pep_blosum` / `tcr_blosum` lists on purpose.
    for tcr, epi, pep_seq, tcr_seq, label in tqdm(loader, desc=desc):
        tcr, epi, label = tcr.to(device), epi.to(device), label.to(device)
        pep_seq, tcr_seq = pep_seq.to(device), tcr_seq.to(device)

        optimizer.zero_grad()
        # TCR BLOSUM -> tcr_blosum_seq, peptide BLOSUM -> epitope_blosum_seq.
        # view(-1) keeps a [batch] shape even for a single-sample batch.
        outputs = model(tcr_seq, pep_seq, tcr, epi).view(-1)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)


def _cv_predict(model, loader, device):
    """Return (probabilities, labels) for every sample in ``loader``, without gradients."""
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for tcr, epi, pep_seq, tcr_seq, label in loader:
            outputs = model(tcr_seq.to(device), pep_seq.to(device),
                            tcr.to(device), epi.to(device)).view(-1)
            preds.extend(outputs.cpu().tolist())
            targets.extend(label.tolist())
    return preds, targets


def _cv_metrics(y_true, y_prob, threshold=0.5):
    """Return (accuracy, precision, recall, ROC-AUC) for probabilistic binary predictions."""
    y_pred = [1 if p >= threshold else 0 for p in y_prob]
    return (
        accuracy_score(y_true, y_pred),
        precision_score(y_true, y_pred, zero_division=0),
        recall_score(y_true, y_pred, zero_division=0),
        roc_auc_score(y_true, y_prob),
    )


def cross_validate(dat, pep_blosum, tcr_blosum, n_splits=5, epochs=5, batch_size=32, learning_rate=0.001, seed=42):
    """Stratified k-fold cross-validation of ``BindingAffinityModel``.

    A fresh model is trained for ``epochs`` epochs on each fold. The validation
    metrics after the final epoch are recorded for that fold.

    Args:
        dat: DataFrame with ``tcr_embeds``, ``epi_embeds`` and ``class`` columns.
        pep_blosum, tcr_blosum: Per-row BLOSUM matrices, positionally aligned with ``dat``.
        n_splits: Number of folds.
        epochs: Training epochs per fold (must be >= 1).
        batch_size: Mini-batch size for training and validation.
        learning_rate: Adam learning rate.
        seed: Seeds the fold split and each fold's model initialisation.

    Returns:
        dict with per-fold metric lists (``fold_*``) and their mean/std (``cv_*``).

    Raises:
        ValueError: if ``epochs < 1`` (no validation metrics would exist).
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1 so every fold yields validation metrics")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    labels = dat['class'].tolist()

    cv_results = {
        'fold_accuracies': [],
        'fold_precisions': [],
        'fold_recalls': [],
        'fold_aucs': [],
        'cv_accuracy_mean': 0,
        'cv_accuracy_std': 0,
        'cv_precision_mean': 0,
        'cv_precision_std': 0,
        'cv_recall_mean': 0,
        'cv_recall_std': 0,
        'cv_auc_mean': 0,
        'cv_auc_std': 0
    }

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)

    for fold, (train_index, val_index) in enumerate(skf.split(dat, labels), 1):
        print(f"\n--- Fold {fold} ---")

        # drop_last: a final one-sample batch would crash BatchNorm1d in train mode.
        train_loader = DataLoader(_cv_subset(dat, pep_blosum, tcr_blosum, train_index),
                                  batch_size=batch_size, shuffle=True, drop_last=True)
        val_loader = DataLoader(_cv_subset(dat, pep_blosum, tcr_blosum, val_index),
                                batch_size=batch_size, shuffle=False)

        torch.manual_seed(seed + fold)  # reproducible init and shuffling per fold
        model = BindingAffinityModel(768, 24).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        criterion = nn.BCELoss()

        for epoch in range(epochs):
            train_loss = _cv_train_epoch(model, train_loader, optimizer, criterion, device,
                                         desc=f"Fold {fold}, Epoch {epoch+1}")
            val_preds, val_labels = _cv_predict(model, val_loader, device)
            fold_acc, fold_prec, fold_rec, fold_auc = _cv_metrics(val_labels, val_preds)

            print(f"Fold {fold}, Epoch {epoch+1} - Val Metrics:")
            print(f"Loss: {train_loss:.4f}")
            print(f"Accuracy: {fold_acc:.4f}")
            print(f"Precision: {fold_prec:.4f}")
            print(f"Recall: {fold_rec:.4f}")
            print(f"AUC: {fold_auc:.4f}")

        # Only the final epoch's validation metrics are kept for the fold.
        cv_results['fold_accuracies'].append(fold_acc)
        cv_results['fold_precisions'].append(fold_prec)
        cv_results['fold_recalls'].append(fold_rec)
        cv_results['fold_aucs'].append(fold_auc)

    for metric, key in (('accuracy', 'fold_accuracies'), ('precision', 'fold_precisions'),
                        ('recall', 'fold_recalls'), ('auc', 'fold_aucs')):
        cv_results[f'cv_{metric}_mean'] = np.mean(cv_results[key])
        cv_results[f'cv_{metric}_std'] = np.std(cv_results[key])

    print("\n--- Cross-Validation Summary ---")
    print(f"Accuracy: {cv_results['cv_accuracy_mean']:.4f} ± {cv_results['cv_accuracy_std']:.4f}")
    print(f"Precision: {cv_results['cv_precision_mean']:.4f} ± {cv_results['cv_precision_std']:.4f}")
    print(f"Recall: {cv_results['cv_recall_mean']:.4f} ± {cv_results['cv_recall_std']:.4f}")
    print(f"AUC: {cv_results['cv_auc_mean']:.4f} ± {cv_results['cv_auc_std']:.4f}")

    return cv_results
