# Train a TCR-pHLA predictor with a K-fold strategy and make predictions on a test set

In [1]:
import pandas as pd
import numpy as np
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA
from utils import negative_sampling_phla, preprocess_input_data
from sklearn.model_selection import StratifiedKFold

In [2]:
import pandas as pd
# Read the example training table and pass it straight through the project's
# preprocess_input_data helper; the result is the frame every later cell uses.
df = preprocess_input_data(pd.read_csv('examples/tcrphla_train_set_example.csv'))
df  # last expression of the cell: rendered as the cell output

Unnamed: 0,cdr3a,cdr3b,Va,Ja,Vb,Jb,peptide,HLA,label,tcra,tcrb,cdr3a_start,cdr3a_end,cdr3b_start,cdr3b_end,HLA_clean,HLA_full
0,CAVVRSGTYKYIF,CASSLTGTGALYEQYF,TRAV8-3*01,TRAJ40*01,TRBV7-9*01,TRBJ2-7*01,GILGFVFTL,HLA-A*02:01,0,AQSVTQPDIHITVSEGASLELRCNYSYGATPYLFWYVQSPGQGLQL...,DTGVSQNPRHKITKRGQNVTFRCDPISEHNRLYWYRQTLGQGPEFL...,90,103,91,107,A*02:01,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...
1,CAVKDGGSQGNLIF,CASSFDSQLYEQYF,TRAV21*01,TRAJ42*01,TRBV7-9*03,TRBJ2-7*01,GILGFVFTL,HLA-A*02:01,0,KQEVTQIPAALSVPEGENLVLNCSFTDSAIYNLQWFRQDPGKGLTS...,DTGVSQDPRHKITKRGQNVTFRCDPISEHNRLYWYRQTLGQGPEFL...,89,103,91,105,A*02:01,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...
2,CAFMEYGNKLVF,CASSFLGVTYGYTF,TRAV38-1*01,TRAJ47*01,TRBV13*01,TRBJ1-2*01,CINGVCWTV,HLA-A*02:01,1,AQTVTQSQPEMSVQEAETVTLSCTYDTSENNYYLFWYKQPPSRQMI...,AAGVIQSPRHLIKEKRETATLKCYPIPRHDTVYWYQQGPGQDPQFL...,91,103,90,104,A*02:01,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...
3,CAVREDNAGNMLTF,CASGPLLLMTNEQFF,TRAV1-2*01,TRAJ39*01,TRBV12-4*01,TRBJ2-1*01,TTDPSFLGRY,HLA-A*01:01,0,GQNIDQPTEMTATEGAIVQINCTYQTSGFNGLFWYQQHAGEAPTFL...,DAGVIQSPRHEVTEMGQEVTLRCKPISGHDYLFWYRQTMMRGLELL...,87,101,91,106,A*01:01,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQKMEP...
4,CAVRDLNARLMF,CASSSVNEQYF,TRAV3*01,TRAJ31*01,TRBV12-3*01,TRBJ2-7*01,CINGVCWTV,HLA-A*02:01,0,AQSVAQPEDQVNVAEGNPLTVKCTYSVSGNPYLFWYVQYPNRGLQF...,DAGVIQSPRHEVTEMGQEVTLRCKPISGHNSLFWYRQTMMRGLELL...,90,102,91,102,A*02:01,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
289,CAVIVAAGNKLTF,CASSQRGGWETQYF,TRAV12-2*01,TRAJ17*01,TRBV27*01,TRBJ2-5*01,NQKLIANQF,HLA-B*15:01,0,QKEVEQNSGPLSVPEGAIASLNCTYSDRGSQSFFWYRQYSGKSPEL...,EAQVTQNPRYLITVTGKKLTVTCSQNMNHEYMSWYRQDPGLGLRQI...,88,101,90,104,B*15:01,SHSMRYFYTAMSRPGRGEPRFIAVGYVDDTQFVRFDSDAASPRMAP...
290,CAVVDSNYQLIW,CASSYTSGGVETQYF,TRAV1-2*01,TRAJ33*01,TRBV6-6*01,TRBJ2-5*01,GILGFVFTL,HLA-A*02:01,0,GQNIDQPTEMTATEGAIVQINCTYQTSGFNGLFWYQQHAGEAPTFL...,NAGVTQTPKFRILKIGQSMTLQCTQDMNHNYMYWYRQDPGMGLKLI...,87,99,90,105,A*02:01,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...
291,CAVSEPSGAGSYQLTF,CAISRGNEQYF,TRAV8-6*01,TRAJ28*01,TRBV10-3*01,TRBJ2-7*01,GILGFVFTL,HLA-A*02:01,1,AQSVTQLDSQVPVFEEAPVELRCNYSSSVSVYLFWYVQYPNQGLQL...,DAGITQSPRHKVTETGTPVTLRCHQTENHRYMYWYRQDPGHGLRLI...,90,106,90,101,A*02:01,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQRMEP...
292,CAARGYPGSYQLTF,CSARESLRDEKLFF,TRAV29/DV5*01,TRAJ28*01,TRBV20-1*01,TRBJ1-4*01,ATDALMTGF,HLA-A*01:01,0,DQQVKQNSPSLSVQEGRISILNCDYTNSMFDYFLWYKKYPAEGPTF...,GAVVSQHPSWVICKSGTSVKIECRSLDFQATTMFWYRQFPKQSLML...,89,103,93,107,A*01:01,SHSMRYFFTSVSRPGRGEPRFIAVGYVDDTQFVRFDSDAASQKMEP...


In [3]:
# --- Configuration -----------------------------------------------------------
# Use the first CUDA device when one is available; otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

phla_fold = 5     # number of folds used in pHLA model training
tcrphla_fold = 5  # number of folds used in TCR-pHLA model training

# --- Build stratified train/validation folds ---------------------------------
# Stratifying on `label` keeps the class proportions of `df` in every fold;
# shuffle + a fixed random_state makes the split reproducible across runs.
skf = StratifiedKFold(n_splits=tcrphla_fold, shuffle=True, random_state=42)

# One entry per fold, filled in lock-step by the loop below.
df_train_folds, df_val_folds = [], []

for i, (train_idx, val_idx) in enumerate(skf.split(df, df["label"])):
    df_train_fold = df.iloc[train_idx].reset_index(drop=True)
    df_val_fold = df.iloc[val_idx].reset_index(drop=True)

    # Only the validation split is augmented here: the rows returned by
    # negative_sampling_phla (seeded with the fold index) are appended to it.
    # Training folds are stored unchanged.
    df_val_fold_neg = negative_sampling_phla(df_val_fold, random_state=i)
    df_val_fold = pd.concat([df_val_fold, df_val_fold_neg], axis=0).reset_index(drop=True)

    df_train_folds.append(df_train_fold)
    df_val_folds.append(df_val_fold)

In [4]:
# Peptide-HLA system; it is handed to the TCR-pHLA system in a later cell.
# 'cache' is the directory passed as cache_dir (relative to the working dir).
pep_hla_system = StriMap_pHLA(
    cache_dir='cache',
    device=device,
)
# Prepare embeddings for the full training table (index reset to 0..n-1).
pep_hla_system.prepare_embeddings(df.reset_index(drop=True))

Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 17 unique peptides
  - 6 unique HLAs



Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 53.99it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 1497.97it/s]

[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt





[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt


INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/cache/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/hla_coord_dict.pt
INFO:model:No new hla sequences found


[ESM2] No new sequences for hla, using existing cache
✓ All embeddings prepared!
  - Phys: 17 peptides, 6 HLAs
  - ESM2: 2090 peptides, 114 HLAs
  - Struct: 6 HLAs



In [None]:
# Build the TCR-pHLA system on top of the existing peptide-HLA system,
# passing one pretrained pHLA checkpoint path per pHLA fold.
tcr_phla_system = StriMap_TCRpHLA(
    pep_hla_system=pep_hla_system,
    pep_hla_params=[f'params/phla_model_fold{i}.pt' for i in range(phla_fold)],  # load trained pHLA model parameters
    device=device,
    # NOTE(review): presumably the base name from which per-fold checkpoint
    # files are derived during K-fold training -- confirm in StriMap_TCRpHLA.
    model_save_path='params/tcrphla_model.pt',
    resample_negatives=True,
    cache_save=True
)
tcr_phla_system.prepare_embeddings(df.reset_index(drop=True))

# train_folds is a list with one (train_df, val_df) tuple per fold.
all_histories = tcr_phla_system.train_kfold(
    train_folds=list(zip(df_train_folds, df_val_folds)),
    epochs=10,
    batch_size=2,  # make sure batch size is smaller than df_train_fold and df_val_fold sizes
)

✓ StriMap_TCRpHLA initialized on cuda:0

Preparing embeddings:
  - TCRα: 285 | TCRβ: 288 | peptides: 17 | HLAs: 6



Phys encoding (TCRpHLA): 100%|██████████| 2/2 [00:00<00:00, 247.53it/s]
Phys encoding (TCRpHLA): 100%|██████████| 2/2 [00:00<00:00, 257.61it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 2563.76it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 1754.94it/s]


[ESM2] No existing cache for tcra, will create new.
[ESM2] Found 285 new sequences → computing embeddings...


ESM2 update (tcra): 100%|██████████| 3/3 [00:03<00:00,  1.29s/it]


[ESM2] Updating cache with new sequences
[ESM2] No existing cache for tcrb, will create new.
[ESM2] Found 288 new sequences → computing embeddings...


ESM2 update (tcrb): 100%|██████████| 3/3 [00:03<00:00,  1.19s/it]


[ESM2] Updating cache with new sequences
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/cache/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache


INFO:model:/ewsc/cao/StriMap/strimap-tools/cache/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/cache/tcra_coord_dict.pt not found or re_embed=True, generating...


Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 285, valid seqs: 285, unique: 285
ESMfold Predicting structure...:   8%|▊         | 22/285 [00:23<04:34,  1.04s/it]

In [None]:
import pandas as pd
import torch
from main import StriMap_pHLA, StriMap_TCRpHLA
from utils import preprocess_input_data

# Load and preprocess the held-out example test set.
df_test = pd.read_csv('examples/tcrphla_test_set_example.csv')
df_test = preprocess_input_data(df_test)

# Fold counts; these determine how many checkpoint paths are built / ensembled
# below, so they should match the values used at training time.
tcrphla_fold = 5
phla_fold = 5

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# NOTE(review): this cache_dir differs from the 'cache' directory used in the
# training cells, so embeddings are not shared between them -- confirm that
# this is intentional.
pep_hla_system = StriMap_pHLA(device=device, cache_dir='tcrphla_cache/pep_background')
pep_hla_system.prepare_embeddings(df_test)

tcr_phla_system = StriMap_TCRpHLA(
    pep_hla_system=pep_hla_system,
    pep_hla_params=[f'params/phla_model_fold{i}.pt' for i in range(phla_fold)],
    device=device,
    model_save_path='params/tcrphla_model.pt',
)

tcr_phla_system.prepare_embeddings(df_test)

# Predict on the test set using K-fold models
# pred, pep_feat, attn_dict = tcr_phla_system.predict(df_test, use_kfold=True, num_folds=tcrphla_fold)

# Alternatively, use evaluate to get performance metrics as well
# NOTE(review): the output saved with this cell ends in
# "ValueError: not enough values to unpack (expected 3, got 2)", which does not
# match the two-value unpacking below. Either the saved output is stale or the
# error is raised inside evaluate(); re-run on a fresh kernel to confirm.
pred, _ = tcr_phla_system.evaluate(df_test, use_kfold=True, num_folds=tcrphla_fold)
df_test['predicted_score'] = pred

# Bare last expression so the notebook renders the preview as a rich table.
df_test.head()

Initializing encoders...
✓ Loaded 20 AAindex features
Initializing binding prediction model...
✓ StriMap initialized on cuda:0

Preparing embeddings for:
  - 16 unique peptides
  - 6 unique HLAs



Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 1573.85it/s]
Phys encoding: 100%|██████████| 1/1 [00:00<00:00, 1756.41it/s]


[ESM2] No existing cache for pep, will create new.
[ESM2] Found 16 new sequences → computing embeddings...


ESM2 update (pep): 100%|██████████| 1/1 [00:00<00:00, 25.63it/s]


[ESM2] Updating cache with new sequences
[ESM2] No existing cache for hla, will create new.
[ESM2] Found 6 new sequences → computing embeddings...


ESM2 update (hla): 100%|██████████| 1/1 [00:00<00:00, 13.40it/s]

[ESM2] Updating cache with new sequences



INFO:model:/ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/hla_coord_dict.pt not found or re_embed=True, generating...


Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 6, valid seqs: 6, unique: 6
ESMfold Predicting structure...: 100%|██████████| 6/6 [00:12<00:00,  2.07s/it]
100%|██████████| 6/6 [00:12<00:00,  2.07s/it]
INFO:model:[DONE] OK: 6, Failed: 0
INFO:model:Saved /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/hla_coord_dict.pt
INFO:model:No new hla sequences found


✓ All embeddings prepared!
  - Phys: 16 peptides, 6 HLAs
  - ESM2: 16 peptides, 6 HLAs
  - Struct: 6 HLAs

✓ StriMap_TCRpHLA initialized on cuda:0

Preparing embeddings:
  - TCRα: 73 | TCRβ: 72 | peptides: 16 | HLAs: 6



Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 398.47it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 416.64it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 2484.78it/s]
Phys encoding (TCRpHLA): 100%|██████████| 1/1 [00:00<00:00, 1639.04it/s]


[ESM2] No existing cache for tcra, will create new.
[ESM2] Found 73 new sequences → computing embeddings...


ESM2 update (tcra): 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]


[ESM2] No existing cache for tcrb, will create new.
[ESM2] Found 72 new sequences → computing embeddings...


ESM2 update (tcrb): 100%|██████████| 1/1 [00:00<00:00,  1.97it/s]
INFO:model:/ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/tcra_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/tcra_coord_dict.pt not found or re_embed=True, generating...


[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/pep_esm2_layer33.pt
[ESM2] No new sequences for pep, using existing cache
[ESM2] Loading cached embeddings from /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/hla_esm2_layer33.pt
[ESM2] No new sequences for hla, using existing cache
Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 73, valid seqs: 73, unique: 73
ESMfold Predicting structure...: 100%|██████████| 73/73 [01:14<00:00,  1.02s/it]
100%|██████████| 73/73 [01:14<00:00,  1.02s/it]
INFO:model:[DONE] OK: 73, Failed: 0
INFO:model:No new tcra sequences found
INFO:model:/ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/tcrb_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/tcrb_coord_dict.pt not found or re_embed=True, generating...


Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 72, valid seqs: 72, unique: 72
ESMfold Predicting structure...: 100%|██████████| 72/72 [01:14<00:00,  1.03s/it]
100%|██████████| 72/72 [01:14<00:00,  1.03s/it]
INFO:model:[DONE] OK: 72, Failed: 0
INFO:model:No new tcrb sequences found
INFO:model:/ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/pep_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/pep_coord_dict.pt not found or re_embed=True, generating...


Loading ESMFold model facebook/esmfold_v1 on cuda:0... with cache_dir: esm_cache


Some weights of EsmForProteinFolding were not initialized from the model checkpoint at facebook/esmfold_v1 and are newly initialized: ['esm.contact_head.regression.bias', 'esm.contact_head.regression.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:model:Total rows: 16, valid seqs: 16, unique: 16
ESMfold Predicting structure...: 100%|██████████| 16/16 [00:14<00:00,  1.13it/s]
100%|██████████| 16/16 [00:14<00:00,  1.13it/s]
INFO:model:[DONE] OK: 16, Failed: 0
INFO:model:No new pep sequences found
INFO:model:Loading /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/hla_feat_dict.pt and /ewsc/cao/StriMap/strimap-tools/tcrphla_cache/pep_background/hla_coord_dict.pt
INFO:model:No new hla sequences found


✓ Embeddings prepared for TCRα/β, peptide (with ESMFold), and HLA.
✓ Embeddings prepared for TCRα/β, peptide, and HLA.
Preparing peptide-HLA features for prediction set...

Precomputing peptide-HLA features for 16 unique pairs...


pHLA features (batched): 100%|██████████| 1/1 [00:00<00:00,  2.73it/s]


✓ Pretrained peptide-HLA features prepared.

Ensemble prediction using 5 TCR–pHLA models...
Ensemble method: mean
Loading model from params/tcrphla_model_fold0.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.56s/it]


Loading model from params/tcrphla_model_fold1.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:02<00:00,  2.07s/it]


Loading model from params/tcrphla_model_fold2.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.70s/it]


Loading model from params/tcrphla_model_fold3.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.47s/it]


Loading model from params/tcrphla_model_fold4.pt...


Predicting (TCR-pHLA): 100%|██████████| 1/1 [00:01<00:00,  1.51s/it]

✓ Ensemble prediction completed using 5 folds

Evaluation Results [K-Fold Ensemble]
tn=66, fp=0, fn=8, tp=0
AUC=0.5265 | PRC=0.2470 | ACC=0.8919 | MCC=0.0000 | F1=0.0000 | P=0.0000 | R=0.0000






ValueError: not enough values to unpack (expected 3, got 2)