# Lesson 3 Demo Notebook — Training Strategies & Transfer

This notebook mirrors the lesson flow with small, CPU-friendly demos. Each section is short and practical.

## 0) Setup
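Install optional libraries if needed. All demos are tiny and run on CPU. Without installing anything, you can still run the standard-library-only parts (Sections 4, 6, and 7). Section 1 falls back to keyword overlap when scikit-learn is missing, and Sections 2 and 5 need only NumPy.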

In [1]:
# Uncomment if running in a fresh environment
# !pip -q install torch --index-url https://download.pytorch.org/whl/cpu
# !pip -q install transformers==4.42.4 peft==0.11.1 accelerate==0.32.1 datasets==2.20.0 scikit-learn==1.4.2
# The RAG demo uses cosine similarity over TF-IDF (scikit-learn) as a lightweight alternative to FAISS.
import math, random, time, os, json
from collections import defaultdict, Counter
random.seed(0)

## 1) Strategy ladder: start with prompting → add RAG if facts matter
We show a **no-training** baseline and a tiny RAG sketch with TF-IDF retrieval. The retrieval is purely lexical, so expect imperfect rankings on a corpus this small.

In [2]:
# Tiny 'corpus' and a trivial retrieval function using TF-IDF (scikit-learn) if available, else fallback.
docs = [
    "LoRA adapters let you fine-tune a small number of weights to steer a base model.",
    "RAG systems retrieve documents at query time so the model can ground its answer.",
    "Gradient accumulation simulates larger batches by summing gradients across steps.",
    "Early stopping halts training when the validation metric stops improving.",
    "Teacher forcing feeds the correct next token at train time for stability.",
]

queries = [
    "How can I adapt a base model cheaply?",
    "How do I keep answers up-to-date with sources?",
]

def simple_retriever(q, docs, k=2):
    try:
        from sklearn.feature_extraction.text import TfidfVectorizer
        from sklearn.metrics.pairwise import cosine_similarity
        tfidf = TfidfVectorizer().fit(docs+[q])
        m = tfidf.transform(docs)
        v = tfidf.transform([q])
        sims = cosine_similarity(v, m)[0]
        idx = sims.argsort()[-k:][::-1]
        return [docs[i] for i in idx], [float(sims[i]) for i in idx]
    except ImportError:
        # Fallback when scikit-learn is not installed: keyword overlap
        qset = set(q.lower().split())
        scored = [(sum(1 for w in d.lower().split() if w in qset), d) for d in docs]
        scored.sort(reverse=True)
        return [d for _, d in scored[:k]], None

for q in queries:
    ctx, score = simple_retriever(q, docs, k=2)
    print(f"Q: {q}\nTop docs:")
    for i,c in enumerate(ctx,1):
        print(f"  {i}. {c}")
    print("-"*60)

Q: How can I adapt a base model cheaply?
Top docs:
  1. RAG systems retrieve documents at query time so the model can ground its answer.
  2. LoRA adapters let you fine-tune a small number of weights to steer a base model.
------------------------------------------------------------
Q: How do I keep answers up-to-date with sources?
Top docs:
  1. LoRA adapters let you fine-tune a small number of weights to steer a base model.
  2. Teacher forcing feeds the correct next token at train time for stability.
------------------------------------------------------------
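
The retriever above only returns passages; in a RAG setup they are then placed in the prompt so the model can ground its answer. A minimal sketch of that step follows. The function name `build_rag_prompt` and the instruction wording are assumptions made for illustration, not part of the lesson code.

```python
def build_rag_prompt(question, passages, max_passages=2):
    """Assemble a grounded prompt: numbered sources first, then the question."""
    lines = ["Answer using only the sources below. Cite sources by number."]
    for i, passage in enumerate(passages[:max_passages], 1):
        lines.append(f"[{i}] {passage}")
    lines.append(f"Question: {question}")
    lines.append("Answer:")
    return "\n".join(lines)

# Passages copied from the first retrieval result shown above
example_passages = [
    "RAG systems retrieve documents at query time so the model can ground its answer.",
    "LoRA adapters let you fine-tune a small number of weights to steer a base model.",
]
prompt = build_rag_prompt("How can I adapt a base model cheaply?", example_passages)
print(prompt)
```

Numbering the sources makes it possible to ask the model for citations and to check them afterwards.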


## 2) Teacher forcing: clean pairs & masking
We simulate SFT with prompt→target pairs and compute loss **only** on target tokens.

In [3]:
# Toy vocabulary & data
PAD, BOS, EOS = "<pad>", "<bos>", "<eos>"
vocab = [PAD,BOS,EOS,"classify","cats","dogs","are","great","ok","."]
tok2id = {t:i for i,t in enumerate(vocab)}
id2tok = {i:t for t,i in tok2id.items()}

def tokenize(seq):
    return [tok2id.get(t, tok2id[PAD]) for t in seq]

# Two clean pairs: (prompt)->(target)
pairs = [
    ([BOS, "classify", "cats", "."], ["cats","are","great",".",EOS]),
    ([BOS, "classify", "dogs", "."], ["dogs","are","ok",".",EOS]),
]

# Build a single training batch of token ids with masks: we predict only on targets.
def build_batch(pairs):
    batch_in, batch_tgt, loss_mask = [], [], []
    for prompt, target in pairs:
        x = tokenize(prompt + target)     # concat for causal setup
        # loss on positions that belong to 'target' portion only
        mask = [0]*len(prompt) + [1]*len(target)
        batch_in.append(x[:-1])
        batch_tgt.append(x[1:])
        loss_mask.append(mask[1:])  # align with next-token targets
    return batch_in, batch_tgt, loss_mask

X, Y, M = build_batch(pairs)
print("Input ids:   ", X[0])
print("Target ids:  ", Y[0])
print("Loss mask:   ", M[0])
print("Decoded demo:", [id2tok[i] for i in X[0]])

Input ids:    [1, 3, 4, 9, 4, 6, 7, 9]
Target ids:   [3, 4, 9, 4, 6, 7, 9, 2]
Loss mask:    [0, 0, 0, 1, 1, 1, 1, 1]
Decoded demo: ['<bos>', 'classify', 'cats', '.', 'cats', 'are', 'great', '.']
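
Many training frameworks express the same masking by overwriting the prompt positions in the label sequence with a sentinel value instead of carrying a separate mask. PyTorch's `CrossEntropyLoss` skips labels equal to `ignore_index`, which defaults to `-100`. The helper below is a hypothetical pure-Python sketch of that conversion, using the ids and mask printed above.

```python
IGNORE_INDEX = -100  # default ignore_index of PyTorch's CrossEntropyLoss

def mask_to_labels(target_ids, loss_mask, ignore_index=IGNORE_INDEX):
    """Replace every target id whose mask is 0 with the ignore sentinel."""
    if len(target_ids) != len(loss_mask):
        raise ValueError("target_ids and loss_mask must have the same length")
    return [t if m else ignore_index for t, m in zip(target_ids, loss_mask)]

# Values copied from the printed batch above (first example)
target_ids = [3, 4, 9, 4, 6, 7, 9, 2]
loss_mask = [0, 0, 0, 1, 1, 1, 1, 1]
labels = mask_to_labels(target_ids, loss_mask)
print(labels)  # [-100, -100, -100, 4, 6, 7, 9, 2]
```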


In [4]:
# Compute a toy cross-entropy loss where mask==1; random logits to illustrate masking effect
import numpy as np
np.random.seed(0)

vocab_size = len(vocab)
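def masked_ce(logits, target_ids, mask):
    # logits: [T, V], target_ids: [T], mask: [T] (0/1)
    T = len(target_ids)
    mask = np.asarray(mask, dtype=float)
    # Numerically stable softmax: subtract the per-row max before exponentiating
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    nll = -np.log(probs[np.arange(T), target_ids] + 1e-12)
    # Average the negative log-likelihood over unmasked (target) positions only
    return (nll * mask).sum() / max(1.0, mask.sum())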

for i in range(len(X)):
    T = len(X[i])
    logits = np.random.randn(T, vocab_size) * 0.5  # dummy model
    loss = masked_ce(logits, Y[i], M[i])
    print(f"Example {i}: masked loss over target-only tokens = {loss:.3f}")

Example 0: masked loss over target-only tokens = 2.242
Example 1: masked loss over target-only tokens = 2.522
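
A quick sanity check for any masked loss: a model that assigns equal probability to all V tokens should score exactly ln(V) per target token, which is about 2.303 for the 10-token vocabulary here. The losses printed above sit near that value because the logits are random. The pure-Python function below is an illustrative re-implementation (the name `masked_ce_py` is an assumption), not the NumPy version from the cell.

```python
import math

def masked_ce_py(logits, target_ids, mask):
    """Mean negative log-likelihood over positions where mask == 1."""
    total, count = 0.0, 0
    for row, target, keep in zip(logits, target_ids, mask):
        if not keep:
            continue
        peak = max(row)  # log-sum-exp with max subtraction for stability
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_norm - row[target]
        count += 1
    return total / max(1, count)

V = 10  # size of the toy vocabulary
target_ids = [3, 4, 9, 4, 6, 7, 9, 2]
mask = [0, 0, 0, 1, 1, 1, 1, 1]
uniform_logits = [[0.0] * V for _ in target_ids]  # equal score for every token

loss = masked_ce_py(uniform_logits, target_ids, mask)
print(f"uniform-model loss = {loss:.3f}, ln(V) = {math.log(V):.3f}")
```

If a freshly initialized model reports a loss far from ln(V), the mask or the label alignment is a good first suspect.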


## 3) Transfer learning options (sketch)
Below are **minimal training stubs** showing how you would:
- attach **LoRA adapters** (lightweight)
- do **last-layer** tuning
- or **full fine-tuning**

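> These are illustrative; `max_steps=10` keeps the run short enough for a CPU demo. Uncomment pip installs above first. The three options are alternatives: apply one of them, not all three in sequence.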

In [5]:
# PSEUDOCODE / MINIMAL: requires transformers + peft installed to actually run
# from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
# from peft import LoraConfig, get_peft_model
# model_name = "sshleifer/tiny-gpt2"  # tiny model for demo
# tok = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_pretrained(model_name)

# --- Adapters (LoRA) ---
# peft_cfg = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, target_modules=['c_attn'])
# model_lora = get_peft_model(model, peft_cfg)

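# --- Last-layer FT --- (freeze all but final lm_head)
# GPT-2 ties lm_head.weight to the input embeddings, so the shared tensor is listed only as
# 'transformer.wte.weight'; filtering named_parameters() on 'lm_head' would freeze everything.
# for p in model.parameters():
#     p.requires_grad = False
# for p in model.lm_head.parameters():
#     p.requires_grad = True   # also unfreezes the tied token embeddings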

# --- Full FT --- (no freezing)
# for p in model.parameters():
#     p.requires_grad = True

# --- Train with small steps ---
# args = TrainingArguments(output_dir="out", per_device_train_batch_size=2, num_train_epochs=1,
#                          learning_rate=5e-4, logging_steps=1, max_steps=10)
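# `your_dataset` is a placeholder: supply a tokenized dataset with input_ids and labels.
# Pass whichever model variant you prepared above (model_lora for the adapter option).
# trainer = Trainer(model=model_lora, args=args, train_dataset=your_dataset)
# trainer.train()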
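
To see why adapters count as the lightweight option, it helps to count parameters. LoRA learns a low-rank update made of two matrices, one of shape `d_out × r` and one of shape `r × d_in`, so it adds `r * (d_in + d_out)` trainable weights per adapted layer. The sketch below uses `r=8` from the stub above; the layer shape 768 × 2304 is an assumed, illustrative size (a fused attention projection for hidden size 768), not a value read from the tiny demo model.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable weights added by one LoRA adapter: B (d_out x r) plus A (r x d_in)."""
    return r * (d_in + d_out)

d_in, d_out, r = 768, 2304, 8   # illustrative layer shape; r matches the stub above
full = d_in * d_out             # weights updated by full fine-tuning of this layer
lora = lora_param_count(d_in, d_out, r)
print(f"full={full:,}  lora={lora:,}  ratio={lora / full:.2%}")
# full=1,769,472  lora=24,576  ratio=1.39%
```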

## 4) Batch-size scaling without a big GPU
We compute **effective batch size** (EBS) and show gradient accumulation flow.

In [6]:
def effective_batch_size(micro_batch, accum_steps, replicas):
    return micro_batch * accum_steps * replicas

def demo_grad_accum(micro_batch=8, accum_steps=4, replicas=2):
    ebs = effective_batch_size(micro_batch, accum_steps, replicas)
    print(f"Effective batch size: {micro_batch} × {accum_steps} × {replicas} = {ebs}")
    print("Pseudo-loop:")
    for step in range(2):
        print(f"Optimizer step {step}:")
        for a in range(accum_steps):
            print(f"  micro-batch {a+1}/{accum_steps}: forward -> backward (grads accumulate)")
        print("  clip gradients -> optimizer.step() -> zero_grad()")

demo_grad_accum()

Effective batch size: 8 × 4 × 2 = 64
Pseudo-loop:
Optimizer step 0:
  micro-batch 1/4: forward -> backward (grads accumulate)
  micro-batch 2/4: forward -> backward (grads accumulate)
  micro-batch 3/4: forward -> backward (grads accumulate)
  micro-batch 4/4: forward -> backward (grads accumulate)
  clip gradients -> optimizer.step() -> zero_grad()
Optimizer step 1:
  micro-batch 1/4: forward -> backward (grads accumulate)
  micro-batch 2/4: forward -> backward (grads accumulate)
  micro-batch 3/4: forward -> backward (grads accumulate)
  micro-batch 4/4: forward -> backward (grads accumulate)
  clip gradients -> optimizer.step() -> zero_grad()
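
The pseudo-loop above can be checked numerically. When the loss is a mean over examples, each micro-batch gradient is usually divided by the number of accumulation steps so that the accumulated result matches the gradient of one large batch. The sketch below assumes a one-parameter linear model, a mean-squared-error loss, made-up data, and equal-sized micro-batches.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

# Made-up data, roughly y = 2x
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5
micro_batch = 2
accum_steps = len(xs) // micro_batch

# Reference: gradient of one batch containing all 8 examples
full_grad = mse_grad(w, xs, ys)

# Accumulation: one backward pass per micro-batch, scaled by 1/accum_steps
accum_grad = 0.0
for step in range(accum_steps):
    lo = step * micro_batch
    hi = lo + micro_batch
    accum_grad += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accum_steps

print(f"full-batch grad  = {full_grad:.6f}")
print(f"accumulated grad = {accum_grad:.6f}")
```

The two values agree up to floating-point rounding. Without the `1/accum_steps` scaling, the accumulated gradient would be `accum_steps` times larger, which acts like an unintended learning-rate increase.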


## 5) Loss spike triage (simulation)
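We simulate a noisy loss curve with an injected spike, then mimic the effect of two quick fixes (a lower LR and gradient clipping) by smoothing the curve and damping the spike. No optimizer runs in this cell, so the numbers are illustrative only.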

In [7]:
import numpy as np
np.random.seed(1)
steps = 60
loss = 2.0*np.exp(-np.linspace(0,3,steps)) + 0.05*np.random.randn(steps)
# Inject a 'spike'
loss[30] += 0.6

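def smooth(x, start=3):
    # Exponential smoothing from index `start` onward: each point blends
    # 60% of its own value with 40% of the previous (already smoothed) point.
    y = np.copy(x)
    for i in range(start, len(x)):
        y[i] = 0.6*y[i] + 0.4*y[i-1]
    return y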

print("Before fixes: mean={:.3f}, spike at step 30 -> {:.3f}".format(loss.mean(), loss[30]))
loss_fixed = smooth(loss)
loss_fixed[30] *= 0.7  # pretend lowering LR + clipping helped
print("After fixes:  mean={:.3f}, step 30 -> {:.3f}".format(loss_fixed.mean(), loss_fixed[30]))

Before fixes: mean=0.651, spike at step 30 -> 1.000
After fixes:  mean=0.665, step 30 -> 0.555
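
The simulation above only pretends that clipping helped. For reference, clipping by global norm rescales the whole gradient vector when its L2 norm exceeds a threshold, leaving its direction unchanged. The function below is a minimal pure-Python sketch with a hypothetical name; in a real training loop you would use your framework's own utility.

```python
import math

def clip_by_global_norm(grads, max_norm):
    """Scale grads so their L2 norm is at most max_norm; return (clipped, original_norm)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm or total_norm == 0.0:
        return list(grads), total_norm  # already within the limit
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

clipped, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print("original norm:", norm)
print("clipped grads:", [round(g, 3) for g in clipped])
```

Logging the pre-clip norm each step is a cheap way to spot a spike before it shows up in the loss.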


## 6) Decision checklist helper
A tiny function that recommends the **lowest rung** given task, data, and constraints.

In [8]:
def choose_strategy(task, data_size, labels, facts_change, latency_strict):
    # Very rough, didactic mapping; `latency_strict` is accepted but not used by these rules
    if facts_change:  # knowledge shifts -> prefer RAG
        return "Prompt + RAG (then small adapters if formatting brittle)"
    if labels and task in {"classify","extract"}:
        if data_size in {"small","medium"}:
            return "Adapters → Last-layer if needed"
        else:
            return "Last-layer FT"
    if task in {"write","generate"}:
        if data_size == "small":
            return "Adapters"
        else:
            return "Adapters → Full FT if style still off"
    # default
    return "Prompt first; escalate as needed"

examples = [
    dict(task="classify", data_size="small", labels=True, facts_change=False, latency_strict=True),
    dict(task="qa", data_size="small", labels=False, facts_change=True, latency_strict=False),
    dict(task="write", data_size="medium", labels=False, facts_change=False, latency_strict=False),
]
for ex in examples:
    print(ex, "->", choose_strategy(**ex))

{'task': 'classify', 'data_size': 'small', 'labels': True, 'facts_change': False, 'latency_strict': True} -> Adapters → Last-layer if needed
{'task': 'qa', 'data_size': 'small', 'labels': False, 'facts_change': True, 'latency_strict': False} -> Prompt + RAG (then small adapters if formatting brittle)
{'task': 'write', 'data_size': 'medium', 'labels': False, 'facts_change': False, 'latency_strict': False} -> Adapters → Full FT if style still off


## 7) Measurement: tokens/sec, early stopping, run log
We show a simple timer-based tokens/sec, a mock early-stopping loop, and a CSV run log.

In [9]:
import time, csv, tempfile, os, math, random

def tokens_per_second(tokens, seconds):
    return tokens / max(1e-6, seconds)

# Demo timer
start = time.time(); tokens = 5000
time.sleep(0.05)  # pretend to process
tps = tokens_per_second(tokens, time.time()-start)
print(f"Tokens/sec (mock): {tps:.0f}")

# Early stopping mock: stop if no improvement after 'patience' checks by at least 'min_delta'
def early_stop_demo(metric_values, patience=3, min_delta=0.0):
    best = -1e9; bad = 0
    for i, m in enumerate(metric_values, 1):
        if m > best + min_delta:
            best, bad = m, 0
            print(f"Step {i}: new best {best:.3f}")
        else:
            bad += 1
            print(f"Step {i}: no improvement (bad={bad})")
        if bad >= patience:
            print(f"Early stop at step {i}. Best={best:.3f}")
            break

metric_vals = [0.70,0.75,0.76,0.761,0.761,0.761,0.762]
early_stop_demo(metric_vals, patience=3, min_delta=0.001)

# Minimal run log
runlog_path = "data/run_log.csv"
os.makedirs(os.path.dirname(runlog_path), exist_ok=True)  # create data/ if it is missing
with open(runlog_path, "w", newline="") as f:
    w = csv.writer(f)
    w.writerow(["date","config_hash","EBS","LR","best_metric","next_action"])
    w.writerow(["2025-09-19","cfg_7b_lora_v2","64","2e-4","EM 88.1","Hold; try LR 1.5e-4"])
    w.writerow(["2025-09-19","cfg_7b_llft_v1","96","3e-4","Acc 90.2","Stop; met target"])
runlog_path

Tokens/sec (mock): 98830
Step 1: new best 0.700
Step 2: new best 0.750
Step 3: new best 0.760
Step 4: no improvement (bad=1)
Step 5: no improvement (bad=2)
Step 6: no improvement (bad=3)
Early stop at step 6. Best=0.760


'data/run_log.csv'
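
The mock above assumes a higher-is-better metric. When you monitor validation loss instead, the comparison flips. The class below is a hedged sketch of a reusable stopper that supports both directions; the name `EarlyStopper`, its `mode` argument, and the loss values are illustrative assumptions.

```python
class EarlyStopper:
    """Track a metric and signal a stop after `patience` checks without improvement."""

    def __init__(self, patience=3, min_delta=0.0, mode="max"):
        if mode not in {"max", "min"}:
            raise ValueError("mode must be 'max' or 'min'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None
        self.bad_checks = 0

    def update(self, value):
        """Record one evaluation; return True when training should stop."""
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best + self.min_delta
        else:
            improved = value < self.best - self.min_delta
        if improved:
            self.best, self.bad_checks = value, 0
        else:
            self.bad_checks += 1
        return self.bad_checks >= self.patience

# Made-up validation losses (lower is better)
val_losses = [1.20, 0.95, 0.90, 0.91, 0.90, 0.92, 0.80]
stopper = EarlyStopper(patience=3, min_delta=0.0, mode="min")
stopped_at = None
for check, val_loss in enumerate(val_losses, 1):
    if stopper.update(val_loss):
        stopped_at = check
        break
print(f"stopped at check {stopped_at}, best loss {stopper.best}")
```

As in the mock above, the late improvement at the final check is never seen, which is the usual trade-off of a short patience window.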

## 8) Summary
- Start at the lowest rung; escalate only when the metric demands.
- Keep pairs clean and masks correct.
- Scale effective batch size via micro-batch × accumulation × replicas; as a rule of thumb, the LR scales roughly with EBS.
- Use the triage playbook for loss spikes; measure honestly with a dev set, tokens/sec, and early stopping.
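
As a worked example of the LR and EBS relationship: two common heuristics are linear scaling (LR grows in proportion to EBS) and square-root scaling (LR grows with the square root of the ratio). Neither is guaranteed to be right for a given model, so treat the result as a starting point for a sweep. The helper name `scale_lr` is hypothetical, and the base values (EBS 64 at LR 2e-4, new EBS 96) are borrowed from the mock run log purely for illustration.

```python
import math

def scale_lr(base_lr, base_ebs, new_ebs, rule="linear"):
    """Suggest a learning rate for a new effective batch size (heuristic only)."""
    ratio = new_ebs / base_ebs
    if rule == "linear":
        return base_lr * ratio
    if rule == "sqrt":
        return base_lr * math.sqrt(ratio)
    raise ValueError("rule must be 'linear' or 'sqrt'")

linear_lr = scale_lr(2e-4, 64, 96, rule="linear")
sqrt_lr = scale_lr(2e-4, 64, 96, rule="sqrt")
print(f"linear: {linear_lr:.2e}   sqrt: {sqrt_lr:.2e}")
```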

**Next:** swap in your real dataset and base model, and re-run the same scaffolding.