In [1]:
# Environment sanity check: print where the standard-library `inspect` module lives
# (which reveals the Python installation backing this kernel) and the installed
# `datasets` version.
import inspect
print(inspect.__file__)
import datasets
print("datasets version:", datasets.__version__)


/home/lina.utenova/.conda/envs/enhancer-predict/lib/python3.12/inspect.py
datasets version: 4.4.1


In [3]:
# =========================================
# Multi-task BERT: Early-stopping, P/R/F1, Saving, and Plots
# =========================================
import os, json, math, random, re, gc
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn

from sklearn.metrics import (
    roc_auc_score, average_precision_score, accuracy_score,
    precision_recall_fscore_support
)

from datasets import Dataset
from transformers import (
    BertConfig, BertModel,
    TrainingArguments, Trainer, default_data_collator, set_seed,
    EarlyStoppingCallback
)

!python -m pip install matplotlib
import matplotlib.pyplot as plt


Collecting matplotlib
  Downloading matplotlib-3.10.7-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.5 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.60.1-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl.metadata (112 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.9-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (6.3 kB)
Collecting pillow>=8 (from matplotlib)
  Downloading pillow-12.0.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.8 kB)
Collecting pyparsing>=3 (from matplotlib)
  Downloading pyparsing-3.2.5-py3-none-any.whl.metadata (5.0 kB)
Downloadin

In [9]:
# =========================================
# Multi-task BERT (CSV schema with kid_*, mask_*, tissue_*, label)
# - Early stopping on F1
# - Precision/Recall/F1 metrics
# - Save model + results + plots
# =========================================
import os, json, math, random, re, gc
from typing import List, Tuple
import numpy as np
import pandas as pd

import torch
import torch.nn as nn

from sklearn.metrics import (
    roc_auc_score, average_precision_score, accuracy_score,
    precision_recall_fscore_support
)

import matplotlib.pyplot as plt
from datasets import Dataset
from transformers import (
    BertConfig, BertModel,
    TrainingArguments, Trainer, default_data_collator, set_seed,
    EarlyStoppingCallback, TrainerCallback, TrainerState, TrainerControl
)

# -------------------------
# Paths & constants
# -------------------------
# Input CSVs and the directory that receives every artifact written by this cell.
TRAIN_CSV = "train.csv"
VAL_CSV   = "validation.csv"
OUT_DIR   = "model_out"
os.makedirs(OUT_DIR, exist_ok=True)

SEP = ","            # CSV delimiter
SEED = 42            # passed to set_all_seeds() and to TrainingArguments(seed=...)

# Sequence layout (no metadata columns)
MAX_LEN = 348                       # kid_0..kid_347 and mask_0..mask_347
ID_COLS   = [f"kid_{i}"  for i in range(MAX_LEN)]   # token-id columns, become `input_ids`
MASK_COLS = [f"mask_{i}" for i in range(MAX_LEN)]   # mask columns, become `attention_mask`

# Model size (compact; adjust as needed)
HIDDEN = 256         # BertConfig.hidden_size and input width of both heads
LAYERS = 6           # BertConfig.num_hidden_layers
HEADS  = 4           # BertConfig.num_attention_heads
INTERM = 1024        # BertConfig.intermediate_size
LR = 3e-4            # TrainingArguments.learning_rate
EPOCHS_MAX = 30      # upper bound on epochs; early stopping may end training sooner
BATCH_TRAIN = 256    # per-device train batch size
BATCH_EVAL  = 1024   # per-device eval batch size
TISSUE_LOSS_WEIGHT = 1.0   # multiplier on the tissue loss inside the combined loss

# Used for model placement and for the loss-weight tensors.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -------------------------
# Reproducibility
# -------------------------
def set_all_seeds(seed=SEED):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device).

    Also switches cuDNN to deterministic mode and turns off its autotuner.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
set_all_seeds(SEED)

# -------------------------
# Column detection (robust to label name)
# -------------------------
def read_header(path: str) -> List[str]:
    """Return the column names of the CSV at `path`; no data rows are read (`nrows=0`)."""
    header_only = pd.read_csv(path, sep=SEP, nrows=0)
    return list(header_only.columns)

hdr = read_header(TRAIN_CSV)

# Every expected token-id (kid_*) and attention-mask (mask_*) column must be in the header.
for cols, name in ((ID_COLS, "kid_*"), (MASK_COLS, "mask_*")):
    missing = list(filter(lambda col: col not in hdr, cols))
    assert not missing, f"Missing {name} columns: {missing[:5]}..."

# Tissue targets: prefer the `tissue_` prefix; if none match, fall back to any column
# whose name contains "tissue" (case-insensitive).
tissue_cols = [col for col in hdr if col.startswith("tissue_")] or [
    col for col in hdr if "tissue" in col.lower()
]
assert len(tissue_cols) > 0, "No tissue columns found (expected tissue_*)."
NUM_TISSUES = len(tissue_cols)

# Detect label column robustly:
#   1) take the first well-known name present in the header;
#   2) otherwise infer it from a 5000-row sample: a column that is *strictly* binary
#      ({0, 1}, both values present, no missing/non-numeric entries) and is not a
#      token-id, mask or tissue column.
LABEL_CANDIDATES = ["enhancer", "enhancer_label", "label", "target"]
label_col = next((cand for cand in LABEL_CANDIDATES if cand in hdr), None)

if label_col is None:
    # try infer from sample
    reserved = set(ID_COLS) | set(MASK_COLS) | set(tissue_cols)  # built once, not per element
    candidate_pool = [c for c in hdr if c not in reserved]
    sample = pd.read_csv(TRAIN_CSV, sep=SEP, usecols=candidate_pool, nrows=5000, low_memory=False)

    def is_binary(s):
        """Return True only if every value of `s` is exactly 0 or 1 and both occur.

        Values are NOT truncated to int and missing values are NOT filled: a float
        score column (e.g. 0.3 / 1.7) or a column with NaN / non-numeric entries
        must not be mistaken for a binary label.
        """
        v = pd.to_numeric(s, errors="coerce")
        if v.isna().any():
            return False
        return set(v.unique().tolist()) == {0, 1}

    # prefer names containing 'enhanc'
    candidates = [c for c in sample.columns if is_binary(sample[c])]
    enh_like   = [c for c in candidates if "enhanc" in c.lower()]
    label_col  = (enh_like or candidates or [None])[0]

assert label_col is not None, "Could not find a binary enhancer label column."
LABEL_COL = label_col
print(f"[INFO] Using LABEL_COL = '{LABEL_COL}', NUM_TISSUES = {NUM_TISSUES}")

# -------------------------
# Load CSVs
# -------------------------
usecols = [*ID_COLS, *MASK_COLS, *tissue_cols, LABEL_COL]

# Explicit unsigned dtypes for every column that is read.
dtype_map = dict.fromkeys(ID_COLS, "uint16")             # token ids
dtype_map.update(dict.fromkeys(MASK_COLS, "uint8"))      # attention mask
dtype_map.update(dict.fromkeys(tissue_cols, "uint8"))    # multi-label tissues
dtype_map[LABEL_COL] = "uint8"

# Train is read first, then validation, with identical column selection and dtypes.
df_tr, df_va = (
    pd.read_csv(csv_path, sep=SEP, usecols=usecols, dtype=dtype_map)
    for csv_path in (TRAIN_CSV, VAL_CSV)
)
print("Train shape:", df_tr.shape, "Val shape:", df_va.shape)

# -------------------------
# Pack to HF Datasets
# -------------------------
def pack_df(df: pd.DataFrame) -> Dataset:
    """Convert a wide DataFrame into a HF `Dataset` with one record per row.

    Each record holds `input_ids` (kid_* values), `attention_mask` (mask_* values),
    `labels` (the scalar label) and `tissues` (the tissue_* values), all as Python ints.
    """
    token_rows  = df[ID_COLS].to_numpy(dtype=np.int64).tolist()
    mask_rows   = df[MASK_COLS].to_numpy(dtype=np.int64).tolist()
    label_vals  = df[LABEL_COL].to_numpy(dtype=np.int64).tolist()
    tissue_rows = df[tissue_cols].to_numpy(dtype=np.int64).tolist()

    records = [
        {
            "input_ids":      ids,
            "attention_mask": attn,
            "labels":         lab,
            "tissues":        tis,
        }
        for ids, attn, lab, tis in zip(token_rows, mask_rows, label_vals, tissue_rows)
    ]
    return Dataset.from_list(records)

ds_train = pack_df(df_tr)
ds_val   = pack_df(df_va)

# Infer vocab size from kid_* values.
# Both splits are scanned: a token id that occurs only in validation would otherwise
# index past the end of the embedding table during evaluation. The maximum is taken
# directly from the DataFrames, which avoids a full `Dataset.map` pass (and an extra
# materialised dataset) whose only purpose was to compute one number.
def _max_token_id(df: pd.DataFrame) -> int:
    """Largest value in the kid_* columns of `df`; -1 when `df` has no rows."""
    return int(df[ID_COLS].to_numpy().max()) if len(df) else -1

VOCAB_SIZE = max(_max_token_id(df_tr), _max_token_id(df_va)) + 1
print(f"[INFO] VOCAB_SIZE = {VOCAB_SIZE}")

# -------------------------
# Model
# -------------------------
class MultiTaskBERT(nn.Module):
    """Randomly initialised BERT encoder with two linear heads on a mean-pooled embedding.

    * ``enhancer_head`` emits 2 logits per sequence.
    * ``tissue_head`` emits one logit per tissue.
    """

    def __init__(self, vocab_size, num_tissues):
        super().__init__()
        encoder_cfg = BertConfig(
            vocab_size=vocab_size,
            hidden_size=HIDDEN,
            num_hidden_layers=LAYERS,
            num_attention_heads=HEADS,
            intermediate_size=INTERM,
            max_position_embeddings=MAX_LEN,
            type_vocab_size=1,
            hidden_dropout_prob=0.1,
            attention_probs_dropout_prob=0.1,
        )
        # Attribute names are part of the checkpoint format (state-dict keys).
        self.bert = BertModel(encoder_cfg)
        self.dropout = nn.Dropout(0.1)
        self.enhancer_head = nn.Linear(HIDDEN, 2)
        self.tissue_head   = nn.Linear(HIDDEN, num_tissues)

    def masked_mean_pool(self, hidden_states, attention_mask):
        """Average `hidden_states` over dim 1, counting only positions where the mask is non-zero.

        The divisor is clamped to 1e-6, so a row whose mask is all zeros yields zeros
        instead of dividing by zero.
        """
        weights = attention_mask.unsqueeze(-1).float()
        weighted_sum = (hidden_states * weights).sum(dim=1)
        n_kept = weights.sum(dim=1).clamp(min=1e-6)
        return weighted_sum / n_kept

    def forward(self, input_ids, attention_mask, labels=None, tissues=None):
        """Return ``{"logits_enh", "logits_tis"}``.

        `labels` and `tissues` are accepted but not used here; no loss is computed
        inside the model.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_repr = self.masked_mean_pool(encoded.last_hidden_state, attention_mask)
        sequence_repr = self.dropout(sequence_repr)
        return {
            "logits_enh": self.enhancer_head(sequence_repr),
            "logits_tis": self.tissue_head(sequence_repr),
        }

# -------------------------
# Custom Trainer (weighted multitask loss)
# -------------------------
class MultiTaskTrainer(Trainer):
    """`Trainer` whose loss is cross-entropy on the enhancer logits plus a weighted
    BCE-with-logits on the tissue logits, the latter restricted to rows with label 1."""

    def __init__(self, class_weights_ce=None, pos_weight_bce=None, **kwargs):
        super().__init__(**kwargs)
        self.class_weights_ce = class_weights_ce   # `weight` for CrossEntropyLoss
        self.pos_weight_bce   = pos_weight_bce     # `pos_weight` for BCEWithLogitsLoss

    # NOTE: accept **kwargs to swallow num_items_in_batch and any future args
    def compute_loss(self, model, inputs, return_outputs: bool = False, **kwargs):
        """Compute `loss_enh + TISSUE_LOSS_WEIGHT * loss_tis`.

        Pops "labels" and "tissues" from `inputs` before the forward pass. Returns the
        scalar loss, or `(loss, {"logits_enh", "logits_tis"})` when `return_outputs` is True.
        """
        enh_targets = inputs.pop("labels")
        tis_targets = inputs.pop("tissues")
        model_out = model(**inputs)
        logits_enh, logits_tis = model_out["logits_enh"], model_out["logits_tis"]

        enh_criterion = nn.CrossEntropyLoss(weight=self.class_weights_ce)
        loss_enh = enh_criterion(logits_enh, enh_targets)

        # The tissue loss only sees rows whose enhancer label is 1.
        is_enhancer = (enh_targets == 1)
        if not is_enhancer.any():
            loss_tis = torch.tensor(0.0, device=logits_tis.device)
        else:
            tis_criterion = nn.BCEWithLogitsLoss(pos_weight=self.pos_weight_bce)
            loss_tis = tis_criterion(logits_tis[is_enhancer], tis_targets[is_enhancer].float())

        total_loss = loss_enh + TISSUE_LOSS_WEIGHT * loss_tis
        if not return_outputs:
            return total_loss
        return (total_loss, {"logits_enh": logits_enh, "logits_tis": logits_tis})

# -------------------------
# Metrics (adds Precision/Recall/F1)
# -------------------------
def _safe_score(score_fn, y_true, y_score):
    """Return `score_fn(y_true, y_score)` as float, or NaN if it raises ValueError."""
    try:
        return float(score_fn(y_true, y_score))
    except ValueError:
        return float("nan")


def _enhancer_metrics(logits_enh, y_enh):
    """Binary metrics from the 2-way enhancer logits (positive class = index 1, threshold 0.5)."""
    p_enh = torch.softmax(torch.tensor(logits_enh), dim=-1).numpy()[:, 1]
    y_pred_enh = (p_enh >= 0.5).astype(int)

    enh_prec, enh_rec, enh_f1, _ = precision_recall_fscore_support(
        y_enh, y_pred_enh, average="binary", zero_division=0
    )
    return {
        "enh_precision": float(enh_prec),
        "enh_recall":    float(enh_rec),
        "enh_f1":        float(enh_f1),
        # Scored independently so a failure of one does not discard the other.
        "enh_auprc":     _safe_score(average_precision_score, y_enh, p_enh),
        "enh_auroc":     _safe_score(roc_auc_score, y_enh, p_enh),
        "enh_acc":       float(accuracy_score(y_enh, y_pred_enh)),
    }


def _tissue_metrics(logits_tis, y_tis, enh_mask):
    """Multi-label tissue metrics restricted to rows selected by `enh_mask`.

    When no row is selected every tissue metric is reported as 0.0.
    """
    if enh_mask.sum() > 0:
        y_true_t = y_tis[enh_mask].astype(int)
        y_prob_t = torch.sigmoid(torch.tensor(logits_tis)).numpy()[enh_mask]
        y_pred_t = (y_prob_t >= 0.5).astype(int)

        tis_prec_micro, tis_rec_micro, tis_f1_micro, _ = precision_recall_fscore_support(
            y_true_t, y_pred_t, average="micro", zero_division=0
        )
        tis_prec_macro, tis_rec_macro, tis_f1_macro, _ = precision_recall_fscore_support(
            y_true_t, y_pred_t, average="macro", zero_division=0
        )
        tis_auprc_micro = _safe_score(
            average_precision_score, y_true_t.reshape(-1), y_prob_t.reshape(-1)
        )
        # Macro AUPRC averages only tissues that have both positives and negatives.
        per_t = []
        for j in range(y_true_t.shape[1]):
            yj = y_true_t[:, j]
            pj = y_prob_t[:, j]
            if 0 < yj.sum() < len(yj):
                per_t.append(average_precision_score(yj, pj))
        tis_auprc_macro = float(np.mean(per_t)) if per_t else float("nan")
    else:
        tis_prec_micro = tis_rec_micro = tis_f1_micro = 0.0
        tis_prec_macro = tis_rec_macro = tis_f1_macro = 0.0
        tis_auprc_micro = tis_auprc_macro = 0.0

    return {
        "tis_precision_micro": float(tis_prec_micro),
        "tis_recall_micro":    float(tis_rec_micro),
        "tis_f1_micro":        float(tis_f1_micro),
        "tis_precision_macro": float(tis_prec_macro),
        "tis_recall_macro":    float(tis_rec_macro),
        "tis_f1_macro":        float(tis_f1_macro),
        "tis_auprc_micro":     float(tis_auprc_micro),
        "tis_auprc_macro":     float(tis_auprc_macro),
    }


def compute_metrics_fn(eval_pred):
    """Compute enhancer and tissue metrics for the HF `Trainer`.

    Args:
        eval_pred: pair `(preds, labels)`. `preds` is either a dict with keys
            "logits_enh" / "logits_tis" or a 2-sequence in that order; `labels` is the
            pair `(y_enh, y_tis)` (order follows `label_names`).

    Returns:
        Dict of float metrics: `enh_*` keys first, then `tis_*` keys. AUPRC/AUROC are
        NaN when scikit-learn raises ValueError for the given inputs.
    """
    preds, labels = eval_pred
    if isinstance(preds, dict):
        logits_enh = preds["logits_enh"]
        logits_tis = preds["logits_tis"]
    else:
        logits_enh, logits_tis = preds

    y_enh, y_tis = labels  # due to label_names

    # Cast to float32 so softmax/sigmoid run in full precision on CPU even if the
    # logits arrive in a narrower float dtype.
    logits_enh = np.asarray(logits_enh, dtype=np.float32)
    logits_tis = np.asarray(logits_tis, dtype=np.float32)
    y_enh      = np.asarray(y_enh)
    y_tis      = np.asarray(y_tis)

    metrics = _enhancer_metrics(logits_enh, y_enh)
    metrics.update(_tissue_metrics(logits_tis, y_tis, y_enh == 1))
    return metrics

# -------------------------
# Param Drift² callback
# -------------------------
class ParamDriftCallback(TrainerCallback):
    """Log the sum of squared changes of all trainable parameters over each epoch.

    A baseline snapshot is taken in `on_train_begin`, so the first epoch's value is
    the drift away from the initial weights (instead of a hard-coded 0.0), and an
    instance reused for a second `train()` call does not compare against stale
    weights. Each epoch end appends `{"epoch", "param_drift2"}` to `state.log_history`.
    """

    def __init__(self):
        # Flattened CPU copy of the trainable parameters at the previous checkpoint.
        self.prev = None

    @staticmethod
    def _flatten_trainable(model):
        """Concatenate every trainable parameter into one detached float32 CPU vector."""
        return torch.cat([p.detach().float().flatten().cpu()
                          for p in model.parameters() if p.requires_grad])

    def on_train_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        model = kwargs.get("model")
        # Without a model there is no baseline; the first epoch then reports 0.0.
        self.prev = self._flatten_trainable(model) if model is not None else None

    def on_epoch_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        model = kwargs["model"]
        vec = self._flatten_trainable(model)
        if self.prev is None:
            drift2 = 0.0
        else:
            diff = vec - self.prev
            drift2 = float((diff * diff).sum().item())
        self.prev = vec
        state.log_history.append({"epoch": state.epoch, "param_drift2": drift2})

param_drift_cb = ParamDriftCallback()

# -------------------------
# Class weights
# -------------------------
y_tr = df_tr[LABEL_COL].to_numpy()
n_pos = int((y_tr == 1).sum())
n_neg = int((y_tr == 0).sum())
n_total = n_pos + n_neg
# Inverse-frequency class weights; `max(..., 1)` guards against an absent class.
w_neg = 0.5 * n_total / max(n_neg, 1)
w_pos = 0.5 * n_total / max(n_pos, 1)
class_weights = torch.tensor([w_neg, w_pos], dtype=torch.float, device=DEVICE)

# Per-tissue `pos_weight` (negatives / positives), counted over label==1 rows only,
# matching the rows the tissue loss is evaluated on.
tis_tr = df_tr[tissue_cols].to_numpy()
enh_mask_tr = (y_tr == 1)
if not enh_mask_tr.any():
    pos_weight = torch.ones(NUM_TISSUES, dtype=torch.float, device=DEVICE)
else:
    pos_counts = tis_tr[enh_mask_tr].sum(axis=0) + 1e-6   # epsilon avoids division by zero
    neg_counts = enh_mask_tr.sum() - pos_counts + 1e-6
    pos_weight = torch.tensor(neg_counts / pos_counts, dtype=torch.float, device=DEVICE)

# -------------------------
# Trainer setup (early stopping on F1)
# -------------------------
model = MultiTaskBERT(vocab_size=VOCAB_SIZE, num_tissues=NUM_TISSUES).to(DEVICE)
fp16_flag = torch.cuda.is_available()   # mixed precision only when a GPU is present

args = TrainingArguments(
    output_dir=os.path.join(OUT_DIR, "checkpoints"),
    per_device_train_batch_size=BATCH_TRAIN,
    per_device_eval_batch_size=BATCH_EVAL,
    learning_rate=LR,
    num_train_epochs=EPOCHS_MAX,
    eval_strategy="epoch",
    save_strategy="epoch",              # must match eval_strategy for load_best_model_at_end
    label_names=["labels","tissues"],   # order defines the (y_enh, y_tis) tuple seen by metrics
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="eval_enh_f1",
    greater_is_better=True,
    fp16=fp16_flag,
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    dataloader_persistent_workers=False,
    logging_steps=100,
    seed=SEED,
    report_to=[],  # add "tensorboard" to log to TB
)

# Stop after 3 consecutive evaluations without improvement of `metric_for_best_model`.
early_stop_cb = EarlyStoppingCallback(
    early_stopping_patience=3, early_stopping_threshold=0.0
)

# `tokenizer=None` was dropped: the keyword is deprecated in favour of
# `processing_class`, and passing None is the same as omitting it.
trainer = MultiTaskTrainer(
    model=model,
    args=args,
    train_dataset=ds_train,
    eval_dataset=ds_val,
    data_collator=default_data_collator,
    class_weights_ce=class_weights,
    pos_weight_bce=pos_weight,
    compute_metrics=compute_metrics_fn,
    callbacks=[early_stop_cb, param_drift_cb],
)

trainer.train()
final_eval = trainer.evaluate()   # metrics of the best checkpoint (reloaded at end of training)

# -------------------------
# Save artifacts (shareable)
# -------------------------
best_dir = os.path.join(OUT_DIR, "best_model")
trainer.save_model(best_dir)
# MultiTaskBERT is a plain nn.Module, not a PreTrainedModel, so `save_model` stores only
# the state dict and no config.json.
# NOTE(review): recent transformers releases write it as `model.safetensors`; confirm for
# the installed version. The reload cell reads `pytorch_model.bin` and the BERT config
# from this directory, so both are written explicitly to keep the folder self-contained.
torch.save(trainer.model.state_dict(), os.path.join(best_dir, "pytorch_model.bin"))
trainer.model.bert.config.save_pretrained(best_dir)

# Everything needed to rebuild the inputs and the model architecture later.
meta = {
    "VOCAB_SIZE": VOCAB_SIZE,
    "MAX_LEN": MAX_LEN,
    "MODEL_DIMS": {"HIDDEN": HIDDEN, "LAYERS": LAYERS, "HEADS": HEADS, "INTERM": INTERM},
    "TRAIN_ARGS": {
        "LR": LR, "BATCH_TRAIN": BATCH_TRAIN, "BATCH_EVAL": BATCH_EVAL,
        "TISSUE_LOSS_WEIGHT": TISSUE_LOSS_WEIGHT, "SEED": SEED
    },
    "COLUMNS": {
        "LABEL_COL": LABEL_COL,
        "ID_COLS": ID_COLS,
        "MASK_COLS": MASK_COLS,
        "TISSUE_COLS": tissue_cols
    },
    "FILES": {"train": TRAIN_CSV, "validation": VAL_CSV}
}
with open(os.path.join(OUT_DIR, "training_meta.json"), "w") as f:
    json.dump(meta, f, indent=2)

with open(os.path.join(OUT_DIR, "final_eval_metrics.json"), "w") as f:
    json.dump(final_eval, f, indent=2)

# Validation predictions snapshot
pred_out = trainer.predict(ds_val)
preds, labels = pred_out.predictions, pred_out.label_ids
if isinstance(preds, dict):
    logits_enh, logits_tis = preds["logits_enh"], preds["logits_tis"]
else:
    logits_enh, logits_tis = preds
y_enh, y_tis = labels   # order follows TrainingArguments.label_names

# Probability of class 1 and the 0.5-thresholded decision, as used in the metrics.
p_enh = torch.softmax(torch.tensor(logits_enh), dim=-1).numpy()[:,1]
y_pred_enh = (p_enh >= 0.5).astype(int)
pred_df = pd.DataFrame({
    "y_true_enh": y_enh.astype(int),
    "p_enh": p_enh,
    "y_pred_enh": y_pred_enh
})
pred_df.to_csv(os.path.join(OUT_DIR, "validation_preds_enhancer.csv"), index=False)

# -------------------------
# Plots from log history
# -------------------------
hist = trainer.state.log_history

# Loss curves: collect (epoch, value) pairs for the train and eval loss entries.
train_points = [(h["epoch"], h["loss"]) for h in hist if "loss" in h and "epoch" in h]
eval_points  = [(h["epoch"], h["eval_loss"]) for h in hist if "eval_loss" in h and "epoch" in h]

plt.figure()
for points, series_label in ((train_points, "train_loss"), (eval_points, "eval_loss")):
    if not points:
        continue
    xs = [epoch for epoch, _ in points]
    ys = [value for _, value in points]
    plt.plot(xs, ys, marker="o", label=series_label)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Loss per epoch")
plt.legend()
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "curve_loss.png"))
plt.close()

# Metric curves
def plot_metric(metric_key, title, fname, ylabel=None):
    """Plot one logged quantity against epoch and save it as `OUT_DIR/fname`.

    Args:
        metric_key: key to look up in each `hist` entry (entries lacking it, or
            lacking "epoch", are skipped).
        title: figure title.
        fname: output file name inside `OUT_DIR`.
        ylabel: y-axis label; defaults to `metric_key`.

    Nothing is drawn or written when no entry contains `metric_key`.
    """
    points = [(h["epoch"], h[metric_key]) for h in hist if metric_key in h and "epoch" in h]
    if not points:
        return
    xs, ys = zip(*points)
    fig, ax = plt.subplots()
    ax.plot(xs, ys, marker="o")
    ax.set_xlabel("Epoch")
    ax.set_ylabel(metric_key if ylabel is None else ylabel)
    ax.set_title(title)
    fig.tight_layout()
    fig.savefig(os.path.join(OUT_DIR, fname))
    plt.close(fig)   # release the figure; several are created in a row

# Validation metric curves
for metric_key, title, fname in (
    ("eval_enh_precision", "Enhancer Precision (val)", "curve_enh_precision.png"),
    ("eval_enh_recall",    "Enhancer Recall (val)",    "curve_enh_recall.png"),
    ("eval_enh_f1",        "Enhancer F1 (val)",        "curve_enh_f1.png"),
    ("eval_enh_auprc",     "Enhancer AUPRC (val)",     "curve_enh_auprc.png"),
    ("eval_tis_f1_micro",  "Tissues F1 micro (val)",   "curve_tis_f1_micro.png"),
    ("eval_tis_f1_macro",  "Tissues F1 macro (val)",   "curve_tis_f1_macro.png"),
):
    plot_metric(metric_key, title, fname)

# Param drift² per epoch (same plot routine, custom y-axis label)
plot_metric(
    "param_drift2",
    "Sum of squared parameter change per epoch",
    "curve_param_drift2.png",
    ylabel="Param drift²",
)

# Final summary of everything written to OUT_DIR.
print("\nSaved artifacts in:", OUT_DIR)
for artifact_line in (
    " - best_model/ (HF weights + config)",
    " - training_meta.json (columns, dims, files)",
    " - final_eval_metrics.json",
    " - validation_preds_enhancer.csv",
    " - curve_*.png (loss, PR/F1, AUPRC, drift²)",
):
    print(artifact_line)


[INFO] Using LABEL_COL = 'enhancer_label', NUM_TISSUES = 9
Train shape: (377626, 706) Val shape: (47203, 706)


inferring vocab:   0%|          | 0/377626 [00:00<?, ? examples/s]

[INFO] VOCAB_SIZE = 65


  super().__init__(**kwargs)


Epoch,Training Loss,Validation Loss,Enh Precision,Enh Recall,Enh F1,Enh Auprc,Enh Auroc,Enh Acc,Tis Precision Micro,Tis Recall Micro,Tis F1 Micro,Tis Precision Macro,Tis Recall Macro,Tis F1 Macro,Tis Auprc Micro,Tis Auprc Macro
1,1.7269,1.733877,0.81659,0.52653,0.640239,0.822766,0.627159,0.55702,0.21494,0.669977,0.325465,0.207761,0.69332,0.307583,0.301053,0.336324
2,1.7039,1.708794,0.810147,0.565753,0.666245,0.823351,0.625633,0.575663,0.242705,0.599517,0.345529,0.235015,0.63087,0.327662,0.32203,0.349423
3,1.6992,1.696402,0.813148,0.597164,0.688618,0.827086,0.635874,0.595704,0.246907,0.621747,0.353452,0.234237,0.643154,0.332468,0.327351,0.348625
4,1.6862,1.700289,0.810172,0.638311,0.714046,0.826708,0.636773,0.61727,0.23904,0.665292,0.35171,0.226685,0.674457,0.330115,0.328564,0.351873
5,1.6791,1.700356,0.818769,0.580694,0.679481,0.830235,0.641098,0.589878,0.278436,0.549202,0.369528,0.265099,0.569098,0.349264,0.334824,0.356161
6,1.6648,1.694719,0.818068,0.574016,0.674649,0.828737,0.639595,0.585535,0.253212,0.559486,0.348638,0.252487,0.60678,0.335221,0.312281,0.357278
7,1.6597,1.693965,0.800943,0.706879,0.750977,0.828307,0.636668,0.649048,0.253481,0.613712,0.358776,0.243216,0.631119,0.338168,0.332902,0.357515
8,1.6673,1.695645,0.798468,0.74313,0.769806,0.829899,0.63918,0.667288,0.238952,0.664074,0.351445,0.226961,0.682782,0.330325,0.322304,0.352885
9,1.6596,1.690291,0.812088,0.645244,0.719116,0.831135,0.641405,0.622651,0.242364,0.611112,0.347078,0.235815,0.644048,0.329802,0.32605,0.358336
10,1.647,1.697982,0.806646,0.684863,0.740783,0.83104,0.6423,0.641188,0.222209,0.683282,0.335357,0.214834,0.711095,0.317175,0.319593,0.354825



Saved artifacts in: model_out
 - best_model/ (HF weights + config)
 - training_meta.json (columns, dims, files)
 - final_eval_metrics.json
 - validation_preds_enhancer.csv
 - curve_*.png (loss, PR/F1, AUPRC, drift²)


In [10]:
# After training (or reload model from the Trainer)
# Persist the BERT encoder config next to the saved weights so the model can be rebuilt
# later. The path is derived from OUT_DIR instead of being hard-coded, so it follows the
# artifact directory if that is changed.
cfg = model.bert.config
cfg.save_pretrained(os.path.join(OUT_DIR, "best_model"))


In [None]:
from transformers import BertModel, BertConfig
import torch, json, pandas as pd
from datasets import Dataset

OUT_DIR = "model_out"
with open(f"{OUT_DIR}/training_meta.json") as f:
    meta = json.load(f)

# Column layout recorded at training time.
columns_meta = meta["COLUMNS"]
ID_COLS, MASK_COLS, LABEL_COL, tissue_cols = (
    columns_meta[key] for key in ("ID_COLS", "MASK_COLS", "LABEL_COL", "TISSUE_COLS")
)
MAX_LEN = meta["MAX_LEN"]

# Rebuild model and load weights
from transformers import AutoConfig
cfg = BertConfig.from_pretrained(f"{OUT_DIR}/best_model")
# This cell imports `torch` but never `nn`; without this line the class definition
# raises NameError when the notebook is reloaded in a fresh kernel.
import torch.nn as nn

class MultiTaskBERT(nn.Module):
    """Inference-time twin of the training model, rebuilt from a saved `BertConfig`.

    Attribute names (`bert`, `dropout`, `enhancer_head`, `tissue_head`) must match the
    training-time module so the saved state dict loads with `strict=True`.

    Args:
        config: `BertConfig` of the encoder.
        num_tissues: number of outputs of the tissue head.
    """
    def __init__(self, config, num_tissues):
        super().__init__()
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        self.enhancer_head = nn.Linear(config.hidden_size, 2)
        # Honour the constructor argument instead of silently reading the global
        # `tissue_cols`.
        self.tissue_head   = nn.Linear(config.hidden_size, num_tissues)
    def masked_mean_pool(self, hs, am):
        """Mean of `hs` over dim 1 using `am` as 0/1 weights; divisor clamped to 1e-6."""
        mask = am.unsqueeze(-1).float()
        return (hs*mask).sum(1) / mask.sum(1).clamp(min=1e-6)
    def forward(self, input_ids, attention_mask):
        """Return `{"logits_enh", "logits_tis"}` for a batch."""
        out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(self.masked_mean_pool(out.last_hidden_state, attention_mask))
        return {
            "logits_enh": self.enhancer_head(pooled),
            "logits_tis": self.tissue_head(pooled)
        }

import os

model = MultiTaskBERT(cfg, num_tissues=len(tissue_cols))

weights_path = f"{OUT_DIR}/best_model/pytorch_model.bin"
if not os.path.exists(weights_path):
    # NOTE(review): for a non-PreTrainedModel, `Trainer.save_model` may write the weights
    # as `model.safetensors` instead; confirm which file the training run produced.
    raise FileNotFoundError(
        f"{weights_path} not found. If the training run only produced "
        f"'{OUT_DIR}/best_model/model.safetensors', re-save the weights with "
        "torch.save(model.state_dict(), <path>/pytorch_model.bin) before reloading."
    )

# weights_only=True limits unpickling to tensors and plain containers, so loading a
# checkpoint cannot execute arbitrary pickled code.
state = torch.load(weights_path, map_location="cpu", weights_only=True)
model.load_state_dict(state, strict=True)   # strict: any key mismatch is an error
model.eval()                                # inference mode, disables dropout

# Prepare test the same way:
def pack_df(df):
    """Convert a test DataFrame into a HF `Dataset` of `input_ids` / `attention_mask`.

    Columns are converted in bulk and rows are addressed by position, rather than with
    one `df.loc[i, c]` label lookup per cell: that was very slow for wide frames and
    failed or mis-addressed rows whenever the index was not the default 0..n-1 range.
    Labels and tissues are not included (optional for test data).
    """
    id_rows   = df[ID_COLS].to_numpy(dtype="int64").tolist()
    mask_rows = df[MASK_COLS].to_numpy(dtype="int64").tolist()
    return Dataset.from_list([
        {"input_ids": ids, "attention_mask": attn}
        for ids, attn in zip(id_rows, mask_rows)
    ])

# Same narrow dtypes as used for the training CSVs; only id and mask columns are read.
test_dtypes = {
    **dict.fromkeys(ID_COLS, "uint16"),
    **dict.fromkeys(MASK_COLS, "uint8"),
}
test = pd.read_csv("test.csv", usecols=ID_COLS+MASK_COLS, dtype=test_dtypes)
ds_test = pack_df(test)
# then build a DataLoader and run forward passes to get logits -> probs
