# Train the TCR-pHLA predictor with a K-fold strategy and make predictions on a test set

In [1]:
import pandas as pd
import numpy as np
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA
from utils import negative_sampling_phla, preprocess_input_data
from sklearn.model_selection import StratifiedKFold

### Load data

Only a small subset is provided here for demonstration purposes. For the full dataset, see the [Zenodo repository](https://zenodo.org/records/18002170).

In [2]:
df = pd.read_csv('examples/tcrphla_train_set_example.csv')

### Process input data

In [3]:
df['label'].value_counts()

label
0    303
1     57
Name: count, dtype: int64
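The subset is imbalanced: 57 positives against 303 negatives. As a rough sketch, the snippet below computes the positive fraction and the negative-to-positive ratio from these counts. Using that ratio as a positive-class weight is a common heuristic and an assumption here; it is not confirmed to be what the `pos_weights` argument of `StriMap_TCRpHLA` expects.

```python
from collections import Counter

# Label counts copied from the value_counts() output above.
labels = [0] * 303 + [1] * 57
counts = Counter(labels)

n_neg, n_pos = counts[0], counts[1]
pos_fraction = n_pos / (n_pos + n_neg)   # share of positive pairs
neg_pos_ratio = n_neg / n_pos            # candidate positive-class weight (heuristic)

print(f"positives: {pos_fraction:.1%}, neg/pos ratio: {neg_pos_ratio:.2f}")
```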

In [4]:
df = preprocess_input_data(df)

### 5-fold train/validation split

In [5]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

tcrphla_fold = 5 # number of folds used in TCR-pHLA model training

skf = StratifiedKFold(n_splits=tcrphla_fold, shuffle=True, random_state=42)

df_train_folds = []
df_val_folds = []

for i, (train_idx, val_idx) in enumerate(skf.split(df, df["label"])):
    
    df_train_fold = df.iloc[train_idx].reset_index(drop=True)
    df_val_fold   = df.iloc[val_idx].reset_index(drop=True)

    # negative sampling for validation set
    df_val_fold_neg = negative_sampling_phla(df_val_fold, random_state=i)
    df_val_fold = pd.concat([df_val_fold, df_val_fold_neg], axis=0).reset_index(drop=True)

    df_train_folds.append(df_train_fold)
    df_val_folds.append(df_val_fold)
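
To see what stratification does for this dataset, the following sketch rebuilds only the label column (303 negatives, 57 positives) and counts the positives in each validation fold before any negative sampling. The synthetic `y` array is an assumption standing in for `df['label']`.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Synthetic labels with the same class counts as the example subset.
y = np.array([0] * 303 + [1] * 57)
X = np.zeros((len(y), 1))  # features are irrelevant for the split itself

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

val_sizes, val_positives = [], []
for train_idx, val_idx in skf.split(X, y):
    val_sizes.append(len(val_idx))
    val_positives.append(int(y[val_idx].sum()))

print("validation sizes:    ", val_sizes)
print("validation positives:", val_positives)
```

Each validation fold holds about a fifth of the 360 rows and 11 or 12 of the 57 positives, so every fold sees roughly the same class ratio as the full set.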

### Initialize pHLA prediction system

The TCR-pHLA prediction system builds on the pre-trained pHLA prediction system, so we initialize the pHLA prediction system first.

In [6]:
pep_hla_system = StriMap_pHLA(device=device, cache_dir='cache')
pep_hla_system.prepare_embeddings(df.reset_index(drop=True))

Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 1 unique peptides
  - 1 unique HLAs



Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 54.29it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 2225.09it/s]
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache
✓ All embeddings prepared!
  - Phys: 1 peptides, 1 HLAs
  - ESM2: 1 peptides, 1 HLAs
  - Struct: 1 HLAs



### Initialize TCR-pHLA prediction system

We then initialize the TCR-pHLA prediction system on top of the pre-trained pHLA prediction system.

ESMFold is downloaded automatically the first time the code below is run.

ESMFold takes roughly 1 to 2 seconds per sequence on a single GPU, so computing embeddings for a large dataset takes a while. If `cache_save=True`, the computed embeddings are saved to disk and reloaded on later runs, which speeds up training and inference on large datasets.

The `StriMap_TCRpHLA` class inherits from the `StriMap_pHLA` class and takes the following additional parameters:

| Parameter | Default | Description |
|----------|---------|-------------|
| `pep_hla_system` | `None` | Pretrained, already-initialized `StriMap_pHLA` system whose peptide–HLA encoders are reused |
| `pep_hla_params` | `None` | List of paths to trained pHLA model parameter files (one path per model) |
| `device` | `"cuda:0"` | Device for computation |
| `model_save_path` | `"params/tcrphla_model.pt"` | Path to save the trained model |
| `tcr_dim` | `256` | Dimension of the TCR embedding |
| `pep_dim` | `256` | Dimension of the peptide embedding |
| `hla_dim` | `256` | Dimension of the HLA embedding |
| `bilinear_dim` | `256` | Dimension of the bilinear layer |
| `loss_fn` | `"focal"` | Loss function to use |
| `alpha` | `0.5` | Alpha parameter for focal loss |
| `gamma` | `2.0` | Gamma parameter for focal loss |
| `resample_negatives` | `False` | Whether to resample negative examples at each training epoch |
| `seed` | `1` | Random seed for reproducibility |
| `pos_weights` | `None` | Positive class weight for handling class imbalance |
| `use_struct` | `True` | Whether to include structural features |
| `cache_save` | `False` | Whether to save embedding caches |
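
The `loss_fn`, `alpha`, and `gamma` entries refer to focal loss. As a minimal sketch, assuming the standard binary formulation `FL = -alpha_t * (1 - p_t)^gamma * log(p_t)` (StriMap's own implementation may differ in details such as the reduction or how `alpha` is applied), the function below shows how `gamma` shrinks the contribution of easy examples. `binary_focal_loss` is a hypothetical helper written for this illustration.

```python
import math

def binary_focal_loss(p, y, alpha=0.5, gamma=2.0, eps=1e-12):
    """Focal loss for one example; p is the predicted probability of class 1."""
    p_t = p if y == 1 else 1.0 - p              # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1.0 - alpha  # class-dependent weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(max(p_t, eps))

easy = binary_focal_loss(0.95, 1)              # confident and correct
hard = binary_focal_loss(0.30, 1)              # positive example scored low
plain = binary_focal_loss(0.30, 1, gamma=0.0)  # gamma=0 gives alpha-weighted cross-entropy

print(f"easy={easy:.5f} hard={hard:.5f} gamma0={plain:.5f}")
```

With `gamma=2.0`, the confidently correct example contributes far less loss than the misranked one, which keeps the many easy negatives from dominating training.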

In [7]:
tcr_phla_system = StriMap_TCRpHLA(
    pep_hla_system=pep_hla_system,
    # paths to trained pHLA model parameters; list several paths if you have multiple models
    pep_hla_params=['examples/params/phla_model.pt'],
    device=device,
    model_save_path='params/tcrphla_model.pt',
    # set to False if negatives were already sampled in the input data,
    # or if you prefer to sample them outside model training
    resample_negatives=False,
    cache_save=True
)
tcr_phla_system.prepare_embeddings(df.reset_index(drop=True))

✓ StriMap_TCRpHLA initialized on cuda:0

Preparing embeddings:
  - TCRα: 354 | TCRβ: 354 | peptides: 1 | HLAs: 1



Phys encoding (TCRpHLA): 100%|██████████| 2/2 [00:00<00:00, 200.06it/s]
Phys encoding (TCRpHLA): 100%|██████████| 2/2 [00:00<00:00, 206.32it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 2621.44it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 2560.63it/s]

[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/tcra_esm2_layer33.pt





[ESM2] No new sequences for tcra, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/tcrb_esm2_layer33.pt


INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt
INFO:model:No new tcra sequences found
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcrb_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcrb_coord_dict.pt
INFO:model:No new tcrb sequences found
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/pep_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/pep_coord_dict.pt


[ESM2] No new sequences for tcrb, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache


INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/pep_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/pep_coord_dict.pt
INFO:model:No new pep sequences found
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


✓ Embeddings prepared for TCRα/β, peptide (with ESMFold), and HLA.
✓ Embeddings prepared for TCRα/β, peptide, and HLA.
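
The log shows every encoder finding its sequences already in the on-disk cache. The pattern behind this kind of incremental cache can be sketched as follows. This is an illustration only, not StriMap's code: `embed_with_cache`, `fake_embed`, and the pickle file layout are hypothetical stand-ins.

```python
import os
import pickle
import tempfile

def fake_embed(seq):
    """Hypothetical stand-in for an expensive encoder such as ESM2 or ESMFold."""
    return [float(ord(ch)) for ch in seq]

def embed_with_cache(sequences, cache_path, embed_fn=fake_embed):
    """Embed only sequences missing from the cache, then persist the updated cache."""
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as fh:
            cache = pickle.load(fh)

    # Unique sequences in input order that are not cached yet.
    new_seqs = [s for s in dict.fromkeys(sequences) if s not in cache]
    for seq in new_seqs:
        cache[seq] = embed_fn(seq)

    if new_seqs:  # rewrite the file only when something changed
        with open(cache_path, "wb") as fh:
            pickle.dump(cache, fh)
    return cache, len(new_seqs)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "pep_cache.pkl")
    _, first = embed_with_cache(["LLWNGPMAV", "LLWNGPMAV"], path)
    _, second = embed_with_cache(["LLWNGPMAV"], path)
    _, third = embed_with_cache(["LLWNGPMAV", "CAVTGNQFYF"], path)
    print(first, second, third)
```

Only the first call and the call that introduces an unseen sequence do any embedding work; the repeated call is served entirely from disk.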


In [8]:
tcr_phla_system.pep_hla_params = ['examples/params/phla_model.pt']

### Train the TCR-pHLA prediction system with 5-fold cross-validation

`train_kfold` Default Parameters (TCR–pHLA)

| Parameter | Default | Description |
|----------|---------|-------------|
| `train_folds` | *(required)* | List of `(train_df, val_df)` tuples, one for each fold |
| `df_test` | `None` | Optional test dataset for evaluation after each epoch |
| `df_add` | `None` | Optional additional samples for training (used when `resample_negatives=True`) |
| `epochs` | `100` | Number of training epochs per fold |
| `batch_size` | `128` | Batch size used during training |
| `lr` | `1e-4` | Learning rate |
| `patience` | `8` | Early stopping patience |
| `num_workers` | `8` | Number of workers for data loading |

**Returns:**

- A list of training histories, one dictionary per fold, containing training and validation metrics across epochs.
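
The `patience` setting controls early stopping on the validation score. As a minimal sketch, assuming the monitored score is validation AUC and that the counter resets whenever the score improves (which is how the training log of the example run behaves), the logic looks roughly like this. `EarlyStopper` is a hypothetical helper, not a StriMap class.

```python
class EarlyStopper:
    """Track the best score and count epochs without improvement."""

    def __init__(self, patience=8):
        self.patience = patience
        self.best = float("-inf")
        self.counter = 0

    def step(self, score):
        """Update the state with a new score; return True when training should stop."""
        if score > self.best:
            self.best = score
            self.counter = 0  # improvement: reset (a real trainer would save the model here)
        else:
            self.counter += 1
        return self.counter >= self.patience

# Validation AUCs of the second fold, copied from the example run's log.
val_auc = [0.7935, 0.7180, 0.6841, 0.6872, 0.7257]

stopper = EarlyStopper(patience=8)
counters = []
for auc in val_auc:
    stopped = stopper.step(auc)
    counters.append(stopper.counter)

print(counters, stopper.best, stopped)
```

With only 5 epochs and a patience of 8, the counter reaches 4 and training ends because the epoch budget runs out, not because early stopping triggers.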

In [9]:
all_histories = tcr_phla_system.train_kfold(
    train_folds=list(zip(df_train_folds, df_val_folds)),
    epochs=5,
    batch_size=5, # make sure batch size is smaller than df_train_fold and df_val_fold sizes
)


Starting 5-Fold Cross-Validation Training (TCR-pHLA)

Training Fold 1/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 1 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00,  1.22it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.02it/s]


Epoch 1/5 | Train Loss: 0.0799 | Train AUC: 0.5638


Epoch 1/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 12.78it/s]


Epoch 1/5 | Val AUC: 0.7704 | Val PRC: 0.5141 | Val Loss: 0.0605
Validation improved → Saving model (Score=0.7704) to params/tcrphla_model.pt


Epoch 2/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.30it/s]


Epoch 2/5 | Train Loss: 0.0615 | Train AUC: 0.6519


Epoch 2/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.93it/s]


Epoch 2/5 | Val AUC: 0.7396 | Val PRC: 0.6690 | Val Loss: 0.0545
EarlyStopping counter: 1 out of 8


Epoch 3/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.35it/s]


Epoch 3/5 | Train Loss: 0.0553 | Train AUC: 0.7104


Epoch 3/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.55it/s]


Epoch 3/5 | Val AUC: 0.7720 | Val PRC: 0.6726 | Val Loss: 0.0533
Validation improved → Saving model (Score=0.7720) to params/tcrphla_model.pt


Epoch 4/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 10.96it/s]


Epoch 4/5 | Train Loss: 0.0489 | Train AUC: 0.7814


Epoch 4/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 14.21it/s]


Epoch 4/5 | Val AUC: 0.7874 | Val PRC: 0.6901 | Val Loss: 0.0528
Validation improved → Saving model (Score=0.7874) to params/tcrphla_model.pt


Epoch 5/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.35it/s]


Epoch 5/5 | Train Loss: 0.0477 | Train AUC: 0.8107


Epoch 5/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 16.90it/s]


Epoch 5/5 | Val AUC: 0.7812 | Val PRC: 0.7266 | Val Loss: 0.0530
EarlyStopping counter: 1 out of 8
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 0 model to params/tcrphla_model_fold0.pt

Training Fold 2/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 1 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00, 16.74it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.08it/s]


Epoch 1/5 | Train Loss: 0.0699 | Train AUC: 0.5860


Epoch 1/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 16.29it/s]


Epoch 1/5 | Val AUC: 0.7935 | Val PRC: 0.6589 | Val Loss: 0.0556
Validation improved → Saving model (Score=0.7935) to params/tcrphla_model.pt


Epoch 2/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.32it/s]


Epoch 2/5 | Train Loss: 0.0525 | Train AUC: 0.7896


Epoch 2/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.22it/s]


Epoch 2/5 | Val AUC: 0.7180 | Val PRC: 0.5920 | Val Loss: 0.0549
EarlyStopping counter: 1 out of 8


Epoch 3/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.33it/s]


Epoch 3/5 | Train Loss: 0.0425 | Train AUC: 0.8582


Epoch 3/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 15.00it/s]


Epoch 3/5 | Val AUC: 0.6841 | Val PRC: 0.6077 | Val Loss: 0.0504
EarlyStopping counter: 2 out of 8


Epoch 4/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.39it/s]


Epoch 4/5 | Train Loss: 0.0431 | Train AUC: 0.8367


Epoch 4/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 15.26it/s]


Epoch 4/5 | Val AUC: 0.6872 | Val PRC: 0.5559 | Val Loss: 0.0498
EarlyStopping counter: 3 out of 8


Epoch 5/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.03it/s]


Epoch 5/5 | Train Loss: 0.0433 | Train AUC: 0.8199


Epoch 5/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 16.42it/s]


Epoch 5/5 | Val AUC: 0.7257 | Val PRC: 0.6164 | Val Loss: 0.0426
EarlyStopping counter: 4 out of 8
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 1 model to params/tcrphla_model_fold1.pt

Training Fold 3/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 1 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00, 16.93it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.37it/s]


Epoch 1/5 | Train Loss: 0.0773 | Train AUC: 0.6076


Epoch 1/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 12.84it/s]


Epoch 1/5 | Val AUC: 0.8968 | Val PRC: 0.7603 | Val Loss: 0.0643
Validation improved → Saving model (Score=0.8968) to params/tcrphla_model.pt


Epoch 2/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.17it/s]


Epoch 2/5 | Train Loss: 0.0591 | Train AUC: 0.7370


Epoch 2/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 14.13it/s]


Epoch 2/5 | Val AUC: 0.8860 | Val PRC: 0.6398 | Val Loss: 0.0496
EarlyStopping counter: 1 out of 8


Epoch 3/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.25it/s]


Epoch 3/5 | Train Loss: 0.0533 | Train AUC: 0.7237


Epoch 3/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 16.13it/s]


Epoch 3/5 | Val AUC: 0.8444 | Val PRC: 0.7031 | Val Loss: 0.0567
EarlyStopping counter: 2 out of 8


Epoch 4/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.33it/s]


Epoch 4/5 | Train Loss: 0.0501 | Train AUC: 0.7547


Epoch 4/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 16.63it/s]


Epoch 4/5 | Val AUC: 0.9045 | Val PRC: 0.7437 | Val Loss: 0.0404
Validation improved → Saving model (Score=0.9045) to params/tcrphla_model.pt


Epoch 5/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.11it/s]


Epoch 5/5 | Train Loss: 0.0435 | Train AUC: 0.8337


Epoch 5/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.45it/s]


Epoch 5/5 | Val AUC: 0.8798 | Val PRC: 0.7673 | Val Loss: 0.0651
EarlyStopping counter: 1 out of 8
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 2 model to params/tcrphla_model_fold2.pt

Training Fold 4/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 1 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00, 16.99it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 10.89it/s]


Epoch 1/5 | Train Loss: 0.0912 | Train AUC: 0.5571


Epoch 1/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 16.06it/s]


Epoch 1/5 | Val AUC: 0.8664 | Val PRC: 0.8204 | Val Loss: 0.0648
Validation improved → Saving model (Score=0.8664) to params/tcrphla_model.pt


Epoch 2/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.12it/s]


Epoch 2/5 | Train Loss: 0.0627 | Train AUC: 0.6885


Epoch 2/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 15.80it/s]


Epoch 2/5 | Val AUC: 0.8477 | Val PRC: 0.8365 | Val Loss: 0.0524
EarlyStopping counter: 1 out of 8


Epoch 3/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 10.60it/s]


Epoch 3/5 | Train Loss: 0.0562 | Train AUC: 0.7109


Epoch 3/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 15.12it/s]


Epoch 3/5 | Val AUC: 0.8649 | Val PRC: 0.8532 | Val Loss: 0.0605
EarlyStopping counter: 2 out of 8


Epoch 4/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 10.81it/s]


Epoch 4/5 | Train Loss: 0.0478 | Train AUC: 0.7875


Epoch 4/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.33it/s]


Epoch 4/5 | Val AUC: 0.8376 | Val PRC: 0.8541 | Val Loss: 0.0392
EarlyStopping counter: 3 out of 8


Epoch 5/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 10.79it/s]


Epoch 5/5 | Train Loss: 0.0465 | Train AUC: 0.8105


Epoch 5/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 16.48it/s]


Epoch 5/5 | Val AUC: 0.8477 | Val PRC: 0.8418 | Val Loss: 0.0462
EarlyStopping counter: 4 out of 8
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 3 model to params/tcrphla_model_fold3.pt

Training Fold 5/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 1 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00, 17.58it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.37it/s]


Epoch 1/5 | Train Loss: 0.0756 | Train AUC: 0.6169


Epoch 1/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.66it/s]


Epoch 1/5 | Val AUC: 0.8305 | Val PRC: 0.7248 | Val Loss: 0.0634
Validation improved → Saving model (Score=0.8305) to params/tcrphla_model.pt


Epoch 2/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 10.87it/s]


Epoch 2/5 | Train Loss: 0.0539 | Train AUC: 0.8028


Epoch 2/5 [Val]: 100%|██████████| 14/14 [00:00<00:00, 14.19it/s]


Epoch 2/5 | Val AUC: 0.8276 | Val PRC: 0.7591 | Val Loss: 0.0448
EarlyStopping counter: 1 out of 8


Epoch 3/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 10.74it/s]


Epoch 3/5 | Train Loss: 0.0434 | Train AUC: 0.8714


Epoch 3/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.88it/s]


Epoch 3/5 | Val AUC: 0.7716 | Val PRC: 0.7120 | Val Loss: 0.0433
EarlyStopping counter: 2 out of 8


Epoch 4/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 11.15it/s]


Epoch 4/5 | Train Loss: 0.0470 | Train AUC: 0.7876


Epoch 4/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.59it/s]


Epoch 4/5 | Val AUC: 0.8319 | Val PRC: 0.7726 | Val Loss: 0.0406
Validation improved → Saving model (Score=0.8319) to params/tcrphla_model.pt


Epoch 5/5 [Train]: 100%|██████████| 57/57 [00:05<00:00, 10.96it/s]


Epoch 5/5 | Train Loss: 0.0445 | Train AUC: 0.8273


Epoch 5/5 [Val]: 100%|██████████| 14/14 [00:01<00:00, 13.99it/s]


Epoch 5/5 | Val AUC: 0.8980 | Val PRC: 0.8138 | Val Loss: 0.0403
Validation improved → Saving model (Score=0.8980) to params/tcrphla_model.pt
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 4 model to params/tcrphla_model_fold4.pt

✓ All 5 folds training completed (TCR-pHLA)

Cross-Validation Summary:
----------------------------------------------------------------------
Fold 0: Best Val AUC = 0.7874, Best Val PRC = 0.7266, (Epoch 4)
Fold 1: Best Val AUC = 0.7935, Best Val PRC = 0.6589, (Epoch 1)
Fold 2: Best Val AUC = 0.9045, Best Val PRC = 0.7673, (Epoch 4)
Fold 3: Best Val AUC = 0.8664, Best Val PRC = 0.8541, (Epoch 1)
Fold 4: Best Val AUC = 0.8980, Best Val PRC = 0.8138, (Epoch 5)
----------------------------------------------------------------------
Mean Val AUC: 0.8499 ± 0.0503
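
The summary line can be checked against the per-fold values. A small sketch, assuming the reported spread is the population standard deviation (`ddof=0`, NumPy's default), which is the variant that matches the printed ± 0.0503:

```python
import statistics

# Best validation AUC per fold, copied from the summary above (rounded to 4 decimals).
best_val_auc = [0.7874, 0.7935, 0.9045, 0.8664, 0.8980]

mean_auc = statistics.fmean(best_val_auc)
std_pop = statistics.pstdev(best_val_auc)     # population std (ddof=0)
std_sample = statistics.stdev(best_val_auc)   # sample std (ddof=1), for comparison

print(f"mean={mean_auc:.4f} pop_std={std_pop:.4f} sample_std={std_sample:.4f}")
```

The recomputed mean can differ from the log in the last digit because the inputs used here are already rounded.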



# Predict on the test set

For the final evaluated models, see `tcrphla_params.zip` in the [Zenodo repository](https://zenodo.org/records/18002170).

Run inference for TCR–pHLA specificity prediction.

`predict` Default Parameters

| Parameter | Default | Description |
|----------|---------|-------------|
| `df` | *(required)* | Input `DataFrame` containing TCR and peptide–HLA features |
| `batch_size` | `128` | Batch size used during inference |
| `return_probs` | `True` | If `True`, return predicted probabilities; otherwise return binary predictions |
| `use_kfold` | `False` | Whether to use an ensemble of K-fold models for prediction |
| `num_folds` | `None` | Number of folds used in training (required if `use_kfold=True`) |
| `ensemble_method` | `"mean"` | Method to aggregate K-fold predictions (`"mean"` or `"median"`) |
| `num_workers` | `8` | Number of workers for data loading |

#### Returns

```python
preds, fused_feat, attn = model.predict(...)
```
- `preds` (`np.ndarray` or `list`): Predicted probabilities (if `return_probs=True`) or binary labels (if `return_probs=False`).
- `fused_feat` (`list[torch.Tensor]`): Fused latent feature representations for each batch, extracted from the model’s joint TCR–pHLA embedding space. These features can be used for downstream analysis or visualization.
- `attn` (`dict[str, torch.Tensor]`): A dictionary of attention maps aggregated across the dataset. Each key corresponds to an attention component (e.g., TCR, peptide, or cross-attention), and values are padded and concatenated tensors aligned across samples.
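
With `use_kfold=True`, the per-fold predictions are combined according to `ensemble_method`. A minimal sketch of that aggregation step, using made-up probabilities; the 0.5 cut-off for binary labels is an assumption, since the threshold applied when `return_probs=False` is not stated here.

```python
import numpy as np

# Hypothetical probabilities from 5 fold models (rows) for 3 test samples (columns).
fold_probs = np.array([
    [0.10, 0.60, 0.40],
    [0.20, 0.70, 0.45],
    [0.15, 0.65, 0.90],
    [0.05, 0.55, 0.35],
    [0.25, 0.75, 0.30],
])

mean_pred = fold_probs.mean(axis=0)          # ensemble_method="mean"
median_pred = np.median(fold_probs, axis=0)  # ensemble_method="median"

threshold = 0.5  # assumed cut-off for turning probabilities into labels
mean_labels = (mean_pred >= threshold).astype(int)

print(mean_pred, median_pred, mean_labels)
```

The median is less sensitive to a single outlying fold: for the third sample one fold predicts 0.90, which pulls the mean to 0.48 while the median stays at 0.40.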

In [10]:
import pandas as pd
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA
from utils import preprocess_input_data

df_test = pd.read_csv('examples/tcrphla_test_set_example.csv')
df_test = preprocess_input_data(df_test)

tcrphla_fold = 5

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
pep_hla_system = StriMap_pHLA(device=device, cache_dir='cache')
pep_hla_system.prepare_embeddings(df_test)

tcr_phla_system = StriMap_TCRpHLA(
    pep_hla_system=pep_hla_system,
    pep_hla_params=['examples/params/phla_model.pt'],
    device=device,
    model_save_path='params/tcrphla_model.pt',
    cache_save=True
)

tcr_phla_system.prepare_embeddings(df_test)

# Predict on the test set using K-fold models
# pred, pep_feat, attn_dict = tcr_phla_system.predict(df_test, use_kfold=True, num_folds=tcrphla_fold)

# Alternatively, use evaluate to get performance metrics as well
pred, _ = tcr_phla_system.evaluate(df_test, use_kfold=True, num_folds=tcrphla_fold)
df_test['predicted_score'] = pred
print(df_test.head())

Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 1 unique peptides
  - 1 unique HLAs



Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 1777.25it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 2454.24it/s]
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache
✓ All embeddings prepared!
  - Phys: 1 peptides, 1 HLAs
  - ESM2: 1 peptides, 1 HLAs
  - Struct: 1 HLAs

✓ StriMap_TCRpHLA initialized on cuda:0

Preparing embeddings:
  - TCRα: 91 | TCRβ: 91 | peptides: 1 | HLAs: 1



Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 313.52it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 329.40it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 2686.93it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 2465.79it/s]


[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/tcra_esm2_layer33.pt
[ESM2] Found 88 new sequences → computing embeddings...


ESM2 update (tcra): 100%|██████████| 1/1 [00:01<00:00,  1.13s/it]


[ESM2] Updating cache with new sequences
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/tcrb_esm2_layer33.pt
[ESM2] Found 89 new sequences → computing embeddings...


ESM2 update (tcrb): 100%|██████████| 1/1 [00:01<00:00,  1.11s/it]


[ESM2] Updating cache with new sequences


INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt


[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache


INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt
INFO:model:Found new tcra sequences, embedding...


Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 88, valid seqs: 88, unique: 88
ESMfold Predicting structure...: 100%|██████████| 88/88 [01:29<00:00,  1.01s/it]
100%|██████████| 88/88 [01:29<00:00,  1.01s/it]
INFO:model:[DONE] OK: 88, Failed: 0
INFO:model:Updated and saved /ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcrb_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcrb_coord_dict.pt
INFO:model:Found new tcrb sequences, embedding...

Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 89, valid seqs: 89, unique: 89
ESMfold Predicting structure...: 100%|██████████| 89/89 [01:31<00:00,  1.03s/it]
100%|██████████| 89/89 [01:31<00:00,  1.03s/it]
INFO:model:[DONE] OK: 89, Failed: 0
INFO:model:Updated and saved /ewsc/cao/StriMap/strimap-tools/cache/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcrb_coord_dict.pt
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/pep_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/pep_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/pep_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/pep_coord_dict.pt
INFO:model:No new pep sequences found
INFO:model:

✓ Embeddings prepared for TCRα/β, peptide (with ESMFold), and HLA.
✓ Embeddings prepared for TCRα/β, peptide, and HLA.
Preparing peptide-HLA features for prediction set...

Precomputing peptide-HLA features for 1 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00, 14.06it/s]

✓ Pretrained peptide-HLA features prepared.

Ensemble prediction using 5 TCR–pHLA models...
Ensemble method: mean
Loading model from params/tcrphla_model_fold0.pt...



Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:02<00:00,  2.50s/it]


Loading model from params/tcrphla_model_fold1.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:02<00:00,  2.72s/it]


Loading model from params/tcrphla_model_fold2.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:03<00:00,  3.05s/it]


Loading model from params/tcrphla_model_fold3.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:03<00:00,  3.27s/it]


Loading model from params/tcrphla_model_fold4.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:03<00:00,  3.27s/it]

✓ Ensemble prediction completed using 5 folds

Evaluation Results [K-Fold Ensemble]
tn=72, fp=0, fn=19, tp=0
AUC=0.6455 | PRC=0.5995 | ACC=0.7912 | MCC=0.0000 | F1=0.0000 | P=0.0000 | R=0.0000

                cdr3a             cdr3b             Va          Vb         Ja  \
0  CAVNPPGAGGTSYGKLTF   CASSDSTSDWETQYF    TRAV12-2*01  TRBV6-4*01  TRAJ52*01   
1          CAVTGNQFYF   CASSGDAAGAYGYTF    TRAV12-2*01    TRBV9*01  TRAJ49*01   
2      CAASSIQGAQKLVF  CASSPWDRANTGELFF  TRAV23/DV6*01  TRBV4-3*01  TRAJ54*01   
3         CAENEGQKLLF    CASSPNGGNTEAFF    TRAV13-2*01    TRBV9*01  TRAJ16*01   
4          CAVMDDKIIF     CASSQEVAYEQYF    TRAV12-2*01  TRBV3-1*01  TRAJ30*01   

           Jb    peptide          HLA  label  \
0  TRBJ2-5*01  LLWNGPMAV  HLA-A*02:01      0   
1  TRBJ1-2*01  LLWNGPMAV  HLA-A*02:01      1   
2  TRBJ2-2*01  LLWNGPMAV  HLA-A*02:01      1   
3  TRBJ1-1*01  LLWNGPMAV  HLA-A*02:01      0   
4  TRBJ2-7*01  LLWNGPMAV  HLA-A*02:01      1   
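
The threshold-dependent metrics in the evaluation line follow from the confusion matrix: with `tp=0` and `fp=0`, no test pair was called positive, so precision, recall, F1, and MCC all collapse to 0 while accuracy equals the fraction of negatives. The sketch below recomputes them. Mapping undefined ratios to 0 is a convention assumed here (it matches the printed values), and `binary_metrics` is a hypothetical helper. AUC and PRC are rank-based and need the scores themselves, so they are not recomputed.

```python
import math

def binary_metrics(tn, fp, fn, tp):
    """Threshold-dependent metrics from confusion-matrix counts; undefined ratios become 0."""
    total = tn + fp + fn + tp
    acc = (tp + tn) / total
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) else 0.0
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = ((tp * tn - fp * fn) / denom) if denom else 0.0
    return {"ACC": acc, "P": precision, "R": recall, "F1": f1, "MCC": mcc}

# Counts copied from the evaluation output above.
metrics = binary_metrics(tn=72, fp=0, fn=19, tp=0)
print({name: round(value, 4) for name, value in metrics.items()})
```

An AUC of 0.6455 alongside zero recall suggests the scores carry some ranking signal but stay below the decision threshold, which is plausible for a model trained for 5 epochs on the small example subset.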

                              


