DATA LOADER

In [None]:
import re
import torch
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple

# Tokenizer
def tokenizer(sequence: str) -> List[str]:
    """Split an amino-acid sequence into single-character tokens.

    The input is coerced with str(), all whitespace is removed, and every
    character outside the uppercase alphabet ARNDCQEGHILKMFPSTWYVBZX is
    replaced by '*'.
    """
    compact = re.sub(r'\s+', '', str(sequence))
    accepted = set('ARNDCQEGHILKMFPSTWYVBZX')
    return [ch if ch in accepted else '*' for ch in compact]

# Vocabulary mappings
# Token -> integer id. Ids 0-19 are the residues A R N D C Q E G H I L K M F P
# S T W Y V, 20-22 are the extra letters B, Z and X, 23 is '*' (what
# tokenizer() substitutes for any other character) and 24 is '<pad>', the id
# pad_sequence() fills with.
AMINO_MAP = {
    '<pad>': 24, '*': 23, 'A': 0, 'C': 4, 'B': 20,
    'E': 6, 'D': 3, 'G': 7, 'F': 13, 'I': 9, 'H': 8,
    'K': 11, 'M': 12, 'L': 10, 'N': 2, 'Q': 5, 'P': 14,
    'S': 15, 'R': 1, 'T': 16, 'W': 17, 'V': 19, 'Y': 18,
    'X': 22, 'Z': 21
}

# Id -> symbol: the inverse of AMINO_MAP, except that the padding id 24
# decodes to the single character '@'.
AMINO_MAP_REV = [
    'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K',
    'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V', 'B', 'Z', 'X', '*', '@'
]

# Alternative id -> symbol table (used for peptides in get_physchem_properties):
# identical to AMINO_MAP_REV except B -> 'N', Z -> 'Q' and X -> '*'.
AMINO_MAP_REV_ = ['A','R','N','D','C','Q','E','G','H','I','L','K',
                 'M','F','P','S','T','W','Y','V','N','Q','*','*','@']

# Padding function
def pad_sequence(sequence: List[int], max_length: int, pad_type: str = "end") -> List[int]:
    """Pad (or truncate) a list of token ids to exactly `max_length`.

    Sequences longer than `max_length` are cut from the end. Shorter ones are
    filled with the '<pad>' id: in front ("front"), split around the sequence
    with the smaller half first ("mid"), or after it (any other value).
    """
    pad_id = AMINO_MAP['<pad>']
    if len(sequence) > max_length:
        return sequence[:max_length]

    n_missing = max_length - len(sequence)
    fill = [pad_id] * n_missing

    if pad_type == "front":
        return fill + sequence
    if pad_type == "mid":
        left = n_missing // 2
        return fill[:left] + sequence + fill[left:]
    return sequence + fill

# Dataset class
class CustomDataset(Dataset):
    """Map-style dataset of peptide/TCR pairs with pre-encoded token ids.

    Every sequence is tokenized, mapped through AMINO_MAP (unknown tokens get
    the '*' id) and padded/truncated once in __init__; __getitem__ only wraps
    the stored id lists in tensors.

    Args:
        peptides, tcrs: raw sequences, aligned by index.
        labels: per-pair targets; when None every pair gets 0.0.
        maxlen_pep, maxlen_tcr: encoded lengths for peptides and TCRs.
        pad_type: "front", "mid", or anything else for end padding
            (see pad_sequence).
    """

    def __init__(self, peptides: List[str], tcrs: List[str], labels: List[float] = None,
                 maxlen_pep: int = 15, maxlen_tcr: int = 25, pad_type: str = "end"):
        def encode_all(raw_sequences, length):
            return [self.encode_sequence(tokenizer(raw), length, pad_type) for raw in raw_sequences]

        self.peptides = encode_all(peptides, maxlen_pep)
        self.tcrs = encode_all(tcrs, maxlen_tcr)
        # Unlabelled input gets placeholder zeros so __getitem__ stays uniform.
        self.labels = [0.0] * len(peptides) if labels is None else labels
        self.maxlen_pep = maxlen_pep
        self.maxlen_tcr = maxlen_tcr

    def __len__(self):
        # The dataset size is the number of encoded peptides.
        return len(self.peptides)

    def __getitem__(self, idx):
        pep_ids = self.peptides[idx]
        tcr_ids = self.tcrs[idx]
        target = self.labels[idx]
        return {
            "peptides": torch.tensor(pep_ids, dtype=torch.long),
            "tcrs": torch.tensor(tcr_ids, dtype=torch.long),
            "labels": torch.tensor(target, dtype=torch.float32),
        }

    @staticmethod
    def encode_sequence(sequence: List[str], max_length: int, pad_type: str) -> List[int]:
        """Map tokens to ids (unknown tokens -> the '*' id) and pad/truncate to max_length."""
        unknown_id = AMINO_MAP['*']
        token_ids = [AMINO_MAP.get(token, unknown_id) for token in sequence]
        return pad_sequence(token_ids, max_length, pad_type)

# Collate function for DataLoader
def collate_fn(batch: List[dict]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Stack per-sample dicts into a (peptides, tcrs, labels) tuple of batch tensors."""
    stacked = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("peptides", "tcrs", "labels")
    }
    return stacked["peptides"], stacked["tcrs"], stacked["labels"]

# DataLoader utility
def create_dataloader(peptides: List[str], tcrs: List[str], labels: List[float] = None,
                      batch_size: int = 32, shuffle: bool = True,
                      maxlen_pep: int = 15, maxlen_tcr: int = 25, pad_type: str = "end"):
    """Build a DataLoader yielding (peptides, tcrs, labels) tensor batches.

    Encoding and padding are done by CustomDataset; batches are assembled by
    collate_fn.
    """
    return DataLoader(
        CustomDataset(peptides, tcrs, labels, maxlen_pep, maxlen_tcr, pad_type),
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )

def load_embedding(filename):
    '''
    Read a BLOSUM-style matrix file into a nested-list embedding table.

    The first 7 lines of the file are skipped as header. Every remaining
    non-blank line is "<symbol> v1 v2 ..."; the symbol column is dropped and
    the values are parsed as floats. An all-zero row is appended at the end
    for the padding token.

    parameters:
        - filename : path (str or os.PathLike) of the matrix file; None or
                     the string 'none' (any case) selects the default
                     Kaggle BLOSUM62 path

    returns:
        - blosum embedding matrix: list of lists of float

    raises:
        - ValueError if the file has no data rows after the header
    '''
    if filename is None or str(filename).lower() == 'none':
        filename = '/kaggle/input/blosum/BLOSUM62.txt'

    # `with` guarantees the handle is closed even if reading fails.
    with open(filename, "r") as embedding_file:
        lines = embedding_file.readlines()[7:]

    # Blank lines (e.g. a trailing newline) would otherwise produce empty,
    # ragged rows that cannot be turned into a tensor.
    embedding = [[float(x) for x in line.split()[1:]] for line in lines if line.strip()]
    if not embedding:
        raise ValueError('no embedding rows found in {}'.format(filename))

    # Zero vector for the padding index (last row).
    embedding.append([0.0] * len(embedding[0]))

    return embedding


UTILS

In [None]:
import os
import re
import sys
import time
import math
import torch
import argparse
import numpy as np
from pathlib import Path
import torch.nn.functional as F
from tensorboardX import SummaryWriter
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score, confusion_matrix
# from data_loader import AMINO_MAP, AMINO_MAP_REV, AMINO_MAP_REV_
from collections import defaultdict

# Per-residue property scales keyed by the symbols produced by num2seq()
# (AMINO_MAP_REV / AMINO_MAP_REV_). Each table covers the 20 standard letters,
# B, Z, X, '*' and '@' (the padding symbol, always 0). get_physchem_properties()
# looks every TCR symbol up in each table, so a missing symbol raises KeyError.
# NOTE(review): the source and units of these scales are not recorded here —
# TODO document where the values come from.
BASICITY = {'A': 206.4, 'B': 210.7, 'C': 206.2, 'D': 208.6, 'E': 215.6, 'F': 212.1, 'G': 202.7,
            'H': 223.7, 'I': 210.8, 'K': 221.8, 'L': 209.6, 'M': 213.3, 'N': 212.8, 'P': 214.4,
            'Q': 214.2, 'R': 237.0, 'S': 207.6, 'T': 211.7, 'V': 208.7, 'W': 216.1, 'X': 210.2,
            'Y': 213.1, 'Z': 214.9, '*': 213.1, '@': 0}

HYDROPHOBICITY = {'A': 0.16, 'B': -3.14, 'C': 2.50, 'D': -2.49, 'E': -1.50, 'F': 5.00, 'G': -3.31,
                  'H': -4.63, 'I': 4.41, 'K': -5.00, 'L': 4.76, 'M': 3.23, 'N': -3.79, 'P': -4.92,
                  'Q': -2.76, 'R': -2.77, 'S': -2.85, 'T': -1.08, 'V': 3.02, 'W': 4.88, 'X': 4.59,
                  'Y': 2.00, 'Z': -2.13, '*': -0.25, '@': 0}

HELICITY = {'A': 1.24, 'B': 0.92, 'C': 0.79, 'D': 0.89, 'E': 0.85, 'F': 1.26, 'G': 1.15, 'H': 0.97,
            'I': 1.29, 'K': 0.88, 'L': 1.28, 'M': 1.22, 'N': 0.94, 'P': 0.57, 'Q': 0.96, 'R': 0.95,
            'S': 1.00, 'T': 1.09, 'V': 1.27, 'W': 1.07, 'X': 1.29, 'Y': 1.11, 'Z': 0.91, '*': 1.04, '@': 0}

MUTATION_STABILITY = {'A': 13.0, 'B': 8.5, 'C': 52.0, 'D': 11.0, 'E': 12.0, 'F': 32.0, 'G': 27.0, 'H': 15.0,
                      'I': 10.0, 'K': 24.0, 'L': 34.0, 'M':  6.0, 'N':  6.0, 'P': 20.0, 'Q': 10.0, 'R': 17.0,
                      'S': 10.0, 'T': 11.0, 'V': 17.0, 'W': 55.0, 'X': 20.65, 'Y': 31.0, 'Z': 11.0, '*': 20.65, '@': 0}


def cuda(tensor, is_cuda):
    """Return `tensor` moved to the default CUDA device if `is_cuda` is truthy, else unchanged."""
    return tensor.cuda() if is_cuda else tensor


def idxtobool(idx, size, is_cuda):
    """
    Turn per-row index lists into a dense float mask of shape `size`.

    With a 2D `size`, row i gets 1.0 at the columns listed in idx[i, :].
    With 3 or more dims, entry (i, j) gets 1.0 at the dim-2 positions listed
    in idx[i, j, :] (whole trailing slices when there are more than 3 dims).
    All other entries are 0.0. The mask is moved to CUDA when `is_cuda` is truthy.

    Raises:
        argparse.ArgumentTypeError: if `size` has fewer than two dimensions.
    """
    mask = cuda(torch.zeros(size, dtype=torch.float), is_cuda)
    ndim = len(size)

    if ndim == 2:
        for row in range(size[0]):
            mask[row, idx[row, :]] = 1.0
    elif ndim > 2:
        for row in range(size[0]):
            for col in range(size[1]):
                mask[row, col, idx[row, col, :]] = 1.0
    else:
        raise argparse.ArgumentTypeError('len(size) should be larger than 1')

    return mask


def create_tensorboard(tensorboard_name):
    """Return a SummaryWriter logging to ./tensorboards/<tensorboard_name>, creating the directory if needed."""
    log_dir = Path('tensorboards') / tensorboard_name
    if not log_dir.exists():
        log_dir.mkdir(parents=True)
    return SummaryWriter(log_dir=str(log_dir))


def write_blackbox_output_batchiter(loader, model, wf, device='cpu', ifscore=True):
    '''
    Run `model` over a loader and write one decoded row per sample to `wf`.

    Args:
     loader  - dict with keys 'pep_amino_idx' and 'tcr_amino_idx' (id -> symbol
               lookup tables indexed by int) and 'loader' (iterable of batches).
               A batch may expose .X_pep / .X_tcr / .y attributes or be an
               (X_pep, X_tcr, y) tuple as produced by collate_fn.
     model   - module called as model(X_pep, X_tcr); its output row i is read
               as score[i][0] (presumably shape (batch, 1))
     wf      - object with a writerow() method (e.g. csv.writer)
     device  - device the inputs are moved to
     ifscore - write [pep, tcr, label, pred, score] when True,
               otherwise [pep, tcr, pred]
    '''
    model.eval()

    rev_peploader = loader['pep_amino_idx']
    rev_tcrloader = loader['tcr_amino_idx']

    def _decode(ids, table):
        # Map ids back to symbols and strip the padding markers.
        seq = ''.join(table[x] for x in ids)
        seq = re.sub(r'<pad>', '', seq)
        return re.sub(r'@', '', seq)

    # Inference only: avoid building autograd graphs.
    with torch.no_grad():
        for batch in loader['loader']:
            if hasattr(batch, 'X_pep'):
                X_pep, X_tcr, y = batch.X_pep, batch.X_tcr, batch.y
            else:
                X_pep, X_tcr, y = batch
            X_pep, X_tcr, y = X_pep.to(device), X_tcr.to(device), y.to(device)

            score = [s[0] for s in model(X_pep, X_tcr).cpu().tolist()]
            pred = [round(s) for s in score]

            # Convert once to Python ints so the lookup tables are indexed by
            # plain ints rather than per-element (possibly GPU) tensors.
            pep_ids = X_pep.cpu().tolist()
            tcr_ids = X_tcr.cpu().tolist()
            true_labels = y.reshape(-1).cpu().tolist()

            for i, p in enumerate(pred):
                pep_seq = _decode(pep_ids[i], rev_peploader)
                tcr_seq = _decode(tcr_ids[i], rev_tcrloader)
                if ifscore:
                    wf.writerow([pep_seq, tcr_seq, int(true_labels[i]),
                                 int(p), float(score[i])])
                else:
                    wf.writerow([pep_seq, tcr_seq, int(p)])

def get_label_batchiter(loader, model, device='cpu'):
    '''
    Predict hard 0/1 labels for every sample in a loader.

    Args:
     loader - iterable of (X_pep, X_tcr, y) batches; y is ignored
     model  - classification model whose outputs are thresholded at 0.5
     device - device the inputs are moved to

    Returns:
     list with one entry per output row, each row thresholded at 0.5 and
     converted to int (e.g. [1] per sample for a (batch, 1) output)
    '''
    model.eval()

    predicted_labels = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for X_pep, X_tcr, _ in loader:
            yhat = model(X_pep.to(device), X_tcr.to(device))
            predicted_labels.extend((yhat.cpu() > 0.5).int().tolist())

    return predicted_labels

def get_label_prob_batchiter(loader, model, device='cpu'):
    '''
    Collect raw model outputs (probabilities) for every sample in a loader.

    Args:
     loader - iterable of (X_pep, X_tcr, y) batches; y is ignored
     model  - classification model
     device - device the inputs are moved to

    Returns:
     list with one entry per output row, as returned by tensor.tolist()
     (e.g. [p] per sample for a (batch, 1) output)
    '''
    model.eval()

    predicted_probs = []
    # Inference only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for X_pep, X_tcr, _ in loader:
            yhat = model(X_pep.to(device), X_tcr.to(device))
            predicted_probs.extend(yhat.cpu().tolist())

    return predicted_probs
def get_performance_batchiter(loader, model, device='cpu'):
    '''
    Evaluate a binary classifier over a loader.

    Args:
     loader - iterable of (X_pep, X_tcr, y) batches
     model  - classification model returning probabilities; y is unsqueezed
              and expanded to the output's shape before computing the loss
     device - device the tensors are moved to

    Returns:
     dict from get_performance(score, label) plus 'loss': the summed binary
     cross-entropy over all samples, rounded to 4 digits
    '''
    model.eval()

    loss = 0
    score, label = [], []
    # Evaluation only: no_grad avoids building an autograd graph per batch.
    with torch.no_grad():
        for X_pep, X_tcr, y in loader:
            X_pep = X_pep.to(device)
            X_tcr = X_tcr.to(device)
            y = y.to(device)
            yhat = model(X_pep, X_tcr)
            y = y.unsqueeze(-1).expand_as(yhat)
            loss += F.binary_cross_entropy(yhat, y, reduction='sum').item()
            score.extend(yhat.cpu().tolist())
            label.extend(y.cpu().tolist())

    perf = get_performance(score, label)
    perf['loss'] = round(loss, 4)

    return perf


def get_performance(score, label):
    '''
    get classification performance for binary task

    Args:
     score - predicted probabilities: a 1D sequence of length n, or a 2D
             (n, k) nested sequence whose first column is used (e.g. the
             [[p], [p], ...] lists built by get_performance_batchiter)
     label - true 0/1 labels, 1D or 2D with the same convention as score

    Returns:
     dict with accuracy, precision1/precision0, recall1/recall0, f1macro,
     f1micro and auc (-1 when `label` holds a single class), each rounded
     to 4 digits. The confusion-matrix counts "tn fp fn tp" are printed.
    '''
    scores = np.asarray(score, dtype=float)
    labels = np.asarray(label)
    # Reduce column-vector inputs to 1D so 1D and (n, 1) inputs behave alike.
    if scores.ndim > 1:
        scores = scores[:, 0]
    if labels.ndim > 1:
        labels = labels[:, 0]

    # Python round() thresholds at 0.5 (ties go to the even integer, so 0.5 -> 0).
    label_pred = [round(s) for s in scores.tolist()]

    accuracy = accuracy_score(labels, label_pred)
    precision1 = precision_score(labels, label_pred, pos_label=1, zero_division=0)
    precision0 = precision_score(labels, label_pred, pos_label=0, zero_division=0)
    recall1 = recall_score(labels, label_pred, pos_label=1, zero_division=0)
    recall0 = recall_score(labels, label_pred, pos_label=0, zero_division=0)
    f1macro = f1_score(labels, label_pred, average='macro')
    f1micro = f1_score(labels, label_pred, average='micro')
    # ROC AUC is undefined when only one class is present.
    auc = roc_auc_score(labels, scores) if len(np.unique(labels)) != 1 else -1

    ndigits = 4
    performance = {'accuracy': round(accuracy, ndigits),
                   'precision1': round(precision1, ndigits), 'precision0': round(precision0, ndigits),
                   'recall1': round(recall1, ndigits), 'recall0': round(recall0, ndigits),
                   'f1macro': round(f1macro, ndigits), 'f1micro': round(f1micro, ndigits),
                   'auc': round(auc, ndigits)}
    tn, fp, fn, tp = confusion_matrix(labels, label_pred, labels=[0, 1]).ravel()
    print(tn, fp, fn, tp)
    return performance


def print_performance(perf, printif=True, writeif=False, boardif=False, **kargs):
    '''
    Report a dict of metrics to stdout, TensorBoard and/or a row writer.

    Args:
     perf    - dict mapping measure name -> numeric value, e.g. the result
               of get_performance(score, label); measures are processed in
               sorted name order
     printif - print one left-aligned "name value" line per measure
               (4 decimals), followed by a blank line
     boardif - call kargs['tbf'].add_scalars for each measure under the tag
               'performance/<name>'; requires kargs 'tbf', 'global_step', 'mode'
     writeif - pass the whole dict to kargs['wf'].writerow; requires kargs 'wf'

    Returns:
     kargs['wf'] when writeif is set, otherwise None.
    '''
    measures = sorted(perf.keys())

    if printif:
        width = max(len(name) for name in measures)
        for name in measures:
            print('{} {:.4f}'.format(name.ljust(width), perf[name]))
        print('')

    if boardif:
        for required in ('tbf', 'global_step', 'mode'):
            assert required in kargs, 'missing argument: ' + required
        writer, step, mode = kargs['tbf'], kargs['global_step'], kargs['mode']
        for name in measures:
            writer.add_scalars(main_tag='performance/{}'.format(name),
                               tag_scalar_dict={mode: perf[name]},
                               global_step=step)

    if writeif:
        assert 'wf' in kargs, 'missing argument: wf'
        kargs['wf'].writerow(perf)
        return kargs['wf']
    return None


def str2bool(v):
    """
    Convert a yes/no style string to a bool (case-insensitive).

    Accepted true values: yes, ye, y, true, t, 1.
    Accepted false values: no, n, false, f, 0.
    A value that is already a bool is returned unchanged.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(v, bool):
        return v
    # Compare against lowercase literals only; mixed-case entries could never
    # match a lowercased input.
    value = v.lower()
    if value in {'yes', 'ye', 'y', 'true', 't', '1'}:
        return True
    if value in {'no', 'n', 'false', 'f', '0'}:
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def timeSince(since):
    """
    Format the wall-clock time elapsed since `since` (a time.time() value)
    as '<minutes>m <seconds>s', both truncated to integers.

    Credit: https://github.com/1Konny/VIB-pytorch
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, seconds)


def check_model_name(model_name, file_path='models'):
    """
    Resolve a clash between `model_name` and existing entries of `file_path`.

    While the name exists in `file_path`, the user is asked on stdin whether
    to overwrite it. "yes" accepts the current name. "no" prompts for a new
    name, which is lowercased and checked again. Any other answer re-asks.
    Returns the name that was finally accepted or that does not exist yet.
    """
    valid = {"yes": True, "y": True, "ye": True, 'true': True, 't': True,
             '1': True, "no": False, "n": False, 'false': False, 'f': False, '0': False}

    # Iterative loop so the name chosen in a later round is the one returned.
    while model_name in os.listdir(file_path):
        sys.stdout.write(
            "The file {} already exists. Do you want to overwrite it? [yes/no]".format(model_name))
        choice = input().lower()

        if choice not in valid:
            sys.stdout.write("Please respond with 'yes' or 'no'\n")
            continue
        if valid[choice]:
            break  # overwrite accepted

        sys.stdout.write(
            "Please assign another name. (ex. 'original_2.ckpt')")
        model_name = input().lower()

    return model_name


def get_physchem_properties_batchiter(loader, lenpep, lentcr, device='cpu'):
    '''
    Collect per-TCR physicochemical properties for positively labelled pairs.

    Args:
     loader - iterable of batches; each batch either exposes .X_pep / .X_tcr / .y
              or is an (X_pep, X_tcr, y) tuple as produced by collate_fn
     lenpep - maximum peptide length passed on to get_physchem_properties
     lentcr - maximum TCR length passed on to get_physchem_properties
     device - device the batch tensors are moved to

    Returns:
     the module-level `features` mapping filled by get_physchem_properties;
     it is reset to None at the start of each call, so an empty loader
     returns None
    '''

    global features
    features = None

    for batch in loader:
        if hasattr(batch, 'X_pep'):
            X_pep, X_tcr, y_true = batch.X_pep, batch.X_tcr, batch.y
        else:
            X_pep, X_tcr, y_true = batch
        X_pep, X_tcr, y_true = X_pep.to(device), X_tcr.to(device), y_true.to(device)

        # A torch boolean mask works on any device (np.where cannot read CUDA tensors).
        positive = y_true == 1.0
        get_physchem_properties(X_pep[positive].tolist(),
                                X_tcr[positive].tolist(),
                                lenpep, lentcr,
                                exclude=set(['X', '*']))

    return features


def get_physchem_properties(pep_batch, tcr_batch,
                            max_len_pep, max_len_tcr, exclude):
    '''
    Decode a batch of id lists and append TCR property profiles per peptide.

    Peptides are decoded with AMINO_MAP_REV_ and TCRs with AMINO_MAP_REV
    (unaligned, dropping symbols in `exclude`). For each pair,
    features[peptide] gains the TCR string plus one list of per-residue values
    from each of BASICITY, HYDROPHOBICITY, HELICITY and MUTATION_STABILITY.

    Args:
     pep_batch, tcr_batch - lists of id lists, aligned by index
     max_len_pep, max_len_tcr - truncation lengths for num2seq
     exclude - symbols removed while decoding

    Side effects:
     creates and/or updates the module-level `features` dict
    '''

    global features

    pep_batch = num2seq(pep_batch, AMINO_MAP_REV_, max_len=max_len_pep,
                        align=False, exclude=exclude)
    tcr_batch = num2seq(tcr_batch, AMINO_MAP_REV, max_len=max_len_tcr,
                        align=False, exclude=exclude)

    # globals().get avoids a NameError when called before `features` exists.
    if not globals().get('features'):
        features = defaultdict(lambda: defaultdict(list))

    property_tables = (('basicity', BASICITY),
                       ('hydrophobicity', HYDROPHOBICITY),
                       ('helicity', HELICITY),
                       ('mutation_stability', MUTATION_STABILITY))

    for pep, tcr in zip(pep_batch, tcr_batch):
        entry = features[pep]
        entry['tcr'].append(tcr)
        for name, table in property_tables:
            entry[name].append([table[aa] for aa in tcr])


def print_physchem_properties(perf, wf, measures=None):
    """Write `perf` as a table via wf.writerow: a header of measure names, then one row per index.

    Columns default to the sorted keys of `perf`. The row count is taken from
    the first column.
    """
    columns = sorted(perf.keys()) if measures is None else measures

    wf.writerow(columns)
    n_rows = len(perf[columns[0]])
    for row_idx in range(n_rows):
        wf.writerow([perf[col][row_idx] for col in columns])


def seq2num(seq_list, mapping, max_len=None, align=True):
    '''
    Convert symbol sequences to id lists via `mapping`.

    Args:
     seq_list - iterable of sequences (strings or lists of tokens)
     mapping  - dict from token to id; must contain '<pad>' when padding
     max_len  - target length; None keeps each sequence's full length
     align    - True: truncate to max_len and pad shorter sequences with
                the '<pad>' id up to max_len;
                False: truncate to max_len and drop '<pad>' tokens

    Returns:
     list of id lists, one per input sequence
    '''
    num_list = []

    for seq in seq_list:
        # Use the current sequence's length (the old code read `num`, which
        # was undefined on the first iteration and stale afterwards).
        limit = len(seq) if max_len is None else min(max_len, len(seq))

        if align:
            num = [mapping[seq[i]] for i in range(limit)]
            if max_len is not None and max_len > len(seq):
                num += [mapping['<pad>']] * (max_len - len(seq))
        else:
            num = [mapping[seq[i]] for i in range(limit) if seq[i] != '<pad>']

        num_list.append(num)

    return num_list


def num2seq(num_list, mapping, max_len=None, align=True, exclude=frozenset(['@'])):
    '''
    Convert id lists back to strings via `mapping`.

    Args:
     num_list - iterable of id sequences
     mapping  - indexable id -> symbol table (e.g. AMINO_MAP_REV)
     max_len  - truncate each sequence to this many ids; None keeps all
     align    - True: keep every symbol (`exclude` is ignored);
                False: drop symbols contained in `exclude`
     exclude  - symbols removed when align is False; an immutable default
                avoids the shared-mutable-default pitfall

    Returns:
     list of strings, one per input sequence
    '''
    seq_list = []

    for num in num_list:
        limit = len(num) if max_len is None else min(max_len, len(num))
        symbols = [mapping[num[i]] for i in range(limit)]
        if not align:
            symbols = [s for s in symbols if s not in exclude]
        seq_list.append(''.join(symbols))

    return seq_list

MODEL

In [None]:
import torch
import torch.nn as nn

class Net(nn.Module):
    """Peptide/TCR binding classifier: shared embedding, per-chain self-attention, MLP head.

    Args:
        embedding: nested list (num_tokens x embedding_dim) that initialises a
            trainable embedding; its last row is the padding index.
        dim_hidden1, dim_hidden2: widths of the two hidden dense layers.
        dropout1, dropout2: dropout probabilities after each hidden layer.
        max_len_pep, max_len_tcr: padded peptide / TCR lengths the dense head
            is sized for. The defaults 22 and 20 keep the previous fixed sizes.

    forward(pep, tcr) takes integer id tensors of shape (batch, max_len_pep)
    and (batch, max_len_tcr) and returns sigmoid outputs of shape (batch, 1).
    """

    def __init__(self, embedding, dim_hidden1, dim_hidden2, dropout1, dropout2,
                 max_len_pep=22, max_len_tcr=20):
        super(Net, self).__init__()

        # Embedding Layer
        self.num_amino = len(embedding)
        self.embedding_dim = len(embedding[0])
        # from_pretrained is a classmethod that builds a new module, so
        # padding_idx has to be passed to it; otherwise the pad row is trained.
        self.embedding = nn.Embedding.from_pretrained(
            torch.FloatTensor(embedding), freeze=False, padding_idx=self.num_amino - 1)

        # Self-Attention Layers (embedding_dim must be divisible by num_heads)
        self.attn_tcr = nn.MultiheadAttention(embed_dim=self.embedding_dim, num_heads=4)
        self.attn_pep = nn.MultiheadAttention(embed_dim=self.embedding_dim, num_heads=4)

        # Dense Layers: input is the flattened attention output of both chains.
        self.size_hidden1_dense = dim_hidden1
        self.size_hidden2_dense = dim_hidden2
        self.net_pep_dim = max_len_pep * self.embedding_dim
        self.net_tcr_dim = max_len_tcr * self.embedding_dim
        self.net = nn.Sequential(
            nn.Linear(self.net_pep_dim + self.net_tcr_dim,
                      self.size_hidden1_dense),
            nn.BatchNorm1d(self.size_hidden1_dense),
            nn.Dropout(dropout1),
            nn.SiLU(),
            nn.Linear(self.size_hidden1_dense, self.size_hidden2_dense),
            nn.BatchNorm1d(self.size_hidden2_dense),
            nn.Dropout(dropout2),
            nn.SiLU(),
            nn.Linear(self.size_hidden2_dense, 1),
            nn.Sigmoid()
        )

    def forward(self, pep, tcr):
        # Embedding: (batch, len) -> (batch, len, dim)
        pep = self.embedding(pep)
        tcr = self.embedding(tcr)

        # nn.MultiheadAttention (default batch_first=False) expects
        # (seq_len, batch, embed_dim).
        pep = torch.transpose(pep, 0, 1)
        tcr = torch.transpose(tcr, 0, 1)

        # Self-Attention
        pep, _ = self.attn_pep(pep, pep, pep)
        tcr, _ = self.attn_tcr(tcr, tcr, tcr)

        # Back to (batch, seq_len, embed_dim)
        pep = torch.transpose(pep, 0, 1)
        tcr = torch.transpose(tcr, 0, 1)

        # Flatten each chain to (batch, len * dim) and concatenate for the MLP.
        peptcr = torch.cat((pep.flatten(1), tcr.flatten(1)), dim=-1)
        return self.net(peptcr)


In [None]:
def train(model, device, train_loader, optimizer, epoch):
    """Run one optimisation pass of `model` over `train_loader`.

    Each batch is an (X_pep, X_tcr, y) tuple. y is unsqueezed and expanded to
    the model output's shape, and the mean binary cross-entropy is minimised
    with `optimizer`. `epoch` is accepted for interface compatibility but not used.
    """
    model.train()

    for pep_ids, tcr_ids, target in train_loader:
        pep_ids = pep_ids.to(device)
        tcr_ids = tcr_ids.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        prediction = model(pep_ids, tcr_ids)
        batch_loss = F.binary_cross_entropy(prediction, target.unsqueeze(-1).expand_as(prediction))
        batch_loss.backward()
        optimizer.step()

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# None makes load_embedding use its default BLOSUM62 file path.
embedding_matrix = load_embedding(None)

In [None]:
import torch.optim as optim

# Seed torch's global RNG (used for weight initialisation and, by default,
# DataLoader shuffling) so runs are repeatable.
SEED = 42
torch.manual_seed(SEED)

In [None]:
import pandas as pd
# The training CSV has no header row: antigen, CDR3 sequence, class label.
column_names = ['antigen', 'cdr3_sequence', 'class']
train_csv_path = '/kaggle/input/data-tcr/data/BAP/tcr_split/train.csv'
df = pd.read_csv(train_csv_path, header=None, names=column_names)

In [None]:
#Hyper-Parameter Tuning with 5 fold cross validation
import time
import torch
import os
import csv
from collections import deque
from sklearn.model_selection import KFold, train_test_split
from torch.optim import Adam, SGD
from itertools import product

# Initialize CSV for metrics
file_name = 'ATMTCR_tcr_base_model_metrics_final.csv'
best_performance_data = []

# Define column names
wf_colnames = ['fold', 'epoch', 'loss', 'accuracy', 'precision1', 'precision0',
               'recall1', 'recall0', 'f1macro', 'f1micro', 'auc']

# Hyperparameter search space
hyperparameter_space = {
    'padding': ['front', 'mid', 'end'],
    'dropout1': [0.5, 0.6],
    'dropout2': [0.25, 0.3, 0.5],
    'maxlen_tcr': [10, 15, 18, 20, 22],
    'maxlen_pep': [10, 15, 18, 20, 22],
    'dim_hidden1': [256, 512, 1024, 2048, 4096],
    'dim_hidden2': [256, 512, 1024, 2048]
}
# NOTE(review): Net sizes its first Linear layer for peptide length 22 and TCR
# length 20 by default, so other maxlen combinations fail with a shape mismatch
# unless Net is given the lengths — TODO confirm before running the full sweep.
best_hyperparameters = {}

# Prepare data
pepts = df['antigen'].tolist()
tcrs = df['cdr3_sequence'].tolist()
labels = df['class'].tolist()

kf = KFold(n_splits=5, shuffle=True, random_state=42)

# A single checkpoint holds the model with the lowest test loss seen so far.
checkpoint_path = "checkpoints/ATMTCR_tcr_base.pt"
os.makedirs('./checkpoints', exist_ok=True)

t0 = time.time()
best_loss = float('inf')

param_combinations = list(product(*hyperparameter_space.values()))

# The with-block makes sure the metrics file is flushed and closed even if a
# fold raises part-way through the sweep.
with open(file_name, 'a+', newline='') as wf_open:
    # 'a+' starts at EOF; rewind to check whether a header is still needed.
    wf_open.seek(0)
    is_empty = len(wf_open.read().strip()) == 0

    wf = csv.DictWriter(wf_open, fieldnames=wf_colnames, delimiter='\t')
    if is_empty:
        wf.writeheader()

    for params in param_combinations:
        param_dict = dict(zip(hyperparameter_space.keys(), params))
        print(f"Testing parameters: {param_dict}")
        padding, dropout_dense1, dropout_dense2, maxlen_tcr, maxlen_pep, dense1_dim, dense2_dim = params

        # Cross-validation
        for fold, (train_indices, test_indices) in enumerate(kf.split(pepts), start=1):
            print(f"\nStarting fold {fold}...")

            # Create training and testing splits
            X_train_pep = [pepts[i] for i in train_indices]
            X_train_tcr = [tcrs[i] for i in train_indices]
            y_train = [labels[i] for i in train_indices]

            X_test_pep = [pepts[i] for i in test_indices]
            X_test_tcr = [tcrs[i] for i in test_indices]
            y_test = [labels[i] for i in test_indices]

            # Create DataLoaders (no validation set)
            Train_Loader = create_dataloader(X_train_pep, X_train_tcr, y_train,
                                             batch_size=32, shuffle=True, maxlen_pep=maxlen_pep,
                                             maxlen_tcr=maxlen_tcr, pad_type=padding)

            Test_Loader = create_dataloader(X_test_pep, X_test_tcr, y_test,
                                            batch_size=32, shuffle=False, maxlen_pep=maxlen_pep,
                                            maxlen_tcr=maxlen_tcr, pad_type=padding)

            # Fresh model and optimizer per fold. They must persist across
            # epochs; re-creating them every epoch discarded all training.
            model = Net(embedding=embedding_matrix, dim_hidden1=dense1_dim, dim_hidden2=dense2_dim,
                        dropout1=dropout_dense1, dropout2=dropout_dense2).to(device)
            optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.5, 0.999), eps=1e-8)

            # Early stopping looks at the last `lossArraySize` test losses.
            lossArraySize = 10
            lossArray = deque([sys.maxsize], maxlen=lossArraySize)

            best_perf_test = None  # best test performance within this fold

            for epoch in range(1, 201):
                # Train the model
                train(model, device, Train_Loader, optimizer, epoch)

                # Evaluate the model on the test set
                perf_test = get_performance_batchiter(Test_Loader, model, device)
                print(f"[TEST - Fold {fold}] Epoch {epoch} ----------------")
                print_performance(perf_test, printif=False, writeif=False)

                # Record the best test performance of this fold.
                if best_perf_test is None or perf_test['loss'] < best_perf_test['loss']:
                    best_perf_test = perf_test
                    wf.writerow({'fold': fold, 'epoch': epoch,
                                 **{key: best_perf_test[key] for key in wf_colnames[2:]}})
                    best_performance_data.append(best_perf_test)
                    print(f"Best test performance for fold {fold} updated at epoch {epoch}")

                # Save a checkpoint when this is the best loss across all runs.
                if perf_test['loss'] < best_loss:
                    best_loss = perf_test['loss']
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': perf_test['loss'],
                    }, checkpoint_path)
                    print(f"Checkpoint saved at {checkpoint_path}")

                    # Update the best hyperparameters
                    best_hyperparameters = param_dict

                # Early stopping is checked last so the stopping epoch's
                # results are still recorded and checkpointed.
                lossArray.append(perf_test['loss'])
                average_loss_change = sum(np.abs(np.diff(lossArray))) / lossArraySize
                if epoch > 5 and average_loss_change < 10:
                    print(f"Early stopping at epoch {epoch} for fold {fold}")
                    break

            print(f"Fold {fold} completed.")

print(f"Training completed. Total time: {timeSince(t0)}")


In [None]:
# Output the best hyperparameters and performance
# best_hyperparameters is the param_dict of the run that wrote the checkpoint.
print("\nBest Hyperparameters and Performance:", best_hyperparameters, sep="\n")

In [None]:
#Training Complete - Now Load Test Data
df = pd.read_csv('/kaggle/input/data-tcr/data/BAP/tcr_split/test.csv', header=None, names=column_names)
test_pepts = df['antigen'].tolist()
test_tcrs = df['cdr3_sequence'].tolist()
dummy_labels = np.zeros(len(test_pepts))  # placeholder labels; only predictions are used

# Encode the test set the same way as the data the selected checkpoint was
# trained on. Fall back to the previous fixed settings if nothing was recorded.
test_maxlen_pep = best_hyperparameters.get('maxlen_pep', 22)
test_maxlen_tcr = best_hyperparameters.get('maxlen_tcr', 20)
test_pad_type = best_hyperparameters.get('padding', "end")
Test_Loader = create_dataloader(test_pepts, test_tcrs, dummy_labels, batch_size=32, shuffle=False,
                                maxlen_pep=test_maxlen_pep, maxlen_tcr=test_maxlen_tcr,
                                pad_type=test_pad_type)

In [None]:
def load_checkpoint(model, optimizer, checkpoint_path, map_location=None):
    """
    Restore model and optimizer state from a training checkpoint.

    Args:
        model: module whose architecture matches the saved state_dict.
        optimizer: optimizer built over `model`'s parameters.
        checkpoint_path: file saved as a dict with keys 'epoch',
            'model_state_dict', 'optimizer_state_dict' and 'loss'.
        map_location: forwarded to torch.load. When None, the device of the
            model's first parameter is used (CPU if the model has none), so a
            checkpoint saved on GPU can be restored on a CPU-only host.

    Returns:
        (model, optimizer, epoch, loss) — the same model and optimizer
        objects, updated in place.
    """
    if map_location is None:
        first_param = next(model.parameters(), None)
        map_location = first_param.device if first_param is not None else 'cpu'

    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    # Load the model and optimizer state_dicts
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    return model, optimizer, checkpoint['epoch'], checkpoint['loss']

In [None]:
# Restore the lowest-loss checkpoint saved during the sweep. The optimizer is
# bound back to `optimizer` so the `optim` alias of torch.optim is not shadowed.
# NOTE(review): `model` is the network built for the last hyperparameter
# combination tried; load_state_dict only succeeds if its layer sizes match the
# checkpoint's — TODO confirm, or rebuild the model from best_hyperparameters.
model, optimizer, epoch, loss = load_checkpoint(model, optimizer, "checkpoints/ATMTCR_tcr_base.pt")

In [None]:
# Export only the learned weights (state_dict) of the restored model.
weights_file = 'ATMTCR_tcr_base_model_weights.pt'
torch.save(model.state_dict(), weights_file)


In [None]:
# Predict hard 0/1 labels (threshold 0.5) for the test set. A separate name
# keeps the training `labels` list intact.
predicted_labels = get_label_batchiter(
        Test_Loader, model, device)

# Save predictions to a CSV file
import pandas as pd

# NOTE(review): the values are thresholded 0/1 labels, not probabilities,
# despite the column name; the name is kept in case downstream readers rely on it.
output_path = 'base_tcr_test_predictions.csv'
pred_df = pd.DataFrame(predicted_labels, columns=['Predicted Probability'])
pred_df.to_csv(output_path, index=False)

print(f"Predictions saved to '{output_path}'")

In [None]:
# Sanity check: number of test pairs read from test.csv.
n_test_pairs = len(test_pepts)
print(n_test_pairs)