In [7]:
from pathlib import Path
import json, math, re, random
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch
import nltk
nltk.download("punkt")


PROJECT = Path("..")  # adjust if notebook sits elsewhere
GEN_DIR = PROJECT / "data" / "processed" / "generative"
SFT_DIR = PROJECT / "outputs" / "sft" / "tinyllama_sft_pubmedqa_cpu"

BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
USE_CPU = True  # you’re on AMD/Windows; keep True
SEED = 42
random.seed(SEED)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\santi\AppData\Roaming\nltk_data...
[nltk_data]   Unzipping tokenizers\punkt.zip.


In [2]:
ds = load_dataset("json", data_files={
    "train": str(GEN_DIR / "train.jsonl"),
    "validation": str(GEN_DIR / "val.jsonl")
})
# Use, say, 200 train + 50 val for reward prep
sub_train = ds["train"].shuffle(seed=SEED).select(range(min(200, len(ds["train"]))))
sub_val   = ds["validation"].shuffle(seed=SEED).select(range(min(50, len(ds["validation"]))))
len(sub_train), len(sub_val)

Generating train split: 799 examples [00:00, 43834.68 examples/s]
Generating validation split: 99 examples [00:00, 14120.31 examples/s]


(200, 50)

In [3]:
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

base = AutoModelForCausalLM.from_pretrained(BASE_MODEL)
model = PeftModel.from_pretrained(base, str(SFT_DIR))
model.eval()
device = torch.device("cpu") if USE_CPU else model.device
model.to(device)

The 8-bit optimizer is not available on your device, only available on CUDA for now.


PeftModel(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 2048)
        (layers): ModuleList(
          (0-21): 22 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=2048, out_features=2048, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2048, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=2048, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
            

In [4]:
SYSTEM_PROMPT = "You are a careful medical QA assistant. Only answer using the provided context."
def make_prompt(q, c):
    # Mirror SFT formatting (system/user/assistant)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Question: {q}\n\nContext:\n{c}\n\nAnswer succinctly based only on the context."},
    ]
    return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

In [5]:
def batched_generate(batch, max_new_tokens=128, temperature=0.2, top_p=0.9):
    prompts = [make_prompt(q, c) for q, c in zip(batch["question"], batch["context"])]
    inputs = tok(prompts, return_tensors="pt", padding=True, truncation=True, max_length=768).to(device)
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tok.eos_token_id
        )
    texts = tok.batch_decode(out, skip_special_tokens=True)
    # Extract only the assistant continuation after the last prompt
    # (simple split that works with TinyLlama chat template)
    replies = []
    for full, prompt in zip(texts, prompts):
        replies.append(full[len(prompt):].strip())
    return {"model_answer": replies}

sub_train_gen = sub_train.map(batched_generate, batched=True, batch_size=4)
sub_val_gen   = sub_val.map(batched_generate,   batched=True, batch_size=4)
len(sub_train_gen), len(sub_val_gen)


Map:   0%|          | 0/200 [00:00<?, ? examples/s]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Map:   2%|▏         | 4/200 [00:55<45:33, 13.95s/ examples]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Map:   4%|▍         | 8/200 [02:05<51:07, 15.98s/ examples]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Map:   6%|▌         | 12/200 [03:16<52:35, 16.78s/ examples]A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.
Map:   8%|▊         | 16/200 [04:31<53:47, 17.54s/ examples]A decoder-only arch

(200, 50)

In [8]:
import evaluate
rouge = evaluate.load("rouge")  # comes with 'evaluate' package

STOP_PHRASES = [
    "as an ai", "cannot provide medical advice", "consult a doctor",
]
HALLUCINATION_FLAGS = [
    "according to", "source:", "wikipedia", "cdc says", "who says",
]

def keyword_tokens(text):
    toks = re.findall(r"[a-z0-9]+", (text or "").lower())
    stop = set(["the","a","an","and","or","is","are","of","to","in","on","with","for","by","as","that","this","it","be"])
    return [t for t in toks if t not in stop and len(t) > 2]

def score_example(model_answer, gold_answer, context):
    # 1) ROUGE-L
    r = rouge.compute(predictions=[model_answer], references=[gold_answer])["rougeL"]

    # 2) Coverage: fraction of unique gold tokens present in model answer
    gold_keys = set(keyword_tokens(gold_answer))
    pred_keys = set(keyword_tokens(model_answer))
    cov = (len(gold_keys & pred_keys) / max(1,len(gold_keys)))

    # 3) Safety / hallucination heuristics
    lower = model_answer.lower()
    safety_pen = 0.0
    if any(p in lower for p in HALLUCINATION_FLAGS):
        safety_pen += 0.1
    # If model uses general disclaimers excessively, small penalty (we want concise, context-grounded)
    if sum(p in lower for p in STOP_PHRASES) >= 2:
        safety_pen += 0.05

    # 4) Length regularizer (prefer answers ~25–180 tokens)
    n_tok = len(keyword_tokens(model_answer))
    if n_tok < 12:
        len_pen = 0.05
    elif n_tok > 220:
        len_pen = 0.05
    else:
        len_pen = 0.0

    # Final reward in [0, 1]-ish
    reward = max(0.0, r*0.6 + cov*0.4 - safety_pen - len_pen)
    return float(reward), {"rougeL": r, "coverage": cov, "safety_pen": safety_pen, "len_pen": len_pen}

def add_rewards(batch):
    rewards, rmeta = [], []
    for ma, ga, ctx in zip(batch["model_answer"], batch["answer"], batch["context"]):
        rw, meta = score_example(ma, ga, ctx)
        rewards.append(rw)
        rmeta.append(meta)
    return {"reward": rewards, "reward_meta": rmeta}

sub_train_rw = sub_train_gen.map(add_rewards, batched=True, batch_size=32)
sub_val_rw   = sub_val_gen.map(add_rewards,   batched=True, batch_size=32)
sub_train_rw[0]


Map: 100%|██████████| 200/200 [00:34<00:00,  5.78 examples/s]
Map: 100%|██████████| 50/50 [00:05<00:00,  9.04 examples/s]


{'id': '26237424',
 'question': 'Does patient-prosthesis mismatch after aortic valve replacement affect survival and quality of life in elderly patients?',
 'context': 'To evaluate the impact of patient-prosthesis mismatch (PPM) on survival, functional status, and quality of life (QoL) after aortic valve replacement (AVR) with small prosthesis size in elderly patients.\n\nBetween January 2005 and December 2013, 152 patients with pure aortic stenosis, aged at least 75 years, underwent AVR, with a 19 or 21 mm prosthetic heart valve. PPM was defined as an indexed effective orifice area less than 0.85 cm/m. Median age was 82 years (range 75-93 years). Mean follow-up was 56 months (range 1-82 months) and was 98% complete. Late survival rate, New York Heart Association functional class, and QoL (RAND SF-36) were assessed.\n\nOverall, PPM was found in 78 patients (53.8%). Among them, 42 patients (29%) had an indexed effective orifice area less than 0.75 cm/m and 17 less than 0.65 cm/m (11.7%)

In [9]:
OUT = PROJECT / "data" / "rewards"
OUT.mkdir(parents=True, exist_ok=True)

def dump_jsonl(path, rows):
    with path.open("w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps({k: r[k] for k in ["question","context","answer","model_answer","reward"]}, ensure_ascii=False) + "\n")

dump_jsonl(OUT / "train_scored.jsonl", [sub_train_rw[i] for i in range(len(sub_train_rw))])
dump_jsonl(OUT / "val_scored.jsonl",   [sub_val_rw[i]   for i in range(len(sub_val_rw))])
(str(OUT / "train_scored.jsonl"), str(OUT / "val_scored.jsonl"))

('..\\data\\rewards\\train_scored.jsonl',
 '..\\data\\rewards\\val_scored.jsonl')