
# Zero‑Shot DNABERT on ClinVar (Ref vs Alt) — **hg38**
This notebook:
1. Uses or downloads **ClinVar GRCh38 VCF** (from NCBI).
2. Downloads **UCSC hg38** FASTA (`chr1`, `chr2`, …) and indexes it.
3. Builds **ref+alt** sequence windows (±100bp) for biallelic SNVs with clear **CLNSIG** (P/LP vs B/LB).
4. Runs **zero‑shot DNABERT (v1 6‑mer)** on ref vs alt and reports AUROC/AUPRC + Accuracy/F1/Precision/Recall.


In [1]:

# === Configuration ===
import os

DATA_DIR = os.path.expanduser("/home/tstil004/phd/cs895_genai/data")    # read/write
os.makedirs(DATA_DIR, exist_ok=True)

VAL_CSV  = f"{DATA_DIR}/clinvar_seq_pairs_hg38_flank100.csv"

CLINVAR_VCF_URL = "https://ftp.ncbi.nlm.nih.gov/pub/clinvar/vcf_GRCh38/clinvar.vcf.gz"
UCSC_HG38_FASTA_URL = "http://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/latest/hg38.fa.gz"

VCF_GZ  = f"{DATA_DIR}/clinvar.vcf.gz"
VCF_TBI = f"{DATA_DIR}/clinvar.vcf.gz.tbi"
FA_GZ   = f"{DATA_DIR}/hg38.fa.gz"
FA      = f"{DATA_DIR}/hg38.fa"

FLANK   = 100
CHROM_PREFIX = "chr"     # Map '1' -> 'chr1' for UCSC hg38
SEED = 42

MODEL_NAME = "zhihan1996/DNA_bert_6"
KMER = 6
MAX_LEN = 512
BATCH_SIZE = 32

print("Config OK.")


Config OK.


In [None]:

# %pip install --quiet transformers==4.44.2 torch==2.4.0 pyfaidx==0.7.2.1 pandas==2.1.4 scikit-learn==1.3.2 numpy==1.26.4 tqdm==4.66.4 pysam==0.22.0


In [2]:

import os, gzip, shutil, random
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Optional
from tqdm import tqdm
from pyfaidx import Fasta
import pysam

random.seed(SEED)
np.random.seed(SEED)

def maybe_download(url: str, dest: str):
    if not os.path.exists(dest):
        import urllib.request
        print(f"Downloading {url} -> {dest}")
        urllib.request.urlretrieve(url, dest)
    else:
        print(f"Found existing: {dest}")

def gunzip_if_needed(src_gz: str, dst: str):
    if not os.path.exists(dst) and os.path.exists(src_gz):
        print(f"Decompressing {src_gz} -> {dst}")
        with gzip.open(src_gz, 'rb') as f_in, open(dst, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
    else:
        if os.path.exists(dst):
            print(f"Found existing: {dst}")
        else:
            print(f"Missing source to gunzip: {src_gz}")

def index_fasta_if_needed(fa_path: str):
    fai = fa_path + ".fai"
    if not os.path.exists(fai):
        print(f"Indexing FASTA: {fa_path}")
        _ = Fasta(fa_path, as_raw=True)  # creates .fai
        _.close()
    else:
        print(f"Found FASTA index: {fai}")

def detect_csv_pairs(path:str) -> bool:
    return os.path.exists(path) and os.path.getsize(path) > 0

print("Utils OK.")


  from pkg_resources import get_distribution


Utils OK.


In [3]:

USE_EXISTING_PAIRS = detect_csv_pairs(VAL_CSV)
print("Has ref+alt pairs CSV:", USE_EXISTING_PAIRS, "|", VAL_CSV)

if not USE_EXISTING_PAIRS:
    print("No ref+alt pairs file found. Preparing inputs...")
    maybe_download(CLINVAR_VCF_URL, VCF_GZ)
    try:
        maybe_download(CLINVAR_VCF_URL + ".tbi", VCF_TBI)
    except Exception as e:
        print("Couldn't fetch prebuilt .tbi; will tabix later:", e)

    maybe_download(UCSC_HG38_FASTA_URL, FA_GZ)
    gunzip_if_needed(FA_GZ, FA)
    index_fasta_if_needed(FA)

    if not os.path.exists(VCF_TBI):
        print("Tabix-indexing ClinVar VCF...")
        pysam.tabix_index(VCF_GZ, preset="vcf", force=True)
    else:
        print("Found VCF index:", VCF_TBI)


Has ref+alt pairs CSV: False | /home/tstil004/phd/cs895_genai/data/clinvar_seq_pairs_hg38_flank100.csv
No ref+alt pairs file found. Preparing inputs...
Found existing: /home/tstil004/phd/cs895_genai/data/clinvar.vcf.gz
Found existing: /home/tstil004/phd/cs895_genai/data/clinvar.vcf.gz.tbi
Downloading http://hgdownload.soe.ucsc.edu/goldenPath/hg38/bigZips/latest/hg38.fa.gz -> /home/tstil004/phd/cs895_genai/data/hg38.fa.gz
Decompressing /home/tstil004/phd/cs895_genai/data/hg38.fa.gz -> /home/tstil004/phd/cs895_genai/data/hg38.fa
Indexing FASTA: /home/tstil004/phd/cs895_genai/data/hg38.fa
Found VCF index: /home/tstil004/phd/cs895_genai/data/clinvar.vcf.gz.tbi


In [5]:
def _norm_base(s):
    return str(s).upper().replace("U", "T")

def _chrom_with_prefix(c, prefix):
    c = str(c)
    if prefix and not c.startswith(prefix):
        return prefix + c
    return c

def _safe_parse_vcf_line(rec):
    # pysam will yield BYTES when encoding=None; handle both types safely
    if isinstance(rec, bytes):
        # decode as utf-8; ignore any stray non-utf8 bytes
        rec = rec.decode("utf-8", errors="ignore")
    rec = rec.strip()
    if not rec or rec.startswith("#"):
        return None
    parts = rec.split("\t")  # REAL tab
    if len(parts) < 8:
        return None
    chrom, pos, _id, ref, alt, qual, flt, info = parts[:8]
    return chrom, pos, _id, ref, alt, qual, flt, info

def build_ref_alt_pairs_from_vcf(
    vcf_gz: str,
    fasta_path: str,
    out_csv: str,
    flank: int = 100,
    chrom_prefix: str = "chr",
    max_records: Optional[int] = None,
):
    # Open FASTA and VCF
    fa = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)
    # CRITICAL: disable pysam's internal ascii decoding
    tbx = pysam.TabixFile(vcf_gz, encoding=None)

    rows, kept = [], 0
    print("Scanning VCF for SNVs...")

    for rec in tbx.fetch():
        parsed = _safe_parse_vcf_line(rec)
        if parsed is None:
            continue
        chrom, pos, _id, ref, alt, qual, flt, info = parsed
        try:
            pos = int(pos)
        except ValueError:
            continue

        ref = _norm_base(ref)
        alt = _norm_base(alt)

        # only biallelic 1bp SNVs
        if len(ref) != 1 or "," in alt or len(alt) != 1:
            continue

        # CLNSIG -> binary label
        label = None
        for field in info.split(";"):
            if field.startswith("CLNSIG="):
                sig = field.split("=", 1)[1].lower()
                has_path = ("pathogenic" in sig)
                has_ben  = ("benign" in sig)
                if has_path and not has_ben:
                    label = 1
                elif has_ben and not has_path:
                    label = 0
                break
        if label is None:
            continue

        c = _chrom_with_prefix(chrom, chrom_prefix)  # e.g., '1' -> 'chr1' for hg38
        start = max(1, pos - flank)
        end   = pos + flank

        try:
            ctx = str(fa[c][start:end])
        except KeyError:
            # contig naming mismatch, skip
            continue
        if len(ctx) != (end - start):
            continue

        center  = pos - start
        ref_seq = ctx
        alt_seq = ctx[:center] + alt + ctx[center+1:]

        rows.append((chrom, pos, ref, alt, ref_seq, alt_seq, label))
        kept += 1
        if max_records and kept >= max_records:
            break

    print(f"Writing pairs: {len(rows)} -> {out_csv}")
    df = pd.DataFrame(rows, columns=["CHROM", "POS", "REF", "ALT", "ref_seq", "alt_seq", "label"])
    df.to_csv(out_csv, index=False)
    return out_csv

# Build if needed
if not USE_EXISTING_PAIRS:
    build_ref_alt_pairs_from_vcf(VCF_GZ, FA, VAL_CSV, flank=FLANK, chrom_prefix=CHROM_PREFIX, max_records=None)
else:
    print("Using existing:", VAL_CSV)

df = pd.read_csv(VAL_CSV, dtype={"CHROM": str})
df.head()

Scanning VCF for SNVs...
Writing pairs: 1531207 -> /home/tstil004/phd/cs895_genai/data/clinvar_seq_pairs_hg38_flank100.csv


Unnamed: 0,CHROM,POS,REF,ALT,ref_seq,alt_seq,label
0,1,69134,A,G,AGGTAACTGCAGAGGCTATTTCCTGGAATGAATCAACGAGTGAAAC...,AGGTAACTGCAGAGGCTATTTCCTGGAATGAATCAACGAGTGAAAC...,0
1,1,924518,G,C,CCACCGGGGCGCCATGCCGGCGGTCAAGAAGGAGTTCCCGGGCCGC...,CCACCGGGGCGCCATGCCGGCGGTCAAGAAGGAGTTCCCGGGCCGC...,0
2,1,925956,C,T,CTGCCGCTGACTGCGCGCAGAAGCGTGCCGCTCCCTCACAGGGTCT...,CTGCCGCTGACTGCGCGCAGAAGCGTGCCGCTCCCTCACAGGGTCT...,0
3,1,925969,C,T,CGCGCAGAAGCGTGCCGCTCCCTCACAGGGTCTGCCTCGGCTCTGC...,CGCGCAGAAGCGTGCCGCTCCCTCACAGGGTCTGCCTCGGCTCTGC...,0
4,1,925980,C,T,GTGCCGCTCCCTCACAGGGTCTGCCTCGGCTCTGCTCGCAGGGAAA...,GTGCCGCTCCCTCACAGGGTCTGCCTCGGCTCTGCTCGCAGGGAAA...,0


In [6]:

# Labels + split
from sklearn.model_selection import train_test_split
import numpy as np

if "label" not in df.columns:
    if "LABEL" in df.columns:
        df["label"] = df["LABEL"].astype(int)
    elif "CLNSIG" in df.columns:
        def clinsig_to_label(v):
            s = str(v).lower()
            has_path = ("pathogenic" in s); has_ben = ("benign" in s)
            if has_path and not has_ben: return 1
            if has_ben  and not has_path: return 0
            return np.nan
        df["label"] = df["CLNSIG"].apply(clinsig_to_label).astype("float")
    else:
        raise KeyError("No label/LABEL/CLNSIG in dataframe.")

before = len(df)
df = df.dropna(subset=["label"]).copy()
df["label"] = df["label"].astype(int)
after = len(df)
print(f"Using {after}/{before} rows with binary labels.")
assert {"ref_seq","alt_seq"}.issubset(df.columns)

df_train, df_tmp = train_test_split(df, test_size=0.4, random_state=SEED, stratify=df["label"])
df_val, df_test  = train_test_split(df_tmp, test_size=0.5, random_state=SEED, stratify=df_tmp["label"])
print(df_train.shape, df_val.shape, df_test.shape)
print("Label balance (val): ", df_val['label'].value_counts(normalize=True).to_dict())
print("Label balance (test):", df_test['label'].value_counts(normalize=True).to_dict())


Using 1531207/1531207 rows with binary labels.
(918724, 7) (306241, 7) (306242, 7)
Label balance (val):  {0: 0.7888884897841896, 1: 0.21111151021581043}
Label balance (test): {0: 0.7888891791459042, 1: 0.21111082085409577}


In [8]:

# Zero-shot DNABERT (v1)
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertConfig, BertTokenizerFast
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_fscore_support, accuracy_score

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAME = "zhihan1996/DNA_bert_6"
KMER = 6; MAX_LEN = 512; BATCH_SIZE = 32

class SNVDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        return {"ref_seq": r["ref_seq"], "alt_seq": r["alt_seq"], "labels": int(r["label"])}

def _kmerize(seq: str, k: int) -> str:
    seq = str(seq).strip().upper().replace("U","T")
    if k is None: return seq
    if len(seq) < k: seq = seq + ("N" * (k - len(seq)))
    return " ".join(seq[i:i+k] for i in range(0, len(seq)-k+1))

def _build_tokenizer(model_name: str):
    return BertTokenizerFast.from_pretrained(model_name, do_lower_case=False)

@torch.no_grad()
def _encode_pooled(encoder, input_ids, attention_mask):
    out = encoder(input_ids=input_ids, attention_mask=attention_mask)
    hs = out.last_hidden_state
    mask = attention_mask.unsqueeze(-1).type_as(hs)
    summed = (hs * mask).sum(dim=1); denom = mask.sum(dim=1).clamp(min=1e-6)
    return summed / denom

def _tokenize_batch(tokenizer, seqs, kmer, max_len):
    texts = [_kmerize(s, kmer) for s in seqs]
    toks = tokenizer(texts, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
    return toks["input_ids"], toks["attention_mask"]

@torch.no_grad()
def zero_shot_scores(ds, batch_size=BATCH_SIZE, device=DEVICE):
    cfg = BertConfig.from_pretrained(MODEL_NAME)
    base_enc = BertModel.from_pretrained(MODEL_NAME, config=cfg).to(device)
    base_enc.eval()
    tokenizer = _build_tokenizer(MODEL_NAME)

    def collate(batch):
        ref_ids, ref_msk = _tokenize_batch(tokenizer, [b["ref_seq"] for b in batch], KMER, MAX_LEN)
        alt_ids, alt_msk = _tokenize_batch(tokenizer, [b["alt_seq"] for b in batch], KMER, MAX_LEN)
        y = torch.tensor([int(b["labels"]) for b in batch], dtype=torch.long)
        return {"ref_input_ids": ref_ids, "ref_attention_mask": ref_msk,
                "alt_input_ids": alt_ids, "alt_attention_mask": alt_msk, "labels": y}

    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)
    all_cos, all_l2, all_y = [], [], []
    for b in tqdm(loader, desc="Zero-shot scoring"):
        ref_ids = b["ref_input_ids"].to(device); ref_msk = b["ref_attention_mask"].to(device)
        alt_ids = b["alt_input_ids"].to(device); alt_msk = b["alt_attention_mask"].to(device)
        y = b["labels"].cpu().numpy()

        ref_repr = _encode_pooled(base_enc, ref_ids, ref_msk)
        alt_repr = _encode_pooled(base_enc, alt_ids, alt_msk)

        ref_norm = torch.nn.functional.normalize(ref_repr, dim=-1)
        alt_norm = torch.nn.functional.normalize(alt_repr, dim=-1)
        cos_sim = (ref_norm * alt_norm).sum(dim=-1)
        cos_dist = (1 - cos_sim).cpu().numpy()
        l2 = torch.norm(alt_repr - ref_repr, dim=-1).cpu().numpy()

        all_cos.append(cos_dist); all_l2.append(l2); all_y.append(y)

    return {"cos_dist": np.concatenate(all_cos),
            "l2": np.concatenate(all_l2),
            "y": np.concatenate(all_y).astype(int)}

def _cls_metrics(y_true, scores, thr):
    preds = (scores >= thr).astype(int)
    acc = accuracy_score(y_true, preds)
    p, r, f1, _ = precision_recall_fscore_support(y_true, preds, average="binary", zero_division=0)
    return {"accuracy": float(acc), "f1": float(f1), "precision": float(p), "recall": float(r)}

def _best_thr_by_f1(y_true, scores, n=501):
    lo, hi = float(scores.min()), float(scores.max())
    lo, hi = lo - 1e-6, hi + 1e-6
    thrs = np.linspace(lo, hi, n)
    best = {"thr": 0.5, "f1": -1.0, "precision": 0.0, "recall": 0.0}
    for t in thrs:
        m = _cls_metrics(y_true, scores, t)
        if m["f1"] > best["f1"]:
            best = {"thr": float(t), **m}
    return best

def zero_shot_eval_full(val_ds, test_ds, batch_size=BATCH_SIZE):
    val = zero_shot_scores(val_ds, batch_size=max(8, batch_size))
    test = zero_shot_scores(test_ds, batch_size=max(8, batch_size))
    report = {}
    for name in ["cos_dist", "l2"]:
        report[("val", name, "auroc")] = roc_auc_score(val["y"], val[name]) if len(set(val["y"]))>1 else float("nan")
        report[("val", name, "auprc")] = average_precision_score(val["y"], val[name]) if len(set(val["y"]))>1 else float("nan")
        report[("test", name, "auroc")] = roc_auc_score(test["y"], test[name]) if len(set(test["y"]))>1 else float("nan")
        report[("test", name, "auprc")] = average_precision_score(test["y"], test[name]) if len(set(test["y"]))>1 else float("nan")
        best = _best_thr_by_f1(val["y"], val[name])
        val_m = _cls_metrics(val["y"], val[name], best["thr"]); tst_m = _cls_metrics(test["y"], test[name], best["thr"])
        report[("val", name, "thr")] = best["thr"]
        for k, v in val_m.items(): report[("val", name, k)] = v
        for k, v in tst_m.items(): report[("test", name, k)] = v

    def _fmt(split, metric):
        cd = report[(split, "cos_dist", metric)]; l2 = report[(split, "l2", metric)]
        return f"cos={cd:.3f} | l2={l2:.3f}"
    print("[Zero-shot: AUROC]   val:", _fmt("val","auroc"), "  test:", _fmt("test","auroc"))
    print("[Zero-shot: AUPRC]   val:", _fmt("val","auprc"), "  test:", _fmt("test","auprc"))
    print(f"[Zero-shot: Thr(F1) ] val: cos={report[('val','cos_dist','thr')]:.4f} | l2={report[('val','l2','thr')]:.4f}")
    for m in ["accuracy","f1","precision","recall"]:
        print(f"[Zero-shot: {m.title():7}] val: {_fmt('val', m)}  test: {_fmt('test', m)}")
    return report


In [9]:

# Execute zero-shot
val_ds  = SNVDataset(df_val)
test_ds = SNVDataset(df_test)
zs_report = zero_shot_eval_full(val_ds, test_ds, batch_size=BATCH_SIZE)
zs_report


config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/359M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/40.0 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Zero-shot scoring: 100%|██████████| 9571/9571 [29:25<00:00,  5.42it/s]
Zero-shot scoring: 100%|██████████| 9571/9571 [29:21<00:00,  5.43it/s]


[Zero-shot: AUROC]   val: cos=0.524 | l2=0.520   test: cos=0.524 | l2=0.521
[Zero-shot: AUPRC]   val: cos=0.224 | l2=0.223   test: cos=0.224 | l2=0.224
[Zero-shot: Thr(F1) ] val: cos=-0.0000 | l2=-0.0000
[Zero-shot: Accuracy] val: cos=0.211 | l2=0.211  test: cos=0.211 | l2=0.211
[Zero-shot: F1     ] val: cos=0.349 | l2=0.349  test: cos=0.349 | l2=0.349
[Zero-shot: Precision] val: cos=0.211 | l2=0.211  test: cos=0.211 | l2=0.211
[Zero-shot: Recall ] val: cos=1.000 | l2=1.000  test: cos=1.000 | l2=1.000


{('val', 'cos_dist', 'auroc'): 0.523705824967194,
 ('val', 'cos_dist', 'auprc'): 0.2239033612626116,
 ('test', 'cos_dist', 'auroc'): 0.523583753264157,
 ('test', 'cos_dist', 'auprc'): 0.2243609812005871,
 ('val', 'cos_dist', 'thr'): -1.2384185791015625e-06,
 ('val', 'cos_dist', 'accuracy'): 0.21111151021581043,
 ('val', 'cos_dist', 'f1'): 0.3486243973987037,
 ('val', 'cos_dist', 'precision'): 0.21111151021581043,
 ('val', 'cos_dist', 'recall'): 1.0,
 ('test', 'cos_dist', 'accuracy'): 0.21111082085409577,
 ('test', 'cos_dist', 'f1'): 0.34862345743920753,
 ('test', 'cos_dist', 'precision'): 0.21111082085409577,
 ('test', 'cos_dist', 'recall'): 1.0,
 ('val', 'l2', 'auroc'): 0.5203423088666612,
 ('val', 'l2', 'auprc'): 0.222761176965432,
 ('test', 'l2', 'auroc'): 0.5212464971094904,
 ('test', 'l2', 'auprc'): 0.22370426651033742,
 ('val', 'l2', 'thr'): -1e-06,
 ('val', 'l2', 'accuracy'): 0.21111151021581043,
 ('val', 'l2', 'f1'): 0.3486243973987037,
 ('val', 'l2', 'precision'): 0.2111115102