# Train a pHLA predictor with a K-fold strategy and make predictions on a test set

In [1]:
import pandas as pd
from main import StriMap_pHLA
from utils import load_train_data
from sklearn.model_selection import train_test_split

### Load data

Only a small subset of the data is provided here for demonstration purposes. For the full dataset, see the [Zenodo repository](https://zenodo.org/records/18002170).

In [2]:
df = pd.read_csv("examples/phla_train_set_example.csv")

### 5-fold train/val splits

In [3]:
num_fold = 5
train_folds, val_folds = zip(*[
    train_test_split(
        df, test_size=0.1, random_state=i, stratify=df["label"]
    )
    for i in range(num_fold)
])

train_folds = [d.reset_index(drop=True) for d in train_folds]
val_folds   = [d.reset_index(drop=True) for d in val_folds]
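
Note that the cell above draws five independent stratified 90/10 hold-out splits, one per random seed, so the validation sets of different folds can overlap. If disjoint validation folds are preferred, the following is a minimal sketch using scikit-learn's `StratifiedKFold`. The toy DataFrame and its `peptide` values are assumptions made for illustration, and each validation fold here holds 20% of the rows rather than 10%.

```python
import pandas as pd
from sklearn.model_selection import StratifiedKFold

# Toy stand-in for the training table (hypothetical values);
# only the "label" column matters for stratification.
df = pd.DataFrame({
    "peptide": [f"PEP{i}" for i in range(50)],
    "label": [1 if i % 5 == 0 else 0 for i in range(50)],
})

num_fold = 5
skf = StratifiedKFold(n_splits=num_fold, shuffle=True, random_state=0)

train_folds, val_folds = [], []
for train_idx, val_idx in skf.split(df, df["label"]):
    # Each row lands in exactly one validation fold.
    train_folds.append(df.iloc[train_idx].reset_index(drop=True))
    val_folds.append(df.iloc[val_idx].reset_index(drop=True))

# Sanity check: the validation folds partition the data.
val_peptides = [set(v["peptide"]) for v in val_folds]
all_val = set().union(*val_peptides)
```

The resulting `train_folds` and `val_folds` lists have the same shape as those built above, so they could be passed on to the next cell unchanged.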

### HLA full sequence mapping

In [4]:
train_folds, val_folds = load_train_data(
    df_train_list=list(train_folds),
    df_val_list=list(val_folds),
    hla_dict_path="HLA_dict.npy",
)

Loading training and validation data...
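
The internals of `load_train_data` are not shown in this notebook. As a rough sketch of what an allele-to-sequence mapping step could look like, assuming a plain dictionary keyed by allele name: the `hla_dict` below is hypothetical, and its values are only the first ten residues visible in the prediction output later in this notebook, not full HLA sequences.

```python
import pandas as pd

# Hypothetical allele -> sequence lookup (truncated prefixes, for illustration only).
hla_dict = {
    "HLA-B*27:08": "SHSMRYFHTS",
    "HLA-A*68:02": "SHSMRYFYTS",
}

df = pd.DataFrame({
    "HLA": ["HLA-B*27:08", "HLA-A*68:02", "HLA-X*00:00"],  # last allele is made up
    "peptide": ["GKPAETIRIGD", "NAENEFVTIK", "AAAAAAAAA"],
})

# Look up each allele; alleles missing from the dictionary become NaN.
df["HLA_full"] = df["HLA"].map(hla_dict)

# Report and drop rows whose allele could not be mapped.
unmapped = df.loc[df["HLA_full"].isna(), "HLA"].tolist()
df_mapped = df.dropna(subset=["HLA_full"]).reset_index(drop=True)
```

Checking `unmapped` before training makes it obvious when an allele name in the input does not match the dictionary's naming convention.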


### Initialize StriMap pHLA model

In [5]:
strimap = StriMap_pHLA(
    device="cuda:0",
    model_save_path="params/phla_model.pt",
    cache_dir="cache",
)

# Prepare embeddings (cached for faster training)
strimap.prepare_embeddings(
    pd.concat([*train_folds, *val_folds], ignore_index=True), 
    force_recompute=False,
)

Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 1659 unique peptides
  - 113 unique HLAs



Phys encoding: 100%|██████████| 7/7 [00:00<00:00, 285.16it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 213.17it/s]

[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt



INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt


[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache


INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


✓ All embeddings prepared!
  - Phys: 1659 peptides, 113 HLAs
  - ESM2: 1659 peptides, 113 HLAs
  - Struct: 113 HLAs



### K-fold training

In [6]:
all_history = strimap.train_kfold(
    train_folds=list(zip(train_folds, val_folds)),
    epochs=10, # Reduced for demonstration; use 100 for actual training
)


Starting 5-Fold Cross-Validation Training

Training Fold 1/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 0]...


                                                                                

[Fold 0] Epoch [1/10] | Train Loss: 0.0747 | Val Loss: 0.0685 | Val AUC: 0.7839 | Val PRC: 0.4513
Validation improved → Saving model (Score=0.4513) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [2/10] | Train Loss: 0.0671 | Val Loss: 0.0627 | Val AUC: 0.7922 | Val PRC: 0.5881
Validation improved → Saving model (Score=0.5881) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [3/10] | Train Loss: 0.0593 | Val Loss: 0.0539 | Val AUC: 0.7653 | Val PRC: 0.4850
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 0] Epoch [4/10] | Train Loss: 0.0552 | Val Loss: 0.0514 | Val AUC: 0.7653 | Val PRC: 0.6220
Validation improved → Saving model (Score=0.6220) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [5/10] | Train Loss: 0.0519 | Val Loss: 0.0476 | Val AUC: 0.8085 | Val PRC: 0.6642
Validation improved → Saving model (Score=0.6642) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [6/10] | Train Loss: 0.0493 | Val Loss: 0.0524 | Val AUC: 0.7803 | Val PRC: 0.6746
Validation improved → Saving model (Score=0.6746) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [7/10] | Train Loss: 0.0472 | Val Loss: 0.0481 | Val AUC: 0.7961 | Val PRC: 0.6848
Validation improved → Saving model (Score=0.6848) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [8/10] | Train Loss: 0.0457 | Val Loss: 0.0439 | Val AUC: 0.8230 | Val PRC: 0.7064
Validation improved → Saving model (Score=0.7064) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [9/10] | Train Loss: 0.0442 | Val Loss: 0.0465 | Val AUC: 0.8295 | Val PRC: 0.6903
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 0] Epoch [10/10] | Train Loss: 0.0424 | Val Loss: 0.0443 | Val AUC: 0.8385 | Val PRC: 0.6730
EarlyStopping counter: 2 out of 5

[Fold 0] Loading best model from params/phla_model_fold0.pt...
✓ Training completed for Fold 0!
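
The Fold 0 log above is consistent with patience-based early stopping on Val PRC: the saved `Score` equals the Val PRC, the counter resets whenever the score improves, and the patience is 5. The actual logic inside `train_kfold` is not shown here, so the `EarlyStopping` class below is a hypothetical sketch that replays the logged Val PRC values and reproduces the same counter sequence.

```python
class EarlyStopping:
    """Patience-based early stopping on a score where higher is better."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best_score = None
        self.counter = 0
        self.should_stop = False

    def step(self, score):
        """Return True if `score` is a new best (a checkpoint would be saved)."""
        if self.best_score is None or score > self.best_score:
            self.best_score = score
            self.counter = 0
            return True
        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return False


# Val PRC values copied from the Fold 0 log above.
fold0_val_prc = [0.4513, 0.5881, 0.4850, 0.6220, 0.6642,
                 0.6746, 0.6848, 0.7064, 0.6903, 0.6730]

stopper = EarlyStopping(patience=5)
counters = []
for epoch, prc in enumerate(fold0_val_prc, start=1):
    improved = stopper.step(prc)
    counters.append(stopper.counter)
    if improved:
        print(f"Epoch {epoch}: validation improved (score={prc:.4f})")
    else:
        print(f"Epoch {epoch}: counter {stopper.counter} out of {stopper.patience}")
    if stopper.should_stop:
        break
```

With only 10 epochs the counter never reaches 5, so training runs to completion and the best checkpoint (epoch 8, score 0.7064) is the one reloaded.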

Training Fold 2/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 1]...


                                                                                

[Fold 1] Epoch [1/10] | Train Loss: 0.0864 | Val Loss: 0.0744 | Val AUC: 0.7200 | Val PRC: 0.3821
Validation improved → Saving model (Score=0.3821) to params/phla_model_fold1.pt


                                                                                

[Fold 1] Epoch [2/10] | Train Loss: 0.0709 | Val Loss: 0.0637 | Val AUC: 0.7932 | Val PRC: 0.4888
Validation improved → Saving model (Score=0.4888) to params/phla_model_fold1.pt


                                                                                

[Fold 1] Epoch [3/10] | Train Loss: 0.0608 | Val Loss: 0.0546 | Val AUC: 0.7679 | Val PRC: 0.4088
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 1] Epoch [4/10] | Train Loss: 0.0554 | Val Loss: 0.0490 | Val AUC: 0.8038 | Val PRC: 0.4816
EarlyStopping counter: 2 out of 5


                                                                                

[Fold 1] Epoch [5/10] | Train Loss: 0.0503 | Val Loss: 0.0453 | Val AUC: 0.8408 | Val PRC: 0.6545
Validation improved → Saving model (Score=0.6545) to params/phla_model_fold1.pt


                                                                                

[Fold 1] Epoch [6/10] | Train Loss: 0.0479 | Val Loss: 0.0438 | Val AUC: 0.8403 | Val PRC: 0.6836
Validation improved → Saving model (Score=0.6836) to params/phla_model_fold1.pt


                                                                                

[Fold 1] Epoch [7/10] | Train Loss: 0.0473 | Val Loss: 0.0429 | Val AUC: 0.8429 | Val PRC: 0.6778
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 1] Epoch [8/10] | Train Loss: 0.0455 | Val Loss: 0.0431 | Val AUC: 0.8313 | Val PRC: 0.6716
EarlyStopping counter: 2 out of 5


                                                                                

[Fold 1] Epoch [9/10] | Train Loss: 0.0433 | Val Loss: 0.0426 | Val AUC: 0.8354 | Val PRC: 0.6702
EarlyStopping counter: 3 out of 5


                                                                                

[Fold 1] Epoch [10/10] | Train Loss: 0.0418 | Val Loss: 0.0405 | Val AUC: 0.8678 | Val PRC: 0.6987
Validation improved → Saving model (Score=0.6987) to params/phla_model_fold1.pt

[Fold 1] Loading best model from params/phla_model_fold1.pt...
✓ Training completed for Fold 1!

Training Fold 3/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 2]...


                                                                                

[Fold 2] Epoch [1/10] | Train Loss: 0.0713 | Val Loss: 0.0823 | Val AUC: 0.6398 | Val PRC: 0.2788
Validation improved → Saving model (Score=0.2788) to params/phla_model_fold2.pt


                                                                                

[Fold 2] Epoch [2/10] | Train Loss: 0.0624 | Val Loss: 0.0637 | Val AUC: 0.6925 | Val PRC: 0.3540
Validation improved → Saving model (Score=0.3540) to params/phla_model_fold2.pt


                                                                                

[Fold 2] Epoch [3/10] | Train Loss: 0.0562 | Val Loss: 0.0576 | Val AUC: 0.6742 | Val PRC: 0.3209
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 2] Epoch [4/10] | Train Loss: 0.0528 | Val Loss: 0.0567 | Val AUC: 0.6796 | Val PRC: 0.3393
EarlyStopping counter: 2 out of 5


                                                                                

[Fold 2] Epoch [5/10] | Train Loss: 0.0480 | Val Loss: 0.0573 | Val AUC: 0.6822 | Val PRC: 0.3425
EarlyStopping counter: 3 out of 5


                                                                                

[Fold 2] Epoch [6/10] | Train Loss: 0.0467 | Val Loss: 0.0583 | Val AUC: 0.6900 | Val PRC: 0.3633
Validation improved → Saving model (Score=0.3633) to params/phla_model_fold2.pt


                                                                                

[Fold 2] Epoch [7/10] | Train Loss: 0.0445 | Val Loss: 0.0606 | Val AUC: 0.7063 | Val PRC: 0.3573
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 2] Epoch [8/10] | Train Loss: 0.0415 | Val Loss: 0.0653 | Val AUC: 0.7037 | Val PRC: 0.3749
Validation improved → Saving model (Score=0.3749) to params/phla_model_fold2.pt


                                                                                

[Fold 2] Epoch [9/10] | Train Loss: 0.0412 | Val Loss: 0.0603 | Val AUC: 0.7208 | Val PRC: 0.4012
Validation improved → Saving model (Score=0.4012) to params/phla_model_fold2.pt


                                                                                

[Fold 2] Epoch [10/10] | Train Loss: 0.0400 | Val Loss: 0.0607 | Val AUC: 0.7042 | Val PRC: 0.4003
EarlyStopping counter: 1 out of 5

[Fold 2] Loading best model from params/phla_model_fold2.pt...
✓ Training completed for Fold 2!

Training Fold 4/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 3]...


                                                                                

[Fold 3] Epoch [1/10] | Train Loss: 0.0684 | Val Loss: 0.0686 | Val AUC: 0.7086 | Val PRC: 0.3798
Validation improved → Saving model (Score=0.3798) to params/phla_model_fold3.pt


                                                                                

[Fold 3] Epoch [2/10] | Train Loss: 0.0628 | Val Loss: 0.0562 | Val AUC: 0.7505 | Val PRC: 0.3118
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 3] Epoch [3/10] | Train Loss: 0.0575 | Val Loss: 0.0535 | Val AUC: 0.7409 | Val PRC: 0.3129
EarlyStopping counter: 2 out of 5


                                                                                

[Fold 3] Epoch [4/10] | Train Loss: 0.0541 | Val Loss: 0.0540 | Val AUC: 0.7562 | Val PRC: 0.3934
Validation improved → Saving model (Score=0.3934) to params/phla_model_fold3.pt


                                                                                

[Fold 3] Epoch [5/10] | Train Loss: 0.0513 | Val Loss: 0.0532 | Val AUC: 0.7570 | Val PRC: 0.4507
Validation improved → Saving model (Score=0.4507) to params/phla_model_fold3.pt


                                                                                

[Fold 3] Epoch [6/10] | Train Loss: 0.0487 | Val Loss: 0.0505 | Val AUC: 0.7893 | Val PRC: 0.4797
Validation improved → Saving model (Score=0.4797) to params/phla_model_fold3.pt


                                                                                

[Fold 3] Epoch [7/10] | Train Loss: 0.0462 | Val Loss: 0.0493 | Val AUC: 0.7836 | Val PRC: 0.4066
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 3] Epoch [8/10] | Train Loss: 0.0439 | Val Loss: 0.0499 | Val AUC: 0.8098 | Val PRC: 0.4995
Validation improved → Saving model (Score=0.4995) to params/phla_model_fold3.pt


                                                                                

[Fold 3] Epoch [9/10] | Train Loss: 0.0407 | Val Loss: 0.0528 | Val AUC: 0.7743 | Val PRC: 0.4613
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 3] Epoch [10/10] | Train Loss: 0.0414 | Val Loss: 0.0536 | Val AUC: 0.7704 | Val PRC: 0.4503
EarlyStopping counter: 2 out of 5

[Fold 3] Loading best model from params/phla_model_fold3.pt...
✓ Training completed for Fold 3!

Training Fold 5/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 4]...


                                                                                

[Fold 4] Epoch [1/10] | Train Loss: 0.0949 | Val Loss: 0.0649 | Val AUC: 0.7381 | Val PRC: 0.4160
Validation improved → Saving model (Score=0.4160) to params/phla_model_fold4.pt


                                                                                

[Fold 4] Epoch [2/10] | Train Loss: 0.0810 | Val Loss: 0.0613 | Val AUC: 0.7717 | Val PRC: 0.5414
Validation improved → Saving model (Score=0.5414) to params/phla_model_fold4.pt


                                                                                

[Fold 4] Epoch [3/10] | Train Loss: 0.0696 | Val Loss: 0.0580 | Val AUC: 0.7748 | Val PRC: 0.4739
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 4] Epoch [4/10] | Train Loss: 0.0610 | Val Loss: 0.0530 | Val AUC: 0.7694 | Val PRC: 0.4669
EarlyStopping counter: 2 out of 5


                                                                                

[Fold 4] Epoch [5/10] | Train Loss: 0.0546 | Val Loss: 0.0500 | Val AUC: 0.7984 | Val PRC: 0.5465
Validation improved → Saving model (Score=0.5465) to params/phla_model_fold4.pt


                                                                                

[Fold 4] Epoch [6/10] | Train Loss: 0.0504 | Val Loss: 0.0476 | Val AUC: 0.8100 | Val PRC: 0.6547
Validation improved → Saving model (Score=0.6547) to params/phla_model_fold4.pt


                                                                                

[Fold 4] Epoch [7/10] | Train Loss: 0.0480 | Val Loss: 0.0455 | Val AUC: 0.8235 | Val PRC: 0.6892
Validation improved → Saving model (Score=0.6892) to params/phla_model_fold4.pt


                                                                                

[Fold 4] Epoch [8/10] | Train Loss: 0.0437 | Val Loss: 0.0430 | Val AUC: 0.8372 | Val PRC: 0.6735
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 4] Epoch [9/10] | Train Loss: 0.0433 | Val Loss: 0.0435 | Val AUC: 0.8341 | Val PRC: 0.6749
EarlyStopping counter: 2 out of 5


                                                                                

[Fold 4] Epoch [10/10] | Train Loss: 0.0423 | Val Loss: 0.0400 | Val AUC: 0.8551 | Val PRC: 0.7167
Validation improved → Saving model (Score=0.7167) to params/phla_model_fold4.pt

[Fold 4] Loading best model from params/phla_model_fold4.pt...
✓ Training completed for Fold 4!

✓ All 5 folds training completed!

Cross-Validation Summary:
----------------------------------------------------------------------
Fold 0: Best Val AUC = 0.8385 (Epoch 10)
Fold 1: Best Val AUC = 0.8678 (Epoch 10)
Fold 2: Best Val AUC = 0.7208 (Epoch 9)
Fold 3: Best Val AUC = 0.8098 (Epoch 8)
Fold 4: Best Val AUC = 0.8551 (Epoch 10)
----------------------------------------------------------------------
Mean Val AUC: 0.8184 ± 0.0525
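
As a quick check, the summary line above can be reproduced from the five per-fold values with the standard library. The reported spread matches the population standard deviation (`pstdev`); the sample standard deviation would be slightly larger, about 0.0587. The summary lists the best Val AUC per fold, while the log shows checkpoints being selected on the PRC score, so the reported epoch is not necessarily the one that was saved.

```python
from statistics import mean, pstdev

# Best Val AUC per fold, copied from the summary above.
best_val_auc = [0.8385, 0.8678, 0.7208, 0.8098, 0.8551]

cv_mean = mean(best_val_auc)
cv_std = pstdev(best_val_auc)  # population standard deviation (divides by N)

print(f"Mean Val AUC: {cv_mean:.4f} ± {cv_std:.4f}")
```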



# Predict

In [10]:
from main import StriMap_pHLA
from utils import load_test_data
import pandas as pd

# Load test data
df_test = pd.read_csv("examples/phla_test_set_example.csv")

if 'label' not in df_test.columns:
    df_test['label'] = 0  # Dummy label column for compatibility

# Standardize HLA fields and map alleles
df_test = load_test_data(
    df_test=df_test,
    hla_dict_path="HLA_dict.npy",
)

# Initialize StriMap with a trained checkpoint
strimap = StriMap_pHLA(
    device="cuda:0",  # or "cpu"
    model_save_path="params/phla_model.pt", # Base path; fold checkpoints load as params/phla_model_fold{k}.pt
    cache_dir="cache", # Cache directory for embeddings
)

# Prepare embeddings (cached for faster inference)
strimap.prepare_embeddings(
    df_test,
    force_recompute=False,
)

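# Prediction only: use this call when the test set has no ground-truth labels
# (metrics computed on the dummy labels added above would be meaningless)
# y_prob_test, _ = strimap.predict(df_test, use_kfold=True, num_folds=5)

# Evaluation: also prints metrics, so it is only meaningful when df_test
# contains real ground-truth labels
y_prob_test, _ = strimap.evaluate(df_test, use_kfold=True, num_folds=5)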

df_test["predicted_score"] = y_prob_test
print(df_test.head())

Processing test data...
✓ Test set: 416 samples
Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 416 unique peptides
  - 97 unique HLAs



Phys encoding: 100%|██████████| 2/2 [00:00<00:00, 977.35it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 266.17it/s]

[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt





[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt


INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


[ESM2] No new sequences for hla, using existing cache
✓ All embeddings prepared!
  - Phys: 416 peptides, 97 HLAs
  - ESM2: 2074 peptides, 114 HLAs
  - Struct: 97 HLAs


Ensemble prediction using 5 models...
Ensemble method: mean
Loading model from params/phla_model_fold0.pt...


Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.67it/s]


Loading model from params/phla_model_fold1.pt...


Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.57it/s]


Loading model from params/phla_model_fold2.pt...


Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.62it/s]


Loading model from params/phla_model_fold3.pt...


Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.62it/s]


Loading model from params/phla_model_fold4.pt...


Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.59it/s]

✓ Ensemble prediction completed using 5 models
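
The log above reports `Ensemble method: mean` over the five fold models. A minimal sketch of mean ensembling, assuming each model yields one probability per test sample, is shown below. The probabilities are hypothetical and are not taken from this run.

```python
# Hypothetical per-fold probabilities for three test samples
# (one row per fold model, one column per sample).
fold_probs = [
    [0.20, 0.70, 0.40],
    [0.30, 0.60, 0.50],
    [0.25, 0.80, 0.45],
    [0.15, 0.65, 0.35],
    [0.35, 0.75, 0.55],
]

num_models = len(fold_probs)

# Average column-wise: each sample's score is the mean over all fold models.
ensemble = [sum(sample_probs) / num_models for sample_probs in zip(*fold_probs)]
```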

Evaluation Results [5-Fold Ensemble (mean)]
tn = 331, fp = 16, fn = 41, tp = 28
y_pred: 0 = 372 | 1 = 44
y_true: 0 = 347 | 1 = 69
AUC: 0.8539 | PRC: 0.6072 | ACC: 0.8630 | MCC: 0.4350 | F1: 0.4956
Precision: 0.6364 | Recall: 0.4058

           HLA      peptide  label  \
0  HLA-B*27:08  GKPAETIRIGD      0   
1  HLA-B*40:01   TEYEEAQDAI      0   
2  HLA-A*68:02   NAENEFVTIK      1   
3  HLA-C*07:01  YKTDVEQIKIN      0   
4  HLA-C*07:01    NGICIYFSR      0   

                                            HLA_full  predicted_score  
0  SHSMRYFHTSVSRPGRGEPRFITVGYVDDTLFVRFDSDAASPREEP...         0.264067  
1  SHSMRYFHTAMSRPGRGEPRFITVGYVDDTLFVRFDSDATSPRKEP...         0.394951  
2  SHSMRYFYTSMSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...         0.461650  
3  SHSMRYFDTAVSRPGRGEPRFISVGYVDDTQFVRFDSDAASPRGEP...         0.308959  
4  SHSMRYFDTAVSRPGRGEPRFISVGYVDDTQFVRFDSDAASPRGEP...         0.317483  
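
The thresholded metrics in the evaluation output above follow directly from the reported confusion matrix, which the short check below reproduces using the standard definitions. AUC and PRC cannot be recovered this way because they depend on the full ranking of predicted scores.

```python
from math import sqrt

# Confusion-matrix counts copied from the evaluation output above.
tn, fp, fn, tp = 331, 16, 41, 28

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * tp / (2 * tp + fp + fn)
# Matthews correlation coefficient
mcc = (tp * tn - fp * fn) / sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

print(f"ACC: {accuracy:.4f} | MCC: {mcc:.4f} | F1: {f1:.4f}")
print(f"Precision: {precision:.4f} | Recall: {recall:.4f}")
```

Recall is noticeably lower than precision here (41 of the 69 true binders are missed), which is worth keeping in mind given that this demonstration used only 10 epochs and a small training subset.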



