In [None]:
import torch
from huggingface_hub import login
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import Trainer, TrainingArguments
from torch.utils.data import DataLoader


# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): no later cell visible here reads this global — confirm it is still needed.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Load the "tau/commonsense_qa" dataset; later cells index its
# "train", "validation" and "test" splits.
csqa = load_dataset("tau/commonsense_qa")
# Bare expression so the notebook renders the loaded object.
csqa

In [3]:
# Option labels, in the order the choices are rendered by format_options.
LETTERS = ["A","B","C","D","E"]
# Token budget for one sequence (prompt + answer); default `max_len` of build_item.
MAX_LEN = 576

# Instruction prompt filled in with str.format: {question} and {options} are
# substituted, while the doubled braces render as the literal "{A, B, C, D, E}".
# The text ends at "Answer:" so the gold letter can be appended right after it.
PROMPT_TEMPLATE = (
    "You are a helpful reasoning assistant.\n"
    "Answer the multiple-choice question by outputting just one capital letter from {{A, B, C, D, E}}.\n\n"
    "Question: {question}\n"
    "Options:\n{options}\n"
    "Answer:"
)

def format_options(choice_texts):
    """Render the answer choices as one `` X) text`` line per choice.

    Choices are labelled positionally from ``LETTERS``; the lines are joined
    with newlines and carry no trailing newline.
    """
    rendered = []
    for position, text in enumerate(choice_texts):
        rendered.append(f" {LETTERS[position]}) {text}")
    return "\n".join(rendered)

def build_prompt(example):
    """Return the full instruction prompt for one dataset example.

    Reads ``example["question"]`` and ``example["choices"]["text"]`` and
    substitutes them into ``PROMPT_TEMPLATE``.
    """
    question_text = example["question"]
    options_block = format_options(example["choices"]["text"])
    return PROMPT_TEMPLATE.format(question=question_text, options=options_block)

In [None]:
import random
# Spot-check three randomly chosen training examples: print the rendered
# prompt followed by the gold answer key, separated by a rule.
# No seed is set in this cell, so the sampled examples can differ between runs.
for i in random.sample(range(len(csqa["train"])), 3):
    e = csqa["train"][i]
    print(build_prompt(e))
    print("GOLD:", e["answerKey"])
    print("="*60)

In [5]:
# Checkpoint id shared by the tokenizer (here) and the model (loaded in a later cell).
model_id = "google/gemma-3-1b-it"

# Load the tokenizer from `model_id` rather than a repeated string literal, so
# changing the checkpoint in one place cannot pair a model with the wrong tokenizer.
tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
# Fall back to EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
# Pad on the right so prompt/answer positions keep their original indices.
tok.padding_side = "right"

In [6]:
def _fit_prompt_ids(prompt_ids, budget, bos_id):
    """Shrink `prompt_ids` to at most `budget` tokens, keeping a leading BOS and the prompt's tail."""
    if len(prompt_ids) <= budget:
        return prompt_ids
    if budget <= 0:
        return []
    if bos_id is not None and prompt_ids[0] == bos_id:
        tail_len = budget - 1
        # `ids[-0:]` would return the whole list, so a zero-length tail is handled explicitly.
        return [prompt_ids[0]] + (prompt_ids[-tail_len:] if tail_len > 0 else [])
    return prompt_ids[-budget:]


def build_item(ex, tok, max_len=MAX_LEN):
    """Turn one dataset example into a supervised causal-LM training item.

    The prompt is rendered from ``PROMPT_TEMPLATE``, the answer key plus a
    newline is appended as the target, and labels are ``-100`` on every prompt
    position so that only the answer tokens are supervised.

    Args:
        ex: mapping with ``"question"``, ``"choices"["text"]`` and ``"answerKey"``.
        tok: tokenizer; called on text and read for ``bos_token_id``.
        max_len: token budget for prompt + answer.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``labels`` (equal-length
        lists) and ``answer_letter`` (the raw answer key, for debugging/metrics).

    Note:
        If ``answerKey`` is an empty string the target is just the newline.
    """
    # 1) Build the text prompt the model will read.
    prompt = PROMPT_TEMPLATE.format(
        question=ex["question"],
        options=format_options(ex["choices"]["text"]),
    )

    # 2) Build the target the model should generate right after "Answer:".
    #    The trailing newline keeps decoding tidy.
    gold_letter = ex["answerKey"]                # 'A'..'E'
    target_ids = list(tok(gold_letter + "\n", add_special_tokens=False)["input_ids"])

    # 3) Tokenize the prompt WITHOUT tokenizer-side truncation. `truncation=True`
    #    trims the end of the text by default, which would remove the "Answer:"
    #    cue; an over-long prompt is instead trimmed below, once, from the left.
    prompt_ids = list(tok(prompt, add_special_tokens=True)["input_ids"])

    # 4) Make prompt + target fit `max_len`, preserving a leading BOS token.
    prompt_ids = _fit_prompt_ids(
        prompt_ids,
        max_len - len(target_ids),
        getattr(tok, "bos_token_id", None),
    )

    # 5) Final input = prompt tokens + answer tokens; nothing is padded here,
    #    so every position is attended to.
    input_ids = prompt_ids + target_ids
    attention = [1] * len(input_ids)

    # 6) Labels: ignore the prompt (-100), supervise only the answer tokens.
    labels = [-100] * len(prompt_ids) + target_ids

    return {
        "input_ids": input_ids,
        "attention_mask": attention,
        "labels": labels,
        "answer_letter": gold_letter,  # helpful for debugging/metrics
    }

In [None]:
# Columns of the raw dataset; they are dropped once each example has been
# converted into model inputs by build_item.
remove_cols = csqa["train"].column_names  # drop original text after mapping

# Apply the same conversion to all three splits in one pass.
train_ds, val_ds, test_ds = (
    csqa[split_name].map(
        lambda ex: build_item(ex, tok, MAX_LEN),
        remove_columns=remove_cols,
    )
    for split_name in ("train", "validation", "test")
)

train_ds, val_ds, test_ds

In [9]:
class Collator:
    """Batch variable-length training items into padded tensors.

    ``input_ids`` and ``attention_mask`` are padded by the tokenizer's ``pad``
    method; ``labels`` are padded with ``-100`` to the same width and on the
    same side the tokenizer pads on, so every label stays aligned with its
    input token.
    """

    def __init__(self, tok):
        # Tokenizer providing `.pad(...)` and (optionally) `.padding_side`.
        self.tok = tok

    def __call__(self, feats):
        """Collate a list of feature dicts into a dict of batch tensors.

        Args:
            feats: list of dicts with ``input_ids``, ``attention_mask`` and
                ``labels`` lists; any other keys are ignored.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors
            that share the same sequence dimension.
        """
        ids  = [f["input_ids"] for f in feats]
        attn = [f["attention_mask"] for f in feats]
        labs = [f["labels"] for f in feats]

        # Pad both together (tokenizer.pad wants input_ids present)
        padded = self.tok.pad(
            {"input_ids": ids, "attention_mask": attn},
            padding=True,
            return_tensors="pt",
        )
        batch_ids  = padded["input_ids"]
        batch_attn = padded["attention_mask"]

        # Pad labels to the same sequence length with -100, on the same side as
        # the tokenizer pads; padding labels on the other side would shift them
        # off the tokens they belong to.
        L = batch_ids.size(1)
        batch_labs = torch.full((len(labs), L), -100, dtype=torch.long)
        pad_left = getattr(self.tok, "padding_side", "right") == "left"
        for i, lab in enumerate(labs):
            if not lab:
                continue  # nothing to copy; the row stays all -100
            row = torch.tensor(lab, dtype=torch.long)
            if pad_left:
                batch_labs[i, L - len(lab):] = row
            else:
                batch_labs[i, :len(lab)] = row

        return {"input_ids": batch_ids, "attention_mask": batch_attn, "labels": batch_labs}

    
# Shared collator instance bound to the notebook's tokenizer; passed to the
# Trainer and to the manual evaluation loop in later cells.
collator = Collator(tok)

In [None]:
# Load the causal-LM checkpoint named by `model_id`.
# NOTE(review): trust_remote_code=True permits running model code shipped in
# the checkpoint repository — confirm this checkpoint actually needs it, and
# drop the flag if it does not.
base = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,   # load weights in bf16
    device_map="auto",            # place layers on GPU automatically
    trust_remote_code=True,
)
print("Loaded base model.")

In [None]:
# Enable gradient checkpointing on the base model and switch off its
# `use_cache` config flag before wrapping it with LoRA.
# NOTE(review): presumably the cache is disabled because it is not useful
# during training with checkpointing — confirm.
base.gradient_checkpointing_enable()
base.config.use_cache = False

# LoRA hyperparameters: rank 16, alpha 32, dropout 0.05, no bias training.
# Adapters are attached to the modules whose names are listed below
# (presumably the attention and MLP projections of the loaded architecture).
lora_cfg = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "q_proj","k_proj","v_proj","o_proj",
        "gate_proj","up_proj","down_proj",
    ],
)

# Wrap the base model; `model` is what the later cells train and inspect.
model = get_peft_model(base, lora_cfg)

In [None]:
# sanity: count trainable params
# `trainable` sums element counts of parameters with requires_grad=True,
# `total` sums over every parameter; the print shows their ratio in percent.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"Trainable: {trainable:,} / {total:,} ({100*trainable/total:.3f}%)")

In [None]:
# (Optional) list the first ten trainable parameters (name and shape) to
# confirm which layers will be updated.
shown = 0
for n, p in model.named_parameters():
    if not p.requires_grad:
        continue
    print("TRAINABLE:", n, p.shape)
    shown += 1
    if shown >= 10:
        break

In [15]:
import torch

def compute_metrics(eval_pred):
    """Accuracy on the first supervised (answer) token of each sequence.

    Args:
        eval_pred: pair ``(logits, labels)``; logits are indexed as
            ``[batch, position, vocab]`` and labels as ``[batch, position]``,
            with ``-100`` marking unsupervised positions.

    Returns:
        ``{"accuracy": float}``. Rows whose first supervised position is index 0
        (including rows with no supervised position at all) are skipped; if
        every row is skipped the accuracy is reported as ``0.0``.
    """
    raw_logits, raw_labels = eval_pred
    logits = torch.tensor(raw_logits)
    labels = torch.tensor(raw_labels)

    # Position of the first supervised label per row (0 when a row has none).
    supervised = labels != -100
    answer_start = supervised.int().argmax(dim=1)

    batch_idx = torch.arange(labels.size(0))
    gold = labels[batch_idx, answer_start]

    # Causal shift: the logits at position t score the token at t + 1, so the
    # first answer token is predicted from the position just before it.
    query_pos = answer_start - 1

    # Rows without a preceding position cannot be scored.
    keep = query_pos >= 0
    if not bool(keep.any()):
        return {"accuracy": 0.0}

    guesses = logits[batch_idx[keep], query_pos[keep], :].argmax(dim=-1)
    hit_rate = (guesses == gold[keep]).float().mean().item()
    return {"accuracy": hit_rate}


In [17]:
# optimizer.eval  = lambda: None   # <-- add this

# eval_metrics = trainer.evaluate()
# print(eval_metrics)  # should include 'eval_accuracy'

In [None]:
# Training configuration for the run over the full train split.
# `output_dir` sits next to the adapter save directory used in the save cell
# ("./csqa_gemma1b_full/adapter") instead of naming an unrelated model.
full_args = TrainingArguments(
    output_dir="./csqa_gemma1b_full",
    num_train_epochs=2,                   # start with 2; you can try 3 later
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,        # 4 x 2 = 8 examples per optimizer step (per device)
    eval_strategy="no",                   # evaluation is run manually in a later cell
    save_strategy="no",
    logging_steps=50,
    save_total_limit=2,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    bf16=True, fp16=False,
    gradient_checkpointing=True,
    remove_unused_columns=False,          # keep dataset columns; the collator picks what it needs
    report_to=["none"],
    load_best_model_at_end=False,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    seed=42,
)

# Optimizer over *only* the parameters left trainable (requires_grad=True).
# The learning rate is read from `full_args` so it is defined in one place.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=full_args.learning_rate,
    betas=(0.9, 0.999),
    weight_decay=0.0,
)
# No-op train()/eval() hooks on the optimizer (accelerate compatibility).
optimizer.train = lambda: None
optimizer.eval  = lambda: None

full_trainer = Trainer(
    model=model,
    args=full_args,
    train_dataset=train_ds,           # full train split
    eval_dataset=val_ds,              # full validation split
    data_collator=collator,
    compute_metrics=compute_metrics,  # first-answer-token accuracy with the causal shift
    optimizers=(optimizer, None),     # custom optimizer; None leaves the scheduler to the Trainer
)

full_trainer.train()

In [None]:
@torch.no_grad()
def manual_eval_accuracy(model, ds, collator, batch_size=2):  # small batch size helps
    """Stream over `ds` and score the first answer token of every example.

    Logits are consumed batch by batch, so the full set of logits is never
    held in memory at once.

    Args:
        model: module called as ``model(**batch)``; its output must expose
            ``.logits`` (indexed ``[batch, position, vocab]``) and ``.loss``.
        ds: dataset of feature dicts accepted by `collator`.
        collator: callable turning a list of items into a dict of tensors that
            includes ``"labels"`` (``-100`` on unsupervised positions).
        batch_size: evaluation batch size.

    Returns:
        ``{"eval_accuracy": float, "eval_loss_stream": float}``. The loss is the
        unweighted mean of the finite per-batch losses (NaN if there were none).

    Side effects:
        Puts `model` in eval mode and leaves it there.
    """
    import contextlib  # local import: only needed for the no-op context below

    model.eval()
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collator)
    device = next(model.parameters()).device
    # bf16 autocast is applied on CUDA only. `torch.autocast` replaces the
    # deprecated `torch.cuda.amp.autocast`, which also warned and disabled
    # itself whenever the model lived on the CPU.
    on_cuda = device.type == "cuda"
    total = correct = 0
    loss_sum = 0.0
    nloss = 0

    for batch in loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        amp_ctx = (
            torch.autocast(device_type="cuda", dtype=torch.bfloat16)
            if on_cuda
            else contextlib.nullcontext()
        )
        with amp_ctx:
            out = model(**batch)  # logits [B,T,V]; loss presumably averaged over supervised tokens
        logits, labels = out.logits, batch["labels"]

        mask = labels.ne(-100)
        first_idx = mask.int().argmax(dim=1)      # start of answer (0 if the row has none)
        rows = torch.arange(labels.size(0), device=device)
        tgt = labels[rows, first_idx]
        pred_pos = first_idx - 1                   # causal shift: logits at t score token t+1
        valid = pred_pos.ge(0)                     # rows with no preceding position are skipped

        if valid.any():
            rows_v = rows[valid]
            pred_v = pred_pos[valid]
            tgt_v  = tgt[valid]
            pred   = logits[rows_v, pred_v, :].argmax(dim=-1)
            correct += (pred == tgt_v).sum().item()
            total   += valid.sum().item()

        if out.loss is not None and torch.isfinite(out.loss):
            loss_sum += float(out.loss)
            nloss += 1

        # Release this batch's tensors before the next forward pass.
        del logits, labels, out
        if on_cuda:
            torch.cuda.empty_cache()

    acc = correct / max(total, 1)
    avg_loss = (loss_sum / nloss) if nloss > 0 else float("nan")
    return {"eval_accuracy": acc, "eval_loss_stream": avg_loss}

# Score the trained model on the validation split with the streaming
# evaluator (batch size 2) and print the resulting metrics dict.
metrics_stream = manual_eval_accuracy(full_trainer.model, val_ds, collator, batch_size=2)
print(metrics_stream)

In [20]:
# import torch, pandas as pd
# from torch.utils.data import DataLoader

# @torch.no_grad()
# def predict_letters(model, ds, collator, batch_size=4):
#     model.eval()
#     loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collator)
#     device = next(model.parameters()).device
#     preds = []
#     for batch in loader:
#         inputs = {k: v.to(device) for k,v in batch.items() if k in ["input_ids","attention_mask"]}
#         out = model.generate(**inputs, max_new_tokens=2)
#         for i in range(out.size(0)):
#             # strip the prompt tokens; take only the generated tokens
#             prompt_len = (batch["input_ids"][i] != tok.pad_token_id).sum().item()
#             gen = tok.decode(out[i][prompt_len:], skip_special_tokens=True).strip()
#             letter = next((c for c in gen if c in LETTERS), None)
#             preds.append(letter or "")
#     return preds

# test_preds = predict_letters(full_trainer.model, test_ds, collator, batch_size=4)

# # Save with the original test ids so you can align externally if needed
# test_ids = [csqa["test"][i]["id"] for i in range(len(csqa["test"]))]
# df = pd.DataFrame({"id": test_ids, "pred": test_preds})
# df.to_csv("csqa_test_predictions.csv", index=False)
# print("Wrote csqa_test_predictions.csv")

In [None]:
# Directory that receives both the trained model weights and the tokenizer files.
save_dir = "./csqa_gemma1b_full/adapter"

# Save only the trainable LoRA layers
full_trainer.model.save_pretrained(save_dir)

# Save tokenizer too (same tok you used for CSQA)
tok.save_pretrained(save_dir)

print("Saved adapter + tokenizer to:", save_dir)

## Llama 3B

Trainable: 24,313,856 / 3,237,063,680 (0.751%)

TrainOutput(global_step=2436, training_loss=0.3055989429085517, metrics={'train_runtime': 1158.7058, 'train_samples_per_second': 16.814, 'train_steps_per_second': 2.102, 'total_flos': 2.8936654637654016e+16, 'train_loss': 0.3055989429085517, 'epoch': 2.0})

{'eval_accuracy': 0.8230958230958231, 'eval_loss_stream': 0.30398253511366224}

## Llama 1B

Trainable: 11,272,192 / 1,247,086,592 (0.904%)

TrainOutput(global_step=2436, training_loss=0.40509649430981215, metrics={'train_runtime': 670.9873, 'train_samples_per_second': 29.035, 'train_steps_per_second': 3.63, 'total_flos': 1.0019401622765568e+16, 'train_loss': 0.40509649430981215, 'epoch': 2.0})

{'eval_accuracy': 0.7592137592137592, 'eval_loss_stream': 0.4115677761792704}

## Qwen 3B

Trainable: 29,933,568 / 3,115,872,256 (0.961%)

TrainOutput(global_step=2436, training_loss=0.3242768377114595, metrics={'train_runtime': 1490.2816, 'train_samples_per_second': 13.073, 'train_steps_per_second': 1.635, 'total_flos': 2.822591263806259e+16, 'train_loss': 0.3242768377114595, 'epoch': 2.0})

{'eval_accuracy': 0.8378378378378378, 'eval_loss_stream': 0.28810480255784116}

## Qwen 1.5B

Trainable: 18,464,768 / 1,562,179,072 (1.182%)

TrainOutput(global_step=2436, training_loss=0.3098135355658132, metrics={'train_runtime': 1149.0526, 'train_samples_per_second': 16.955, 'train_steps_per_second': 2.12, 'total_flos': 1.3372783705995264e+16, 'train_loss': 0.3098135355658132, 'epoch': 2.0})

{'eval_accuracy': 0.7993447993447993, 'eval_loss_stream': 0.3620616310815679}

## Qwen 0.5B

Trainable: 8,798,208 / 502,830,976 (1.750%)

TrainOutput(global_step=2436, training_loss=0.4317081522667545, metrics={'train_runtime': 978.8359, 'train_samples_per_second': 19.903, 'train_steps_per_second': 2.489, 'total_flos': 3690345224148480.0, 'train_loss': 0.4317081522667545, 'epoch': 2.0})

{'eval_accuracy': 0.6764946764946765, 'eval_loss_stream': 0.5564539690013082}

## Gemma 4B

Trainable: 32,788,480 / 4,332,867,952 (0.757%)

TrainOutput(global_step=2436, training_loss=0.596402333092024, metrics={'train_runtime': 1788.0983, 'train_samples_per_second': 10.895, 'train_steps_per_second': 1.362, 'total_flos': 3.884991434613744e+16, 'train_loss': 0.596402333092024, 'epoch': 2.0})


{'eval_accuracy': 0.8091728091728092, 'eval_loss_stream': 0.34752246979196766}
