# Train a pHLA predictor with a K-fold strategy and make predictions on a test set

In [1]:
import pandas as pd
from main import StriMap_pHLA
from utils import load_train_data
from sklearn.model_selection import train_test_split

### Load data

Only a small subset of the data is provided here for demonstration purposes. For the full dataset, please refer to the [Zenodo repository](https://zenodo.org/records/18002170).

In [2]:
df = pd.read_csv("examples/phla_train_set_example.csv")

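### 5-fold train/validation splits

Each of the five splits holds out a stratified 10% of the data for validation and uses a different random seed (`random_state=i`). Because the splits are drawn independently, a sample can appear in more than one validation set. These are repeated hold-out splits rather than disjoint K-fold partitions.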

In [3]:
num_fold = 5
train_folds, val_folds = zip(*[
    train_test_split(
        df, test_size=0.1, random_state=i, stratify=df["label"]
    )
    for i in range(num_fold)
])

train_folds = [d.reset_index(drop=True) for d in train_folds]
val_folds   = [d.reset_index(drop=True) for d in val_folds]
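
If disjoint validation folds are preferred (standard K-fold, where every sample is validated exactly once), scikit-learn's `StratifiedKFold` can produce lists with the same shape as `train_folds` and `val_folds`. This is an alternative sketch, not what this tutorial runs: with `n_splits=5`, each validation fold holds about 20% of the data instead of 10%. The helper `stratified_kfold_splits` and the toy DataFrame are hypothetical and only illustrate the idea.

```python
import pandas as pd
from sklearn.model_selection import StratifiedKFold


def stratified_kfold_splits(df, label_col="label", n_splits=5, seed=0):
    """Hypothetical helper: disjoint, label-stratified (train, val) folds."""
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    train_folds, val_folds = [], []
    for train_idx, val_idx in skf.split(df, df[label_col]):
        # split() yields positional indices, so select rows with iloc
        train_folds.append(df.iloc[train_idx].reset_index(drop=True))
        val_folds.append(df.iloc[val_idx].reset_index(drop=True))
    return train_folds, val_folds


# Toy stand-in for the training CSV: 50 rows, 10 positives
toy = pd.DataFrame({
    "peptide": [f"PEP{i:03d}" for i in range(50)],
    "label": [1 if i % 5 == 0 else 0 for i in range(50)],
})
train_folds_kf, val_folds_kf = stratified_kfold_splits(toy, n_splits=5)
print([len(v) for v in val_folds_kf])  # [10, 10, 10, 10, 10]
```

Each toy validation fold keeps the 1:4 positive/negative ratio (2 positives, 8 negatives), and together the five folds cover every row exactly once.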

### HLA full sequence mapping

In [4]:
train_folds, val_folds = load_train_data(
    df_train_list=list(train_folds),
    df_val_list=list(val_folds),
    hla_dict_path="HLA_dict.npy",
)

Loading training and validation data...
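
`load_train_data` is not shown in this notebook. Judging from the `HLA` and `HLA_full` columns in the test-set output further down, it attaches a full HLA sequence to each allele name. The following is a hypothetical sketch of such a lookup. It assumes that `HLA_dict.npy` stores a pickled `{allele: sequence}` dictionary; the real file layout, and any allele-name standardization StriMap performs, may differ. `map_hla_full` and the short toy sequences are illustrative only.

```python
import os
import tempfile

import numpy as np
import pandas as pd


def map_hla_full(df, hla_dict_path, allele_col="HLA", seq_col="HLA_full"):
    """Hypothetical helper: add full HLA sequences and drop unmapped alleles."""
    # Assumption: the .npy file holds a pickled {allele name: sequence} dict
    hla_dict = np.load(hla_dict_path, allow_pickle=True).item()
    out = df.copy()
    out[seq_col] = out[allele_col].map(hla_dict)  # unknown alleles become NaN
    missing = out[seq_col].isna()
    if missing.any():
        print(f"Dropping {int(missing.sum())} rows with unmapped alleles")
    return out[~missing].reset_index(drop=True)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy_hla_dict.npy")
    # Toy, truncated sequences for illustration only
    np.save(path, {"HLA-A*02:01": "GSHSMRYF", "HLA-B*07:02": "GSHSMRYY"})
    toy = pd.DataFrame({
        "HLA": ["HLA-A*02:01", "HLA-B*07:02", "HLA-Z*99:99"],
        "peptide": ["SIINFEKL", "NLVPMVATV", "GILGFVFTL"],
    })
    mapped = map_hla_full(toy, path)
print(mapped)
```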


### Initialize StriMap pHLA model

ESMFold will be automatically downloaded when running the code below for the first time.

Running ESMFold can take a while (about 2 s per HLA sequence in the run below). With `cache_save=True` (the default), the computed embeddings are saved to disk and reloaded on later runs, which speeds up training and inference, especially on large datasets.
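
The caching behaviour visible in the logs (for example, `[ESM2] Found 415 new sequences → computing embeddings...` on the test set, where only one of the 416 peptides was already cached) amounts to embedding only the sequences missing from the cache. Below is a minimal sketch of that pattern. It assumes a pickled dictionary keyed by sequence and a hypothetical `embed_fn`; StriMap's own cache files are `.pt` files (e.g. `pep_esm2_layer33.pt`), and its internals may differ.

```python
import os
import pickle
import tempfile

import numpy as np


def update_embedding_cache(seqs, cache_path, embed_fn):
    """Embed only sequences missing from the on-disk cache; return (cache, n_new)."""
    cache = {}
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            cache = pickle.load(f)
    new_seqs = sorted(set(seqs) - set(cache))
    if new_seqs:
        for seq, emb in zip(new_seqs, embed_fn(new_seqs)):
            cache[seq] = emb
        with open(cache_path, "wb") as f:
            pickle.dump(cache, f)
    return cache, len(new_seqs)


def toy_embed(seqs):
    # Hypothetical stand-in for an ESM2 forward pass: one 4-dim vector per sequence
    return [np.full(4, float(len(s))) for s in seqs]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "pep_cache.pkl")
    cache, n_first = update_embedding_cache(
        ["SIINFEKL", "NLVPMVATV", "GILGFVFTL"], path, toy_embed)
    # Second call: only the unseen peptide is embedded
    cache, n_second = update_embedding_cache(
        ["GILGFVFTL", "KLVALGINAV"], path, toy_embed)
print(n_first, n_second, len(cache))  # 3 1 4
```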

The `StriMap_pHLA` constructor accepts the following parameters:

| Parameter | Default | Description |
|----------|---------|-------------|
| `device` | `"cuda:0"` | Device for computation (e.g., `"cuda:0"` or `"cpu"`) |
| `model_save_path` | `"params/phla_model.pt"` | Path to save the best-performing pHLA model |
| `pep_dim` | `256` | Peptide embedding dimension |
| `hla_dim` | `256` | HLA embedding dimension |
| `bilinear_dim` | `256` | Dimension of the bilinear attention layer |
| `loss_fn` | `"focal"` | Loss function to use (`"bce"` or `"focal"`) |
| `alpha` | `0.5` | Alpha parameter for focal loss |
| `gamma` | `2.0` | Gamma parameter for focal loss |
| `esm2_layer` | `33` | ESM2 layer index used to extract sequence embeddings |
| `batch_size` | `256` | Batch size for embedding computation and model training |
| `esmfold_cache_dir` | `"esm_cache"` | Cache directory for ESMFold outputs |
| `cache_dir` | `"cache"` | General cache directory for embeddings and intermediate results |
| `cache_save` | `True` | Whether to save computed embeddings to cache |
| `seed` | `1` | Random seed for reproducibility |
| `pos_weights` | `None` | Positive class weight for handling class imbalance |
| `neg_ratio` | `None` | Negative sampling ratio (if applicable) |
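
The `loss_fn`, `alpha`, and `gamma` parameters point to the widely used binary focal loss, $\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log p_t$. The NumPy sketch below implements that textbook form to show how `gamma` down-weights well-classified examples. Whether StriMap's `"focal"` option matches this form exactly is an assumption.

```python
import numpy as np


def binary_focal_loss(probs, labels, alpha=0.5, gamma=2.0, eps=1e-7):
    """Mean binary focal loss: -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    p = np.clip(np.asarray(probs, dtype=float), eps, 1.0 - eps)
    y = np.asarray(labels)
    p_t = np.where(y == 1, p, 1.0 - p)              # probability assigned to the true class
    alpha_t = np.where(y == 1, alpha, 1.0 - alpha)  # per-class weighting
    return float(np.mean(-alpha_t * (1.0 - p_t) ** gamma * np.log(p_t)))


probs = [0.9, 0.6, 0.2, 0.7]
labels = [1, 1, 0, 0]
# Plain BCE on the same examples (p_t = 0.9, 0.6, 0.8, 0.3)
bce = float(np.mean(-np.log([0.9, 0.6, 0.8, 0.3])))

# With gamma=0 and alpha=0.5, focal loss reduces to half of BCE
print(np.isclose(binary_focal_loss(probs, labels, alpha=0.5, gamma=0.0), 0.5 * bce))  # True
# gamma > 0 shrinks every term, most strongly for confident correct predictions
print(binary_focal_loss(probs, labels) < binary_focal_loss(probs, labels, gamma=0.0))  # True
```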

In [None]:
strimap = StriMap_pHLA(
    device="cuda:0",
    model_save_path="params/phla_model.pt",
    cache_dir="cache",
)

# Prepare embeddings (cached for faster training)
strimap.prepare_embeddings(pd.concat([*train_folds, *val_folds], ignore_index=True))

Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 1659 unique peptides
  - 113 unique HLAs

Phys encoding: 100%|██████████| 7/7 [00:00<00:00, 160.43it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 200.42it/s]

[ESM2] No existing cache for pep, will create new.
[ESM2] Found 1659 new sequences → computing embeddings...
ESM2 update (pep): 100%|██████████| 7/7 [00:03<00:00,  2.27it/s]
[ESM2] Updating cache with new sequences
[ESM2] No existing cache for hla, will create new.
[ESM2] Found 113 new sequences → computing embeddings...
ESM2 update (hla): 100%|██████████| 1/1 [00:02<00:00,  2.06s/it]
[ESM2] Updating cache with new sequences

INFO:model:/ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt not found or re_embed=True, generating...

Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 113, valid seqs: 113, unique: 113
ESMfold Predicting structure...: 100%|██████████| 113/113 [03:55<00:00,  2.09s/it]
INFO:model:[DONE] OK: 113, Failed: 0
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found

✓ All embeddings prepared!
  - Phys: 1659 peptides, 113 HLAs
  - ESM2: 1659 peptides, 113 HLAs
  - Struct: 113 HLAs

### K-fold training

Run K-fold cross-validation training with `train_kfold`, which accepts the following parameters:

| Parameter | Default | Description |
|----------|---------|-------------|
| `train_folds` | *(required)* | List of `(train_df, val_df)` tuples, one per fold |
| `df_test` | `None` | Optional test dataset for evaluation after training |
| `epochs` | `100` | Number of training epochs per fold |
| `batch_size` | `256` | Batch size used during training |
| `lr` | `1e-4` | Learning rate |
| `patience` | `5` | Early stopping patience |
| `num_workers` | `8` | Number of workers for data loading |

**Returns:**

- A list of training histories, one for each fold
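
`patience` controls early stopping. In the training log below, a checkpoint is written whenever validation PRC improves (the saved `Score` equals `Val PRC`), and a fold stops after `patience` consecutive epochs without improvement. The following is a minimal sketch of that rule. It assumes a strict `>` comparison with no minimum delta. The `EarlyStopping` class here is hypothetical, not StriMap's own implementation. It is replayed on the Fold 0 PRC values from the log.

```python
class EarlyStopping:
    """Hypothetical early-stopping tracker that maximizes a validation score."""

    def __init__(self, patience=5):
        self.patience = patience
        self.best = None
        self.best_epoch = None
        self.counter = 0

    def step(self, score, epoch):
        """Record one epoch's score; return True when training should stop."""
        if self.best is None or score > self.best:
            # Improvement: this is where a checkpoint would be saved
            self.best, self.best_epoch, self.counter = score, epoch, 0
            return False
        self.counter += 1
        return self.counter >= self.patience


# Validation PRC for Fold 0, copied from the training log
fold0_prc = [0.4359, 0.5188, 0.5509, 0.6305, 0.6944,
             0.6585, 0.6457, 0.6840, 0.6765, 0.6593]
stopper = EarlyStopping(patience=5)
for epoch, prc in enumerate(fold0_prc, start=1):
    if stopper.step(prc, epoch):
        break
print(stopper.best_epoch, stopper.best, epoch)  # 5 0.6944 10
```

Replaying Fold 2's PRC values the same way stops at epoch 7 with the best checkpoint from epoch 2, matching the log.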

In [6]:
all_history = strimap.train_kfold(
    train_folds=list(zip(train_folds, val_folds)),
    epochs=10, # Reduced for demonstration; use 100 for actual training
)


Starting 5-Fold Cross-Validation Training

Training Fold 1/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 0]...


                                                                                

[Fold 0] Epoch [1/10] | Train Loss: 0.0748 | Val Loss: 0.0700 | Val AUC: 0.7842 | Val PRC: 0.4359
Validation improved → Saving model (Score=0.4359) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [2/10] | Train Loss: 0.0672 | Val Loss: 0.0649 | Val AUC: 0.7394 | Val PRC: 0.5188
Validation improved → Saving model (Score=0.5188) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [3/10] | Train Loss: 0.0595 | Val Loss: 0.0565 | Val AUC: 0.7692 | Val PRC: 0.5509
Validation improved → Saving model (Score=0.5509) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [4/10] | Train Loss: 0.0550 | Val Loss: 0.0519 | Val AUC: 0.7738 | Val PRC: 0.6305
Validation improved → Saving model (Score=0.6305) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [5/10] | Train Loss: 0.0516 | Val Loss: 0.0515 | Val AUC: 0.7782 | Val PRC: 0.6944
Validation improved → Saving model (Score=0.6944) to params/phla_model_fold0.pt


                                                                                

[Fold 0] Epoch [6/10] | Train Loss: 0.0494 | Val Loss: 0.0517 | Val AUC: 0.7655 | Val PRC: 0.6585
EarlyStopping counter: 1 out of 5


                                                                                

[Fold 0] Epoch [7/10] | Train Loss: 0.0476 | Val Loss: 0.0509 | Val AUC: 0.7712 | Val PRC: 0.6457
EarlyStopping counter: 2 out of 5


                                                                                

[Fold 0] Epoch [8/10] | Train Loss: 0.0463 | Val Loss: 0.0462 | Val AUC: 0.7953 | Val PRC: 0.6840
EarlyStopping counter: 3 out of 5


                                                                                

[Fold 0] Epoch [9/10] | Train Loss: 0.0447 | Val Loss: 0.0448 | Val AUC: 0.8240 | Val PRC: 0.6765
EarlyStopping counter: 4 out of 5


                                                                                

[Fold 0] Epoch [10/10] | Train Loss: 0.0433 | Val Loss: 0.0471 | Val AUC: 0.8018 | Val PRC: 0.6593
EarlyStopping counter: 5 out of 5

[Fold 0] Early stopping triggered at epoch 10!

[Fold 0] Loading best model from params/phla_model_fold0.pt...
✓ Training completed for Fold 0!

Training Fold 2/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 1]...

[Fold 1] Epoch [1/10] | Train Loss: 0.0866 | Val Loss: 0.0727 | Val AUC: 0.7184 | Val PRC: 0.3746
Validation improved → Saving model (Score=0.3746) to params/phla_model_fold1.pt

[Fold 1] Epoch [2/10] | Train Loss: 0.0709 | Val Loss: 0.0623 | Val AUC: 0.7808 | Val PRC: 0.4953
Validation improved → Saving model (Score=0.4953) to params/phla_model_fold1.pt

[Fold 1] Epoch [3/10] | Train Loss: 0.0606 | Val Loss: 0.0557 | Val AUC: 0.7826 | Val PRC: 0.4405
EarlyStopping counter: 1 out of 5

[Fold 1] Epoch [4/10] | Train Loss: 0.0553 | Val Loss: 0.0500 | Val AUC: 0.7935 | Val PRC: 0.5574
Validation improved → Saving model (Score=0.5574) to params/phla_model_fold1.pt

[Fold 1] Epoch [5/10] | Train Loss: 0.0505 | Val Loss: 0.0466 | Val AUC: 0.8175 | Val PRC: 0.6739
Validation improved → Saving model (Score=0.6739) to params/phla_model_fold1.pt

[Fold 1] Epoch [6/10] | Train Loss: 0.0471 | Val Loss: 0.0434 | Val AUC: 0.8439 | Val PRC: 0.6896
Validation improved → Saving model (Score=0.6896) to params/phla_model_fold1.pt

[Fold 1] Epoch [7/10] | Train Loss: 0.0468 | Val Loss: 0.0426 | Val AUC: 0.8333 | Val PRC: 0.6874
EarlyStopping counter: 1 out of 5

[Fold 1] Epoch [8/10] | Train Loss: 0.0447 | Val Loss: 0.0422 | Val AUC: 0.8463 | Val PRC: 0.7122
Validation improved → Saving model (Score=0.7122) to params/phla_model_fold1.pt

[Fold 1] Epoch [9/10] | Train Loss: 0.0424 | Val Loss: 0.0418 | Val AUC: 0.8383 | Val PRC: 0.6964
EarlyStopping counter: 1 out of 5

[Fold 1] Epoch [10/10] | Train Loss: 0.0416 | Val Loss: 0.0395 | Val AUC: 0.8587 | Val PRC: 0.7444
Validation improved → Saving model (Score=0.7444) to params/phla_model_fold1.pt

[Fold 1] Loading best model from params/phla_model_fold1.pt...
✓ Training completed for Fold 1!

Training Fold 3/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 2]...

[Fold 2] Epoch [1/10] | Train Loss: 0.0713 | Val Loss: 0.0819 | Val AUC: 0.6426 | Val PRC: 0.2984
Validation improved → Saving model (Score=0.2984) to params/phla_model_fold2.pt

[Fold 2] Epoch [2/10] | Train Loss: 0.0626 | Val Loss: 0.0683 | Val AUC: 0.6654 | Val PRC: 0.3426
Validation improved → Saving model (Score=0.3426) to params/phla_model_fold2.pt

[Fold 2] Epoch [3/10] | Train Loss: 0.0561 | Val Loss: 0.0639 | Val AUC: 0.6573 | Val PRC: 0.3318
EarlyStopping counter: 1 out of 5

[Fold 2] Epoch [4/10] | Train Loss: 0.0535 | Val Loss: 0.0596 | Val AUC: 0.6607 | Val PRC: 0.3221
EarlyStopping counter: 2 out of 5

[Fold 2] Epoch [5/10] | Train Loss: 0.0498 | Val Loss: 0.0597 | Val AUC: 0.6607 | Val PRC: 0.3008
EarlyStopping counter: 3 out of 5

[Fold 2] Epoch [6/10] | Train Loss: 0.0463 | Val Loss: 0.0627 | Val AUC: 0.6511 | Val PRC: 0.3270
EarlyStopping counter: 4 out of 5

[Fold 2] Epoch [7/10] | Train Loss: 0.0434 | Val Loss: 0.0642 | Val AUC: 0.6749 | Val PRC: 0.3322
EarlyStopping counter: 5 out of 5

[Fold 2] Early stopping triggered at epoch 7!

[Fold 2] Loading best model from params/phla_model_fold2.pt...
✓ Training completed for Fold 2!

Training Fold 4/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 3]...

[Fold 3] Epoch [1/10] | Train Loss: 0.0684 | Val Loss: 0.0684 | Val AUC: 0.7161 | Val PRC: 0.4046
Validation improved → Saving model (Score=0.4046) to params/phla_model_fold3.pt

[Fold 3] Epoch [2/10] | Train Loss: 0.0632 | Val Loss: 0.0599 | Val AUC: 0.7730 | Val PRC: 0.3882
EarlyStopping counter: 1 out of 5

[Fold 3] Epoch [3/10] | Train Loss: 0.0571 | Val Loss: 0.0611 | Val AUC: 0.7585 | Val PRC: 0.3782
EarlyStopping counter: 2 out of 5

[Fold 3] Epoch [4/10] | Train Loss: 0.0530 | Val Loss: 0.0532 | Val AUC: 0.7736 | Val PRC: 0.4631
Validation improved → Saving model (Score=0.4631) to params/phla_model_fold3.pt

[Fold 3] Epoch [5/10] | Train Loss: 0.0516 | Val Loss: 0.0491 | Val AUC: 0.7932 | Val PRC: 0.4458
EarlyStopping counter: 1 out of 5

[Fold 3] Epoch [6/10] | Train Loss: 0.0481 | Val Loss: 0.0494 | Val AUC: 0.7927 | Val PRC: 0.4846
Validation improved → Saving model (Score=0.4846) to params/phla_model_fold3.pt

[Fold 3] Epoch [7/10] | Train Loss: 0.0463 | Val Loss: 0.0501 | Val AUC: 0.7772 | Val PRC: 0.4288
EarlyStopping counter: 1 out of 5

[Fold 3] Epoch [8/10] | Train Loss: 0.0433 | Val Loss: 0.0516 | Val AUC: 0.7847 | Val PRC: 0.4984
Validation improved → Saving model (Score=0.4984) to params/phla_model_fold3.pt

[Fold 3] Epoch [9/10] | Train Loss: 0.0413 | Val Loss: 0.0518 | Val AUC: 0.7736 | Val PRC: 0.4714
EarlyStopping counter: 1 out of 5

[Fold 3] Epoch [10/10] | Train Loss: 0.0410 | Val Loss: 0.0535 | Val AUC: 0.7699 | Val PRC: 0.4594
EarlyStopping counter: 2 out of 5

[Fold 3] Loading best model from params/phla_model_fold3.pt...
✓ Training completed for Fold 3!

Training Fold 5/5
Train: 1494 samples | Val: 166 samples
Creating datasets...

Starting training for 10 epochs [Fold 4]...

[Fold 4] Epoch [1/10] | Train Loss: 0.0950 | Val Loss: 0.0663 | Val AUC: 0.7487 | Val PRC: 0.4224
Validation improved → Saving model (Score=0.4224) to params/phla_model_fold4.pt

[Fold 4] Epoch [2/10] | Train Loss: 0.0813 | Val Loss: 0.0622 | Val AUC: 0.7909 | Val PRC: 0.5390
Validation improved → Saving model (Score=0.5390) to params/phla_model_fold4.pt

[Fold 4] Epoch [3/10] | Train Loss: 0.0691 | Val Loss: 0.0575 | Val AUC: 0.7686 | Val PRC: 0.5470
Validation improved → Saving model (Score=0.5470) to params/phla_model_fold4.pt

[Fold 4] Epoch [4/10] | Train Loss: 0.0606 | Val Loss: 0.0535 | Val AUC: 0.7759 | Val PRC: 0.5324
EarlyStopping counter: 1 out of 5

[Fold 4] Epoch [5/10] | Train Loss: 0.0544 | Val Loss: 0.0488 | Val AUC: 0.8113 | Val PRC: 0.5460
EarlyStopping counter: 2 out of 5

[Fold 4] Epoch [6/10] | Train Loss: 0.0511 | Val Loss: 0.0462 | Val AUC: 0.8248 | Val PRC: 0.6159
Validation improved → Saving model (Score=0.6159) to params/phla_model_fold4.pt

[Fold 4] Epoch [7/10] | Train Loss: 0.0478 | Val Loss: 0.0442 | Val AUC: 0.8318 | Val PRC: 0.6173
Validation improved → Saving model (Score=0.6173) to params/phla_model_fold4.pt

[Fold 4] Epoch [8/10] | Train Loss: 0.0449 | Val Loss: 0.0432 | Val AUC: 0.8450 | Val PRC: 0.6228
Validation improved → Saving model (Score=0.6228) to params/phla_model_fold4.pt

[Fold 4] Epoch [9/10] | Train Loss: 0.0439 | Val Loss: 0.0456 | Val AUC: 0.8315 | Val PRC: 0.6352
Validation improved → Saving model (Score=0.6352) to params/phla_model_fold4.pt

[Fold 4] Epoch [10/10] | Train Loss: 0.0422 | Val Loss: 0.0401 | Val AUC: 0.8675 | Val PRC: 0.6911
Validation improved → Saving model (Score=0.6911) to params/phla_model_fold4.pt

[Fold 4] Loading best model from params/phla_model_fold4.pt...
✓ Training completed for Fold 4!

✓ All 5 folds training completed!

Cross-Validation Summary:
----------------------------------------------------------------------
Fold 0: Best Val AUC = 0.8240 (Epoch 9)
Fold 1: Best Val AUC = 0.8587 (Epoch 10)
Fold 2: Best Val AUC = 0.6749 (Epoch 7)
Fold 3: Best Val AUC = 0.7932 (Epoch 5)
Fold 4: Best Val AUC = 0.8675 (Epoch 10)
----------------------------------------------------------------------
Mean Val AUC: 0.8037 ± 0.0696
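
The reported spread matches NumPy's default population standard deviation (`ddof=0`) of the five per-fold values. The sample standard deviation (`ddof=1`) is somewhat larger. Note also that the summary lists the highest validation AUC seen in each fold, while checkpoints were saved based on validation PRC. For Fold 0, for example, the saved model is from epoch 5 (Val AUC 0.7782), not epoch 9. The snippet below only checks that the printed numbers are consistent; how `train_kfold` computes them internally is an assumption.

```python
import numpy as np

# Best validation AUC per fold, copied from the summary above
best_auc = np.array([0.8240, 0.8587, 0.6749, 0.7932, 0.8675])

# np.std defaults to the population standard deviation (ddof=0)
print(f"Mean Val AUC: {best_auc.mean():.4f} ± {best_auc.std():.4f}")  # Mean Val AUC: 0.8037 ± 0.0696
print(f"Sample std (ddof=1): {best_auc.std(ddof=1):.4f}")           # Sample std (ddof=1): 0.0778
```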



### Predict on the test set

The final trained models used for evaluation are available in `phla_params.zip` in the [Zenodo repository](https://zenodo.org/records/18002170).

Run inference with a trained StriMap pHLA model using the `predict` method, which accepts the following parameters:

| Parameter | Default | Description |
|----------|---------|-------------|
| `df` | *(required)* | Input `DataFrame` containing `peptide` and `HLA_full` columns |
| `batch_size` | `256` | Batch size used during inference |
| `return_probs` | `True` | If `True`, return predicted probabilities; otherwise return binary predictions |
| `return_attn` | `False` | Whether to return attention maps from the model |
| `use_kfold` | `False` | Whether to use an ensemble of K-fold models for prediction |
| `num_folds` | `None` | Number of folds used in training (required if `use_kfold=True`) |
| `ensemble_method` | `"mean"` | Method to aggregate K-fold predictions (`"mean"` or `"median"`) |
| `num_workers` | `8` | Number of workers for data loading |

**Returns:**

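- Predictions for each row of `df` (probabilities or binary labels, depending on `return_probs`). The cell below unpacks the call into two values and uses the first one as the predictions.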
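
With `use_kfold=True`, each fold model scores every sample and `ensemble_method` combines the per-fold scores. The sketch below shows what `"mean"` and `"median"` aggregation do, assuming each fold yields a 1-D array of probabilities. `ensemble_predictions` and the toy scores are hypothetical, not part of StriMap.

```python
import numpy as np


def ensemble_predictions(fold_probs, method="mean"):
    """Combine per-fold probability arrays into one array of ensemble scores."""
    # Shape (n_folds, n_samples): one row per fold model
    stacked = np.stack([np.asarray(p, dtype=float) for p in fold_probs], axis=0)
    if method == "mean":
        return stacked.mean(axis=0)
    if method == "median":
        return np.median(stacked, axis=0)
    raise ValueError(f"Unknown ensemble method: {method!r}")


# Hypothetical probabilities from three fold models for two samples
fold_probs = [[0.2, 0.9], [0.4, 0.7], [0.9, 0.8]]
print(ensemble_predictions(fold_probs, "mean"))    # [0.5 0.8]
print(ensemble_predictions(fold_probs, "median"))  # [0.4 0.8]
```

The median is less sensitive to a single outlying fold (here the 0.9 score for the first sample), at the cost of ignoring how far the other folds disagree.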

In [7]:
from main import StriMap_pHLA
from utils import load_test_data
import pandas as pd

# Load test data
df_test = pd.read_csv("examples/phla_test_set_example.csv")

if 'label' not in df_test.columns:
    df_test['label'] = 0  # Dummy label column for compatibility

# Standardize HLA fields and map alleles
df_test = load_test_data(
    df_test=df_test,
    hla_dict_path="HLA_dict.npy",
)

# Initialize StriMap with trained checkpoints
strimap = StriMap_pHLA(
    device="cuda:0",  # or "cpu"
    # Base checkpoint path; with use_kfold=True the per-fold files
    # params/phla_model_fold{0..4}.pt are loaded (see the log below)
    model_save_path="params/phla_model.pt",
    cache_dir="cache",  # cache directory for embeddings
)

# Prepare embeddings (reuses cached embeddings where available)
strimap.prepare_embeddings(
    df_test,
    force_recompute=False,
)

# Predict only (use this when the test set has no ground-truth labels)
# y_prob_test, _ = strimap.predict(df_test, use_kfold=True, num_folds=5)

# Predict and report metrics; only meaningful with real labels,
# not with the dummy labels added above
y_prob_test, _ = strimap.evaluate(df_test, use_kfold=True, num_folds=5)

df_test["predicted_score"] = y_prob_test
print(df_test.head())

Processing test data...
✓ Test set: 416 samples
Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 416 unique peptides
  - 97 unique HLAs

Phys encoding: 100%|██████████| 2/2 [00:00<00:00, 856.85it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 268.68it/s]

[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] Found 415 new sequences → computing embeddings...
ESM2 update (pep): 100%|██████████| 2/2 [00:00<00:00,  5.49it/s]
[ESM2] Updating cache with new sequences
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] Found 1 new sequences → computing embeddings...
ESM2 update (hla): 100%|██████████| 1/1 [00:00<00:00, 32.87it/s]
[ESM2] Updating cache with new sequences

INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:Found new hla sequences, embedding...

Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache

Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 1, valid seqs: 1, unique: 1
ESMfold Predicting structure...: 100%|██████████| 1/1 [00:02<00:00,  2.04s/it]
INFO:model:[DONE] OK: 1, Failed: 0
INFO:model:Updated and saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt

✓ All embeddings prepared!
  - Phys: 416 peptides, 97 HLAs
  - ESM2: 2074 peptides, 114 HLAs
  - Struct: 97 HLAs

Ensemble prediction using 5 models...
Ensemble method: mean
Loading model from params/phla_model_fold0.pt...
Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.72it/s]
Loading model from params/phla_model_fold1.pt...
Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.60it/s]
Loading model from params/phla_model_fold2.pt...
Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.63it/s]
Loading model from params/phla_model_fold3.pt...
Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.65it/s]
Loading model from params/phla_model_fold4.pt...
Predicting: 100%|██████████| 2/2 [00:01<00:00,  1.66it/s]

✓ Ensemble prediction completed using 5 models

Evaluation Results [5-Fold Ensemble (mean)]
tn = 328, fp = 19, fn = 37, tp = 32
y_pred: 0 = 365 | 1 = 51
y_true: 0 = 347 | 1 = 69
AUC: 0.8423 | PRC: 0.5929 | ACC: 0.8654 | MCC: 0.4639 | F1: 0.5333
Precision: 0.6275 | Recall: 0.4638

           HLA      peptide  label  \
0  HLA-B*27:08  GKPAETIRIGD      0   
1  HLA-B*40:01   TEYEEAQDAI      0   
2  HLA-A*68:02   NAENEFVTIK      1   
3  HLA-C*07:01  YKTDVEQIKIN      0   
4  HLA-C*07:01    NGICIYFSR      0   

                                            HLA_full  predicted_score  
0  SHSMRYFHTSVSRPGRGEPRFITVGYVDDTLFVRFDSDAASPREEP...         0.303836  
1  SHSMRYFHTAMSRPGRGEPRFITVGYVDDTLFVRFDSDATSPRKEP...         0.425987  
2  SHSMRYFYTSMSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...         0.428319  
3  SHSMRYFDTAVSRPGRGEPRFISVGYVDDTQFVRFDSDAASPRGEP...         0.349639  
4  SHSMRYFDTAVSRPGRGEPRFISVGYVDDTQFVRFDSDAASPRGEP...         0.377198  
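
The threshold-based metrics in the evaluation report follow directly from its confusion-matrix counts (AUC and PRC depend on the raw scores and cannot be recovered this way). The check below uses the standard definitions of accuracy, precision, recall, F1, and the Matthews correlation coefficient. `confusion_metrics` is a hypothetical helper, not part of StriMap.

```python
import math


def confusion_metrics(tn, fp, fn, tp):
    """Standard threshold-based metrics from confusion-matrix counts."""
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return {
        "ACC": (tp + tn) / (tp + tn + fp + fn),
        "Precision": precision,
        "Recall": recall,
        "F1": 2 * precision * recall / (precision + recall),
        "MCC": (tp * tn - fp * fn)
        / math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)),
    }


# Counts copied from the 5-fold ensemble evaluation report above
metrics = confusion_metrics(tn=328, fp=19, fn=37, tp=32)
print({k: round(v, 4) for k, v in metrics.items()})
# {'ACC': 0.8654, 'Precision': 0.6275, 'Recall': 0.4638, 'F1': 0.5333, 'MCC': 0.4639}
```

On this imbalanced test set (69 positives out of 416), accuracy is high mainly because of the negatives. MCC and F1 give a more balanced view of how well binders are recovered.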



