# Train the TCR-pHLA predictor with a K-fold strategy and make predictions on a test set

In [1]:
import pandas as pd
import numpy as np
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA
from utils import negative_sampling_phla, preprocess_input_data
from sklearn.model_selection import StratifiedKFold

### Load data

Only a small subset is provided here for demonstration purposes. For the full dataset, see the [Zenodo repository](https://zenodo.org/records/18002170).

In [2]:
import pandas as pd
df = pd.read_csv('examples/tcrphla_train_set_example.csv')

### Process input data

In [3]:
df = preprocess_input_data(df)

### 5-fold train/validation split

In [4]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

phla_fold = 5     # number of folds used in pHLA model training
tcrphla_fold = 5  # number of folds used in TCR-pHLA model training

# Stratify on the label so every fold keeps a similar positive/negative ratio
skf = StratifiedKFold(n_splits=tcrphla_fold, shuffle=True, random_state=42)

df_train_folds = []
df_val_folds = []

for i, (train_idx, val_idx) in enumerate(skf.split(df, df["label"])):
    df_train_fold = df.iloc[train_idx].reset_index(drop=True)
    df_val_fold = df.iloc[val_idx].reset_index(drop=True)

    # Append sampled negatives to the validation fold (seeded per fold for reproducibility)
    df_val_fold_neg = negative_sampling_phla(df_val_fold, random_state=i)
    df_val_fold = pd.concat([df_val_fold, df_val_fold_neg], axis=0).reset_index(drop=True)

    df_train_folds.append(df_train_fold)
    df_val_folds.append(df_val_fold)
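
Stratification matters when positives are scarce: a plain random split can leave a fold with very few positive pairs, which makes its validation AUC unstable. The sketch below is a simplified, pure-Python stand-in for the guarantee that `StratifiedKFold` provides, not scikit-learn's implementation. The helper `stratified_folds` and the toy labels are hypothetical.

```python
import random
from collections import defaultdict

def stratified_folds(labels, n_splits=5, seed=42):
    """Assign each sample index to a fold so every fold has a similar label mix."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    folds = [[] for _ in range(n_splits)]
    for indices in by_label.values():
        rng.shuffle(indices)
        # Deal the indices of one class round-robin across the folds.
        for position, idx in enumerate(indices):
            folds[position % n_splits].append(idx)
    return folds

# Toy labels: 10 positives among 100 samples.
labels = [1] * 10 + [0] * 90
folds = stratified_folds(labels)
positives_per_fold = [sum(labels[i] for i in fold) for fold in folds]
print(positives_per_fold)  # every fold receives 2 of the 10 positives
```

Each fold then serves once as the validation set while the remaining folds form the training set, as in the loop above.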

### Initialize pHLA prediction system

The TCR-pHLA prediction system builds on the pre-trained pHLA prediction system, so we initialize the pHLA prediction system first.

In [5]:
pep_hla_system = StriMap_pHLA(device=device, cache_dir='cache')
pep_hla_system.prepare_embeddings(df.reset_index(drop=True))

Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 17 unique peptides
  - 6 unique HLAs



Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 50.10it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 1592.98it/s]

[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt



INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache
✓ All embeddings prepared!
  - Phys: 17 peptides, 6 HLAs
  - ESM2: 2090 peptides, 114 HLAs
  - Struct: 6 HLAs



### Initialize TCR-pHLA prediction system

We then initialize the TCR-pHLA prediction system on top of the pre-trained pHLA prediction system.

ESMFold is downloaded automatically the first time the code below is run.

Running ESMFold takes a while. If `cache_save=True`, the computed embeddings are saved to disk so that later runs can load them instead of recomputing them, which speeds up training and inference on large datasets.
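
The caching pattern suggested by log messages such as "No new sequences for pep, using existing cache" can be sketched as follows. This is an illustration under assumptions, not StriMap's implementation: `embed_with_cache`, `toy_embed`, and the pickle file are hypothetical, and the real system stores `.pt` files.

```python
import os
import pickle
import tempfile

def embed_with_cache(sequences, cache_path, embed_fn, save=True):
    """Return ({sequence: embedding}, n_new), computing only sequences missing from the cache."""
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as handle:
            cache = pickle.load(handle)

    # Deduplicate while keeping order, then keep only unseen sequences.
    new_sequences = [s for s in dict.fromkeys(sequences) if s not in cache]
    for seq in new_sequences:
        cache[seq] = embed_fn(seq)  # the expensive step in a real pipeline

    if save and new_sequences:
        with open(cache_path, "wb") as handle:
            pickle.dump(cache, handle)
    return {s: cache[s] for s in sequences}, len(new_sequences)

def toy_embed(seq):
    """Stand-in 'embedding': sequence length and number of cysteines."""
    return (len(seq), seq.count("C"))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_cache.pkl")
    _, computed_first = embed_with_cache(["CASSLSGTYEQYF", "GILGFVFTL"], path, toy_embed)
    _, computed_second = embed_with_cache(["GILGFVFTL", "NLVPMVATV"], path, toy_embed)

print(computed_first, computed_second)  # 2 1
```

On the second call only the unseen sequence is embedded, which is why repeated runs on overlapping datasets are much faster once the cache exists.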

The `StriMap_TCRpHLA` class inherits from `StriMap_pHLA` and takes the following additional parameters:

| Parameter | Default | Description |
|----------|---------|-------------|
| `pep_hla_system` | `None` | Pretrained and already-initialized `StriMap_pHLA` system used to reuse peptide–HLA encoders |
| `pep_hla_params` | `None` | List of paths to trained `StriMap_pHLA` model parameter files (one per pHLA fold in the example below) |
| `device` | `"cuda:0"` | Device for computation |
| `model_save_path` | `"params/tcrphla_model.pt"` | Path to save the trained model |
| `tcr_dim` | `256` | Dimension of the TCR embedding |
| `pep_dim` | `256` | Dimension of the peptide embedding |
| `hla_dim` | `256` | Dimension of the HLA embedding |
| `bilinear_dim` | `256` | Dimension of the bilinear layer |
| `loss_fn` | `"focal"` | Loss function to use |
| `alpha` | `0.5` | Alpha parameter for focal loss |
| `gamma` | `2.0` | Gamma parameter for focal loss |
| `resample_negatives` | `False` | Whether to resample negative examples at each training epoch |
| `seed` | `1` | Random seed for reproducibility |
| `pos_weights` | `None` | Positive class weight for handling class imbalance |
| `use_struct` | `True` | Whether to include structural features |
| `cache_save` | `False` | Whether to save embedding caches |
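
To make the roles of `alpha` and `gamma` concrete, here is the standard binary focal loss for a single sample. Whether StriMap's `"focal"` option matches this form term for term is an assumption; the function name `binary_focal_loss` is hypothetical.

```python
import math

def binary_focal_loss(prob, label, alpha=0.5, gamma=2.0):
    """Focal loss for one sample; `prob` is the predicted probability of the positive class."""
    p_t = prob if label == 1 else 1.0 - prob        # probability assigned to the true class
    alpha_t = alpha if label == 1 else 1.0 - alpha  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

easy = binary_focal_loss(0.9, 1)  # confident and correct: strongly down-weighted
hard = binary_focal_loss(0.1, 1)  # confident and wrong: dominates the loss
bce_scaled = binary_focal_loss(0.9, 1, gamma=0.0)  # gamma=0 reduces to alpha-weighted cross-entropy

print(f"easy={easy:.6f} hard={hard:.6f} gamma0={bce_scaled:.6f}")
```

A larger `gamma` shrinks the contribution of well-classified samples, which helps when negatives heavily outnumber positives; `alpha` shifts weight between the two classes.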

In [6]:
tcr_phla_system = StriMap_TCRpHLA(
    pep_hla_system=pep_hla_system,
    pep_hla_params=[f'params/phla_model_fold{i}.pt' for i in range(phla_fold)], # load trained pHLA model parameters
    device=device,
    model_save_path='params/tcrphla_model.pt',
    resample_negatives=True,
    cache_save=True
)
tcr_phla_system.prepare_embeddings(df.reset_index(drop=True))

✓ StriMap_TCRpHLA initialized on cuda:0

Preparing embeddings:
  - TCRα: 285 | TCRβ: 288 | peptides: 17 | HLAs: 6



Phys encoding (TCRpHLA): 100%|██████████| 2/2 [00:00<00:00, 281.89it/s]
Phys encoding (TCRpHLA): 100%|██████████| 2/2 [00:00<00:00, 289.80it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 2587.48it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 1755.67it/s]


[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/tcra_esm2_layer33.pt
[ESM2] No new sequences for tcra, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/tcrb_esm2_layer33.pt
[ESM2] No new sequences for tcrb, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt


INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt
INFO:model:No new tcra sequences found
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcrb_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcrb_coord_dict.pt
INFO:model:No new tcrb sequences found
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/pep_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/pep_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/pep_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/pep_coord_dict.pt
INFO:model:No new pep sequences found
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and 

[ESM2] No new sequences for hla, using existing cache
✓ Embeddings prepared for TCRα/β, peptide (with ESMFold), and HLA.
✓ Embeddings prepared for TCRα/β, peptide, and HLA.


### Train the TCR-pHLA prediction system with 5-fold cross-validation

`train_kfold` Default Parameters (TCR–pHLA)

| Parameter | Default | Description |
|----------|---------|-------------|
| `train_folds` | *(required)* | List of `(train_df, val_df)` tuples, one for each fold |
| `df_test` | `None` | Optional test dataset for evaluation after each epoch |
| `df_add` | `None` | Optional additional samples for training (used when `resample_negatives=True`) |
| `epochs` | `100` | Number of training epochs per fold |
| `batch_size` | `128` | Batch size used during training |
| `lr` | `1e-4` | Learning rate |
| `patience` | `8` | Early stopping patience |
| `num_workers` | `8` | Number of workers for data loading |

**Returns:**

- A list of training histories, one dictionary per fold, containing training and validation metrics across epochs.
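
The `patience` parameter can be illustrated with a minimal early-stopping counter. This is a sketch that assumes improvement means a strictly higher validation AUC, which is consistent with the log messages but not confirmed by the source; the class `EarlyStopping` here is hypothetical. The AUC values are those logged for "Training Fold 4/5" in the output below.

```python
class EarlyStopping:
    """Stop when the monitored score has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=8):
        self.patience = patience
        self.best_score = None
        self.best_epoch = None
        self.counter = 0

    def step(self, epoch, score):
        """Record one epoch; return True when training should stop."""
        if self.best_score is None or score > self.best_score:
            self.best_score, self.best_epoch = score, epoch
            self.counter = 0  # a checkpoint would be saved at this point
            return False
        self.counter += 1
        return self.counter >= self.patience

val_aucs = [0.6882, 0.4691, 0.4435, 0.6579, 0.4453, 0.4441, 0.5647, 0.4368, 0.4341]
stopper = EarlyStopping(patience=8)
stopped_at = None
for epoch, auc in enumerate(val_aucs, start=1):
    if stopper.step(epoch, auc):
        stopped_at = epoch
        break

print(stopper.best_epoch, stopped_at)  # 1 9
```

The best score comes from epoch 1 and the eighth consecutive non-improving epoch is epoch 9, matching the "Early stopping at epoch 9" message for that fold.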

In [7]:
all_histories = tcr_phla_system.train_kfold(
    train_folds=list(zip(df_train_folds, df_val_folds)),
    epochs=10,     # short run for demonstration; the default is 100
    batch_size=5,  # must be smaller than the number of rows in each training and validation fold
)


Starting 5-Fold Cross-Validation Training (TCR-pHLA)

Training Fold 1/5


Preparing peptide-HLA features...

Precomputing peptide-HLA features for 17 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00,  1.43it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.55it/s]


Epoch 1/10 | Train Loss: 0.0821 | Train AUC: 0.5601


Epoch 1/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.88it/s]


Epoch 1/10 | Val AUC: 0.6126 | Val PRC: 0.0428 | Val Loss: 0.0687
Validation improved → Saving model (Score=0.6126) to params/tcrphla_model.pt


Epoch 2/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.90it/s]


Epoch 2/10 | Train Loss: 0.0676 | Train AUC: 0.5391


Epoch 2/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 30.52it/s]


Epoch 2/10 | Val AUC: 0.5367 | Val PRC: 0.0334 | Val Loss: 0.0285
EarlyStopping counter: 1 out of 8


Epoch 3/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.73it/s]


Epoch 3/10 | Train Loss: 0.0612 | Train AUC: 0.5498


Epoch 3/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.48it/s]


Epoch 3/10 | Val AUC: 0.5552 | Val PRC: 0.0372 | Val Loss: 0.0293
EarlyStopping counter: 2 out of 8


Epoch 4/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 12.37it/s]


Epoch 4/10 | Train Loss: 0.0646 | Train AUC: 0.4907


Epoch 4/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 30.65it/s]


Epoch 4/10 | Val AUC: 0.6184 | Val PRC: 0.0995 | Val Loss: 0.0359
Validation improved → Saving model (Score=0.6184) to params/tcrphla_model.pt


Epoch 5/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.93it/s]


Epoch 5/10 | Train Loss: 0.0648 | Train AUC: 0.4527


Epoch 5/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.54it/s]


Epoch 5/10 | Val AUC: 0.7468 | Val PRC: 0.0660 | Val Loss: 0.0230
Validation improved → Saving model (Score=0.7468) to params/tcrphla_model.pt


Epoch 6/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 12.21it/s]


Epoch 6/10 | Train Loss: 0.0625 | Train AUC: 0.4938


Epoch 6/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 28.99it/s]


Epoch 6/10 | Val AUC: 0.6347 | Val PRC: 0.0480 | Val Loss: 0.0265
EarlyStopping counter: 1 out of 8


Epoch 7/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 12.24it/s]


Epoch 7/10 | Train Loss: 0.0629 | Train AUC: 0.4692


Epoch 7/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.94it/s]


Epoch 7/10 | Val AUC: 0.6390 | Val PRC: 0.0462 | Val Loss: 0.0184
EarlyStopping counter: 2 out of 8


Epoch 8/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.93it/s]


Epoch 8/10 | Train Loss: 0.0616 | Train AUC: 0.5168


Epoch 8/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.83it/s]


Epoch 8/10 | Val AUC: 0.6644 | Val PRC: 0.1527 | Val Loss: 0.0226
EarlyStopping counter: 3 out of 8


Epoch 9/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.91it/s]


Epoch 9/10 | Train Loss: 0.0616 | Train AUC: 0.5316


Epoch 9/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 30.68it/s]


Epoch 9/10 | Val AUC: 0.5868 | Val PRC: 0.0586 | Val Loss: 0.0319
EarlyStopping counter: 4 out of 8


Epoch 10/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 12.04it/s]


Epoch 10/10 | Train Loss: 0.0631 | Train AUC: 0.5286


Epoch 10/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.38it/s]


Epoch 10/10 | Val AUC: 0.5526 | Val PRC: 0.0386 | Val Loss: 0.0357
EarlyStopping counter: 5 out of 8
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 0 model to params/tcrphla_model_fold0.pt

Training Fold 2/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 17 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00,  3.11it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.58it/s]


Epoch 1/10 | Train Loss: 0.0728 | Train AUC: 0.5040


Epoch 1/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.39it/s]


Epoch 1/10 | Val AUC: 0.5363 | Val PRC: 0.0447 | Val Loss: 0.0392
Validation improved → Saving model (Score=0.5363) to params/tcrphla_model.pt


Epoch 2/10 [Train]: 100%|██████████| 45/45 [00:04<00:00, 11.03it/s]


Epoch 2/10 | Train Loss: 0.0648 | Train AUC: 0.5265


Epoch 2/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.19it/s]


Epoch 2/10 | Val AUC: 0.4699 | Val PRC: 0.0413 | Val Loss: 0.0490
EarlyStopping counter: 1 out of 8


Epoch 3/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.82it/s]


Epoch 3/10 | Train Loss: 0.0644 | Train AUC: 0.4728


Epoch 3/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.14it/s]


Epoch 3/10 | Val AUC: 0.6100 | Val PRC: 0.0948 | Val Loss: 0.0228
Validation improved → Saving model (Score=0.6100) to params/tcrphla_model.pt


Epoch 4/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.96it/s]


Epoch 4/10 | Train Loss: 0.0647 | Train AUC: 0.4420


Epoch 4/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.43it/s]


Epoch 4/10 | Val AUC: 0.4526 | Val PRC: 0.0343 | Val Loss: 0.0285
EarlyStopping counter: 1 out of 8


Epoch 5/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.97it/s]


Epoch 5/10 | Train Loss: 0.0642 | Train AUC: 0.4627


Epoch 5/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.82it/s]


Epoch 5/10 | Val AUC: 0.4666 | Val PRC: 0.0556 | Val Loss: 0.0467
EarlyStopping counter: 2 out of 8


Epoch 6/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.96it/s]


Epoch 6/10 | Train Loss: 0.0656 | Train AUC: 0.4101


Epoch 6/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.55it/s]


Epoch 6/10 | Val AUC: 0.4438 | Val PRC: 0.0389 | Val Loss: 0.0358
EarlyStopping counter: 3 out of 8


Epoch 7/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 12.07it/s]


Epoch 7/10 | Train Loss: 0.0633 | Train AUC: 0.4931


Epoch 7/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.01it/s]


Epoch 7/10 | Val AUC: 0.4379 | Val PRC: 0.0321 | Val Loss: 0.0491
EarlyStopping counter: 4 out of 8


Epoch 8/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.84it/s]


Epoch 8/10 | Train Loss: 0.0602 | Train AUC: 0.5877


Epoch 8/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.98it/s]


Epoch 8/10 | Val AUC: 0.4578 | Val PRC: 0.0319 | Val Loss: 0.0376
EarlyStopping counter: 5 out of 8


Epoch 9/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 12.14it/s]


Epoch 9/10 | Train Loss: 0.0626 | Train AUC: 0.5153


Epoch 9/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 32.13it/s]


Epoch 9/10 | Val AUC: 0.6729 | Val PRC: 0.0873 | Val Loss: 0.0587
Validation improved → Saving model (Score=0.6729) to params/tcrphla_model.pt


Epoch 10/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 12.17it/s]


Epoch 10/10 | Train Loss: 0.0629 | Train AUC: 0.4918


Epoch 10/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.31it/s]


Epoch 10/10 | Val AUC: 0.6523 | Val PRC: 0.0452 | Val Loss: 0.0217
EarlyStopping counter: 1 out of 8
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 1 model to params/tcrphla_model_fold1.pt

Training Fold 3/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 17 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00,  3.09it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 12.81it/s]


Epoch 1/10 | Train Loss: 0.0845 | Train AUC: 0.4903


Epoch 1/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 32.01it/s]


Epoch 1/10 | Val AUC: 0.5574 | Val PRC: 0.0355 | Val Loss: 0.0585
Validation improved → Saving model (Score=0.5574) to params/tcrphla_model.pt


Epoch 2/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 12.14it/s]


Epoch 2/10 | Train Loss: 0.0681 | Train AUC: 0.5379


Epoch 2/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.24it/s]


Epoch 2/10 | Val AUC: 0.4729 | Val PRC: 0.0339 | Val Loss: 0.0739
EarlyStopping counter: 1 out of 8


Epoch 3/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.73it/s]


Epoch 3/10 | Train Loss: 0.0656 | Train AUC: 0.4695


Epoch 3/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.15it/s]


Epoch 3/10 | Val AUC: 0.4235 | Val PRC: 0.0580 | Val Loss: 0.0611
EarlyStopping counter: 2 out of 8


Epoch 4/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.43it/s]


Epoch 4/10 | Train Loss: 0.0614 | Train AUC: 0.5294


Epoch 4/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.48it/s]


Epoch 4/10 | Val AUC: 0.4474 | Val PRC: 0.0467 | Val Loss: 0.0397
EarlyStopping counter: 3 out of 8


Epoch 5/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.78it/s]


Epoch 5/10 | Train Loss: 0.0626 | Train AUC: 0.4849


Epoch 5/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.93it/s]


Epoch 5/10 | Val AUC: 0.5909 | Val PRC: 0.0547 | Val Loss: 0.0600
Validation improved → Saving model (Score=0.5909) to params/tcrphla_model.pt


Epoch 6/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.90it/s]


Epoch 6/10 | Train Loss: 0.0614 | Train AUC: 0.4900


Epoch 6/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.69it/s]


Epoch 6/10 | Val AUC: 0.6062 | Val PRC: 0.0440 | Val Loss: 0.0234
Validation improved → Saving model (Score=0.6062) to params/tcrphla_model.pt


Epoch 7/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.60it/s]


Epoch 7/10 | Train Loss: 0.0651 | Train AUC: 0.4307


Epoch 7/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 28.78it/s]


Epoch 7/10 | Val AUC: 0.5644 | Val PRC: 0.0538 | Val Loss: 0.0286
EarlyStopping counter: 1 out of 8


Epoch 8/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.31it/s]


Epoch 8/10 | Train Loss: 0.0664 | Train AUC: 0.4090


Epoch 8/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.09it/s]


Epoch 8/10 | Val AUC: 0.5656 | Val PRC: 0.0398 | Val Loss: 0.0405
EarlyStopping counter: 2 out of 8


Epoch 9/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.50it/s]


Epoch 9/10 | Train Loss: 0.0626 | Train AUC: 0.4983


Epoch 9/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 30.00it/s]


Epoch 9/10 | Val AUC: 0.5291 | Val PRC: 0.0375 | Val Loss: 0.0514
EarlyStopping counter: 3 out of 8


Epoch 10/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.86it/s]


Epoch 10/10 | Train Loss: 0.0613 | Train AUC: 0.5290


Epoch 10/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 29.04it/s]


Epoch 10/10 | Val AUC: 0.4812 | Val PRC: 0.0546 | Val Loss: 0.0239
EarlyStopping counter: 4 out of 8
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 2 model to params/tcrphla_model_fold2.pt

Training Fold 4/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 17 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00,  3.01it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.87it/s]


Epoch 1/10 | Train Loss: 0.0921 | Train AUC: 0.5974


Epoch 1/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.38it/s]


Epoch 1/10 | Val AUC: 0.6882 | Val PRC: 0.0460 | Val Loss: 0.0930
Validation improved → Saving model (Score=0.6882) to params/tcrphla_model.pt


Epoch 2/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 12.35it/s]


Epoch 2/10 | Train Loss: 0.0743 | Train AUC: 0.4265


Epoch 2/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.92it/s]


Epoch 2/10 | Val AUC: 0.4691 | Val PRC: 0.0675 | Val Loss: 0.0641
EarlyStopping counter: 1 out of 8


Epoch 3/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 12.08it/s]


Epoch 3/10 | Train Loss: 0.0652 | Train AUC: 0.4872


Epoch 3/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 30.33it/s]


Epoch 3/10 | Val AUC: 0.4435 | Val PRC: 0.0445 | Val Loss: 0.0461
EarlyStopping counter: 2 out of 8


Epoch 4/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.53it/s]


Epoch 4/10 | Train Loss: 0.0651 | Train AUC: 0.5094


Epoch 4/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 28.79it/s]


Epoch 4/10 | Val AUC: 0.6579 | Val PRC: 0.0542 | Val Loss: 0.0290
EarlyStopping counter: 3 out of 8


Epoch 5/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.77it/s]


Epoch 5/10 | Train Loss: 0.0664 | Train AUC: 0.4468


Epoch 5/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 30.00it/s]


Epoch 5/10 | Val AUC: 0.4453 | Val PRC: 0.0533 | Val Loss: 0.0434
EarlyStopping counter: 4 out of 8


Epoch 6/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.33it/s]


Epoch 6/10 | Train Loss: 0.0676 | Train AUC: 0.3902


Epoch 6/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 30.78it/s]


Epoch 6/10 | Val AUC: 0.4441 | Val PRC: 0.0407 | Val Loss: 0.0381
EarlyStopping counter: 5 out of 8


Epoch 7/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 12.09it/s]


Epoch 7/10 | Train Loss: 0.0625 | Train AUC: 0.5471


Epoch 7/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 30.97it/s]


Epoch 7/10 | Val AUC: 0.5647 | Val PRC: 0.0675 | Val Loss: 0.0336
EarlyStopping counter: 6 out of 8


Epoch 8/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.66it/s]


Epoch 8/10 | Train Loss: 0.0614 | Train AUC: 0.5159


Epoch 8/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.66it/s]


Epoch 8/10 | Val AUC: 0.4368 | Val PRC: 0.0283 | Val Loss: 0.0201
EarlyStopping counter: 7 out of 8


Epoch 9/10 [Train]: 100%|██████████| 44/44 [00:03<00:00, 11.92it/s]


Epoch 9/10 | Train Loss: 0.0632 | Train AUC: 0.4697


Epoch 9/10 [Val]: 100%|██████████| 70/70 [00:02<00:00, 31.55it/s]


Epoch 9/10 | Val AUC: 0.4341 | Val PRC: 0.0315 | Val Loss: 0.0202
EarlyStopping counter: 8 out of 8
Early stopping at epoch 9
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 3 model to params/tcrphla_model_fold3.pt

Training Fold 5/5
Preparing peptide-HLA features...

Precomputing peptide-HLA features for 17 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00,  3.05it/s]


✓ Pretrained peptide-HLA features prepared.

Start training TCR-pHLA model...


Epoch 1/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.63it/s]


Epoch 1/10 | Train Loss: 0.0807 | Train AUC: 0.4878


Epoch 1/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 30.41it/s]


Epoch 1/10 | Val AUC: 0.4200 | Val PRC: 0.0302 | Val Loss: 0.0761
Validation improved → Saving model (Score=0.4200) to params/tcrphla_model.pt


Epoch 2/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.64it/s]


Epoch 2/10 | Train Loss: 0.0687 | Train AUC: 0.4869


Epoch 2/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 29.09it/s]


Epoch 2/10 | Val AUC: 0.4259 | Val PRC: 0.0322 | Val Loss: 0.0523
Validation improved → Saving model (Score=0.4259) to params/tcrphla_model.pt


Epoch 3/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.85it/s]


Epoch 3/10 | Train Loss: 0.0642 | Train AUC: 0.4674


Epoch 3/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 30.57it/s]


Epoch 3/10 | Val AUC: 0.4630 | Val PRC: 0.0256 | Val Loss: 0.0324
Validation improved → Saving model (Score=0.4630) to params/tcrphla_model.pt


Epoch 4/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.95it/s]


Epoch 4/10 | Train Loss: 0.0629 | Train AUC: 0.4543


Epoch 4/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 29.96it/s]


Epoch 4/10 | Val AUC: 0.5820 | Val PRC: 0.0463 | Val Loss: 0.0213
Validation improved → Saving model (Score=0.5820) to params/tcrphla_model.pt


Epoch 5/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.88it/s]


Epoch 5/10 | Train Loss: 0.0642 | Train AUC: 0.4610


Epoch 5/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 29.49it/s]


Epoch 5/10 | Val AUC: 0.5387 | Val PRC: 0.0376 | Val Loss: 0.0394
EarlyStopping counter: 1 out of 8


Epoch 6/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.45it/s]


Epoch 6/10 | Train Loss: 0.0606 | Train AUC: 0.4985


Epoch 6/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 30.06it/s]


Epoch 6/10 | Val AUC: 0.4296 | Val PRC: 0.0274 | Val Loss: 0.0475
EarlyStopping counter: 2 out of 8


Epoch 7/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.93it/s]


Epoch 7/10 | Train Loss: 0.0618 | Train AUC: 0.5370


Epoch 7/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 30.29it/s]


Epoch 7/10 | Val AUC: 0.4663 | Val PRC: 0.0292 | Val Loss: 0.0356
EarlyStopping counter: 3 out of 8


Epoch 8/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.87it/s]


Epoch 8/10 | Train Loss: 0.0642 | Train AUC: 0.4665


Epoch 8/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 30.73it/s]


Epoch 8/10 | Val AUC: 0.4101 | Val PRC: 0.0267 | Val Loss: 0.0329
EarlyStopping counter: 4 out of 8


Epoch 9/10 [Train]: 100%|██████████| 45/45 [00:04<00:00, 11.15it/s]


Epoch 9/10 | Train Loss: 0.0627 | Train AUC: 0.4885


Epoch 9/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 31.01it/s]


Epoch 9/10 | Val AUC: 0.3929 | Val PRC: 0.0265 | Val Loss: 0.0246
EarlyStopping counter: 5 out of 8


Epoch 10/10 [Train]: 100%|██████████| 45/45 [00:03<00:00, 11.72it/s]


Epoch 10/10 | Train Loss: 0.0628 | Train AUC: 0.5160


Epoch 10/10 [Val]: 100%|██████████| 69/69 [00:02<00:00, 29.47it/s]


Epoch 10/10 | Val AUC: 0.5314 | Val PRC: 0.0432 | Val Loss: 0.0468
EarlyStopping counter: 6 out of 8
✓ Training finished. Best model loaded from params/tcrphla_model.pt
✓ Saved fold 4 model to params/tcrphla_model_fold4.pt

✓ All 5 folds training completed (TCR-pHLA)

Cross-Validation Summary:
----------------------------------------------------------------------
Fold 0: Best Val AUC = 0.7468, Best Val PRC = 0.1527, (Epoch 5)
Fold 1: Best Val AUC = 0.6729, Best Val PRC = 0.0948, (Epoch 9)
Fold 2: Best Val AUC = 0.6062, Best Val PRC = 0.0580, (Epoch 6)
Fold 3: Best Val AUC = 0.6882, Best Val PRC = 0.0675, (Epoch 1)
Fold 4: Best Val AUC = 0.5820, Best Val PRC = 0.0463, (Epoch 4)
----------------------------------------------------------------------
Mean Val AUC: 0.6592 ± 0.0591
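
The summary line can be reproduced from the per-fold best AUCs. The reported spread matches the population standard deviation (the NumPy default); the sample standard deviation would be slightly larger, about 0.066. Note also that "Best Val PRC" appears to be the maximum over epochs rather than the value at the best-AUC epoch: for fold 0, a PRC of 0.1527 is logged at epoch 8, while epoch 5 has a PRC of 0.0660.

```python
import statistics

# Best validation AUC per fold, copied from the summary above.
best_val_auc = [0.7468, 0.6729, 0.6062, 0.6882, 0.5820]

mean_auc = statistics.fmean(best_val_auc)
std_auc = statistics.pstdev(best_val_auc)  # population standard deviation

print(f"Mean Val AUC: {mean_auc:.4f} ± {std_auc:.4f}")  # Mean Val AUC: 0.6592 ± 0.0591
```

With only a few hundred rows per fold, the spread across folds is large, so these numbers demonstrate the workflow rather than the model's performance.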



# Predict on the test set

The final models used for evaluation are available in `tcrphla_params.zip` at the [Zenodo repository](https://zenodo.org/records/18002170).

Run inference for TCR–pHLA specificity prediction.

`predict` Default Parameters

| Parameter | Default | Description |
|----------|---------|-------------|
| `df` | *(required)* | Input `DataFrame` containing TCR and peptide–HLA features |
| `batch_size` | `128` | Batch size used during inference |
| `return_probs` | `True` | If `True`, return predicted probabilities; otherwise return binary predictions |
| `use_kfold` | `False` | Whether to use an ensemble of K-fold models for prediction |
| `num_folds` | `None` | Number of folds used in training (required if `use_kfold=True`) |
| `ensemble_method` | `"mean"` | Method to aggregate K-fold predictions (`"mean"` or `"median"`) |
| `num_workers` | `8` | Number of workers for data loading |

#### Returns

```python
preds, fused_feat, attn = model.predict(...)
```
- `preds` (`np.ndarray` or `list`): Predicted probabilities (if `return_probs=True`) or binary labels (if `return_probs=False`).
- `fused_feat` (`list[torch.Tensor]`): Fused latent feature representations for each batch, extracted from the model’s joint TCR–pHLA embedding space. These features can be used for downstream analysis or visualization.
- `attn` (`dict[str, torch.Tensor]`): A dictionary of attention maps aggregated across the dataset. Each key corresponds to an attention component (e.g., TCR, peptide, or cross-attention), and values are padded and concatenated tensors aligned across samples.
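
When `use_kfold=True`, each fold model scores every sample and the scores are combined according to `ensemble_method`. A minimal sketch of that aggregation is shown below; the helper `ensemble_scores`, the toy probabilities, and the 0.5 cutoff for binary labels are assumptions, not taken from the StriMap code.

```python
import statistics

def ensemble_scores(fold_probs, method="mean"):
    """Combine per-fold probabilities (one list per fold) into one score per sample."""
    if method not in ("mean", "median"):
        raise ValueError(f"unknown ensemble method: {method}")
    combine = statistics.fmean if method == "mean" else statistics.median
    return [combine(sample) for sample in zip(*fold_probs)]

# Hypothetical probabilities from 5 fold models for 3 samples.
fold_probs = [
    [0.10, 0.80, 0.40],
    [0.20, 0.70, 0.45],
    [0.10, 0.90, 0.95],  # this fold disagrees strongly on the third sample
    [0.30, 0.60, 0.35],
    [0.30, 0.75, 0.30],
]

mean_scores = ensemble_scores(fold_probs, "mean")
median_scores = ensemble_scores(fold_probs, "median")
binary = [int(score >= 0.5) for score in mean_scores]  # the 0.5 cutoff is an assumption

print(mean_scores, median_scores, binary)
```

The median is less sensitive to a single outlying fold model (third sample above), while the mean uses the information from every fold.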

In [8]:
import pandas as pd
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA
from utils import preprocess_input_data

df_test = pd.read_csv('examples/tcrphla_test_set_example.csv')
df_test = preprocess_input_data(df_test)

tcrphla_fold = 5
phla_fold = 5

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
pep_hla_system = StriMap_pHLA(device=device, cache_dir='cache')
pep_hla_system.prepare_embeddings(df_test)

tcr_phla_system = StriMap_TCRpHLA(
    pep_hla_system=pep_hla_system,
    pep_hla_params=[f'params/phla_model_fold{i}.pt' for i in range(phla_fold)],
    device=device,
    model_save_path='params/tcrphla_model.pt',
)

tcr_phla_system.prepare_embeddings(df_test)

# Predict on the test set using K-fold models
# pred, pep_feat, attn_dict = tcr_phla_system.predict(df_test, use_kfold=True, num_folds=tcrphla_fold)

# Alternatively, use evaluate to get performance metrics as well
pred, _ = tcr_phla_system.evaluate(df_test, use_kfold=True, num_folds=tcrphla_fold)
df_test['predicted_score'] = pred
print(df_test.head())

Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 16 unique peptides
  - 6 unique HLAs



Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 1672.37it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 1780.26it/s]

[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt



INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt


[ESM2] No new sequences for hla, using existing cache


INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


✓ All embeddings prepared!
  - Phys: 16 peptides, 6 HLAs
  - ESM2: 2090 peptides, 114 HLAs
  - Struct: 6 HLAs

✓ StriMap_TCRpHLA initialized on cuda:0

Preparing embeddings:
  - TCRα: 73 | TCRβ: 72 | peptides: 16 | HLAs: 6



Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 421.07it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 494.03it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 2493.64it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 1743.27it/s]

[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/tcra_esm2_layer33.pt





[ESM2] Found 67 new sequences → computing embeddings...


ESM2 update (tcra): 100%|██████████| 1/1 [00:00<00:00,  1.26it/s]


[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/tcrb_esm2_layer33.pt
[ESM2] Found 66 new sequences → computing embeddings...


ESM2 update (tcrb): 100%|██████████| 1/1 [00:00<00:00,  1.26it/s]


[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache


INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt
INFO:model:Found new tcra sequences, embedding...


Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 67, valid seqs: 67, unique: 67
ESMfold Predicting structure...: 100%|██████████| 67/67 [01:07<00:00,  1.01s/it]
100%|██████████| 67/67 [01:07<00:00,  1.01s/it]
INFO:model:[DONE] OK: 67, Failed: 0
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcrb_coord_dict.pt
INFO:model:Found new tcrb sequences, embedding...


Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 66, valid seqs: 66, unique: 66
ESMfold Predicting structure...: 100%|██████████| 66/66 [01:06<00:00,  1.01s/it]
100%|██████████| 66/66 [01:06<00:00,  1.01s/it]
INFO:model:[DONE] OK: 66, Failed: 0
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/pep_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/pep_coord_dict.pt
INFO:model:No new pep sequences found
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


✓ Embeddings prepared for TCRα/β, peptide (with ESMFold), and HLA.
✓ Embeddings prepared for TCRα/β, peptide, and HLA.
Preparing peptide-HLA features for prediction set...

Precomputing peptide-HLA features for 16 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00,  3.03it/s]


✓ Pretrained peptide-HLA features prepared.

Ensemble prediction using 5 TCR–pHLA models...
Ensemble method: mean
Loading model from params/tcrphla_model_fold0.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.07s/it]


Loading model from params/tcrphla_model_fold1.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.06s/it]


Loading model from params/tcrphla_model_fold2.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.06s/it]


Loading model from params/tcrphla_model_fold3.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.07s/it]


Loading model from params/tcrphla_model_fold4.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.07s/it]

✓ Ensemble prediction completed using 5 folds

Evaluation Results [K-Fold Ensemble]
tn=66, fp=0, fn=8, tp=0
AUC=0.4413 | PRC=0.1425 | ACC=0.8919 | MCC=0.0000 | F1=0.0000 | P=0.0000 | R=0.0000
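
The threshold-dependent metrics in the line above follow directly from the confusion matrix. In this demonstration run the ensemble predicted no positives, so accuracy equals the fraction of negatives (66/74) while precision, recall, F1, and MCC are all 0. The sketch below recomputes them; reporting undefined ratios as 0 is a common convention and is assumed here, and `metrics_from_confusion` is a hypothetical helper. AUC and PRC are computed from the continuous scores and cannot be derived from these four counts.

```python
import math

def metrics_from_confusion(tn, fp, fn, tp):
    """Threshold-dependent metrics; ratios with a zero denominator are reported as 0."""
    def safe_div(num, den):
        return num / den if den else 0.0

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "ACC": safe_div(tp + tn, tn + fp + fn + tp),
        "P": precision,
        "R": recall,
        "F1": safe_div(2 * precision * recall, precision + recall),
        "MCC": safe_div(tp * tn - fp * fn, mcc_den),
    }

result = metrics_from_confusion(tn=66, fp=0, fn=8, tp=0)
print({name: round(value, 4) for name, value in result.items()})
# {'ACC': 0.8919, 'P': 0.0, 'R': 0.0, 'F1': 0.0, 'MCC': 0.0}
```

On imbalanced data, accuracy alone is therefore misleading; MCC and the precision-recall curve give a clearer picture.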

              cdr3a           cdr3b             Va         Ja           Vb  \
0   CALRSPQAAGNKLTF   CASSLSGTYEQYF      TRAV19*01  TRAJ17*01   TRBV6-2*01   
1       CVVNENNDMRF   CASGDENTGELFF    TRAV12-1*01  TRAJ43*01   TRBV5-1*01   
2     FSLSGGSRGNLLF  CASNAGGGVETQYF      TRAV25*01  TRAJ24*01     TRBV2*01   
3       CVVNNNNDMRF   CASGEANTGELFF    TRAV12-1*01  TRAJ43*01  TRBV12-3*01   
4  CAMREAGWEGAQKLVF  CASRLTGGDQPQHF  TRAV14/DV4*01  TRAJ54*01    TRBV27*01   

           Jb    peptide      HLA  label  \
0  TRBJ2-7*01  LTDEMIAQY  A*01:01      0   
1  TRBJ2-2*01  GILGFVFTL  A*02:01      0   
2  TRBJ2-5*01  RAQAPPPSW  B*57:01      1   
3  TRBJ2-2*01  GPRLGVRAT  B*07:02      0   
4  TRBJ1-5*01  NLVPMVATV  A*02:01      0   

                                                tcra  \
0  AQKVTQAQTEISVV


