In [None]:
# Clean installs with compatible pins to avoid resolver conflicts
!pip -q install -U pip setuptools wheel

# Core libs for the project
!pip -q install "datasets>=2.20" "transformers>=4.43" "accelerate>=0.33" \
                "trl>=0.9.6" "peft>=0.12" evaluate pandas numpy tqdm regex

# Quantization for QLoRA
!pip -q install bitsandbytes

# Professor extras
!pip -q install wandb llama-recipes

# --- Compatibility pins to satisfy preinstalled packages in Colab ---
# google-adk -> needs PyYAML >=6.0.2,<7
# Flask/Werkzeug want MarkupSafe >=2.1.1
# pymc wants rich >=13.7.1
!pip -q install "pyyaml>=6.0.2,<7" "markupsafe>=2.1.5" "rich>=13.7.1"

# Quick sanity print so we can see resolved versions
import pkgutil, importlib
# Print the resolved version of each pinned package ("ok" if it exposes none).
for mod in ("yaml", "markupsafe", "rich"):
    resolved = getattr(importlib.import_module(mod), "__version__", "ok")
    print(mod, resolved)


In [None]:
# Clean out unrelated preinstalled packages that pin conflicting versions.
# These are not needed for this assignment.
!pip -q uninstall -y llama-cookbook semgrep bigframes pymc google-adk || true

# Install versions that work great with our stack (and are widely compatible)
!pip -q install "pyyaml>=6.0.2,<7" "markupsafe>=2.1.5" "rich==13.7.1"

# Sanity check
import importlib
# Re-check the same three packages after the uninstall/reinstall above.
for mod in ("yaml", "markupsafe", "rich"):
    version_text = getattr(importlib.import_module(mod), "__version__", "ok")
    print(mod, version_text)


In [None]:
# [MY ADDITION] Only needed if the model is gated and asks for a token
# from huggingface_hub import login
# login("paste your HF token here")


In [None]:
# [MY ADDITION] Common imports + dirs
import os, re, json, time, random, gc
import numpy as np
import torch

from dataclasses import dataclass
from datetime import datetime
from tqdm.auto import tqdm
from datasets import load_dataset, Dataset

from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    BitsAndBytesConfig,
    default_data_collator, DataCollatorForSeq2Seq
)

import torch.optim as optim
from torch.optim.lr_scheduler import StepLR

# Try MemoryTrace; define a no-op fallback if not available
# Prefer the llama-recipes MemoryTrace. When that import fails for any reason,
# define a minimal stand-in exposing the same surface the training loop uses:
# a context manager plus print_stats().
try:
    from llama_recipes.utils.memory_utils import MemoryTrace
except Exception:
    class MemoryTrace:
        """Fallback context manager that only reports peak CUDA memory on request."""

        def __enter__(self):
            return self

        def __exit__(self, *args):
            # Returning None lets any exception raised inside the block propagate.
            return None

        def print_stats(self):
            """Print the peak allocated/reserved CUDA memory in GB (best effort)."""
            try:
                mem_alloc = torch.cuda.max_memory_allocated()/(1024**3)
                mem_reserved = torch.cuda.max_memory_reserved()/(1024**3)
                print(f"Max CUDA memory allocated was {mem_alloc:.1f} GB")
                print(f"Max CUDA memory reserved was {mem_reserved:.1f} GB")
            except Exception as err:
                # Still best effort, but no longer swallows KeyboardInterrupt /
                # SystemExit, and says why the numbers are missing.
                print(f"[MemoryTrace] CUDA memory stats unavailable: {err}")

# Dirs
# Working directories, created up front so later cells can write into them.
OUT_DIR, DATA_DIR, CKPT_DIR = "outputs", "data", "checkpoints"
for _workdir in (OUT_DIR, DATA_DIR, CKPT_DIR):
    os.makedirs(_workdir, exist_ok=True)

# Reproducibility: seed Python, NumPy and PyTorch with the same value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)

# Minor speedup: allow TF32 matmuls, then report whether a GPU is visible.
torch.backends.cuda.matmul.allow_tf32 = True
print("CUDA?", torch.cuda.is_available())


In [None]:
# === Cell 3.9 — BitsAndBytes fix & sanity checks ===
!pip -q install -U bitsandbytes accelerate

import torch, importlib, os
# Report the accelerator PyTorch can see.
cuda_ok = torch.cuda.is_available()
print("CUDA available:", cuda_ok)
if cuda_ok:
    print("GPU:", torch.cuda.get_device_name(0))

# Import bitsandbytes dynamically so a broken install is reported, not fatal.
try:
    bnb = importlib.import_module("bitsandbytes")
except Exception as e:
    print("bitsandbytes import failed:", e)
else:
    print("bitsandbytes version:", getattr(bnb, "__version__", "unknown"))

# Helps with CUDA memory fragmentation when (re)loading models
os.environ.update(PYTORCH_CUDA_ALLOC_CONF="expandable_segments:True")


In [None]:
# === Cell 4 (UPDATED) — Load Llama-3.2-3B-Instruct with QLoRA if available ===
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers.utils import is_bitsandbytes_available
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
import torch, gc, os

# Hugging Face model id used for every load in this notebook.
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

# Release Python garbage and cached CUDA blocks before loading the model, and
# set the allocator option only if no earlier cell/environment already set it.
gc.collect(); torch.cuda.empty_cache()
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# 4-bit loading is attempted only when both a GPU and bitsandbytes are present.
use_cuda = torch.cuda.is_available()
use_4bit = is_bitsandbytes_available(check_library_only=True) and use_cuda

print(f"CUDA available: {use_cuda}")
print(f"bitsandbytes available: {is_bitsandbytes_available(check_library_only=True)}")
print("Attempting", "4-bit QLoRA load" if use_4bit else "fallback (no 4-bit)")

if use_4bit:
    # NF4 quantization with bfloat16 compute.
    # NOTE(review): bfloat16 is chosen whenever CUDA is present; confirm the
    # target GPU supports it (the message further down mentions Colab T4).
    bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16 if use_cuda else torch.float32,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_cfg,
        device_map="auto",
    )
else:
    # Fallback: non-quantized (heavier). bfloat16 on GPU, float32 on CPU.
    dtype = torch.bfloat16 if use_cuda else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=dtype,
        device_map="auto",
    )

# The tokenizer may ship without a pad token; reuse EOS so batches can be padded.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token_id is None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Prepare for training: the KV cache is switched off in the model config.
model.config.use_cache = False

if use_4bit:
    # QLoRA path: run PEFT's k-bit preparation helper, enable gradient
    # checkpointing, then wrap the model with LoRA adapters.
    model = prepare_model_for_kbit_training(model)
    model.gradient_checkpointing_enable()

    # LoRA on all attention and MLP projection layers (rank 16, alpha 32).
    lora_cfg = LoraConfig(
        task_type="CAUSAL_LM",
        r=16, lora_alpha=32, lora_dropout=0.05,
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
    )
    model = get_peft_model(model, lora_cfg)
    print("[QLoRA] Trainable parameters:")
    model.print_trainable_parameters()
else:
    # Without 4-bit we can still attach LoRA, but VRAM will be tighter.
    # Any failure here is reported and the un-adapted model is kept, so the
    # zero-shot baseline can still run.
    try:
        model.gradient_checkpointing_enable()
        # Same LoRA hyper-parameters as the QLoRA branch above.
        lora_cfg = LoraConfig(
            task_type="CAUSAL_LM",
            r=16, lora_alpha=32, lora_dropout=0.05,
            target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
        )
        model = get_peft_model(model, lora_cfg)
        print("[Fallback] LoRA attached without 4-bit. Consider reducing context_length to 320–384.")
        model.print_trainable_parameters()
    except Exception as e:
        print("[Fallback] Could not attach LoRA:", e)
        print("You can still run Zero-Shot baseline. For SFT on Colab T4, 4-bit is strongly recommended.")


In [None]:
# [EDIT FROM PROF] Simple check the pipeline runs
# One-prompt smoke test: tokenize, generate greedily, print the decoded text.
prompt = "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?"
# Send the inputs to the device the model reports rather than a hard-coded one.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    # Greedy decoding: temperature only applies when sampling, so it is omitted.
    out = model.generate(**inputs, max_new_tokens=40, do_sample=False)
text = tokenizer.decode(out[0], skip_special_tokens=True)
print(text)
# Decode only the newly generated tokens. Slicing the decoded string by
# len(prompt) breaks whenever decoding does not reproduce the prompt exactly.
prompt_len = inputs["input_ids"].shape[1]
reply = tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
print("Model reply:", reply)


In [None]:
# [MY ADDITION / FIX] Handles HF dict-of-lists batches + allows live model OR path (full model / PEFT adapters)

import os, re, json, torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm as tq

# Ensure outputs dir exists
# Reuse OUT_DIR when an earlier cell defined it; otherwise default to "outputs".
OUT_DIR = globals().get("OUT_DIR", "outputs")
os.makedirs(OUT_DIR, exist_ok=True)

# Chain-of-thought prompt template (put this in your report).
# `{question}` is filled in with str.format. The template asks the model to
# finish with a line of the form "#### <number>", which is the marker that
# NUM_RE / extract_final look for.
COT_PROMPT = """{question}

Solve the problem step by step. Keep steps concise.
End with the exact final answer on a new last line as: "#### <number>"

Let's think step by step.
"""

# Strict "#### <number>" matcher, kept unchanged for code that uses it directly.
NUM_RE = re.compile(r"####\s*(-?\d+(?:\.\d+)?)")

# Lenient matcher used by extract_final: allows an optional "$" after the
# marker and thousands separators inside the number.
_FINAL_ANSWER_RE = re.compile(r"####\s*\$?\s*(-?\d[\d,]*(?:\.\d+)?)")

def extract_final(text: str):
    """Return the normalized number that follows the first "####" marker.

    Normalization makes equal values compare equal as strings: thousands
    separators are removed ("1,200" -> "1200") and a decimal part loses its
    trailing zeros ("18.00" -> "18", "-3.50" -> "-3.5").

    Args:
        text: Model output or reference answer; may be empty or None.

    Returns:
        The normalized number as a string, or None when `text` is empty or
        contains no "#### <number>" marker.
    """
    if not text:
        return None
    m = _FINAL_ANSWER_RE.search(text)
    if m is None:
        return None
    number = m.group(1).replace(",", "")
    if "." in number:
        # Only strip zeros after a decimal point, so "100" stays "100".
        number = number.rstrip("0").rstrip(".")
    return number

def batched_generate(model, tokenizer, prompts, max_new_tokens=128, temperature=0.0, device="cuda"):
    """Generate one continuation per prompt and return only the new text.

    Args:
        model: Object exposing a Hugging Face style `generate(**kwargs)`.
        tokenizer: Matching tokenizer (callable, with `batch_decode`).
        prompts: List of prompt strings; an empty list returns [].
        max_new_tokens: Upper bound on generated tokens per prompt.
        temperature: 0.0 selects greedy decoding; any other value enables
            sampling at that temperature.
        device: Device the tokenized inputs are moved to.

    Returns:
        List of stripped continuation strings, aligned with `prompts`.
    """
    if not prompts:
        return []

    # Batched generation with a decoder-only model needs left padding;
    # otherwise pad tokens sit between a short prompt and its continuation.
    # The caller's setting is restored right after tokenization.
    previous_side = getattr(tokenizer, "padding_side", None)
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
    finally:
        if previous_side is not None:
            tokenizer.padding_side = previous_side

    sample = temperature != 0.0
    gen_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=sample,
        use_cache=True,
        eos_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=False,
    )
    if sample:
        # temperature only matters when sampling; omit it for greedy decoding.
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Each output row is the padded prompt followed by the new tokens, so the
    # continuation starts at the prompt width. Slicing tokens avoids relying
    # on the decoded text containing the prompt verbatim.
    prompt_width = inputs["input_ids"].shape[1]
    texts = tokenizer.batch_decode(outputs[:, prompt_width:], skip_special_tokens=True)
    return [t.strip() for t in texts]

def _load_tokenizer_with_pad(name_or_path):
    """Load a fast tokenizer and use EOS as the pad token when none is set."""
    tok = AutoTokenizer.from_pretrained(name_or_path, use_fast=True)
    if tok.pad_token_id is None:
        tok.pad_token_id = tok.eos_token_id
    return tok


def load_model_and_tokenizer_for_eval(model_or_path, default_base=None):
    """Resolve a (model, tokenizer) pair for evaluation.

    Args:
        model_or_path: A live model object, or a string naming either a PEFT
            adapter folder or a full model directory / hub id.
        default_base: Tokenizer source to fall back on. Defaults to the global
            MODEL_ID when it is defined as a string, otherwise to
            "meta-llama/Llama-3.2-3B-Instruct".

    Returns:
        Tuple `(model, tokenizer)`. A live model is returned unchanged,
        paired with the notebook-level `tokenizer` when one exists.
    """
    if default_base is None:
        default_base = "meta-llama/Llama-3.2-3B-Instruct"
        if "MODEL_ID" in globals() and isinstance(MODEL_ID, str):
            default_base = MODEL_ID

    # Live model: reuse the notebook's tokenizer, loading one only if missing.
    if not isinstance(model_or_path, str):
        tok = globals().get("tokenizer", None)
        if tok is None:
            tok = _load_tokenizer_with_pad(default_base)
        return model_or_path, tok

    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    # Try PEFT adapter folder first
    try:
        from peft import AutoPeftModelForCausalLM
        mdl = AutoPeftModelForCausalLM.from_pretrained(
            model_or_path,
            device_map="auto",
            torch_dtype=dtype,
        )
        # Take the tokenizer from the base model recorded in the adapter
        # config; fall back to default_base when none is recorded.
        base_name = None
        peft_cfg = getattr(mdl, "peft_config", None)
        if isinstance(peft_cfg, dict):
            for v in peft_cfg.values():
                candidate = getattr(v, "base_model_name_or_path", None)
                if candidate:
                    base_name = candidate
                    break
        tok = _load_tokenizer_with_pad(base_name or default_base)
        return mdl, tok
    except Exception as err:
        # Fallback to full model dir; report why so real errors stay visible.
        print(f"[eval] Loading '{model_or_path}' as a PEFT adapter failed "
              f"({type(err).__name__}: {err}); loading it as a full model instead.")
        tok = _load_tokenizer_with_pad(model_or_path)
        mdl = AutoModelForCausalLM.from_pretrained(
            model_or_path,
            device_map="auto",
            torch_dtype=dtype,
        )
        return mdl, tok

def eval_em(model_or_path, split="test", outpath=f"{OUT_DIR}/eval.jsonl",
            batch_size=6, temperature=0.0, max_new=128, limit=None):
    """Compute exact-match accuracy on a GSM8K split and dump per-example records.

    Args:
        model_or_path: Live model or path/id, resolved by
            load_model_and_tokenizer_for_eval.
        split: Name of the "openai/gsm8k" (config "main") split to evaluate.
        outpath: JSONL file receiving one record per example
            (keys: q, gold, pred, gen, em).
        batch_size: Number of prompts per generation call.
        temperature: Forwarded to batched_generate (0.0 = greedy).
        max_new: Maximum new tokens per generation.
        limit: Evaluate only the first `limit` examples; None means all.

    Returns:
        Exact-match accuracy as a float (0.0 when nothing was evaluated).
    """
    ds_full = load_dataset("openai/gsm8k", "main")[split]
    N = len(ds_full) if limit is None else min(limit, len(ds_full))

    mdl, tok = load_model_and_tokenizer_for_eval(model_or_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Score in eval mode. The training loop leaves the model in train mode,
    # which would keep dropout (including LoRA dropout) active here.
    # The previous mode is restored afterwards.
    was_training = getattr(mdl, "training", False)
    mdl.eval()

    em = 0; recs = []
    try:
        for start in tq(range(0, N, batch_size), desc=f"Eval {split} ({N} ex)"):
            end = min(start + batch_size, N)
            b = ds_full[start:end]  # slicing a Dataset yields a dict of lists
            questions, answers = b["question"], b["answer"]

            prompts = [COT_PROMPT.format(question=q) for q in questions]
            gens = batched_generate(mdl, tok, prompts, max_new_tokens=max_new,
                                    temperature=temperature, device=device)

            for q, gold_text, g in zip(questions, answers, gens):
                # Gold and prediction go through the same parser so both
                # sides are compared in the same textual form.
                gold = extract_final(gold_text)
                pred = extract_final(g)
                ok = int(pred is not None and gold is not None and pred == gold)
                em += ok
                recs.append({"q": q, "gold": gold, "pred": pred, "gen": g, "em": ok})
    finally:
        mdl.train(was_training)

    # Make sure the target directory exists ("" means the current directory).
    os.makedirs(os.path.dirname(outpath) or ".", exist_ok=True)
    with open(outpath, "w") as f:
        for r in recs:
            f.write(json.dumps(r) + "\n")

    acc = em / N if N > 0 else 0.0
    print(f"EM: {acc:.4f}  ({em}/{N})  | batch_size={batch_size}, max_new={max_new}, temp={temperature}")
    return acc


In [None]:
# [MY ADDITION] Baseline EM — quick 200-sample run
# Settings for the quick zero-shot baseline (first 200 test examples).
baseline_eval_kwargs = dict(
    split="test",
    outpath=f"{OUT_DIR}/baseline_test200.jsonl",
    batch_size=6,
    temperature=0.0,
    max_new=128,
    limit=200,
)
baseline_em_200 = eval_em(model, **baseline_eval_kwargs)
baseline_em_200

# Full test later (longer): remove limit
# baseline_em = eval_em(model, split="test",
#                       outpath=f"{OUT_DIR}/baseline_test.jsonl",
#                       batch_size=6, temperature=0.0, max_new=128)
# baseline_em


In [None]:
# [PROFESSOR CODE] cleaned to use padding path (simpler & stable on Colab)

def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
    """Write the eight metric series to `output_filename` as one JSON object."""
    metric_names = (
        "train_step_loss", "train_epoch_loss",
        "train_step_perplexity", "train_epoch_perplexity",
        "val_step_loss", "val_epoch_loss",
        "val_step_perplexity", "val_epoch_perplexity",
    )
    metric_values = (
        train_step_loss, train_epoch_loss,
        train_step_ppl, train_epoch_ppl,
        val_step_loss, val_epoch_loss,
        val_step_ppl, val_epoch_ppl,
    )
    # zip keeps the key order the file has always been written in.
    with open(output_filename, "w") as f:
        json.dump(dict(zip(metric_names, metric_values)), f)

def get_dataloader_kwargs(train_config, tokenizer, mode):
    """Build the DataLoader keyword arguments for one split.

    Args:
        train_config: Object with `batch_size_training` and `val_batch_size`.
        tokenizer: Passed to DataCollatorForSeq2Seq, which collates each batch.
        mode: "train" selects the training settings; any other value selects
            the validation settings.

    Returns:
        Dict with `batch_size`, `shuffle`, `drop_last` and `collate_fn`.
    """
    is_train = mode == "train"
    kwargs = {}
    kwargs["batch_size"] = train_config.batch_size_training if is_train else train_config.val_batch_size
    # Reshuffle training batches every epoch; keep validation order fixed.
    kwargs["shuffle"] = is_train
    # Only training drops the ragged last batch, so every validation example
    # is scored.
    kwargs["drop_last"] = is_train
    kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
    return kwargs

import contextlib
@contextlib.contextmanager
def profile(cfg, local_rank=None):
    """Yield a torch profiler when `cfg.use_profiler` is set, otherwise None.

    The profiler records CPU and CUDA activity on a wait=1 / warmup=2 /
    active=3 schedule and writes TensorBoard traces to `cfg.profiler_dir`.
    `local_rank` is accepted but not used.
    """
    if not cfg.use_profiler:
        yield None
        return

    wait_step, warmup_step, active_step = 1, 2, 3
    print(f"Profiler active; saving to {cfg.profiler_dir}")
    tracked = [torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]
    plan = torch.profiler.schedule(wait=wait_step, warmup=warmup_step, active=active_step, repeat=1)
    sink = torch.profiler.tensorboard_trace_handler(cfg.profiler_dir)
    with torch.profiler.profile(
        activities=tracked,
        schedule=plan,
        on_trace_ready=sink,
        profile_memory=True, with_stack=False, with_flops=True, record_shapes=True,
    ) as torch_profiler:
        yield torch_profiler

def evaluation(model, train_config, eval_dataloader, tokenizer, wandb_run=None):
    """Run one validation pass and return loss / perplexity statistics.

    Args:
        model: Callable returning an object with a `.loss` tensor; it is
            switched to eval mode.
        train_config: Read only for an optional `max_eval_step` attribute
            (values > 0 cap the number of evaluated batches).
        eval_dataloader: Iterable of dict-like batches of tensors.
        tokenizer: Unused; kept for signature compatibility.
        wandb_run: Unused; kept for signature compatibility.

    Returns:
        Tuple `(eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity)`.
        The first two are 0-dim tensors (mean loss over the evaluated batches
        and its exponential); the last two are per-batch float lists.
    """
    model.eval()
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    max_eval_step = getattr(train_config, "max_eval_step", 0) or 0
    val_step_loss, val_step_perplexity = [], []
    # Accumulate in a tensor so torch.exp below works even when no batch was
    # evaluated (a plain float 0.0 would raise a TypeError).
    eval_loss = torch.zeros((), dtype=torch.float32)
    steps_done = 0
    with MemoryTrace():
        for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating", dynamic_ncols=True)):
            if max_eval_step > 0 and step >= max_eval_step:
                break
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                loss = model(**batch).loss.detach().float()
            val_step_loss.append(loss.item())
            val_step_perplexity.append(float(torch.exp(loss)))
            eval_loss = eval_loss + loss.to(eval_loss.device)
            steps_done += 1
    # Average over the batches actually evaluated.
    eval_epoch_loss = eval_loss / max(1, steps_done)
    eval_ppl = torch.exp(eval_epoch_loss)
    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity

def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, wandb_run=None):
    """Run the fine-tuning loop and return summary statistics.

    Args:
        model: Module whose forward returns an object with a `.loss` tensor.
        train_dataloader: Sized iterable of dict-like batches of tensors.
        eval_dataloader: Validation loader, or None to skip validation.
        tokenizer: Forwarded to `evaluation`.
        optimizer: Stepped once every `gradient_accumulation_steps` batches
            and on the last batch of the loader.
        lr_scheduler: Stepped once per epoch.
        gradient_accumulation_steps: Number of batches per optimizer step.
        train_config: Reads `save_metrics`, `output_dir`, `num_epochs`,
            `max_train_step`, `gradient_clipping`,
            `gradient_clipping_threshold`, `use_profiler`, `run_validation`
            and `save_model`.
        wandb_run: Forwarded to `evaluation`; otherwise unused.

    Returns:
        Dict of averages over epochs (`avg_train_prep`, `avg_train_loss`,
        `avg_eval_prep`, `avg_eval_loss`, `avg_epoch_time`,
        `avg_checkpoint_time`), plus `metrics_filename` when metrics are
        saved. Reported losses are the unscaled per-batch loss averaged over
        the batches actually processed.
    """
    autocast = contextlib.nullcontext  # no mixed-precision context is applied
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    train_prep, train_loss, val_prep, val_loss = [], [], [], []
    if train_config.save_metrics:
        # The metrics file is written into output_dir, so it must exist.
        os.makedirs(train_config.output_dir, exist_ok=True)
        metrics_filename = f"{train_config.output_dir}/metrics_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        train_step_perplexity, train_step_loss = [], []
        val_step_loss, val_step_perplexity = [], []

    epoch_times, checkpoint_times = [], []
    best_val_loss = float("inf")
    total_train_steps = 0
    max_steps_reached = False

    for epoch in range(train_config.num_epochs):
        if max_steps_reached: break
        epoch_start = time.perf_counter()
        with MemoryTrace() as memtrace:
            model.train()
            # Tensor accumulator so torch.exp below never receives a float.
            total_loss = torch.zeros((), dtype=torch.float32)
            steps_in_epoch = 0
            total_length = len(train_dataloader)//gradient_accumulation_steps
            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
            with profile(train_config, None) as prof:
                for step, batch in enumerate(train_dataloader):
                    total_train_steps += 1
                    if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                        max_steps_reached = True; print("Reached max_train_step; stopping."); break
                    for k in batch: batch[k] = batch[k].to(device)
                    with autocast():
                        outputs = model(**batch); loss = outputs.loss

                    # Record the unscaled loss. Only the tensor sent to
                    # backward() is divided by the accumulation factor.
                    step_loss = loss.detach().float()
                    steps_in_epoch += 1
                    if train_config.save_metrics:
                        train_step_loss.append(step_loss.item())
                        train_step_perplexity.append(float(torch.exp(step_loss)))
                    total_loss = total_loss + step_loss.to(total_loss.device)

                    (loss / gradient_accumulation_steps).backward()

                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                        optimizer.step(); optimizer.zero_grad(); pbar.update(1)

                    if train_config.use_profiler and prof is not None: prof.step()
                    pbar.set_description(f"Epoch {epoch+1}/{train_config.num_epochs} step {step}/{len(train_dataloader)} (loss {step_loss.item():.4f})")
                pbar.close()

        epoch_time = time.perf_counter()-epoch_start
        epoch_times.append(epoch_time)
        # Average over the batches actually processed, which is fewer than
        # len(train_dataloader) when max_train_step stops the epoch early.
        train_epoch_loss = total_loss / max(1, steps_in_epoch)
        train_perplexity = torch.exp(train_epoch_loss)
        train_prep.append(float(train_perplexity)); train_loss.append(float(train_epoch_loss))
        memtrace.print_stats()
        lr_scheduler.step()

        if train_config.run_validation and eval_dataloader is not None:
            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, tokenizer, wandb_run)
            if train_config.save_metrics:
                # Keep the per-step validation series so the metrics file
                # actually contains them.
                val_step_loss.extend(temp_val_loss)
                val_step_perplexity.extend(temp_step_perplexity)
            val_loss.append(float(eval_epoch_loss))
            val_prep.append(float(eval_ppl))
            ckpt_start = time.perf_counter()
            if train_config.save_model:
                epoch_dir = os.path.join(train_config.output_dir, f"epoch{epoch}")
                os.makedirs(epoch_dir, exist_ok=True)
                model.save_pretrained(epoch_dir)
                print(f"Model saved to {epoch_dir}")
            checkpoint_times.append(time.perf_counter()-ckpt_start)
            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss; print(f"best eval loss @ epoch {epoch+1}: {best_val_loss:.4f}")

        print(f"Epoch {epoch+1}: train_ppl={train_perplexity:.4f}, train_loss={train_epoch_loss:.4f}, time={epoch_time:.1f}s")
        if train_config.save_metrics:
            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)

    results = {
        'avg_train_prep': np.mean(train_prep) if train_prep else None,
        'avg_train_loss': np.mean(train_loss) if train_loss else None,
        'avg_eval_prep':  np.mean(val_prep)   if val_prep   else None,
        'avg_eval_loss':  np.mean(val_loss)   if val_loss   else None,
        'avg_epoch_time': np.mean(epoch_times) if epoch_times else None,
        'avg_checkpoint_time': np.mean(checkpoint_times) if checkpoint_times else None,
    }
    if train_config.save_metrics: results["metrics_filename"] = metrics_filename
    return results


In [None]:
# [EDIT FROM PROF] Train on gold reasoning from GSM8K train
def get_preprocessed_dataset(tokenizer, max_train=0):
    """Tokenize GSM8K train into (train, val) datasets for causal-LM SFT.

    Args:
        tokenizer: Provides `bos_token`, `eos_token` and `encode`.
        max_train: When > 0, keep only the first `max_train` training rows
            before the 90/10 split (seeded with the global SEED).

    Returns:
        Tuple `(train_ds, val_ds)` whose rows hold `input_ids`,
        `attention_mask` and `labels` (labels are the input ids).
    """
    def encode_example(sample):
        # Prompt and target are encoded separately, then concatenated.
        prompt_text = tokenizer.bos_token + "###Input:\n" + sample["question"] + "\n"
        target_text = "###Output:\n" + sample["answer"] + tokenizer.eos_token
        ids = (tokenizer.encode(prompt_text, add_special_tokens=False)
               + tokenizer.encode(target_text, add_special_tokens=False))
        return {"input_ids": ids, "attention_mask": [1]*len(ids), "labels": ids}

    train = load_dataset("openai/gsm8k", "main")["train"]
    if max_train and max_train > 0:
        train = train.select(range(max_train))

    # 10% val split
    train_ds, val_ds = train.train_test_split(test_size=0.1, seed=SEED).values()

    train_ds = train_ds.map(encode_example, remove_columns=list(train_ds.features))
    val_ds = val_ds.map(encode_example, remove_columns=list(val_ds.features))
    return train_ds, val_ds


In [None]:
# [EDIT FROM PROF] Colab-safe defaults; "padding" batching; QLoRA optimizer

@dataclass
class train_configy:
    """Hyper-parameters and switches for the SFT runs in this notebook.

    Defaults are the fast, Colab-friendly settings: one epoch, capped at
    200 steps, no validation during training.
    """
    model_name: str = MODEL_ID
    tokenizer_name: str = None
    run_validation: bool = False          # FAST: no val during train (enable for full runs)
    batch_size_training: int = 1          # per-step batch size of the training DataLoader
    batching_strategy: str = "padding"    # simpler & stable on Colab
    context_length: int = 384             # shorter seq helps speed/memory
                                          # NOTE(review): no visible preprocessing code reads this — confirm
    gradient_accumulation_steps: int = 2  # effective batch ~2
    gradient_clipping: bool = True        # clip the gradient norm before each optimizer step
    gradient_clipping_threshold: float = 1.0
    num_epochs: int = 1
    max_train_step: int = 200             # FAST: cap steps; set 0 for full
    max_eval_step: int = 0
    num_workers_dataloader: int = 0       # passed as num_workers to the DataLoaders
    lr: float = 2e-4                      # LoRA-friendly LR
    weight_decay: float = 0.0             # passed to AdamW
    gamma: float = 0.85                   # StepLR decay factor
    seed: int = SEED
    mixed_precision: bool = True          # NOTE(review): train() uses a null autocast context and does not read this
    val_batch_size: int = 1               # batch size of the validation DataLoader
    output_dir: str = "./content/vanilla_sft"   # adapters, checkpoints and metrics go here
    save_model: bool = True               # train() saves only inside its validation branch
    save_metrics: bool = False            # when True, train() writes a metrics_<timestamp>.json
    flop_counter: bool = False
    flop_counter_start: int = 3
    use_profiler: bool = False            # enables the torch profiler via profile()
    profiler_dir: str = "./content/vanilla_sft"  # trace output directory for the profiler
train_config = train_configy()

# Tokenized GSM8K train/val splits.
dataset_train, dataset_val = get_preprocessed_dataset(tokenizer)
print(f"Train size: {len(dataset_train)} | Val size: {len(dataset_val)}")

dl_kwargs_train = get_dataloader_kwargs(train_config, tokenizer, "train")
dl_kwargs_val   = get_dataloader_kwargs(train_config, tokenizer, "val")

# Options shared by both loaders.
_loader_common = dict(num_workers=train_config.num_workers_dataloader, pin_memory=True)
train_dataloader = torch.utils.data.DataLoader(dataset_train, **_loader_common, **dl_kwargs_train)
eval_dataloader = torch.utils.data.DataLoader(dataset_val, **_loader_common, **dl_kwargs_val)

# Optimizer over the parameters that require gradients (the LoRA weights).
trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = optim.AdamW(trainable_params, lr=train_config.lr, weight_decay=train_config.weight_decay)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

n_trainable = sum(p.numel() for p in trainable_params)
print("Loaders ready. Trainable params:", n_trainable)


In [None]:
# Vanilla SFT on the gold GSM8K rationales.
results_vanilla = train(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader if train_config.run_validation else None,
    tokenizer=tokenizer,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    gradient_accumulation_steps=train_config.gradient_accumulation_steps,
    train_config=train_config,
    wandb_run=None,
)
for k, v in results_vanilla.items():
    print(f'Key: {k}, Value: {v}')

# Save adapters (optional; included for submission)
os.makedirs(train_config.output_dir, exist_ok=True)
model.save_pretrained(train_config.output_dir)

# Evaluate Vanilla SFT — use live model (includes LoRA adapters)
vanilla_eval_kwargs = dict(
    split="test",
    outpath=f"{OUT_DIR}/vanilla_test200.jsonl",
    batch_size=6,
    temperature=0.0,
    max_new=128,
    limit=200,
)
vanilla_em_200 = eval_em(model, **vanilla_eval_kwargs)
vanilla_em_200


In [None]:
# --- Recovery Prelude: (re)define dirs & tokenizer if missing ---
import os, torch
# Each guard below only fills in a name that is still undefined, so running
# this cell after the earlier cells changes nothing.
try:
    OUT_DIR
except NameError:
    OUT_DIR = "outputs"
os.makedirs(OUT_DIR, exist_ok=True)

try:
    DATA_DIR
except NameError:
    DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

# Make sure MODEL_ID & tokenizer exist
try:
    MODEL_ID
except NameError:
    MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"

try:
    tokenizer
except NameError:
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
    # Reuse EOS as the pad token so batched padding works.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

In [None]:
# ===== Cell 12 (UPDATED) — STaR bootstrapping (robust batching; uses dtype; no global in defaults) =====
import os, json, torch
from tqdm.auto import tqdm as tq
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM

# Rationalization prompt, used when the first rationale was wrong; include
# this in your report. `{question}` and `{final_answer}` are filled in with
# str.format; the template shows the gold answer and asks for a solution
# that ends with "#### <final_answer>".
HINT_PROMPT = """{question}

You are given that the correct final answer is: {final_answer}
Reverse-engineer a correct, logically-sound step-by-step solution that ends with:
#### {final_answer}
"""

def build_star_data(model_or_path, out_path=None, temp=0.0, max_new=128, bsz=6, limit=None):
    """Build one iteration of STaR training data from the GSM8K train split.

    Each question is first answered with COT_PROMPT. Rationales whose final
    answer equals the gold answer are kept. For the rest, the model is
    re-prompted with HINT_PROMPT (which reveals the gold answer), and that
    rationale is kept only if it ends on the gold answer.

    Args:
        model_or_path: Live model (paired with the global `tokenizer`), or a
            hub id / local path to load.
        out_path: JSONL destination; defaults to
            DATA_DIR/star_train_iter1.jsonl.
        temp: Generation temperature (0.0 = greedy).
        max_new: Maximum new tokens per generation.
        bsz: Prompts per generation call.
        limit: Use only the first `limit` training examples; None means all.

    Returns:
        Dict with the counts `total`, `correct_first` and `rationalized`.
    """
    # choose default save path lazily so it never errors if DATA_DIR wasn't set earlier
    if out_path is None:
        out_path = os.path.join(DATA_DIR, "star_train_iter1.jsonl")

    ds_tr_full = load_dataset("openai/gsm8k","main")["train"]
    # Clamp so a negative limit behaves like an empty run.
    N = len(ds_tr_full) if limit is None else max(0, min(limit, len(ds_tr_full)))
    ds = ds_tr_full.select(range(N)) if N < len(ds_tr_full) else ds_tr_full

    # Accept a live model or an HF model id / local path
    if isinstance(model_or_path, str):
        tok = AutoTokenizer.from_pretrained(model_or_path, use_fast=True)
        if tok.pad_token_id is None:
            tok.pad_token_id = tok.eos_token_id
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        mdl = AutoModelForCausalLM.from_pretrained(
            model_or_path,
            dtype=dtype,              # <- use dtype (torch_dtype is deprecated)
            device_map="auto",
        )
    else:
        mdl = model_or_path
        tok = tokenizer

    kept, rat = [], []
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Generate in eval mode so dropout is off. A live model coming straight
    # from training would otherwise still be in train mode.
    # The previous mode is restored afterwards.
    was_training = getattr(mdl, "training", False)
    mdl.eval()
    try:
        for start in tq(range(0, N, bsz), desc="STaR gen"):
            batch = ds[start:min(start + bsz, N)]  # HF returns dict-of-lists
            questions, answers = batch["question"], batch["answer"]

            # First pass: regular CoT
            prompts = [COT_PROMPT.format(question=q) for q in questions]
            gens = batched_generate(mdl, tok, prompts, max_new_tokens=max_new,
                                    temperature=temp, device=device)

            wrong_q, wrong_gold = [], []
            for q, a, g in zip(questions, answers, gens):
                # Gold and prediction go through the same parser so both
                # sides are compared in the same textual form.
                gold = extract_final(a)
                if gold is None:
                    continue  # no usable reference answer for this row
                if extract_final(g) == gold:
                    kept.append({"question": q, "rationale": g.strip(), "final_answer": gold})
                else:
                    wrong_q.append(q); wrong_gold.append(gold)

            # Second pass: rationalize with the correct answer
            if wrong_q:
                rprompts = [HINT_PROMPT.format(question=q, final_answer=ga) for q, ga in zip(wrong_q, wrong_gold)]
                rgens = batched_generate(mdl, tok, rprompts, max_new_tokens=max_new,
                                         temperature=temp, device=device)
                for q, ga, rg in zip(wrong_q, wrong_gold, rgens):
                    if extract_final(rg) == ga:
                        rat.append({"question": q, "rationale": rg.strip(), "final_answer": ga})
    finally:
        mdl.train(was_training)

    rows = kept + rat
    # dirname is "" for a bare filename; makedirs("") would raise.
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    with open(out_path, "w") as f:
        for r in rows:
            f.write(json.dumps(r) + "\n")
    print(f"Saved {len(rows)} rows ({len(kept)} correct-first, {len(rat)} rationalized) -> {out_path}")
    return {"total": len(rows), "correct_first": len(kept), "rationalized": len(rat)}

# ---- Run STaR data build (FAST: subset; FULL: set limit=None) ----
# Settings for the STaR data build.
star_build_kwargs = dict(
    model_or_path=MODEL_ID,                     # same base you train
    out_path=os.path.join(DATA_DIR, "star_train_iter1.jsonl"),
    temp=0.0,
    max_new=128,
    bsz=6,
    limit=1000,                                 # change to None for full train set
)
_ = build_star_data(**star_build_kwargs)


In [None]:
# [MY ADDITION] Turn STaR JSONL into LM text and reuse QLoRA SFT loop

def load_star_as_lm_text(jsonl_path):
    """Read a STaR JSONL file into a Dataset with a single `text` column.

    Each row's text is the question, a blank line, then the rationale, with
    surrounding whitespace stripped.
    """
    with open(jsonl_path) as f:
        examples = [json.loads(line) for line in f]
    return Dataset.from_list(
        [{"text": f"{ex['question']}\n\n{ex['rationale']}".strip()} for ex in examples]
    )

# Load the STaR rows from DATA_DIR/star_train_iter1.jsonl as a one-column text dataset.
star_ds = load_star_as_lm_text(f"{DATA_DIR}/star_train_iter1.jsonl")

def tok_row(ex):
    """Tokenize one STaR text row for causal-LM training.

    The text is wrapped in the tokenizer's BOS and EOS tokens, matching the
    vanilla SFT preprocessing. Without EOS the model never sees where an
    answer ends.

    Args:
        ex: Mapping with a `text` string.

    Returns:
        Dict with `input_ids`, `attention_mask` (all ones) and `labels`
        (the input ids).
    """
    text = (tokenizer.bos_token or "") + ex["text"] + (tokenizer.eos_token or "")
    ids = tokenizer.encode(text, add_special_tokens=False)
    return {"input_ids": ids, "attention_mask": [1]*len(ids), "labels": ids}

# Tokenize, then hold out 5% of the STaR rows for validation.
star_ds = star_ds.map(tok_row, remove_columns=list(star_ds.features))
star_train, star_val = star_ds.train_test_split(test_size=0.05, seed=SEED).values()

# New output dir
star_out = "./content/star_iter1"
os.makedirs(star_out, exist_ok=True)

# Reuse padding strategy
dl_kwargs_train = get_dataloader_kwargs(train_config, tokenizer, "train")
dl_kwargs_val   = get_dataloader_kwargs(train_config, tokenizer, "val")

_star_loader_common = dict(num_workers=0, pin_memory=True)
star_train_loader = torch.utils.data.DataLoader(star_train, **_star_loader_common, **dl_kwargs_train)
star_val_loader   = torch.utils.data.DataLoader(star_val,   **_star_loader_common, **dl_kwargs_val)

# Fresh optimizer/scheduler on LoRA params
trainable_params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = optim.AdamW(trainable_params, lr=train_config.lr, weight_decay=train_config.weight_decay)
scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

# Use same FAST settings (you can later set max_train_step=0 for full)
train_config.output_dir = star_out

results_star1 = train(
    model=model,
    train_dataloader=star_train_loader,
    eval_dataloader=star_val_loader if train_config.run_validation else None,
    tokenizer=tokenizer,
    optimizer=optimizer,
    lr_scheduler=scheduler,
    gradient_accumulation_steps=train_config.gradient_accumulation_steps,
    train_config=train_config,
    wandb_run=None,
)
for k, v in results_star1.items():
    print(f'Key: {k}, Value: {v}')

# Save adapters
model.save_pretrained(star_out)

# Evaluate (200-ex quick) on the live model with the newest adapters
star_eval_kwargs = dict(
    split="test",
    outpath=f"{OUT_DIR}/star_iter1_test200.jsonl",
    batch_size=6,
    temperature=0.0,
    max_new=128,
    limit=200,
)
star1_em_200 = eval_em(model, **star_eval_kwargs)
star1_em_200


In [None]:
# Summary table: one line per experiment, or "not run yet" when its result
# variable does not exist. Looking the name up replaces the bare `except:`
# that used to hide every other kind of error as well.
print("=== Exact Match (GSM8K test) — Summary (fast subset) ===")
for _label, _result_name in (
    ("Zero-Shot CoT (200 ex): ", "baseline_em_200"),
    ("Vanilla SFT (200 ex):  ", "vanilla_em_200"),
    ("STaR (iter 1, 200 ex): ", "star1_em_200"),
):
    _score = globals().get(_result_name)
    if _score is None:
        print(f"{_label}not run yet")
    else:
        print(f"{_label}{_score:.4f}")


In [None]:
# [MY ADDITION] Package code + outputs for submission
!mkdir -p submission/code submission/report

# Save prompts for the report
# Assemble both prompt templates first, then write the report file in one go.
prompts_report = (
    "COT_PROMPT:\n" + COT_PROMPT + "\n\n"
    + "HINT_PROMPT:\n" + HINT_PROMPT + "\n"
)
with open("submission/report/prompts.txt","w") as f:
    f.write(prompts_report)

# Copy useful outputs
!cp -r outputs submission/ 2>/dev/null || true
!cp -r content submission/checkpoints 2>/dev/null || true

# Zip it
!zip -r submission.zip submission >/dev/null
print("Created submission.zip")
