In [None]:
!pip -q install -U transformers datasets accelerate peft


In [None]:
import os
# Hide GPUs *before* torch is imported so nothing can initialise CUDA first.
os.environ["CUDA_VISIBLE_DEVICES"] = ""   # ensure CPU only

import re, dataclasses, torch

from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
)
from peft import LoraConfig, get_peft_model, TaskType

# ---------- Paths ----------
DATA_PATH = "/content/merged_corpus_cleaned_Final.jsonl"  # your repaired JSONL
OUT_DIR   = "/content/biobart_5k_cpu_lora"

# ---------- Model ----------
MODEL_ID  = "facebook/bart-base"   # swap to your BioBART if you have one

# ---------- Sizes ----------
SAMPLE_N        = 5_000            # use 5k rows total
VAL_FRAC        = 0.05

# ---------- Token lengths ----------
MAX_SOURCE_LEN  = 256
MAX_TARGET_LEN  = 16

# ---------- HParams (CPU-sane) ----------
EPOCHS          = 1
PER_DEVICE_BATCH= 8                # drop to 4 if memory is tight
GRAD_ACC        = 1
LEARNING_RATE   = 3e-5
WEIGHT_DECAY    = 0.01
WARMUP_RATIO    = 0.03
SEED            = 42

# Make PyTorch play nice on Colab CPU
torch.set_num_threads(4)  # you can try 2 or 8 depending on your session
print("Device:", "cuda" if torch.cuda.is_available() else "cpu")


Device: cpu


In [None]:
def _norm(s: str) -> str:
    return re.sub(r"\s+"," ", (s or "").strip())

def extract_short_answer(ans: str) -> str:
    s = _norm(ans or "")
    if not s: return ""
    s = re.sub(r"(?:^|\b)Ref(?:erences?)?:.*$", "", s, flags=re.IGNORECASE).strip()
    if not s: return ""
    m = re.match(r"^[A-Z]\.\s*([^\n\.]{1,120})", s)  # e.g., "B. Fluoxetine"
    if m: short = m.group(1)
    else:
        m2 = re.match(r"^(.{1,120}?)([\.!\?]|$)", s)
        short = m2.group(1) if m2 else s[:120]
    short = short.split(" - ")[0].split(" (")[0].strip(" .,:;-'\"|[]()")
    return " ".join(short.split()[:12])

# Load the JSONL corpus and keep only rows that have both a question and an answer.
raw = load_dataset("json", data_files={"train": DATA_PATH})["train"]
raw = raw.filter(lambda ex: bool(ex.get("question")) and bool(ex.get("answer")))

def add_cols(ex):
    """Return the new ``context`` (normalised answer) and ``label`` (its short form) columns."""
    answer_text = str(ex.get("answer") or "")
    return {"context": _norm(answer_text), "label": extract_short_answer(answer_text)}

# Attach the derived columns, then drop rows where either came out empty.
raw = raw.map(add_cols).filter(lambda ex: bool(ex["context"]) and bool(ex["label"]))

# Seeded sample of up to SAMPLE_N rows, followed by a seeded train/validation split.
raw = raw.shuffle(seed=SEED)
N = min(SAMPLE_N, len(raw))
raw_5k = raw.select(range(N))
ds = raw_5k.train_test_split(test_size=VAL_FRAC, seed=SEED)
ds_train = ds["train"]
ds_val = ds["test"]
print(f"Sampled: {len(raw_5k)} | Train: {len(ds_train)} | Val: {len(ds_val)}")


Generating train split: 0 examples [00:00, ? examples/s]

Filter:   0%|          | 0/183030 [00:00<?, ? examples/s]

Map:   0%|          | 0/183028 [00:00<?, ? examples/s]

Filter:   0%|          | 0/183028 [00:00<?, ? examples/s]

Sampled: 5000 | Train: 4750 | Val: 250


In [None]:
def _norm(s: str) -> str:
    return re.sub(r"\s+"," ", (s or "").strip())

def extract_short_answer(ans: str) -> str:
    s = _norm(ans or "")
    if not s: return ""
    s = re.sub(r"(?:^|\b)Ref(?:erences?)?:.*$", "", s, flags=re.IGNORECASE).strip()
    if not s: return ""
    m = re.match(r"^[A-Z]\.\s*([^\n\.]{1,120})", s)  # e.g., "B. Fluoxetine"
    if m: short = m.group(1)
    else:
        m2 = re.match(r"^(.{1,120}?)([\.!\?]|$)", s)
        short = m2.group(1) if m2 else s[:120]
    short = short.split(" - ")[0].split(" (")[0].strip(" .,:;-'\"|[]()")
    return " ".join(short.split()[:12])

# Load the JSONL corpus and keep only rows that have both a question and an answer.
raw = load_dataset("json", data_files={"train": DATA_PATH})["train"]
raw = raw.filter(lambda ex: bool(ex.get("question")) and bool(ex.get("answer")))

def add_cols(ex):
    """Return the new ``context`` (normalised answer) and ``label`` (its short form) columns."""
    answer_text = str(ex.get("answer") or "")
    return {"context": _norm(answer_text), "label": extract_short_answer(answer_text)}

# Attach the derived columns, then drop rows where either came out empty.
raw = raw.map(add_cols).filter(lambda ex: bool(ex["context"]) and bool(ex["label"]))

# Seeded sample of up to SAMPLE_N rows, followed by a seeded train/validation split.
raw = raw.shuffle(seed=SEED)
N = min(SAMPLE_N, len(raw))
raw_5k = raw.select(range(N))
ds = raw_5k.train_test_split(test_size=VAL_FRAC, seed=SEED)
ds_train = ds["train"]
ds_val = ds["test"]
print(f"Sampled: {len(raw_5k)} | Train: {len(ds_train)} | Val: {len(ds_val)}")


Sampled: 5000 | Train: 4750 | Val: 250


In [None]:
# Instruction prepended to every training prompt.
INSTRUCTION = " ".join([
    "Use ONLY the context to answer.",
    "If the information is not present, reply exactly: Insufficient context.",
])

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

def _safe_norm(x):
    """Stringify and whitespace-normalise ``x``; ``None`` maps to ""."""
    if x is None:
        return ""
    return _norm(str(x))

def _build_head(batch, i):
    """Return space-separated metadata tags for row ``i`` of a batched slice.

    Emits "[<dataset>]", "Subject: <subject>" and "Topic: <topic>", in that
    order, for each column that exists in ``batch`` and is truthy at row ``i``.
    Returns "" when none apply.
    """
    def _val(col):
        return batch[col][i] if col in batch else None

    parts = []
    if _val("dataset"):
        parts.append(f"[{_safe_norm(_val('dataset'))}]")
    if _val("subject"):
        parts.append(f"Subject: {_safe_norm(_val('subject'))}")
    if _val("topic"):
        parts.append(f"Topic: {_safe_norm(_val('topic'))}")
    return " ".join(parts)

def _fit_context(ctx: str, prefix: str, suffix: str, max_len: int, margin: int = 4) -> str:
    """Trim ``ctx`` so that ``prefix + ctx + suffix`` tokenises to at most ``max_len`` tokens.

    Only the context is shortened; ``margin`` leaves a few tokens of slack for
    BPE merges at the string joins. Returns ``ctx`` unchanged when it already fits.
    """
    budget = max_len - len(tokenizer(prefix + suffix)["input_ids"]) - margin
    ctx_ids = tokenizer(ctx, add_special_tokens=False)["input_ids"]
    if len(ctx_ids) <= budget:
        return ctx
    if budget <= 0:
        return ""
    return tokenizer.decode(ctx_ids[:budget], skip_special_tokens=True).strip()

def make_src_tgt(batch):
    """Build (source prompts, target labels) lists for a batched slice.

    The question follows the context in the prompt, and ``preprocess`` truncates
    sources on the right at MAX_SOURCE_LEN. An untrimmed long context would
    therefore push the question and the "Answer:" cue out of the model input,
    so the context is trimmed to the remaining token budget first.
    """
    srcs, tgts = [], []
    for i in range(len(batch["question"])):
        q    = _safe_norm(batch["question"][i])
        head = _build_head(batch, i)
        prefix = f"{INSTRUCTION}\n\n{head}\nContext:\n"
        suffix = f"\n\nQuestion:\n{q}\n\nAnswer:"
        ctx  = _fit_context(batch["context"][i], prefix, suffix, MAX_SOURCE_LEN)
        srcs.append(f"{prefix}{ctx}{suffix}")
        tgts.append(_safe_norm(batch["label"][i]))
    return srcs, tgts

def preprocess(batch):
    """Tokenise a batched slice into encoder inputs plus a ``labels`` column (unpadded)."""
    sources, targets = make_src_tgt(batch)
    encoded = tokenizer(sources, max_length=MAX_SOURCE_LEN, truncation=True, padding=False)
    target_enc = tokenizer(text_target=targets, max_length=MAX_TARGET_LEN,
                           truncation=True, padding=False)
    encoded["labels"] = target_enc["input_ids"]
    return encoded

# Tokenise both splits; the raw text columns are removed so only model inputs remain.
ds_train_tok, ds_val_tok = (
    split.map(preprocess, batched=True, batch_size=1024, remove_columns=split.column_names)
    for split in (ds_train, ds_val)
)

# Pads each batch dynamically, rounding lengths up to a multiple of 8.
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=None, pad_to_multiple_of=8)
print("Tokenized:", len(ds_train_tok), len(ds_val_tok))


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Map:   0%|          | 0/4750 [00:00<?, ? examples/s]

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

Tokenized: 4750 250


In [None]:
# Base model
base_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# LoRA config (seq2seq): adapt only the attention query/value projections.
lora_cfg = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.1,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
    target_modules=["q_proj","v_proj"]  # common for BART
)
model = get_peft_model(base_model, lora_cfg)
model.print_trainable_parameters()

# TrainingArguments — field names differ across transformers versions, so build
# the kwargs first and keep only fields the installed version defines.
fields_set = {f.name for f in dataclasses.fields(Seq2SeqTrainingArguments)}
kwargs = dict(
    output_dir=OUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH,
    per_device_eval_batch_size=PER_DEVICE_BATCH,
    gradient_accumulation_steps=GRAD_ACC,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    lr_scheduler_type="cosine",
    warmup_ratio=WARMUP_RATIO,
    logging_steps=50,
    save_strategy="no",              # fastest; add steps if you want checkpoints
    report_to="none",
    seed=SEED,
    predict_with_generate=False,
    dataloader_num_workers=2,
    optim="adamw_torch",             # avoid fused variants
)

# Force CPU even if a GPU appears: `use_cpu` supersedes the deprecated `no_cuda`.
kwargs["use_cpu" if "use_cpu" in fields_set else "no_cuda"] = True

# Handle eval strategy field name (renamed evaluation_strategy -> eval_strategy)
if "eval_strategy" in fields_set:
    kwargs["eval_strategy"] = "no"
elif "evaluation_strategy" in fields_set:
    kwargs["evaluation_strategy"] = "no"

# Report anything this version does not accept instead of dropping it silently.
unsupported = sorted(k for k in kwargs if k not in fields_set)
if unsupported:
    print("Ignoring TrainingArguments unsupported by this transformers version:", unsupported)
train_args = Seq2SeqTrainingArguments(**{k: v for k, v in kwargs.items() if k in fields_set})

trainer = Seq2SeqTrainer(
    model=model,
    args=train_args,
    train_dataset=ds_train_tok,
    eval_dataset=ds_val_tok,         # unused while eval strategy is "no"; kept for manual trainer.evaluate()
    processing_class=tokenizer,
    data_collator=collator,
)

train_out = trainer.train()
print({k: v for k, v in train_out.metrics.items() if isinstance(v, (int, float))})


model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

trainable params: 442,368 || all params: 139,862,784 || trainable%: 0.3163




Step,Training Loss
50,4.5436
100,3.632
150,2.3123
200,1.4957
250,0.9367
300,0.5834
350,0.5608
400,0.4134
450,0.437
500,0.357


{'train_runtime': 6000.9856, 'train_samples_per_second': 0.792, 'train_steps_per_second': 0.099, 'total_flos': 700303290531840.0, 'train_loss': 1.3430684455717452, 'epoch': 1.0}


In [None]:
# Persist the trained model and its tokenizer side by side in OUT_DIR.
for _save in (trainer.save_model, tokenizer.save_pretrained):
    _save(OUT_DIR)
print("Saved to:", OUT_DIR)

@torch.no_grad()
def infer(question: str, context: str, max_new_tokens: int = 48):
    """Greedy-decode an answer to ``question`` from ``context`` using the training-style prompt.

    Returns the decoded text after the last "Answer:" marker, or the whole
    decoded text when the marker is absent.
    """
    # With evaluation disabled, the model may still be in train mode after
    # trainer.train() (dropout active); force eval for deterministic generation.
    model.eval()
    prefix = f"{INSTRUCTION}\n\nContext:\n"
    suffix = f"\n\nQuestion:\n{_norm(question)}\n\nAnswer:"
    ctx = _norm(context)
    # The question follows the context, so right-side truncation would cut the
    # question off; trim only the context (4 tokens of slack for BPE joins).
    budget = MAX_SOURCE_LEN - len(tokenizer(prefix + suffix)["input_ids"]) - 4
    ctx_ids = tokenizer(ctx, add_special_tokens=False)["input_ids"]
    if len(ctx_ids) > budget:
        ctx = tokenizer.decode(ctx_ids[:max(budget, 0)], skip_special_tokens=True).strip()
    enc = tokenizer(prefix + ctx + suffix, return_tensors="pt", truncation=True, max_length=MAX_SOURCE_LEN)
    dev = next(model.parameters()).device
    enc = {k: v.to(dev) for k, v in enc.items()}
    out = model.generate(**enc, max_new_tokens=max_new_tokens, do_sample=False, num_beams=1,
                         eos_token_id=tokenizer.eos_token_id)
    txt = tokenizer.decode(out[0], skip_special_tokens=True)
    return txt.split("Answer:")[-1].strip()

# Quick smoke test of the freshly trained model.
_demo_question = "Do 5 mg Zolmitriptan tablets contain gluten?"
_demo_context = ("Zolmitriptan tablets contain lactose, microcrystalline cellulose, "
                 "sodium starch glycolate, magnesium stearate...")
print(infer(_demo_question, _demo_context))


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Saved to: /content/biobart_5k_cpu_lora
Zol


In [None]:
# Reuse values from the training cells when they exist; otherwise fall back to the
# same defaults used there, so this section also runs on a fresh kernel.
MODEL_ID   = globals().get("MODEL_ID") or "facebook/bart-base"   # or "GanjinZero/biobart-base" — must match the adapter's base
ADAPTER_DIR= globals().get("OUT_DIR") or "/content/biobart_5k_cpu_lora"  # your fine-tuned output folder
DATA_PATH  = globals().get("DATA_PATH") or "/content/merged_corpus_cleaned_Final.jsonl"

# Inference knobs
MAX_SOURCE_LEN  = 512       # more room for concatenated contexts
MAX_NEW_TOKENS  = 64
TOP_K           = 3         # concatenate top-k retrieved contexts
SIM_THRESHOLD   = 0.25      # cosine similarity threshold for "good enough" context
ALLOW_FALLBACK  = True      # use general medical knowledge if retrieval is weak
SEED            = 42


In [None]:
!pip -q install -U rank-bm25 sentence-transformers
# optional: cross-encoder (commented by default)
!pip -q install -U cross-encoder



[31mERROR: Could not find a version that satisfies the requirement cross-encoder (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for cross-encoder[0m[31m
[0m

In [None]:
import pathlib, torch, re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

device = "cuda" if torch.cuda.is_available() else "cpu"

def _norm(s: str) -> str:
    return re.sub(r"\s+"," ", (s or "").strip())

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

def load_model(base_id, adapter_dir):
    """Load ``base_id`` and wrap it with the LoRA adapter in ``adapter_dir`` when one is present.

    An adapter is used only if ``adapter_dir`` holds ``adapter_config.json`` plus a
    weights file (``PeftModel.from_pretrained`` needs both); otherwise the base
    weights are used. Returns the model in eval mode, moved to ``device``.
    """
    p = pathlib.Path(adapter_dir) if adapter_dir else None
    has_config = p is not None and (p / "adapter_config.json").exists()
    has_weights = p is not None and any(
        (p / n).exists() for n in ("adapter_model.safetensors", "adapter_model.bin")
    )
    base = AutoModelForSeq2SeqLM.from_pretrained(base_id)
    if has_config and has_weights:
        print(f"Using LoRA adapters from: {adapter_dir}")
        model = PeftModel.from_pretrained(base, adapter_dir)
    elif has_config or has_weights:
        print(f"Incomplete adapter folder {adapter_dir} (need config + weights); using base weights.")
        model = base
    else:
        print("No adapters found; using base weights.")
        model = base
    return model.eval().to(device)

model = load_model(MODEL_ID, ADAPTER_DIR)


Using LoRA adapters from: /content/biobart_5k_cpu_lora


In [None]:
import re, json, numpy as np
from datasets import load_dataset
from rank_bm25 import BM25Okapi

# DATA_PATH comes from the inference config cell above

def _norm(s: str) -> str:
    return re.sub(r"\s+"," ", (s or "").strip())

def _tok(s: str):
    # simple, fast analyzer: lowercase, keep words/numbers/hyphens
    return re.findall(r"[A-Za-z0-9][A-Za-z0-9\-]+", s.lower())

raw = load_dataset("json", data_files={"train": DATA_PATH})["train"]
raw = raw.filter(lambda ex: bool(ex.get("answer")))
ANSWERS = [_norm(a) for a in raw["answer"]]
# Not every corpus has an "id" column; fall back to row positions so hits stay traceable.
IDS     = raw["id"] if "id" in raw.column_names else list(range(len(raw)))

# Deduplicate identical contexts to reduce index size (first occurrence keeps its id;
# dicts preserve insertion order). Contexts that are empty after normalisation are skipped.
first_id_by_ctx = {}
for ctx, doc_id in zip(ANSWERS, IDS):
    if ctx and ctx not in first_id_by_ctx:
        first_id_by_ctx[ctx] = doc_id
uniq_answers = list(first_id_by_ctx.keys())
uniq_ids     = list(first_id_by_ctx.values())

# Tokenize corpus once
DOCS_TOK = [_tok(t) for t in uniq_answers]

# Build BM25 (fast on CPU)
bm25 = BM25Okapi(DOCS_TOK)
print(f"BM25 index ready | contexts: {len(uniq_answers)}")


BM25 index ready | contexts: 171944


In [None]:
from sentence_transformers import SentenceTransformer

# Small, fast CPU embedding model
embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

USE_CROSS_ENCODER = False  # set True to enable final re-rank
if USE_CROSS_ENCODER:
    # CrossEncoder lives in sentence-transformers; there is no `cross_encoder`
    # module (the separate pip install of it fails).
    from sentence_transformers import CrossEncoder
    xenc = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")  # fast, decent accuracy

def retrieve_candidates(question: str, k_bm25=800, k_dense=120, k_cross=24, k_final=6):
    """Hybrid retrieval over the deduplicated answer corpus.

    1) BM25 over ``_tok`` tokens keeps the top ``k_bm25`` contexts (lexical recall).
    2) The bi-encoder re-ranks those by cosine similarity and keeps ``k_dense``.
    3) If ``USE_CROSS_ENCODER`` is set, the cross-encoder re-scores the best
       ``k_cross`` of them; otherwise the dense ranking is kept.
    At most ``k_final`` results are returned.

    Returns:
        list of ``(id, context, score)`` tuples, best first. ``score`` is the
        cosine similarity in dense mode, or the raw cross-encoder score otherwise
        (different scales — tune thresholds per mode).
    """
    # 1) BM25 candidates, ordered by lexical score.
    bm25_scores = bm25.get_scores(_tok(question))
    if k_bm25 < len(bm25_scores):
        cand = np.argpartition(-bm25_scores, k_bm25)[:k_bm25]
        cand = cand[np.argsort(-bm25_scores[cand])]
    else:
        cand = np.argsort(-bm25_scores)
    if len(cand) == 0:
        return []

    # 2) Dense re-rank; embeddings are L2-normalised, so the dot product is cosine.
    q_emb  = embedder.encode([_norm(question)], normalize_embeddings=True)
    c_embs = embedder.encode([uniq_answers[i] for i in cand], batch_size=128, normalize_embeddings=True)
    dense_scores = (c_embs @ q_emb.T).ravel()
    top = np.argsort(-dense_scores)[:k_dense]
    cand, scores = cand[top], dense_scores[top]

    # 3) Optional cross-encoder re-rank of the strongest k_cross dense candidates.
    if USE_CROSS_ENCODER:
        cand = cand[:k_cross]
        pairs = [(question, uniq_answers[i]) for i in cand]
        xscore = np.asarray(xenc.predict(pairs, batch_size=32))
        top = np.argsort(-xscore)
        cand, scores = cand[top], xscore[top]

    cand, scores = cand[:k_final], scores[:k_final]
    return [(uniq_ids[i], uniq_answers[i], float(s)) for i, s in zip(cand, scores)]


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from peft import PeftModel

# Reuse your already-loaded tokenizer/model if present; otherwise load them here.
try:
    tokenizer
    model
except NameError:
    # The LoRA adapters were trained on MODEL_ID from the first config cell
    # (facebook/bart-base). Attaching them to a different base (e.g. BioBART)
    # would likely still load, since the shapes look compatible, but the outputs
    # would be meaningless. So default to the training base rather than overriding it.
    MODEL_ID    = globals().get("MODEL_ID") or "facebook/bart-base"
    ADAPTER_DIR = globals().get("ADAPTER_DIR") or "/content/biobart_5k_cpu_lora"   # or merged dir
    tokenizer   = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    base        = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
    # try attach adapters; fall back to base if not found
    try:
        model = PeftModel.from_pretrained(base, ADAPTER_DIR).eval()
        print("Loaded LoRA adapters.")
    except Exception as exc:  # best-effort fallback, but report why adapters were skipped
        model = base.eval()
        print(f"Using base model (could not load adapters from {ADAPTER_DIR}: {exc!r}).")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Prompt preambles for answer_query: one for retrieval-grounded answers, one for
# the ungrounded fallback used when retrieval is weak.
INSTR_GROUNDED = " ".join([
    "Using only the CONTEXT below, answer the QUESTION fully and concisely.",
    "If the CONTEXT does not contain the answer, reply exactly: Insufficient context.",
])
INSTR_FALLBACK = " ".join([
    "Answer the medical QUESTION concisely and cautiously using general medical knowledge.",
    "If uncertain, say: 'Insufficient information to answer reliably.'",
])

def _trim_ctx_to_budget(ctx: str, prefix: str, suffix: str, max_len: int, margin: int = 4) -> str:
    """Shorten ``ctx`` so that ``prefix + ctx + suffix`` fits in ``max_len`` tokens.

    Only the context is cut; ``margin`` leaves slack for BPE merges at the joins.
    Returns ``ctx`` unchanged when it already fits.
    """
    budget = max_len - len(tokenizer(prefix + suffix)["input_ids"]) - margin
    ctx_ids = tokenizer(ctx, add_special_tokens=False)["input_ids"]
    if len(ctx_ids) <= budget:
        return ctx
    if budget <= 0:
        return ""
    return tokenizer.decode(ctx_ids[:budget], skip_special_tokens=True).strip()


def _best_sentence(ctx: str, q: str) -> str:
    """Return the sentence of ``ctx`` sharing the most ``\\w+`` words with ``q`` (case-insensitive)."""
    sents = re.split(r'(?<=[\.\!\?])\s+', ctx.strip())
    if not sents: return ctx.strip()[:200]
    q_words = set(re.findall(r"\w+", q.lower()))
    def score(s):
        s_words = set(re.findall(r"\w+", s.lower()))
        return len(q_words & s_words)
    return max(sents, key=score).strip()


@torch.no_grad()
def answer_query(
    question: str,
    *,
    k_bm25=800, k_dense=120, k_cross=24, k_final=6,
    sim_threshold=0.25,       # accept grounding if top score >= threshold (cosine scale in dense mode)
    allow_fallback=True,      # use model knowledge if no good ground
    style="medium",           # "short" | "medium" | "long"
    add_evidence=False,       # append one evidence sentence when grounded
    return_meta=False,
    max_source_len=512,       # encoder input budget in tokens
):
    """Answer ``question`` with retrieval-grounded generation, optionally falling back to the model alone.

    Retrieves candidates via ``retrieve_candidates``. If the best score reaches
    ``sim_threshold``, the top ``k_final`` contexts are concatenated into a
    grounded prompt. Otherwise the fallback prompt is used, or "Insufficient
    context." is returned when ``allow_fallback`` is False. ``style`` selects
    beam-search length presets.

    Returns:
        the answer string, or ``(answer, {"grounded": bool, "hits": hits})``
        when ``return_meta`` is True.
    """
    hits = retrieve_candidates(question, k_bm25=k_bm25, k_dense=k_dense, k_cross=k_cross, k_final=k_final)
    use_grounded = bool(hits) and hits[0][2] >= sim_threshold
    q_norm = _norm(question)

    if use_grounded:
        ctx = "\n\n---\n\n".join(h[1] for h in hits[:k_final])
        prefix = f"{INSTR_GROUNDED}\n\nCONTEXT:\n"
        suffix = f"\n\nQUESTION:\n{q_norm}\n\nAnswer:"
        # QUESTION comes after the (long) concatenated contexts, so plain right-side
        # truncation would drop it; trim the contexts to the token budget instead.
        prompt = prefix + _trim_ctx_to_budget(ctx, prefix, suffix, max_source_len) + suffix
    else:
        if not allow_fallback:
            ans = "Insufficient context."  # same wording the grounded instruction asks for
            return (ans, {"grounded": False, "hits": hits}) if return_meta else ans
        prompt = f"{INSTR_FALLBACK}\n\nQUESTION:\n{q_norm}\n\nAnswer:"

    # Decode presets: (max_new_tokens, min_new_tokens, num_beams, length_penalty); unknown -> medium.
    presets = {"short": (48, 6, 4, 0.7), "long": (160, 12, 6, 0.6)}
    max_new, min_new, beams, lp = presets.get(style.lower(), (96, 8, 6, 0.65))

    model.eval()  # no dropout during generation, whichever path loaded the model
    enc = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_source_len)
    enc = {k: v.to(device) for k, v in enc.items()}
    out = model.generate(
        **enc,
        max_new_tokens=max_new,
        min_new_tokens=min_new,
        do_sample=False,
        num_beams=beams,
        length_penalty=lp,
        no_repeat_ngram_size=3,
        eos_token_id=tokenizer.eos_token_id
    )
    text = tokenizer.decode(out[0], skip_special_tokens=True)
    ans  = text.split("Answer:")[-1].strip()

    if add_evidence and use_grounded:
        ev = _best_sentence(ctx, question)
        if ev and ev not in ans:
            ans = f"{ans}\n\nEvidence: {ev}"

    meta = {"grounded": use_grounded, "hits": hits}
    return (ans, meta) if return_meta else ans


In [None]:
# Smoke tests: a grounded lookup with evidence, a medium answer, and a long free-form one.
q1 = "Do 5 mg Zolmitriptan tablets contain gluten?"
a1, m1 = answer_query(q1, style="short", add_evidence=True, return_meta=True)
print("Grounded:", m1["grounded"])
print(a1)

q2 = "What are the neurological symptoms of hyperthyroidism?"
a2, m2 = answer_query(q2, style="medium", return_meta=True)
print("\nGrounded:", m2["grounded"])
print(a2)

q3 = "Explain the difference between type 1 and type 2 diabetes."
a3 = answer_query(q3, style="long")
print("\n", a3)


Grounded: True
Zolmitriptan tablets

Evidence: Zolmitriptan tablets are available as 2.5 mg (yellow and functionally-scored) and mg (pink, not scored) film coated tablets for oral administration.

Grounded: True
hyperthyroidismCONTEXT:

 Diabetes is disease in which Glucose comes from the foods you eat


In [None]:
import shutil
from google.colab import files

ADAPTER_DIR = "/content/biobart_5k_cpu_lora"  # change if your folder is different
# Zip the adapter folder into biobart_adapter.zip and send it to the browser.
archive_path = shutil.make_archive("biobart_adapter", 'zip', ADAPTER_DIR)
files.download(archive_path)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>