
# ClinVar Variant Effect Prediction with DNABERT — Three Fine-Tuning Setups

This notebook runs **three** supervised approaches on the same stratified 80/10/10 train/val/test split (seeded by `SEED`) and then prints a side‑by‑side comparison of their validation and test metrics:

1. **Joint Paired Input** (single pass): `BertForSequenceClassification` with `(ref, alt)` passed as a **paired** sequence (segment IDs 0/1).  
2. **Dual-Input (Concat)**: Shared encoder encodes ref and alt **separately**, mean-pools each over its non-padding tokens, concatenates the two embeddings → MLP classifier.  
3. **Siamese (Delta)**: Shared encoder encodes ref and alt separately and fuses **[ref, alt, (alt−ref), |alt−ref|, ref⊙alt]** → MLP classifier.

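> Assumes you already have a prepared CSV (path set via `VAL_CSV` in the configuration cell) with columns `ref_seq`, `alt_seq` and a binary `label` (e.g., from the hg38 prep notebook). If `label` is absent, it is taken from `LABEL` or derived from ClinVar `CLNSIG` (pathogenic → 1, benign → 0; rows that cannot be mapped are dropped).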


In [None]:

# === Configuration ===
import os
import torch

# Path to the prepared ClinVar CSV. NOTE: despite its name, this is the FULL labeled
# dataset; it is split 80/10/10 into train/val/test in the data-loading cell.
VAL_CSV = os.path.expanduser("~/data/clinvar_seq_pairs_hg38_flank100.csv")  # change if needed
MODEL_NAME = "zhihan1996/DNA_bert_6"   # DNABERT v1 (6-mer)
KMER       = 6      # k used by kmerize(); should match MODEL_NAME's k-mer vocabulary (6 here)
MAX_LEN    = 512    # tokenizer padding/truncation length, in tokens
SEED       = 42

# Training
TRAIN_BATCH_SIZE = 16
EVAL_BATCH_SIZE  = 32
LR               = 3e-5
WEIGHT_DECAY     = 0.01
NUM_EPOCHS       = 3
WARMUP_RATIO     = 0.06
# Mixed precision only when a CUDA GPU is present: TrainingArguments rejects
# fp16=True on a CPU-only machine. Set to False to force fp32 training.
FP16             = torch.cuda.is_available()

OUT_ROOT = "./dnabert_all_ft"
os.makedirs(OUT_ROOT, exist_ok=True)

print("Config OK.")


In [None]:

# If needed, install deps (the HF Trainer additionally requires `accelerate`):
# %pip install --quiet transformers==4.44.2 accelerate==0.33.0 torch==2.4.0 pandas==2.1.4 scikit-learn==1.3.2 numpy==1.26.4 tqdm==4.66.4 evaluate==0.4.1


In [None]:

import pandas as pd, numpy as np
from sklearn.model_selection import train_test_split

df = pd.read_csv(VAL_CSV, dtype={"CHROM": str})
assert {"ref_seq","alt_seq"}.issubset(df.columns), f"Missing sequence cols in {VAL_CSV}"


def clinsig_to_label(v):
    """Map a ClinVar CLNSIG value to 1 (pathogenic), 0 (benign) or NaN (unusable).

    Matching is case-insensitive and substring based, so 'Likely_pathogenic' -> 1 and
    'Benign/Likely_benign' -> 0. Values that mention both classes, neither class
    (e.g. 'Uncertain_significance'), or a conflict return NaN and are dropped below.
    """
    s = str(v).lower()
    # 'Conflicting_interpretations_of_pathogenicity' (and the newer
    # 'Conflicting_classifications_of_pathogenicity') contain the substring
    # 'pathogenic' but are NOT pathogenic calls; without this guard they become 1.
    if "conflicting" in s:
        return np.nan
    has_path = "pathogenic" in s
    has_ben = "benign" in s
    if has_path and not has_ben:
        return 1
    if has_ben and not has_path:
        return 0
    return np.nan


# Ensure a binary `label` column: prefer `label`, then `LABEL`, else derive from CLNSIG.
if "label" not in df.columns:
    if "LABEL" in df.columns:
        # float (not int) so missing LABEL values become NaN and are dropped below
        # instead of crashing the cast.
        df["label"] = df["LABEL"].astype(float)
    elif "CLNSIG" in df.columns:
        df["label"] = df["CLNSIG"].apply(clinsig_to_label).astype("float")
    else:
        raise KeyError("Need label/LABEL/CLNSIG in CSV to derive labels.")

before = len(df)
df = df.dropna(subset=["label"]).copy()
df["label"] = df["label"].astype(float)
# Reject non-binary labels (e.g. 2, -1, 0.5) instead of silently truncating them with
# astype(int) or feeding out-of-range class indices to the 2-way classifier.
non_binary = ~df["label"].isin([0.0, 1.0])
assert not non_binary.any(), f"Non-binary labels found: {sorted(df.loc[non_binary, 'label'].unique().tolist())}"
df["label"] = df["label"].astype(int)
after = len(df)
print(f"Using {after}/{before} labeled rows.")
print("Class balance:", df['label'].value_counts(normalize=True).to_dict())

# Freeze a single stratified 80/10/10 split reused by all models
df_train, df_tmp = train_test_split(df, test_size=0.2, random_state=SEED, stratify=df["label"])
df_val, df_test  = train_test_split(df_tmp, test_size=0.5, random_state=SEED, stratify=df_tmp["label"])
assert len(df_train) + len(df_val) + len(df_test) == len(df)
print("Splits:", df_train.shape, df_val.shape, df_test.shape)


In [None]:

import torch, math
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import BertTokenizerFast, BertForSequenceClassification, BertModel, BertConfig
from transformers import TrainingArguments, Trainer
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score, precision_recall_fscore_support

# Device string, kept for reference; the Trainer cells below do not pass it anywhere.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# kmerize() upper-cases every k-mer, so the tokenizer must not lower-case them.
tokenizer = BertTokenizerFast.from_pretrained(
    MODEL_NAME,
    do_lower_case=False,
)

def kmerize(seq, k=KMER):
    """Turn a nucleotide string into space-separated, overlapping k-mers.

    The input is stringified, stripped, upper-cased, and RNA 'U' is rewritten to 'T'.
    Inputs shorter than ``k`` are right-padded with 'N' so at least one k-mer results.

    Example (k=3): "acgu" -> "ACG CGT"
    """
    bases = str(seq).strip().upper().replace("U", "T")
    shortfall = k - len(bases)
    if shortfall > 0:
        bases = bases + "N" * shortfall
    windows = [bases[start:start + k] for start in range(len(bases) - k + 1)]
    return " ".join(windows)

def compute_metrics(eval_pred):
    """Binary-classification metrics for the Hugging Face ``Trainer``.

    Args:
        eval_pred: ``(predictions, label_ids)`` as handed over by ``Trainer``.
            ``predictions`` are 2-class logits; if the model emitted several output
            tensors (a tuple/list), the first one is taken as the logits.

    Returns:
        dict with accuracy, precision, recall, f1 (argmax decision, positive class 1)
        and threshold-free auroc / auprc on P(class 1). A metric that sklearn refuses
        to compute (ValueError, e.g. only one class present) is reported as NaN.
    """
    logits, labels = eval_pred
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    # Logits may be float16 when FP16 is enabled; upcast before the softmax.
    probs = torch.as_tensor(np.asarray(logits), dtype=torch.float32).softmax(dim=-1).numpy()
    labels = np.asarray(labels).reshape(-1)
    preds = probs.argmax(axis=1)
    pos = probs[:, 1]
    acc = accuracy_score(labels, preds)
    p, r, f1, _ = precision_recall_fscore_support(labels, preds, average="binary", zero_division=0)
    # Only swallow the "metric undefined" case; other errors should surface.
    try:
        auroc = roc_auc_score(labels, pos)
    except ValueError:
        auroc = float("nan")
    try:
        auprc = average_precision_score(labels, pos)
    except ValueError:
        auprc = float("nan")
    return {"accuracy": acc, "precision": p, "recall": r, "f1": f1, "auroc": auroc, "auprc": auprc}

# Class-imbalance weights from sklearn's "balanced" heuristic, computed on the TRAIN
# split only so val/test label frequencies never leak into the loss.
classes = np.array([0, 1])
weights = compute_class_weight(
    class_weight="balanced",
    classes=classes,
    y=df_train["label"].to_numpy(),
)
class_weights = torch.as_tensor(weights, dtype=torch.float)
print("Class weights:", class_weights.tolist())


In [None]:

# Dataset A: Joint paired input (single pass with token_type_ids distinguishing ref/alt)
class JointPairDataset(Dataset):
    """Single-pass (ref, alt) dataset for ``BertForSequenceClassification``.

    Each item tokenizes the k-merized ref and alt sequences together as a sentence
    pair (segment ids then distinguish ref from alt), padded/truncated to MAX_LEN,
    and adds an integer ``labels`` tensor.
    """

    def __init__(self, frame):
        # Reindex 0..n-1 so positional iloc lookups match dataset indices.
        self.df = frame.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, i):
        row = self.df.iloc[i]
        encoded = tokenizer(
            kmerize(row["ref_seq"]),
            kmerize(row["alt_seq"]),
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )
        # return_tensors="pt" adds a leading batch dim of 1; drop it for every field.
        item = {name: tensor.squeeze(0) for name, tensor in encoded.items()}
        item["labels"] = torch.tensor(int(row["label"]), dtype=torch.long)
        return item

# Dataset B: Separate encodings (for Dual-Input & Siamese)
class SepPairDataset(Dataset):
    """Two-tower dataset: ref and alt are tokenized independently (Dual-Input / Siamese).

    Each item holds ``ref_*`` and ``alt_*`` input_ids / attention_mask tensors (plus
    token_type_ids when the tokenizer returns them), each padded/truncated to MAX_LEN,
    and an integer ``labels`` tensor.
    """

    def __init__(self, frame):
        # Reindex 0..n-1 so positional iloc lookups match dataset indices.
        self.df = frame.reset_index(drop=True)

    def __len__(self):
        return self.df.shape[0]

    @staticmethod
    def _encode_single(text):
        """Tokenize one k-mer string to MAX_LEN tensors (leading batch dim of 1 kept)."""
        return tokenizer(text, padding="max_length", truncation=True, max_length=MAX_LEN, return_tensors="pt")

    def __getitem__(self, i):
        row = self.df.iloc[i]
        ref_txt = kmerize(row["ref_seq"])
        alt_txt = kmerize(row["alt_seq"])
        ref_enc = self._encode_single(ref_txt)
        alt_enc = self._encode_single(alt_txt)

        def flat(enc, key):
            # Drop the batch dim added by return_tensors="pt".
            return enc[key].squeeze(0)

        item = {
            "ref_input_ids": flat(ref_enc, "input_ids"),
            "ref_attention_mask": flat(ref_enc, "attention_mask"),
            "alt_input_ids": flat(alt_enc, "input_ids"),
            "alt_attention_mask": flat(alt_enc, "attention_mask"),
            "labels": torch.tensor(int(row["label"]), dtype=torch.long),
        }
        # Segment ids are only forwarded when the tokenizer actually produced them.
        if "token_type_ids" in ref_enc:
            item["ref_token_type_ids"] = flat(ref_enc, "token_type_ids")
            item["alt_token_type_ids"] = flat(alt_enc, "token_type_ids")
        return item

# Both dataset flavours wrap the same frozen train/val/test frames, so every model
# below is trained and evaluated on identical examples.
train_joint = JointPairDataset(df_train)
val_joint = JointPairDataset(df_val)
test_joint = JointPairDataset(df_test)

train_sep = SepPairDataset(df_train)
val_sep = SepPairDataset(df_val)
test_sep = SepPairDataset(df_test)

len(train_joint), len(val_joint), len(test_joint)


## Step 1 — Fine-tune: Joint Paired Input (single-pass BERT)

In [None]:

from torch.nn import CrossEntropyLoss

from transformers import set_seed

# Seed before building the model so its freshly initialized classification head is
# reproducible (the model is constructed before the Trainer ever sees seed=SEED).
set_seed(SEED)
joint_model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

args_joint = TrainingArguments(
    output_dir=f"{OUT_ROOT}/joint",
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=NUM_EPOCHS,
    warmup_ratio=WARMUP_RATIO,
    eval_strategy="epoch",          # current name of the deprecated `evaluation_strategy`
    save_strategy="epoch",          # must match eval_strategy for load_best_model_at_end
    load_best_model_at_end=True,
    metric_for_best_model="auroc",  # a key returned by compute_metrics
    greater_is_better=True,
    fp16=FP16,
    logging_steps=100,
    seed=SEED,
    dataloader_num_workers=2,
)


class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted cross-entropy (``class_weights`` from TRAIN)."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments that newer transformers versions pass
        # (e.g. num_items_in_batch), so this override survives upgrades.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        loss_fct = CrossEntropyLoss(weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss


trainer_joint = WeightedTrainer(
    model=joint_model,
    args=args_joint,
    train_dataset=train_joint,
    eval_dataset=val_joint,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

print("Training joint model...")
joint_train = trainer_joint.train()
joint_val   = trainer_joint.evaluate(eval_dataset=val_joint)
joint_test  = trainer_joint.evaluate(eval_dataset=test_joint)
print("Joint Val:", joint_val)
print("Joint Test:", joint_test)

# store for summary
results = []
results.append({"model":"joint", **{f"val_{k}":v for k,v in joint_val.items()}, **{f"test_{k}":v for k,v in joint_test.items()}})


## Step 2 — Fine-tune: Dual-Input (Concat) with Shared Encoder

In [None]:

def masked_mean(hs, mask):
    """Average token vectors over the sequence axis, ignoring padded positions.

    Args:
        hs: hidden states, expected (batch, seq, hidden); the mean is taken over dim 1.
        mask: attention mask, expected (batch, seq) with 1 = real token, 0 = padding.
            Its values act as per-token weights.

    Returns:
        (batch, hidden) tensor. The token count is clamped to >= 1e-6, so an all-zero
        mask yields zeros instead of NaN.
    """
    weights = mask.unsqueeze(-1).type_as(hs)   # broadcast the mask over the hidden dim
    token_counts = torch.clamp(weights.sum(dim=1), min=1e-6)
    return (hs * weights).sum(dim=1) / token_counts

class SiameseFusion(nn.Module):
    """Shared-encoder pair classifier: encode ref and alt separately, fuse, classify.

    Both sequences go through the *same* ``BertModel`` (shared weights). Each is
    mean-pooled over its non-padding tokens and the two vectors are fused per ``fusion``:

    * ``"concat"``: [ref, alt]                               -> 2 * hidden features
    * ``"delta"`` : [ref, alt, alt-ref, |alt-ref|, ref*alt]  -> 5 * hidden features

    The fused vector feeds a Dropout-Linear(512)-ReLU-Dropout-Linear(num_labels) head.

    Args:
        model_name: checkpoint passed to BertConfig / BertModel ``from_pretrained``.
        num_labels: number of output classes.
        fusion: "concat" or "delta".
        dropout: dropout probability used inside the head.

    Raises:
        ValueError: if ``fusion`` is not "concat" or "delta".
    """

    # Width of the fused feature vector, in multiples of hidden_size, per fusion mode.
    _FUSION_WIDTH = {"concat": 2, "delta": 5}

    def __init__(self, model_name=MODEL_NAME, num_labels=2, fusion="concat", dropout=0.1):
        super().__init__()
        # Validate before loading the encoder so a typo fails fast.
        if fusion not in self._FUSION_WIDTH:
            raise ValueError("fusion must be 'concat' or 'delta'")
        self.cfg = BertConfig.from_pretrained(model_name)
        self.enc = BertModel.from_pretrained(model_name, config=self.cfg)
        self.fusion = fusion
        in_dim = self.cfg.hidden_size * self._FUSION_WIDTH[fusion]
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(in_dim, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, num_labels),
        )

    def forward(self, **batch):
        """Return ``{"logits": (batch, num_labels)}`` for a SepPairDataset-style batch.

        Expects ``ref_input_ids``, ``ref_attention_mask``, ``alt_input_ids`` and
        ``alt_attention_mask``, plus optional ``ref_token_type_ids`` /
        ``alt_token_type_ids``. Other keys (e.g. ``labels``) are ignored; the loss is
        computed by the Trainer subclass.

        Only logits are returned. The HF Trainer treats every non-"loss" entry of a
        dict output as a prediction, so echoing ``labels`` back here would turn the
        predictions into a (logits, labels) tuple and break ``compute_metrics``.
        """
        ref_msk = batch["ref_attention_mask"]
        alt_msk = batch["alt_attention_mask"]

        ref_out = self.enc(
            input_ids=batch["ref_input_ids"],
            attention_mask=ref_msk,
            token_type_ids=batch.get("ref_token_type_ids", None),
        )
        alt_out = self.enc(
            input_ids=batch["alt_input_ids"],
            attention_mask=alt_msk,
            token_type_ids=batch.get("alt_token_type_ids", None),
        )

        ref_pool = masked_mean(ref_out.last_hidden_state, ref_msk)
        alt_pool = masked_mean(alt_out.last_hidden_state, alt_msk)

        if self.fusion == "concat":
            feat = torch.cat([ref_pool, alt_pool], dim=-1)
        else:
            delta = alt_pool - ref_pool
            feat = torch.cat([ref_pool, alt_pool, delta, delta.abs(), ref_pool * alt_pool], dim=-1)

        return {"logits": self.head(feat)}

# Concat model
from transformers import default_data_collator, set_seed

# Seed right before construction so the randomly initialized MLP head is reproducible.
set_seed(SEED)
concat_model = SiameseFusion(fusion="concat")

args_concat = TrainingArguments(
    output_dir=f"{OUT_ROOT}/concat",
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=NUM_EPOCHS,
    warmup_ratio=WARMUP_RATIO,
    eval_strategy="epoch",          # current name of the deprecated `evaluation_strategy`
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="auroc",
    greater_is_better=True,
    fp16=FP16,
    logging_steps=100,
    seed=SEED,
    dataloader_num_workers=2,
    # SiameseFusion.forward(**batch) exposes no named arguments, so with the default
    # remove_unused_columns=True the Trainer would drop every ref_*/alt_* column.
    remove_unused_columns=False,
    # The label key cannot be inferred from forward(**batch); without this no labels
    # reach compute_metrics and "eval_auroc" (needed for best-model selection) is missing.
    label_names=["labels"],
)


class WeightedTrainer2(Trainer):
    """Trainer for dict-returning custom models; loss is class-weighted cross-entropy."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments that newer transformers versions pass
        # (e.g. num_items_in_batch), so this override survives upgrades.
        labels = inputs["labels"]
        outputs = model(**inputs)
        logits = outputs["logits"]
        loss = nn.CrossEntropyLoss(weight=class_weights.to(logits.device))(logits, labels)
        return (loss, outputs) if return_outputs else loss


trainer_concat = WeightedTrainer2(
    model=concat_model,
    args=args_concat,
    train_dataset=train_sep,
    eval_dataset=val_sep,
    tokenizer=tokenizer,
    # Items are already padded to MAX_LEN; the tokenizer-based default collator would
    # require an `input_ids` key, which SepPairDataset items do not have.
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

print("Training Dual-Input (Concat)...")
concat_train = trainer_concat.train()
concat_val   = trainer_concat.evaluate(eval_dataset=val_sep)
concat_test  = trainer_concat.evaluate(eval_dataset=test_sep)
print("Concat Val:", concat_val)
print("Concat Test:", concat_test)

results.append({"model":"concat", **{f"val_{k}":v for k,v in concat_val.items()}, **{f"test_{k}":v for k,v in concat_test.items()}})


## Step 3 — Fine-tune: Siamese (Delta) with Shared Encoder

In [None]:

# Reuse SiameseFusion with fusion='delta'
from transformers import default_data_collator, set_seed

# Seed right before construction so the randomly initialized MLP head is reproducible.
set_seed(SEED)
siamese_model = SiameseFusion(fusion="delta")

args_siamese = TrainingArguments(
    output_dir=f"{OUT_ROOT}/siamese",
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH_SIZE,
    learning_rate=LR,
    weight_decay=WEIGHT_DECAY,
    num_train_epochs=NUM_EPOCHS,
    warmup_ratio=WARMUP_RATIO,
    eval_strategy="epoch",          # current name of the deprecated `evaluation_strategy`
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="auroc",
    greater_is_better=True,
    fp16=FP16,
    logging_steps=100,
    seed=SEED,
    dataloader_num_workers=2,
    # forward(**batch) has no named arguments: keep all ref_*/alt_* columns ...
    remove_unused_columns=False,
    # ... and tell the Trainer which key holds the labels so metrics are computed.
    label_names=["labels"],
)

trainer_siamese = WeightedTrainer2(
    model=siamese_model,
    args=args_siamese,
    train_dataset=train_sep,
    eval_dataset=val_sep,
    tokenizer=tokenizer,
    # Items are pre-padded and have no `input_ids` key, so just stack the tensors.
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

print("Training Siamese (Delta)...")
siam_train = trainer_siamese.train()
siam_val   = trainer_siamese.evaluate(eval_dataset=val_sep)
siam_test  = trainer_siamese.evaluate(eval_dataset=test_sep)
print("Siamese Val:", siam_val)
print("Siamese Test:", siam_test)

results.append({"model":"siamese_delta", **{f"val_{k}":v for k,v in siam_val.items()}, **{f"test_{k}":v for k,v in siam_test.items()}})


## Comparison Table

In [None]:

import pandas as pd
cmp = pd.DataFrame(results)


def _strip_eval_prefix(col):
    """'val_eval_auroc' -> 'val_auroc'; any other column name is returned unchanged."""
    for split in ("val_", "test_"):
        prefixed = split + "eval_"
        if col.startswith(prefixed):
            return split + col[len(prefixed):]
    return col


# Rows were built as f"val_{k}" / f"test_{k}" from Trainer.evaluate() output, whose keys
# carry an "eval_" prefix (e.g. "eval_auroc"); normalize so the column list below matches.
cmp = cmp.rename(columns=_strip_eval_prefix)
# Re-running a training cell appends another row for the same model; keep the latest.
if "model" in cmp.columns:
    cmp = cmp.drop_duplicates(subset="model", keep="last").reset_index(drop=True)

# Keep the most important columns
cols = ["model",
        "val_auroc","test_auroc",
        "val_auprc","test_auprc",
        "val_accuracy","test_accuracy",
        "val_f1","test_f1",
        "val_precision","test_precision",
        "val_recall","test_recall"]
existing = [c for c in cols if c in cmp.columns]
missing = [c for c in cols if c not in cmp.columns]
if missing:
    print("Warning: metric columns not found:", missing)
cmp = cmp[existing].copy()

display(cmp)
cmp.to_csv(f"{OUT_ROOT}/comparison_metrics.csv", index=False)
print("Saved:", f"{OUT_ROOT}/comparison_metrics.csv")
