In [1]:
# Core libs for generation
%pip install -U transformers accelerate bitsandbytes datasets --quiet
# Metrics
%pip install --no-deps -U rouge-score==0.1.2 nltk==3.9.1 tqdm==4.66.5
%pip install --no-deps -U pytorch_pretrained_bert==0.6.2 moverscore==1.0.3


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.0/44.0 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m91.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m375.8/375.8 kB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m506.3/506.3 kB[0m [31m24.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.3/564.3 kB[0m [31m28.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.8/42.8 MB[0m [31m44.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.3/3.3 MB[0m [31m84.9 MB/s[0m eta [36

In [14]:
%pip install -U --no-cache-dir \
  "transformers==4.44.2" "accelerate==0.34.2" "bitsandbytes==0.43.1" \
  "rouge-score==0.1.2" "nltk==3.9.1" "tqdm==4.66.5" "safetensors>=0.4.3"


Collecting transformers==4.44.2
  Downloading transformers-4.44.2-py3-none-any.whl.metadata (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate==0.34.2
  Downloading accelerate-0.34.2-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes==0.43.1
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl.metadata (2.2 kB)
Collecting safetensors>=0.4.3
  Downloading safetensors-0.6.2-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.1 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers==4.44.2)
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.44.2-py3-none-any.whl (9.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.5/9.5 MB[0m [31m79.1 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading accelerate-0.34.2-py3-none-any.whl (324

In [15]:
import os, gc, json, random, time
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch
import sys, os
sys.path.append("/kaggle/input/emnlp19-moverscore")  # path that contains the 'moverscore_v2' folder
from moverscore_v2 import get_idf_dict, word_mover_score


In [16]:
# ====== SPEED PROFILE ======
SPEED_PROFILE = os.getenv("SPEED_PROFILE", "balanced").lower()
def preset(profile: str):
    if profile == "speed":
        return dict(MAX_INPUT_TOKENS=512, MAX_NEW_TOKENS=120, BATCH_TOKENS_BUDGET=16000, MAX_BATCH_SAMPLES=3, SAFETY=0.9)
    if profile == "safe":
        return dict(MAX_INPUT_TOKENS=768, MAX_NEW_TOKENS=128, BATCH_TOKENS_BUDGET=11000, MAX_BATCH_SAMPLES=2, SAFETY=0.85)
    # balanced
    return dict(MAX_INPUT_TOKENS=768, MAX_NEW_TOKENS=128, BATCH_TOKENS_BUDGET=14000, MAX_BATCH_SAMPLES=2, SAFETY=0.9)

_P = preset(SPEED_PROFILE)

# ====== CFG ======
class CFG:
    LOCAL_MODEL_DIR = os.getenv("LLAMA_MODEL_DIR", "/kaggle/input/llama-3.1/transformers/8b-instruct/2")
    MEDRED_CSV      = os.getenv("MEDRED_CSV", "/kaggle/input/medred/medredqa_test.csv")
    OUT_DIR = Path(os.getenv("OUT_DIR", "/kaggle/working")); OUT_DIR.mkdir(parents=True, exist_ok=True)

    # Generation (greedy)
    MAX_NEW_TOKENS = int(os.getenv("MAX_NEW_TOKENS", 80))  
    DO_SAMPLE = False
    MAX_BATCH_SAMPLES = int(os.getenv("MAX_BATCH_SAMPLES", 3))

    # Context / batching
    MAX_INPUT_TOKENS    = int(os.getenv("MAX_INPUT_TOKENS", 512))   # ↓ from 1024/768
    BATCH_TOKENS_BUDGET = int(os.getenv("BATCH_TOKENS_BUDGET", 18000))
    SAFETY = float(os.getenv("BATCH_SAFETY", 0.9))

    SYSTEM_PROMPT = "You are a careful medical assistant. Answer clearly and concisely for lay readers."
    N_ROWS = int(os.getenv("N_ROWS", 1000))           
    SAMPLE_STRATEGY = os.getenv("SAMPLE_STRATEGY", "random")  
    SHARD_TOTAL = int(os.getenv("SHARD_TOTAL", 1))    
    SHARD_INDEX = int(os.getenv("SHARD_INDEX", 0))    


In [4]:
# ====== perf knobs ======
SEED = 42
os.environ.setdefault("PYTHONHASHSEED", str(SEED))
random.seed(SEED); np.random.seed(SEED)
torch.manual_seed(SEED); torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.benchmark = True
try:
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
except Exception:
    pass
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:128")
os.environ["TRANSFORMERS_NO_TF"] = "1"; os.environ["JAX_PLATFORMS"]="cpu"; os.environ["TF_CPP_MIN_LOG_LEVEL"]="3"
os.environ["TOKENIZERS_PARALLELISM"]="true"

from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig


# 4-bit quant
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.float16)
tok = AutoTokenizer.from_pretrained(CFG.LOCAL_MODEL_DIR, local_files_only=True, use_fast=True)
if tok.pad_token is None: tok.pad_token = tok.eos_token
tok.padding_side = "left"

# prefer eager; flash if available
model = AutoModelForCausalLM.from_pretrained(
    CFG.LOCAL_MODEL_DIR,
    device_map={"": "cuda:0"},
    low_cpu_mem_usage=True,
    quantization_config=bnb,
    torch_dtype=torch.float16,
    attn_implementation="eager",
    local_files_only=True,
)
model.eval()
model.config.pad_token_id = tok.pad_token_id


`torch_dtype` is deprecated! Use `dtype` instead!
E0000 00:00:1761215457.977483      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1761215458.039873      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [7]:
# ========= batching =========
def make_batches_by_padded_cost(prompt_lens, budget_tokens, max_new, max_batch_samples, safety=0.9):
    batches, cur = [], []
    cur_max = 0
    eff = int(budget_tokens * safety)
    for i, L in enumerate(prompt_lens):
        new_max = max(cur_max, L)
        bs_next = len(cur) + 1
        cost = bs_next * (new_max + max_new)
        if cur and (cost > eff or len(cur) >= max_batch_samples):
            batches.append(cur); cur = [i]; cur_max = L
        else:
            cur.append(i); cur_max = new_max
    if cur: batches.append(cur)
    return batches

# ========= build prompts =========
def build_prompt(q: str) -> str:
    msgs = [{"role":"system","content":CFG.SYSTEM_PROMPT},{"role":"user","content":q}]
    try:
        return tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    except Exception:
        return (
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{CFG.SYSTEM_PROMPT}\n"
            f"<|eot_id|><|start_header_id|>user<|end_header_id|>\n{q}\n"
            f"<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
        )

# ========= encode all once (CPU), track lens =========
# df = pd.read_csv(CFG.MEDRED_CSV)
# df = pd.DataFrame({
#     "id": df["Unnamed: 0"].astype(str),
#     "query": (df["Title"].fillna("") + "\n\n" + df["Body"].fillna("")).str.strip(),
#     "reference": df["Response"].fillna("").astype(str).str.strip()
# })
# df = df[df["query"].str.len() > 0].reset_index(drop=True)

# prompts = [build_prompt(q) for q in df["query"].tolist()]

df_raw = pd.read_csv(CFG.MEDRED_CSV)
df = pd.DataFrame({
    "id": df_raw.iloc[:, 0].astype(str),
    "query": (df_raw["Title"].fillna("") + "\n\n" + df_raw["Body"].fillna("")).str.strip(),
    "reference": df_raw["Response"].fillna("").astype(str).str.strip()
})
df = df[df["query"].str.len() > 0].reset_index(drop=True)

if CFG.N_ROWS and CFG.N_ROWS > 0 and CFG.N_ROWS < len(df):
    if CFG.SAMPLE_STRATEGY.lower() == "head":
        df = df.iloc[:CFG.N_ROWS].copy()
    else:
        # random sample for better coverage; reproducible via SEED
        df = df.sample(n=CFG.N_ROWS, random_state=SEED).sort_index().copy()
        
prompts = [build_prompt(q) for q in df["query"].tolist()]
assert len(prompts) == len(df), "prompts/df length mismatch"
# IMPORTANT: get ragged LISTS, not tensors (avoids ValueError)
enc = tok(
    prompts,
    add_special_tokens=True,
    padding=False,                 # no global padding here
    truncation=True,
    max_length=CFG.MAX_INPUT_TOKENS,
    return_tensors=None            # <- changed from "pt"
)

input_ids_list = enc["input_ids"]         # List[List[int]]
lens = [len(x) for x in input_ids_list]   # lengths per sample

# sort by length for tighter padding
order = np.argsort(lens).tolist()
ids_sorted  = [df["id"].iloc[i]        for i in order]
ref_sorted  = [df["reference"].iloc[i] for i in order]
iids_sorted = [input_ids_list[i]       for i in order]
lens_sorted = [lens[i]                 for i in order]

batches = make_batches_by_padded_cost(
    lens_sorted,
    CFG.BATCH_TOKENS_BUDGET,
    CFG.MAX_NEW_TOKENS,
    CFG.MAX_BATCH_SAMPLES,
    safety=CFG.SAFETY
)

# multi-EOS for early stop (eot/eos and a newline fence)
eos_token_ids = [tok.eos_token_id]
try:
    eot_id = tok.convert_tokens_to_ids("<|eot_id|>")
    if isinstance(eot_id, int) and eot_id != tok.eos_token_id:
        eos_token_ids.append(eot_id)
except Exception:
    pass
for stop_str in ["\n\n", "\n###", "\nUser:", "\nAssistant:"]:
    ids = tok(stop_str, add_special_tokens=False).input_ids
    if ids:
        eos_token_ids.append(ids[-1])


def left_pad_collate(iid_list, device):
    maxL = max(len(x) for x in iid_list)
    pad_id = tok.pad_token_id
    input_ids = torch.full((len(iid_list), maxL), pad_id, dtype=torch.long)
    attn_mask = torch.zeros((len(iid_list), maxL), dtype=torch.bool)
    for r, ids_r in enumerate(iid_list):
        L = len(ids_r)
        input_ids[r, -L:] = torch.as_tensor(ids_r, dtype=torch.long)
        attn_mask[r,  -L:] = True
    return input_ids.to(device, non_blocking=True), attn_mask.to(device, non_blocking=True)


In [8]:
# ========= generation (lean) =========

pred_rows = []
with torch.inference_mode():
    for b_ix, idxs in enumerate(tqdm(batches, desc=f"Generating[{SPEED_PROFILE}]")):
        batch_iids = [iids_sorted[i] for i in idxs]
        input_ids, attn_mask = left_pad_collate(batch_iids, model.device)

        out = model.generate(
            input_ids=input_ids,
            attention_mask=attn_mask,
            max_new_tokens=CFG.MAX_NEW_TOKENS,  # paper: 150; using 96–128/150 for speed
            do_sample=False,                      # greedy; sampling didn’t help MoverScore in paper
            eos_token_id=eos_token_ids,
            pad_token_id=tok.pad_token_id,
            use_cache=True,
        )

        new_tokens = out[:, input_ids.shape[1]:]
        texts = tok.batch_decode(new_tokens, skip_special_tokens=True)

        for j, txt in zip(idxs, texts):
            pred_rows.append({"id": str(ids_sorted[j]), "prediction": txt.strip()})

        del input_ids, attn_mask, out, new_tokens, texts
        torch.cuda.synchronize(); torch.cuda.empty_cache(); gc.collect()

preds = pd.DataFrame(pred_rows)
preds.to_csv(CFG.OUT_DIR / "predictions_sub.csv", index=False)


Generating[balanced]:   0%|          | 0/334 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Generating[balanced]: 100%|██████████| 334/334 [1:36:39<00:00, 17.36s/it]


In [17]:
# ==== Compatibility shim for moverscore_v2 ====
def _patch_tokenizer_max_len():
    try:
        import transformers
        for name in ("PreTrainedTokenizer", "PreTrainedTokenizerFast"):
            cls = getattr(transformers, name, None)
            if cls is not None and not hasattr(cls, "max_len"):
                cls.max_len = property(lambda self: getattr(self, "model_max_length", 512))
    except Exception:
        pass
_patch_tokenizer_max_len()

# ==== END shim ====

# ========= evaluation: fast ROUGE + faster MoverScore =========
def normalize(s): return " ".join((s or "").split())
eval_df = df[["id","reference"]].merge(preds, on="id", how="left")
eval_df["reference"]  = eval_df["reference"].map(normalize)
eval_df["prediction"] = eval_df["prediction"].fillna("").map(normalize)

# ROUGE-1 (parallel)
from rouge_score import rouge_scorer
from multiprocessing import Pool, cpu_count
def _init_scorer():
    global _SCORER
    _SCORER = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
def _r1_pair(args):
    ref, hyp = args
    return _SCORER.score(ref, hyp)["rouge1"].fmeasure
pairs = list(zip(eval_df["reference"], eval_df["prediction"]))
with Pool(processes=max(1, cpu_count()-1), initializer=_init_scorer) as pool:
    r1 = np.array(list(tqdm(pool.imap(_r1_pair, pairs), total=len(eval_df), desc="ROUGE-1(F1)")), dtype=np.float32)

# MoverScore v2 (now works)
DEV_SKIP_MOVERSCORE = bool(int(os.getenv("DEV_SKIP_MOVERSCORE", "0")))
try:
    if DEV_SKIP_MOVERSCORE:
        raise RuntimeError("DEV_SKIP_MOVERSCORE=1")
    from moverscore_v2 import get_idf_dict, word_mover_score

    # OPTIONAL: only score a subset to speed up (comment this block to score all)
    MS_MAX = int(os.getenv("MS_MAX", "0"))  # e.g., set 200 to score first 200
    eval_df_ms = eval_df if MS_MAX <= 0 or MS_MAX >= len(eval_df) else eval_df.iloc[:MS_MAX].copy()

    refs = eval_df_ms["reference"].tolist(); hyps = eval_df_ms["prediction"].tolist()
    idf_ref = get_idf_dict(refs); idf_hyp = get_idf_dict(hyps)  # compute once
    ms_list = word_mover_score(
        refs, hyps, idf_ref, idf_hyp,
        stop_words=None, n_gram=1, remove_subwords=True, batch_size=64
    )
    ms = np.array([float(x) for x in ms_list], dtype=np.float32)

    # If subset scored, broadcast back with NaNs elsewhere
    if len(eval_df_ms) != len(eval_df):
        ms_full = np.full(len(eval_df), np.nan, dtype=np.float32)
        ms_full[:len(ms)] = ms
        ms = ms_full
    moverscore_ok = True
except Exception as e:
    print("WARNING: MoverScore failed:", e)
    ms = np.full(len(eval_df), np.nan, dtype=np.float32)
    moverscore_ok = False

eval_df["rouge1_f1"] = r1
eval_df["moverscore"] = ms
summary = {
    "n": int(len(eval_df)),
    "rouge1_f1_mean": float(np.nanmean(r1)),
    "rouge1_f1_std":  float(np.nanstd(r1)),
    "moverscore_mean": float(np.nanmean(ms)),
    "moverscore_std":  float(np.nanstd(ms)),
    "moverscore_ok": moverscore_ok,
    "speed_profile": SPEED_PROFILE,
    "ms_max": int(os.getenv("MS_MAX", "0")),
    "dev_skip_moverscore": DEV_SKIP_MOVERSCORE,
}
with open(CFG.OUT_DIR/"summary.json", "w") as f: json.dump(summary, f, indent=2)
eval_df.to_csv(CFG.OUT_DIR/"per_example_scores.csv", index=False)
print("[DONE]", json.dumps(summary, indent=2))


  self.pid = os.fork()
ROUGE-1(F1): 100%|██████████| 1002/1002 [00:00<00:00, 1450.54it/s]


[DONE] {
  "n": 1002,
  "rouge1_f1_mean": 0.16458486020565033,
  "rouge1_f1_std": 0.08978775888681412,
  "moverscore_mean": NaN,
  "moverscore_std": NaN,
  "moverscore_ok": false,
  "speed_profile": "balanced",
  "ms_max": 0,
  "dev_skip_moverscore": false
}


  "moverscore_mean": float(np.nanmean(ms)),
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
