# AutoMode GSM8K Fine-Tuning
This notebook implements a dynamic fine-tuning strategy for GSM8K using Gemma-2 2B (`google/gemma-2-2b`).
It supports:
- Full Fine-Tuning
- LoRA
- Dynamic Gradient Norm (Hybrid)
- BitFit
- Top-K Fine-Tuning

## 1. Environment Setup
Install necessary libraries including `transformers`, `peft`, `bitsandbytes`, and `accelerate`.

In [None]:
!pip install -q "transformers>=4.45.0" "datasets" "accelerate" "einops" "sentencepiece" "protobuf==3.20.3"
!pip install -q "huggingface_hub[cli]" "peft" "bitsandbytes"

## 2. Configuration & Imports
Define Global Configuration and Experiment Settings.

In [None]:
import re
from collections import Counter
import math, torch, json, os, time
from tqdm.auto import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, get_scheduler
from bitsandbytes.optim import AdamW8bit
from peft import LoraConfig, get_peft_model, TaskType
from peft.tuners.lora import LoraLayer
from torch.utils.data import DataLoader
from dataclasses import dataclass
from typing import List, Dict, Any
from collections import defaultdict
import numpy as np
import random
from huggingface_hub import login


# Root directory for every artifact this notebook writes (logs, eval dumps, CSV).
SAVE_PATH = "/home/jupyter/GSM8K_Results_dec9"
os.makedirs(SAVE_PATH, exist_ok=True)
# Location handed to `save_to_disk` / `load_from_disk` in `load_gsm8k`.
DATA_CACHE_PATH = f"{SAVE_PATH}/gsm8k_tokenized.arrow"

# Global configuration. `run_experiment` overwrites individual keys with the
# values of each EXPERIMENT_GRID entry before training.
GSM8K_CONFIG = {
    "model_checkpoint": "google/gemma-2-2b",
    "seed": 42,
    "max_input_length": 2048,  # tokenizer truncation length in `load_gsm8k`

    # Training
    "batch_train": 1,          # micro-batch size of the training DataLoader
    "batch_eval": 8,           # prompts per `generate` call in `evaluate`
    "learning_rate": 5e-5,
    "epochs": 3,
    "weight_decay": 0.01,
    "warmup_ratio": 0.03,      # fraction of optimizer steps used for warmup
    "grad_accum": 16,          # micro-batches per optimizer step

    # Strategy options
    "strategy": "dynamic_grad_norm",  # "full_ft", "lora", "dynamic_grad_norm", "bitfit", "topk_full"

    # Dynamic FT
    "dynamic_updates": 6,      # divides the per-epoch step count to get the check interval
    "dynamic_threshold": 10,   # percentile of per-layer grad norms passed to np.percentile

    # Misc
    "fp16": True,              # load weights as float16 when CUDA is available

    # Generation for evaluation
    "gen_max_tokens": 128,     # max_new_tokens for generation
    "sampling_k": 5,           # sampled answers per question (num_return_sequences)
}

# layer name -> (sum of squared gradient norms, number of parameters counted);
# filled by `accumulate_gradients`, cleared by `update_frozen_layers_HYBRID`.
gradient_accumulator = defaultdict(lambda: (0.0, 0))
# One entry per dynamic check, appended by `update_frozen_layers_HYBRID`.
freezing_log = []
model_name = GSM8K_CONFIG["model_checkpoint"]

## 3. Experiment Grid
Define the set of experiments to run.

In [None]:
import hashlib
import pandas as pd

# CSV that `load_results` reads and `append_result` rewrites with one row per experiment.
RESULTS_CSV = f"{SAVE_PATH}/experiments_log.csv"

# Each entry is a set of overrides that `run_experiment` writes into
# GSM8K_CONFIG; keys not listed here keep their value from GSM8K_CONFIG.
EXPERIMENT_GRID = [
    {"strategy": "dynamic_grad_norm", "dynamic_updates": 6, "dynamic_threshold": 10, "sampling_k": 5, "epochs": 2},
    {"strategy": "dynamic_grad_norm", "dynamic_updates": 6, "dynamic_threshold": 25, "sampling_k": 5, "epochs": 2},
    {"strategy": "dynamic_grad_norm", "dynamic_updates": 10, "dynamic_threshold": 10, "sampling_k": 5, "epochs": 2},
    {"strategy": "dynamic_grad_norm", "dynamic_updates": 10, "dynamic_threshold": 25, "sampling_k": 5, "epochs": 2},
    {"strategy": "lora",               "sampling_k": 5, "epochs": 2},
    {"strategy": "full_ft",            "sampling_k": 5, "epochs": 2},
]

## 4. Reproducibility
Set random seeds for consistency.

In [None]:
# Seed the torch, Python and NumPy random generators from the configured seed.
torch.manual_seed(GSM8K_CONFIG["seed"])
random.seed(GSM8K_CONFIG["seed"])
np.random.seed(GSM8K_CONFIG["seed"])

## 5. Utility Functions
Helper functions for answer extraction, prompt building, logging, and result persistence.

In [None]:
# ---------- Answer Extractor ----------
def _normalize_number(token: str) -> str:
    """Drop thousands separators and trailing sentence punctuation from a number token."""
    # The capture patterns allow '.' anywhere, so "18." (number followed by a
    # full stop) must be trimmed or it never string-matches the gold "18".
    return token.replace(",", "").strip().rstrip(".")


def extract_final_answer(text: str) -> str:
    """
    Extract the final numerical answer from a solution or model output.

    Looks for the GSM8K marker '#### <number>' first; if it is absent (or is
    not followed by any digit) the last number found in the text is used.

    Returns the number as a string without thousands separators or trailing
    punctuation, or "" when the text contains no number.
    """
    marked = re.search(r"####\s*([-\d,\.]+)", text)
    if marked:
        candidate = _normalize_number(marked.group(1))
        # Guard against captures such as "-" or "." that contain no digit.
        if any(ch.isdigit() for ch in candidate):
            return candidate
    numbers = re.findall(r"-?\d[\d,\.]*", text)
    return _normalize_number(numbers[-1]) if numbers else ""

# ---------- Instruction ----------
# Fixed preamble placed in front of every question (training and evaluation).
INSTRUCTION = "".join((
    "You are an expert grade-school math tutor. ",
    "Solve the problem step by step, then give the final numeric answer on ",
    "a separate line as:\n#### <NUMERIC_ANSWER>\n\nQuestion:\n",
))


def build_prompt(q: str) -> str:
    """Wrap question ``q`` in the instruction preamble and the answer cue."""
    return f"{INSTRUCTION}{q}\n\nAnswer:\n"

# ---------- Logging ----------
def timestamp():
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    now = time.localtime()
    return time.strftime("%Y%m%d-%H%M%S", now)

def save_json(obj, path):
    """Write ``obj`` to ``path`` as JSON indented by two spaces and report the path."""
    encoder = json.JSONEncoder(indent=2)
    with open(path, "w") as out:
        # Stream the encoded chunks straight into the file.
        for chunk in encoder.iterencode(obj):
            out.write(chunk)
    print(f"üíæ Saved: {path}")

def make_exp_id(config: dict):
    """Return a 10-hex-character id derived from the key-sorted JSON form of ``config``."""
    ordered = {key: config[key] for key in sorted(config)}
    canonical = json.dumps(ordered, sort_keys=True)
    digest = hashlib.md5(canonical.encode())
    return digest.hexdigest()[:10]

def load_results():
    """Return the experiment log as a DataFrame; empty when RESULTS_CSV does not exist yet."""
    if not os.path.exists(RESULTS_CSV):
        return pd.DataFrame()
    return pd.read_csv(RESULTS_CSV)

def append_result(row: dict):
    """Add ``row`` to the experiment log and rewrite RESULTS_CSV with the combined table."""
    existing = load_results()
    new_entry = pd.DataFrame([row])
    combined = pd.concat([existing, new_entry], ignore_index=True)
    combined.to_csv(RESULTS_CSV, index=False)
    print(f"üìå Logged exp_id={row['exp_id']} to CSV: {RESULTS_CSV}")

## 6. Authentication
Login to Hugging Face Hub (required for accessing Gemma models).

In [None]:
# Prefer the HF_TOKEN environment variable so the secret does not have to be
# pasted into (and saved with) the notebook.
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # Add your Hugging Face token here (or export HF_TOKEN)
# An empty string is not a usable token; passing None instead lets
# huggingface_hub ask for one interactively (NOTE(review): confirm the prompt
# behaviour for the installed huggingface_hub version).
login(token=HF_TOKEN or None)

## 7. Model & Fine-Tuning Strategies
Core logic for:
- Per-layer gradient-norm accumulation (used by the dynamic strategy)
- Model Initialization (Full FT, LoRA, etc.)
- Dynamic Layer Freezing/Unfreezing
- BitFit and Top-K implementations

In [None]:
def get_layer_name_from_param(param_name: str) -> str:
    """
    Map a dotted parameter name to the transformer block it belongs to.

    Example: 'model.layers.0.self_attn.q_proj.lora_A.default.weight'
    becomes 'model.layers.0' (everything up to and including the index that
    follows the first 'layers' segment). Names without a 'layers' segment,
    or with nothing after it, map to 'other_params'.
    """
    segments = param_name.split(".")
    try:
        marker = segments.index("layers")
    except ValueError:
        return "other_params"
    if marker + 1 >= len(segments):
        return "other_params"
    return ".".join(segments[: marker + 2])

def accumulate_gradients(model):
    """
    Add each trainable parameter's squared gradient L2 norm, and its element
    count, to the running totals of its layer in ``gradient_accumulator``.

    Parameters without a gradient, frozen parameters and parameters that do
    not belong to a 'layers.N' block are skipped.
    """
    for param_name, tensor in model.named_parameters():
        grad = tensor.grad
        if grad is None or not tensor.requires_grad:
            continue
        layer_key = get_layer_name_from_param(param_name)
        if layer_key == "other_params":
            continue
        running_sq, running_count = gradient_accumulator[layer_key]
        grad_norm = torch.norm(grad, p=2).item()
        gradient_accumulator[layer_key] = (
            running_sq + (grad_norm ** 2),
            running_count + tensor.numel(),
        )

def get_layer_name_from_module(module_name: str) -> str:
    """Map a dotted module path to its layer key, using the same rule as for parameter names."""
    layer_key = get_layer_name_from_param(module_name)
    return layer_key

In [None]:
def get_model(strategy: str) -> torch.nn.Module:
    """
    Load the base checkpoint and prepare it for the requested strategy.

    strategy:
      - 'full_ft', 'bitfit', 'topk_full': the plain base model is returned
        (any BitFit / Top-k freezing is applied by the caller afterwards).
      - 'lora', 'dynamic_grad_norm': the base model is wrapped with LoRA
        adapters on q_proj and v_proj, and the LM head is marked trainable.

    Raises:
        ValueError: if ``strategy`` is not one of the names above. Checked
        before the checkpoint is loaded so a typo fails immediately instead
        of silently training a LoRA model.
    """
    plain_strategies = ("full_ft", "bitfit", "topk_full")
    lora_strategies = ("lora", "dynamic_grad_norm")
    if strategy not in plain_strategies + lora_strategies:
        raise ValueError(
            f"Unknown strategy {strategy!r}; expected one of "
            f"{plain_strategies + lora_strategies}"
        )

    print(f"Loading Gemma-2B with strategy = {strategy}")
    # Pick dtype and device from the same CUDA check so a CPU-only machine
    # loads float32 weights on the CPU instead of failing on device_map="cuda".
    use_cuda = torch.cuda.is_available()
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if GSM8K_CONFIG["fp16"] and use_cuda else torch.float32,
        device_map="cuda" if use_cuda else "cpu",
        low_cpu_mem_usage=True,
    )

    # gradient checkpointing to save VRAM
    base_model.gradient_checkpointing_enable()
    base_model.enable_input_require_grads()

    if strategy in plain_strategies:
        return base_model

    # LoRA configs (q_proj + v_proj only)
    lora_cfg = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        inference_mode=False,
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )

    model = get_peft_model(base_model, lora_cfg)
    # The LM head is trained in addition to the adapters.
    if hasattr(model, "lm_head"):
        for p in model.lm_head.parameters():
            p.requires_grad = True
    model.print_trainable_parameters()

    # For 'lora' the transformer blocks stay frozen and only adapters (+ head) train.
    # For 'dynamic_grad_norm' we start in this LoRA mode, then switch layers
    # between LoRA and full fine-tuning via update_frozen_layers_HYBRID.
    return model

In [None]:
def update_frozen_layers_HYBRID(model, threshold_percentile: float, global_step: int):
    """
    Dynamic Hybrid Fine-Tuning for Gemma CausalLM:
      - Computes the RMS gradient per layer from ``gradient_accumulator``
      - Layers at or above the percentile threshold -> Full-FT
        (LoRA merged into the base weights, base trainable, adapters frozen)
      - Layers below the threshold -> LoRA-Frozen
        (base frozen, adapters re-initialised, un-merged and trainable)

    Args:
        model: model whose LoRA layers are switched in place.
        threshold_percentile: percentile (0-100) of the per-layer norms used
            as the cut-off.
        global_step: step index used for logging (reported as ``global_step + 1``).

    Returns:
        True if any parameter's trainability changed (the caller must then
        rebuild its optimizer), False otherwise.

    Side effects: clears ``gradient_accumulator`` and appends one entry
    (step, threshold, per-layer norm + action, and the trainable parameter
    count when something changed) to ``freezing_log``.
    """
    if not gradient_accumulator:
        return False

    # 1. average gradient norms
    avg_layer_norms = {}
    for layer, (sum_sq_norm, param_count) in gradient_accumulator.items():
        if param_count > 0:
            avg_norm = (sum_sq_norm / (param_count + 1e-9)) ** 0.5
            avg_layer_norms[layer] = avg_norm

    if not avg_layer_norms:
        return False

    norms = list(avg_layer_norms.values())
    # Plain float keeps the JSON log free of NumPy scalar types.
    threshold_val = float(np.percentile(norms, threshold_percentile))

    print(f"\n--- Dynamic Hybrid Check @ Step {global_step+1} ---")
    print(f"Average grad-norm {threshold_percentile}th percentile: {threshold_val}")

    target_state_map = {}
    log_entry = {"step": global_step + 1, "threshold": threshold_val, "layers": {}}

    for layer_name, avg_norm in avg_layer_norms.items():
        is_full_ft_target = (avg_norm >= threshold_val)
        target_state = "full_ft" if is_full_ft_target else "lora_frozen"
        target_state_map[layer_name] = target_state
        log_entry["layers"][layer_name] = {"norm": avg_norm, "action": target_state}

    params_changed = False

    # 2. apply changes over LoRA modules
    for module_name, module in model.named_modules():
        if not isinstance(module, LoraLayer):
            continue
        layer_name = get_layer_name_from_module(module_name)
        if layer_name not in target_state_map:
            continue
        target_state = target_state_map[layer_name]
        # base layer trainability indicates full-ft or not
        is_currently_full_ft = next(module.get_base_layer().parameters()).requires_grad

        if target_state == "full_ft" and not is_currently_full_ft:
            print(f"Switch {layer_name} -> Full-FT (merge LoRA)")

            # merge LoRA to base
            module.merge()

            # freeze LoRA params
            for p in module.lora_A.parameters():
                p.requires_grad = False
            for p in module.lora_B.parameters():
                p.requires_grad = False

            # unfreeze base
            for p in module.get_base_layer().parameters():
                p.requires_grad = True

            params_changed = True

        elif target_state == "lora_frozen" and is_currently_full_ft:
            print(f"Switch {layer_name} -> LoRA-Frozen (reset adapters)")

            # freeze base
            for p in module.get_base_layer().parameters():
                p.requires_grad = False

            # unfreeze LoRA
            for p in module.lora_A.parameters():
                p.requires_grad = True
            for p in module.lora_B.parameters():
                p.requires_grad = True

            # reset LoRA parameters
            module.reset_lora_parameters("default", True)

            # merge() left the adapter flagged as merged, and a merged LoRA
            # layer bypasses its adapters in forward(), so without this the
            # re-enabled adapters would never receive gradients. Un-merging
            # AFTER the reset subtracts the freshly re-initialised delta
            # (expected to be zero), so the fine-tuned base weights are kept.
            if getattr(module, "merged", False):
                module.unmerge()

            params_changed = True

    if params_changed:
        current_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        log_entry["current_trainable_params"] = current_trainable
        print(f"Current trainable params: {current_trainable:,}")

    gradient_accumulator.clear()
    freezing_log.append(log_entry)

    return params_changed

In [None]:
# ============ GSM8K — BitFit and Top-k Full Fine-Tune Utils ============

def apply_bitfit_gsm(model):
    """
    BitFit for the causal-LM setup: leave only bias terms and the LM head trainable.

    A parameter counts as a bias when the last dotted component of its name,
    lower-cased, is exactly 'bias'.
    """
    # Single pass: a parameter is trainable iff it is a bias term.
    for param_name, weight in model.named_parameters():
        is_bias = param_name.lower().rsplit(".", 1)[-1] == "bias"
        weight.requires_grad = is_bias

    # The decoder head is always trained, whatever its parameter names are.
    if hasattr(model, "lm_head"):
        for weight in model.lm_head.parameters():
            weight.requires_grad = True

    print("üü¶ Applied BitFit for GSM8K: train only bias + lm_head.")


def apply_topk_full_ft_gsm(model, k):
    """
    Freeze the whole model, then make the last ``k`` transformer blocks and
    the LM head trainable.

    Blocks are looked up at ``model.model.layers`` first and at
    ``model.transformer.h`` otherwise; ``k`` is capped at the block count.

    Raises:
        ValueError: if neither attribute path exists on ``model``.
    """
    for weight in model.parameters():
        weight.requires_grad_(False)

    # Sentinel distinguishes "attribute missing" from any real attribute value.
    missing = object()
    blocks = getattr(getattr(model, "model", missing), "layers", missing)
    if blocks is missing:
        blocks = getattr(getattr(model, "transformer", missing), "h", missing)  # GPT-style
    if blocks is missing:
        raise ValueError("‚ùå Could not locate transformer layers for Top-k FT")

    depth = len(blocks)
    k = min(k, depth)
    selected = list(range(depth - k, depth))

    print(f"üü• Top-k Full FT: Unfreezing transformer blocks {selected}")

    for block_idx in selected:
        for weight in blocks[block_idx].parameters():
            weight.requires_grad_(True)

    # The head is trained regardless of k.
    if hasattr(model, "lm_head"):
        for weight in model.lm_head.parameters():
            weight.requires_grad_(True)

    print("üü• Applied Top-k Full FT for GSM8K.")

## 8. Data Preparation
- `DataCollatorForCausalLM`: Handles padding and label masking.
- `load_gsm8k`: Loads and tokenizes the dataset, with caching.

In [None]:
# ---------- Data Preparation ----------
@dataclass
class DataCollatorForCausalLM:
    """
    Collate tokenized examples into a padded batch for causal-LM training.

    Text-only fields are dropped, the remaining model inputs are padded to
    the longest sequence with ``tokenizer.pad``, and labels are padded with
    -100 so that only padding positions are excluded from the loss.
    """
    # Tokenizer providing `pad`; its `padding_side` (default "right" if the
    # attribute is absent) decides on which side labels are padded.
    tokenizer: Any

    def __call__(self, features):
        # Remove text fields before batching
        text_fields = {"question", "full_answer_text", "answer_str", "answer"}
        model_features = [
            {k: v for k, v in f.items() if k not in text_fields}
            for f in features
        ]

        # Pad input + attention mask
        inputs = {
            key: [f[key] for f in model_features]
            for key in model_features[0]
            if key != "labels"
        }
        batch = self.tokenizer.pad(inputs, padding="longest", return_tensors="pt")

        # Pad labels directly with the ignore index. Masking by
        # `labels == pad_token_id` after padding would also hide real tokens
        # that share the pad id (e.g. every EOS when pad_token falls back to
        # eos_token).
        label_seqs = [list(f["labels"]) for f in model_features]
        width = max(len(seq) for seq in label_seqs)
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"
        labels = torch.full((len(label_seqs), width), -100, dtype=torch.long)
        for row, seq in enumerate(label_seqs):
            if not seq:
                continue
            values = torch.tensor(seq, dtype=torch.long)
            if pad_left:
                labels[row, width - len(seq):] = values
            else:
                labels[row, :len(seq)] = values

        batch["labels"] = labels
        return batch


from datasets import Dataset, load_from_disk

def load_gsm8k(tokenizer):
    """
    Return the tokenized GSM8K dataset, reusing the on-disk copy when present.

    If DATA_CACHE_PATH exists it is loaded and returned as is. Otherwise the
    'main' configuration is downloaded, every example is tokenized as
    prompt + reference answer (truncated to ``max_input_length``), the result
    is saved to DATA_CACHE_PATH and returned.

    Columns added per example: input ids / attention mask from the tokenizer,
    ``labels`` (copy of the input ids), ``answer_str`` (extracted final
    answer) and ``full_answer_text``; ``question`` is kept as raw text.
    """
    if os.path.exists(DATA_CACHE_PATH):
        print(f"‚ö° Loading cached tokenized dataset from {DATA_CACHE_PATH}")
        return load_from_disk(DATA_CACHE_PATH)

    print("‚è≥ Tokenizing GSM8K dataset for the first time‚Ä¶ (will be cached)")

    raw = load_dataset("gsm8k", "main")

    def _tokenize_batch(examples):
        questions = examples["question"]
        solutions = examples["answer"]
        full_texts = [build_prompt(q) + a for q, a in zip(questions, solutions)]

        encoded = tokenizer(
            full_texts,
            max_length=GSM8K_CONFIG["max_input_length"],
            truncation=True,
        )
        encoded["labels"] = encoded["input_ids"].copy()
        encoded["answer_str"] = [extract_final_answer(a) for a in solutions]

        # Raw text is kept so evaluation can rebuild prompts and gold answers.
        encoded["question"] = questions
        encoded["full_answer_text"] = solutions
        return encoded

    tokenized = raw.map(_tokenize_batch, batched=True)

    print(f"üíæ Saving tokenized dataset to cache: {DATA_CACHE_PATH}")
    tokenized.save_to_disk(DATA_CACHE_PATH)

    return tokenized

## 9. Training Loop
The main training function `train_model` which:
1. Loads the model and dataset.
2. Applies the selected strategy.
3. Runs the training loop with gradient accumulation.
4. Triggers dynamic updates if enabled.

In [None]:
# ---------- Train Function ----------
def train_model(exp_id):
    """
    Train the model described by GSM8K_CONFIG and return it.

    Steps: load tokenizer and tokenized dataset, build the model for the
    configured strategy (applying BitFit / Top-k freezing when requested),
    then run the epoch loop with gradient accumulation, gradient clipping
    and a cosine schedule. For 'dynamic_grad_norm' the per-layer gradient
    statistics are collected after every backward pass and the layer states
    are re-evaluated at regular optimizer-step intervals; whenever that
    changes which parameters are trainable, optimizer and scheduler are
    rebuilt.

    Args:
        exp_id: experiment id, used in the name of the dynamic-log JSON file.

    Returns:
        (model, tokenizer, tokenized dataset, training time in seconds)
    """
    start = time.time()
    tok = AutoTokenizer.from_pretrained(GSM8K_CONFIG["model_checkpoint"])
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    ds = load_gsm8k(tok)

    train_dl = DataLoader(
        ds["train"],
        batch_size=GSM8K_CONFIG["batch_train"],
        shuffle=True,
        collate_fn=DataCollatorForCausalLM(tok),
    )

    strategy = GSM8K_CONFIG["strategy"]
    is_dynamic = strategy == "dynamic_grad_norm"
    model = get_model(strategy)

    # Apply strategy-specific freezing BEFORE reporting, so the numbers below
    # describe what will actually be trained.
    if strategy == "bitfit":
        apply_bitfit_gsm(model)
    elif strategy == "topk_full":
        apply_topk_full_ft_gsm(model, k=4)  # or number you choose

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"üîß Trainable Params: {trainable:,} / {total:,}  ({trainable/total*100:.2f}%)\n")

    print("üîß Example Trainable Parameter Names:")
    trainable_names = [name for name, p in model.named_parameters() if p.requires_grad]
    for name in trainable_names[:5]:
        print("  üè∑Ô∏è", name)

    model.train()
    # Move batches to wherever the model lives instead of assuming CUDA.
    device = next(model.parameters()).device

    epochs = GSM8K_CONFIG["epochs"]
    grad_accum = GSM8K_CONFIG["grad_accum"]
    total_micro_batches = len(train_dl) * epochs
    steps = math.ceil(total_micro_batches / grad_accum)

    def _new_optimizer_and_scheduler(warmup_steps):
        """Build optimizer + cosine scheduler over the currently trainable parameters."""
        optimizer = AdamW8bit(
            [p for p in model.parameters() if p.requires_grad],
            lr=GSM8K_CONFIG["learning_rate"],
            weight_decay=GSM8K_CONFIG.get("weight_decay", 0.01),
        )
        scheduler = get_scheduler(
            name="cosine",
            optimizer=optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=steps,
        )
        return optimizer, scheduler

    optim, sched = _new_optimizer_and_scheduler(int(GSM8K_CONFIG["warmup_ratio"] * steps))

    steps_per_epoch = steps / epochs
    if is_dynamic:
        # Number of optimizer steps between two dynamic checks.
        interval = max(1, steps_per_epoch // GSM8K_CONFIG["dynamic_updates"])

    progress = tqdm(range(steps))
    g = 0          # micro-batches processed (across epochs)
    opt_step = 0   # optimizer steps taken
    for epoch in range(epochs):
        in_epoch_steps = 0
        print(f"\n=== üèãÔ∏è Epoch {epoch+1}/{GSM8K_CONFIG['epochs']} ===")
        for batch in train_dl:
            batch = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
            loss = model(**batch).loss / grad_accum
            loss.backward()

            # Only accumulate dynamic grads AFTER backward
            if is_dynamic:
                accumulate_gradients(model)

            g += 1

            # Step once per full accumulation group, and once more on the very
            # last micro-batch so a trailing partial group is not discarded
            # (`steps` above already counts it through math.ceil).
            if g % grad_accum != 0 and g != total_micro_batches:
                continue

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optim.step()
            sched.step()
            optim.zero_grad()

            opt_step += 1  # <-- count real steps
            in_epoch_steps += 1

            progress.update(1)
            progress.set_description(f"loss={loss.item():.4f}")

            # Dynamic Logic ON REAL STEPS ONLY
            if not is_dynamic:
                continue
            if (in_epoch_steps % interval == 0) or (in_epoch_steps == steps_per_epoch):
                changed = update_frozen_layers_HYBRID(
                    model, GSM8K_CONFIG["dynamic_threshold"], opt_step
                )
                if changed:
                    # The set of trainable parameters changed, so the optimizer
                    # must be rebuilt over the new set.
                    # NOTE(review): the rebuilt scheduler restarts warmup and
                    # the cosine decay from step 0 over the full horizon —
                    # confirm this restart is intended.
                    optim, sched = _new_optimizer_and_scheduler(
                        max(1, int(GSM8K_CONFIG["warmup_ratio"] * steps))
                    )
    runtime = time.time()-start
    save_json(freezing_log, f"{SAVE_PATH}/dynamic_log_{exp_id}.json")
    print(f"‚è±Ô∏è Training time: {runtime:.2f}s")
    return model, tok,ds, runtime

## 10. Evaluation
`evaluate`: Samples `sampling_k` answers per test question, takes the majority-voted answer, and reports the fraction of questions where it matches the gold answer (logged as `maj@1`).

In [None]:
def _answers_match(pred: str, gold: str) -> bool:
    """Return True when two extracted answers are equal as strings or as numbers."""
    if pred == gold:
        return True
    try:
        # "18.00" and "18" are the same answer; plain string equality misses that.
        return math.isclose(float(pred), float(gold), rel_tol=0.0, abs_tol=1e-6)
    except ValueError:
        return False


def evaluate(model, tok, ds_test):
    """
    Sample ``sampling_k`` answers per question and score the majority vote.

    Args:
        model: model used for generation (switched to eval mode).
        tok: tokenizer; its padding side is temporarily set to "left" for
            batched generation and restored afterwards.
        ds_test: dataset with 'question' and 'answer' columns.

    Returns:
        (accuracy, path of the JSON file holding per-question predictions).
        Accuracy is 0.0 for an empty dataset.
    """
    model.eval()
    B = GSM8K_CONFIG["batch_eval"]
    k = GSM8K_CONFIG["sampling_k"]
    max_t = GSM8K_CONFIG["gen_max_tokens"]

    logs = []
    correct = 0
    N = len(ds_test)

    # Decoder-only generation continues from the last position of each row,
    # so prompts of different length must be padded on the left.
    previous_side = tok.padding_side
    tok.padding_side = "left"
    try:
        for i in tqdm(range(0, N, B), desc="Batch maj@1"):
            batch = ds_test[i:i+B]
            prompts = [build_prompt(q) for q in batch["question"]]
            inp = tok(prompts, return_tensors="pt", padding=True).to(model.device)
            prompt_len = inp["input_ids"].shape[1]

            with torch.no_grad():
                out = model.generate(
                    **inp, max_new_tokens=max_t, do_sample=True,
                    temperature=0.7, top_p=0.9,
                    num_return_sequences=k,
                    pad_token_id=tok.pad_token_id,
                    eos_token_id=tok.eos_token_id,
                )

            # unpack generations: (prompt, sample, token)
            seqs = out.reshape(len(prompts), k, -1)
            for j in range(len(prompts)):
                gold = extract_final_answer(batch["answer"][j])
                texts = [
                    tok.decode(seqs[j][s][prompt_len:], skip_special_tokens=True)
                    for s in range(k)
                ]
                preds = [extract_final_answer(t) for t in texts]
                # Counter keeps first-seen order among equal counts, so ties
                # are broken deterministically (set iteration order is not).
                maj = Counter(preds).most_common(1)[0][0]
                correct += _answers_match(maj, gold)
                logs.append({"question": batch["question"][j],
                             "gold": gold, "samples": preds, "maj_pred": maj})
                if i % 50 == 0:
                    print("\nüìå Extracted Predictions:", preds)
                    print("üèÜ Majority Prediction:", maj)
                    print("Gold:", gold)
                    print("====================================\n")
    finally:
        tok.padding_side = previous_side

    acc = correct / N if N else 0.0
    path = f"{SAVE_PATH}/full_eval_{acc:.3f}.json"
    save_json(logs, path)

    print(f"\nmaj@1={acc*100:.2f}%")
    return acc, path

## 11. Experiment Runner
Orchestrates the experiments defined in the grid, running training and evaluation for each configuration.

In [None]:
from copy import deepcopy
import json
def run_experiment(config):
    """
    Train and evaluate one configuration and log the outcome.

    The overrides in ``config`` are applied on top of GSM8K_CONFIG for the
    duration of this call only; the previous configuration is restored
    afterwards so settings cannot leak into the next experiment. Random
    generators are re-seeded so every experiment starts from the same state.

    Args:
        config: dict of GSM8K_CONFIG overrides (one EXPERIMENT_GRID entry).

    Returns:
        The result row (dict) that was appended to the results CSV.
    """
    global GSM8K_CONFIG, gradient_accumulator, freezing_log
    import gc

    # Reset dynamic global state
    gradient_accumulator = defaultdict(lambda: (0.0, 0))
    freezing_log = []

    # Apply config values to main config, remembering the previous state.
    base_config = deepcopy(GSM8K_CONFIG)
    GSM8K_CONFIG.update(config)

    # Re-seed per experiment; otherwise each run continues the RNG stream of
    # the previous one and results depend on grid order.
    seed = GSM8K_CONFIG["seed"]
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

    exp_id = make_exp_id(config)
    print(f"\nüöÄ Running experiment {exp_id} with config={config}")

    model = None
    try:
        # === TRAIN ===
        model, tok, ds, runtime = train_model(exp_id)

        # === EVAL (subset. change to full test when stable) ===
        test_split = ds["test"]
        subset_size = min(400, len(test_split))
        acc, _ = evaluate(model, tok, test_split.select(range(subset_size)))

        # === COUNT PARAMS ===
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in model.parameters())

        # === LOG ROW ===
        result = {
            "exp_id": exp_id,
            "strategy": GSM8K_CONFIG["strategy"],
            "dynamic_updates": GSM8K_CONFIG.get("dynamic_updates", None),
            "learning_rate": GSM8K_CONFIG["learning_rate"],
            "dynamic_threshold": GSM8K_CONFIG.get("dynamic_threshold", None),
            "sampling_k": GSM8K_CONFIG.get("sampling_k", None),
            "epochs": GSM8K_CONFIG.get("epochs", None),
            "trainable_params": trainable,
            "total_params": total,
            "trainable_pct": trainable / total,
            "maj@1": acc,
            "runtime_sec": runtime,
            "timestamp": timestamp()
        }
    finally:
        # Restore the configuration in place (other code holds this dict).
        GSM8K_CONFIG.clear()
        GSM8K_CONFIG.update(base_config)
        # Drop the model before the next experiment loads another copy.
        model = None
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    append_result(result)

    return result

## 12. Main Execution
Iterate through the `EXPERIMENT_GRID` and run all experiments.

In [None]:
# Run every grid entry in order; each call trains, evaluates and appends one
# row to the results CSV, and the returned row is echoed here.
for cfg in EXPERIMENT_GRID:
    row = run_experiment(cfg)
    print("‚úîÔ∏è Finished:", row)