# 04 — Candidate Generation (K per article)

**Goal:** For each article, generate **K diverse summary candidates** and persist them with their **decoding params, lengths, and log‑probs**.  
We’ll support: **diverse beam search**, **top‑k sampling**, and **top‑p (nucleus) sampling**.

Format: **What/Why → Code (commented) → How to read results**.


## What / Why
**What:** Install and import dependencies, load the shared config, and locate the trained baseline checkpoint.  
**Why:** We need the **same dataset split** and tokenization limits as training, plus the best baseline checkpoint (loaded in the next cell) to generate candidates.


In [None]:

import sys, subprocess
def pip_install(pkgs):
    """Install the given requirement strings with pip, quietly.

    Runs ``<current interpreter> -m pip install -q ...`` so the packages land
    in the same environment as the running kernel.

    Args:
        pkgs: List of pip requirement specifiers (e.g. ``"name==1.2.3"``).

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status
            (the call uses ``check=True``).
    """
    pip_cmd = [sys.executable, "-m", "pip", "install", "-q"]
    subprocess.run(pip_cmd + pkgs, check=True)

# Install exact, pinned versions of the libraries this notebook relies on.
# The (package, version) pairs are joined into "package==version" specifiers.
pip_install([
    f"{pkg}=={version}"
    for pkg, version in (
        ("transformers", "4.41.2"),
        ("datasets", "2.19.1"),
        ("evaluate", "0.4.2"),
        ("accelerate", "0.30.1"),
        ("sentencepiece", "0.1.99"),
    )
])

import os, json, math, random
from pathlib import Path
import numpy as np
import torch
from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForSeq2SeqLM, set_seed)

# Load the shared experiment config (same file the baseline notebook uses).
# An explicit encoding keeps the read independent of the platform default.
with open("configs/baseline.json", "r", encoding="utf-8") as f:
    cfg = json.load(f)

# Seed Python, NumPy, CUDA (when present) and transformers so sampling-based
# decoding is repeatable across runs.
SEED = cfg.get("seed", 42)
random.seed(SEED); np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)
set_seed(SEED)

# Directory the training run wrote its checkpoint to.
CKPT_DIR = Path(cfg["training"]["output_dir"])
# Raise explicitly rather than `assert`: asserts are stripped under `python -O`,
# which would let a missing checkpoint slip through to a less clear failure.
if not CKPT_DIR.exists():
    raise FileNotFoundError(f"Checkpoint dir not found: {CKPT_DIR}. Run notebook 03 first.")

# Dataset and tokenizer are resolved from the config's Hugging Face identifiers.
raw = load_dataset(cfg["dataset"]["hf_id"], cfg["dataset"]["config"])
tok = AutoTokenizer.from_pretrained(cfg["model"]["hf_id"], use_fast=True)
val = raw[cfg["dataset"]["split_val"]]

print(f"Using checkpoint at: {CKPT_DIR.resolve()}")
print(val)


**How to read results:**  
If the checkpoint path prints and the validation split summary shows its features and row count, you’re set. Use the validation split for development; later you can run on the full test split.


## What / Why
**What:** Define **helper functions** to (1) tokenize a batch of sources, (2) generate multiple candidates per example with *different strategies*, and (3) compute approximate **log‑probabilities** for each output.  
**Why:** We need *comparable ingredients* for reranking: the text, its length, and its model log‑probability.


In [None]:

from typing import List, Dict, Any
from tqdm.auto import tqdm

# Prefer the GPU when one is visible; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Restore the seq2seq model from the checkpoint directory, move it to the
# chosen device, and switch it to eval mode for generation.
model = AutoModelForSeq2SeqLM.from_pretrained(str(CKPT_DIR))
model = model.to(device)
model.eval()

# Dataset column names and token-length limits, all taken from the config.
SRC_COL, REF_COL = (cfg["text_fields"][key] for key in ("source", "summary"))
MAX_SRC, MAX_TGT = (
    cfg["tokenization"][key] for key in ("max_source_len", "max_target_len")
)

def tokenize_sources(texts: List[str]):
    """Encode source texts as a padded tensor batch placed on ``device``.

    Each text is truncated to at most ``MAX_SRC`` tokens and the batch is
    padded and returned as PyTorch tensors.

    Args:
        texts: Source documents to encode.

    Returns:
        A plain dict mapping each tokenizer output name to its tensor, with
        every tensor moved to ``device``.
    """
    encoded = tok(
        texts,
        truncation=True,
        max_length=MAX_SRC,
        padding=True,
        return_tensors="pt",
    )
    on_device = {}
    for name, tensor in encoded.items():
        on_device[name] = tensor.to(device)
    return on_device

def sequence_logprob(output_ids: torch.Tensor, scores: List[torch.Tensor]) -> float:
    """Sum the log-probabilities of the generated tokens of one sequence.

    Each entry of ``scores`` is treated as the (unnormalised) scores for one
    decoding step; it is log-softmaxed over the last dimension and row 0 is
    indexed by the token chosen at that step. The chosen tokens are taken to
    be the last ``len(scores)`` ids of ``output_ids``.

    Args:
        output_ids: 1-D tensor of token ids for a single sequence; any leading
            ids beyond ``len(scores)`` (e.g. a start token) are ignored.
        scores: One 2-D score tensor per decoding step. Only row 0 of each is
            read, so the caller must pass the row belonging to this sequence.

    Returns:
        The summed log-probability as a Python float; ``0.0`` when ``scores``
        is empty (nothing was generated).
    """
    # Guard: with no steps, `output_ids[-0:]` would select the WHOLE sequence
    # and the loop below would index past the end of `scores`.
    if not scores:
        return 0.0

    gen_token_ids = output_ids[-len(scores):]
    total = 0.0
    # Normalise one step at a time instead of materialising every step's
    # full-vocabulary log-softmax up front.
    for step_scores, tok_id in zip(scores, gen_token_ids):
        step_logprobs = torch.log_softmax(step_scores, dim=-1)
        total += float(step_logprobs[0, tok_id.item()].detach().cpu())
    return total

def generate_candidates(inputs: Dict[str, torch.Tensor], strategy: str, num_return_sequences: int, common_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Generate candidates with one decoding strategy and score each of them.

    Args:
        inputs: Tokenized encoder inputs, forwarded to ``model.generate``.
        strategy: One of ``"diverse_beam"``, ``"topk_sampling"`` or
            ``"topp_sampling"``.
        num_return_sequences: Number of candidates to return per input.
        common_kwargs: Extra ``generate`` keyword arguments shared by every
            strategy; keys must not collide with the strategy-specific ones.

    Returns:
        ``{"texts": [...], "logprobs": [...]}`` with one decoded string and
        one summed token log-probability per returned sequence, in the order
        ``generate`` returned them.

    Raises:
        ValueError: If ``strategy`` is not one of the supported names.
    """
    if strategy == "diverse_beam":
        num_beam_groups = max(2, num_return_sequences // 2)
        num_beams = max(4, num_return_sequences)
        # generate() requires num_beams to be divisible by num_beam_groups
        # (e.g. K=5 would otherwise give 5 beams / 2 groups), so round the
        # beam count up to the next multiple of the group count.
        num_beams = -(-num_beams // num_beam_groups) * num_beam_groups
        kwargs = dict(num_beams=num_beams,
                      num_beam_groups=num_beam_groups,
                      diversity_penalty=1.0,
                      do_sample=False)
    elif strategy == "topk_sampling":
        kwargs = dict(do_sample=True, top_k=50, temperature=1.0, num_beams=1)
    elif strategy == "topp_sampling":
        kwargs = dict(do_sample=True, top_p=0.92, temperature=1.0, num_beams=1)
    else:
        raise ValueError(f"Unknown strategy: {strategy}")

    out = model.generate(
        **inputs,
        return_dict_in_generate=True,
        output_scores=True,
        max_length=MAX_TGT,
        num_return_sequences=num_return_sequences,
        **kwargs,
        **common_kwargs
    )
    seqs = out.sequences
    decoded = tok.batch_decode(seqs, skip_special_tokens=True, clean_up_tokenization_spaces=True)

    # Nothing was generated (no decoding steps): every candidate scores 0.0.
    if not out.scores:
        return {"texts": decoded, "logprobs": [0.0] * seqs.shape[0]}

    # In beam search, row i of a step's scores belongs to beam slot i at that
    # step, not to returned sequence i; `beam_indices` records which slot each
    # returned token came from. compute_transition_scores does that lookup
    # (and also covers sampling, where beam_indices is absent). Scores are
    # log-softmaxed first via normalize_logits=True.
    step_logp = model.compute_transition_scores(
        seqs,
        out.scores,
        getattr(out, "beam_indices", None),
        normalize_logits=True,
    )

    # Sequences that finish early are padded after EOS; the scores at those
    # padded steps are meaningless (often -inf after top-k/top-p filtering)
    # and must not be counted. Zero every step strictly after the first EOS.
    n_steps = step_logp.shape[1]
    gen_ids = seqs[:, seqs.shape[1] - n_steps:]
    eos = model.generation_config.eos_token_id
    if eos is None:
        eos = tok.eos_token_id
    if eos is not None:
        eos_list = [eos] if isinstance(eos, int) else list(eos)
        eos_ids = torch.as_tensor(eos_list, device=gen_ids.device)
        is_eos = torch.isin(gen_ids, eos_ids)
        # Count of EOS tokens seen strictly before each position.
        eos_seen_before = (is_eos.cumsum(dim=1) - is_eos.long()) > 0
        step_logp = step_logp.masked_fill(eos_seen_before, 0.0)

    per_seq_logp = step_logp.sum(dim=1).tolist()

    return {"texts": decoded, "logprobs": per_seq_logp}


**How to read results:**  
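Above, we defined *three strategies* and a function that returns generated texts and an **approximate sequence log‑probability** using the per‑step generation scores. It is approximate because those scores are taken after generation‑time processing (e.g. top‑k/top‑p filtering and n‑gram blocking) and then renormalized, rather than from the raw model distribution. This log‑prob is a good baseline for “best‑of‑K by the model.”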


## What / Why
**What:** Run generation over the validation set (or a slice) and persist **JSONL** with per‑candidate metadata.  
**Why:** Reranking needs a persistent, analyzable artifact: `{id, source, reference, candidate, strategy, params, length, logprob}`.


In [None]:

from itertools import islice
from pathlib import Path

# Settings for this generation run.
GEN_CFG = {
    # Use the validation split named in the shared config, the same key the
    # setup cell uses, instead of a hard-coded split name.
    "split": cfg["dataset"]["split_val"],
    "max_items": 500,  # None means "use the whole split"
    "K_per_strategy": 4,
    "common_kwargs": {"length_penalty": 1.0, "no_repeat_ngram_size": 3},
    "strategies": ["diverse_beam", "topk_sampling", "topp_sampling"],
    "output_jsonl": "outputs/candidates_val.jsonl",
}

# Create whatever directory the configured output path lives in, so changing
# "output_jsonl" to another folder does not fail when the file is opened.
out_path = Path(GEN_CFG["output_jsonl"])
out_path.parent.mkdir(parents=True, exist_ok=True)

ds = raw[GEN_CFG["split"]]
N = len(ds) if GEN_CFG["max_items"] is None else min(GEN_CFG["max_items"], len(ds))

with open(out_path, "w", encoding="utf-8") as f:
    for idx in tqdm(range(N), total=N):
        ex = ds[idx]
        src = ex[SRC_COL]
        ref = ex[REF_COL]
        # Tokenize once per article; every strategy reuses the same inputs.
        batch_inputs = tokenize_sources([src])
        for strat in GEN_CFG["strategies"]:
            results = generate_candidates(
                batch_inputs,
                strategy=strat,
                num_return_sequences=GEN_CFG["K_per_strategy"],
                common_kwargs=GEN_CFG["common_kwargs"],
            )
            # Identical for every candidate of this strategy, so build it once.
            params = {
                **GEN_CFG["common_kwargs"],
                "strategy": strat,
                "K_per_strategy": GEN_CFG["K_per_strategy"],
            }
            for txt, lp in zip(results["texts"], results["logprobs"]):
                rec = {
                    "id": idx,  # row index within the split
                    "source": src,
                    "reference": ref,
                    "candidate": txt,
                    "strategy": strat,
                    "params": params,
                    # Length in tokenizer tokens of the decoded text, without special tokens.
                    "length": len(tok.encode(txt, add_special_tokens=False)),
                    "logprob": lp,
                }
                # The file is UTF-8, so write non-ASCII text as-is rather than
                # as \uXXXX escapes; the parsed content is identical.
                f.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Wrote candidates to: {GEN_CFG['output_jsonl']}")


**How to read results:**  
You’ll get a growing `outputs/candidates_val.jsonl`. Each line captures one candidate with its **strategy, params, length, and log‑prob**. This file is the input to the verification/reranking step.
