In [None]:
import pandas as pd
from sklearn.model_selection import StratifiedKFold, train_test_split

from main import StriMap_pHLA, load_train_data

# Load the labelled training set.
df = pd.read_csv("path/to/train_set.csv")

# Stratified 5-fold train/val splits.
# StratifiedKFold is used instead of repeated train_test_split calls with
# different seeds: repeated random hold-outs produce validation sets that
# overlap (some rows are never validated, others several times), which is not
# K-fold cross-validation. Here the validation folds are disjoint, every row
# is validated exactly once (each val fold is ~1/num_fold of the data), and
# the "label" class ratio is preserved in every fold.
num_fold = 5
split_seed = 0  # fixed so the folds are reproducible across runs
skf = StratifiedKFold(n_splits=num_fold, shuffle=True, random_state=split_seed)

train_folds, val_folds = [], []
for train_idx, val_idx in skf.split(df, df["label"]):
    # split() yields positional indices, hence iloc; reset the index so every
    # fold frame is 0..len-1 like the downstream code expects.
    train_folds.append(df.iloc[train_idx].reset_index(drop=True))
    val_folds.append(df.iloc[val_idx].reset_index(drop=True))

# Attach full HLA sequences to every train/val fold, using the lookup stored
# at HLA_dict.npy; the mapped folds replace the raw ones.
hla_dict_path = "HLA_dict.npy"
train_folds, val_folds = load_train_data(
    df_train_list=[*train_folds],
    df_val_list=[*val_folds],
    hla_dict_path=hla_dict_path,
)

# Construct the StriMap pHLA model on CUDA device 0, with its model save path
# and embedding cache directory.
strimap_config = {
    "device": "cuda:0",
    "model_save_path": "params/phla/best_model.pt",
    "cache_dir": "cache/phla",
}
strimap = StriMap_pHLA(**strimap_config)

# Prepare embeddings (cached for faster training).
# Each fold's train + val frames partition the same set of rows, so a single
# fold pair already covers the whole dataset exactly once. Concatenating every
# fold would hand prepare_embeddings each row num_fold times.
# NOTE(review): assumes load_train_data maps each fold row-by-row without
# fold-specific filtering — confirm against main.load_train_data.
all_rows = pd.concat([train_folds[0], val_folds[0]], ignore_index=True)
strimap.prepare_embeddings(
    all_rows,
    force_recompute=False,
)

# K-fold training: each entry pairs one fold's training frame with the
# matching validation frame; the per-fold history is collected in all_history.
fold_pairs = [
    (fold_train, fold_val) for fold_train, fold_val in zip(train_folds, val_folds)
]
all_history = strimap.train_kfold(
    train_folds=fold_pairs, epochs=100, num_workers=4
)