In [1]:
import os
from glob import glob

import numpy as np
import pandas as pd
import tensorflow as tf

2025-06-18 15:34:36.163774: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-18 15:34:36.201258: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-06-18 15:34:36.201280: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-06-18 15:34:36.201944: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-06-18 15:34:36.207344: I tensorflow/core/platform/cpu_feature_guar

In [2]:
def load_tcr_dataset_from_dir(data_dir, batch_size=100, max_len=40, shuffle=True, seed=None):
    """Build a batched ``tf.data.Dataset`` of integer-encoded sequences.

    Every ``<data_dir>/<HLA>-<Antigen>/*.tsv`` file is read as a tab-separated
    table. The ``aminoAcid`` column supplies the sequences and the ``Antigen``
    part of the parent directory name supplies the class label. Directories
    whose name is not exactly ``<HLA>-<Antigen>`` (one dash, both parts
    non-empty) carry no label and are skipped; rows with a missing
    ``aminoAcid`` value are dropped.

    Args:
        data_dir: Root directory containing the ``<HLA>-<Antigen>`` folders.
        batch_size: Number of examples per batch.
        max_len: Upper bound on the encoded sequence length. Longer sequences
            are truncated; shorter ones are right-padded with 0.
        shuffle: If true, shuffle with a buffer covering the whole dataset.
        seed: Optional seed forwarded to ``Dataset.shuffle`` for reproducible
            ordering. Ignored when ``shuffle`` is false.

    Returns:
        A tuple ``(ds, aa_dict, label_map, vocab_size)`` where ``ds`` yields
        batches of ``(encoded_sequences, one_hot_labels)``, ``aa_dict`` maps
        each character seen in the kept sequences to an integer id >= 1
        (0 is the padding id), ``label_map`` maps class index to antigen name,
        and ``vocab_size`` is ``len(aa_dict) + 1``.

    Raises:
        ValueError: If no labelled ``.tsv`` file is found under ``data_dir``,
            or if no row survives the missing-value filter.
    """
    # Collect files; sort because glob order is filesystem dependent.
    filepaths = sorted(glob(os.path.join(data_dir, '*/*.tsv')))
    if not filepaths:
        raise ValueError(f"No .tsv files found under {data_dir!r}")

    # Parse "<HLA>-<Antigen>" from the parent directory name. Going through
    # os.path keeps this independent of the platform's path separator.
    parent_names = pd.Series(
        [os.path.basename(os.path.dirname(path)) for path in filepaths],
        dtype=object,
    )
    parsed = parent_names.str.extract(r'^([^-]+)-([^-]+)$')
    df_rep = pd.DataFrame({'filepath': filepaths, 'HLA': parsed[0], 'Antigen': parsed[1]})

    # Files without a parsable label would have every row filtered out below,
    # so skip reading them altogether.
    df_rep = df_rep.dropna(subset=['HLA', 'Antigen'])
    if df_rep.empty:
        raise ValueError(
            f"No '<HLA>-<Antigen>' sub-directory with .tsv files found under {data_dir!r}"
        )

    # Load all TSV files into one DataFrame.
    frames = []
    for file_idx, path, hla, antigen in df_rep.itertuples(name=None):
        frame = pd.read_csv(path, sep='\t')
        frame['index'] = file_idx  # position of the source file in the sorted listing
        frame['HLA'] = hla
        frame['Antigen'] = antigen
        frames.append(frame)
    df_tcr = pd.concat(frames, ignore_index=True)
    df_tcr['Antigen'] = df_tcr['Antigen'].astype('category')

    # Filter rows with a missing sequence or label, then encode labels.
    keep = df_tcr[['aminoAcid', 'Antigen']].notna().all(axis=1)
    X = df_tcr.loc[keep, 'aminoAcid'].to_numpy()
    antigens = df_tcr.loc[keep, 'Antigen']
    y = antigens.cat.codes.to_numpy()
    label_map = dict(enumerate(antigens.cat.categories))
    if len(X) == 0:
        raise ValueError(f"No rows with a non-missing 'aminoAcid' value under {data_dir!r}")

    max_length = min(max(map(len, X)), max_len)
    vocab = sorted(set(''.join(X)))
    aa_dict = {aa: i + 1 for i, aa in enumerate(vocab)}  # +1 for 0 padding
    vocab_size = len(aa_dict) + 1

    # Pad and convert to integers. Every character is in aa_dict because the
    # vocabulary was built from these same sequences.
    X_encoded = np.zeros((len(X), max_length), dtype=np.int32)
    for row, seq in enumerate(X):
        codes = [aa_dict[aa] for aa in seq[:max_length]]
        X_encoded[row, :len(codes)] = codes

    # Create TensorFlow dataset.
    num_classes = len(label_map)
    ds = tf.data.Dataset.from_tensor_slices((X_encoded, y))
    ds = ds.map(lambda seq, label: (seq, tf.one_hot(label, num_classes)))
    if shuffle:
        ds = ds.shuffle(len(X_encoded), seed=seed)
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

    return ds, aa_dict, label_map, vocab_size