# Train

In [None]:
from main import StriMap_pHLA, load_train_data
import pandas as pd
from sklearn.model_selection import train_test_split

### Load data

In [None]:
# Read the full training table from disk. The table must contain a `label`
# column, which the fold-splitting step uses for stratification.
train_csv_path = "path/to/train_set.csv"
df = pd.read_csv(train_csv_path)

### 5-fold train/val splits (five independently seeded, stratified 90/10 hold-out splits)

In [None]:
num_fold = 5

# Build one stratified 90/10 train/validation split per fold. Each fold uses
# its own index as the random seed, so the folds are independent random
# hold-outs of `df` (validation sets of different folds may overlap).
train_folds = []
val_folds = []
for seed in range(num_fold):
    train_part, val_part = train_test_split(
        df,
        test_size=0.1,
        random_state=seed,
        stratify=df["label"],
    )
    # Drop the original row indices so every fold is indexed 0..n-1.
    train_folds.append(train_part.reset_index(drop=True))
    val_folds.append(val_part.reset_index(drop=True))

### HLA full sequence mapping

In [None]:
# Pass copies of the fold lists through `load_train_data` together with the
# HLA dictionary file, and rebind both names to what it returns.
# NOTE(review): presumably this attaches the full HLA sequence to every row
# (per the section heading) — confirm against `main.load_train_data`.
hla_dict_file = "HLA_dict.npy"
train_folds, val_folds = load_train_data(
    df_train_list=[*train_folds],
    df_val_list=[*val_folds],
    hla_dict_path=hla_dict_file,
)

### Initialize StriMap pHLA model

In [None]:
# Create the StriMap pHLA model wrapper. `model_save_path` is where the
# checkpoint is stored and `cache_dir` is where embeddings are cached.
strimap = StriMap_pHLA(
    device="cuda:0",
    model_save_path="params/phla_model.pt",
    cache_dir="cache",
)

# Prepare embeddings once for every row of every fold (train and validation
# together). With force_recompute=False, previously cached embeddings are
# reused instead of being recomputed.
strimap.prepare_embeddings(
    pd.concat([*train_folds, *val_folds], ignore_index=True),
    force_recompute=False,
)

# NOTE: training is intentionally NOT started in this cell. It is launched
# exactly once by the `strimap.train_kfold(...)` call in the K-fold training
# cell; calling it here as well would train every fold twice.

### K-fold training

In [None]:
# Pair each training frame with the validation frame of the same fold.
fold_pairs = list(zip(train_folds, val_folds))

# Train on every (train, val) pair and keep whatever `train_kfold` returns.
all_history = strimap.train_kfold(
    train_folds=fold_pairs,
    epochs=100,
    num_workers=4,
)

# Predict

In [None]:
from main import StriMap_pHLA, load_test_data
import pandas as pd

# Load test data
df_test = pd.read_csv("path/to/test_set.csv")

# The pipeline expects a `label` column for compatibility. Add a dummy one
# only when the test set is unlabeled, so real labels (if the CSV has them)
# are not silently overwritten with zeros and written back out below.
if "label" not in df_test.columns:
    df_test["label"] = 0

# Standardize HLA fields and map alleles
df_test = load_test_data(
    df_test=df_test,
    hla_dict_path="HLA_dict.npy",
)

# Initialize StriMap with a trained checkpoint
strimap = StriMap_pHLA(
    device="cuda:0",  # or "cpu"
    model_save_path="params/phla_model.pt",  # Path to trained model
    cache_dir="cache",  # Cache directory for embeddings
)

# Prepare embeddings (cached for faster inference)
strimap.prepare_embeddings(
    df_test,
    force_recompute=False,
)

# Run prediction. Only the first returned value (the per-row scores) is used;
# the second is discarded. num_folds=5 matches `num_fold` in the training
# section.
y_prob_test, _ = strimap.predict(
    df_test,
    use_kfold=True,
    num_folds=5,
)

# Attach the scores to the table, save it, and show the first rows.
df_test["predicted_score"] = y_prob_test
df_test.to_csv("test_with_predictions.csv", index=False)
print(df_test.head())