In [8]:
import json
import os
import random
import re

import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

# --- Sampling settings ---
TEMP = 0.7              # softmax temperature used for sampling and scoring
MAX_NEW_TOKENS = 200    # cap on generated tokens per response

# --- Training schedule ---
BATCH_SIZE = 4
NUM_EPOCHS = 5
TRAIN_SIZE = 5000
VAL_SIZE = 500
LOG_EVERY = 10          # log training scalars every N batches

# --- Runtime ---
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# --- Per-batch JSON logs: <LOG_JSON_DIR>/, <LOG_JSON_DIR>/train, <LOG_JSON_DIR>/val ---
LOG_JSON_DIR = "batch_logs"
os.makedirs(LOG_JSON_DIR, exist_ok=True)
for _log_subdir in ("train", "val"):
    os.makedirs(os.path.join(LOG_JSON_DIR, _log_subdir), exist_ok=True)

In [9]:
# Load the tokenizer and the causal LM from the same checkpoint.
MODEL_NAME = "Qwen/Qwen3-1.7B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
if tokenizer.pad_token is None:
    # Fall back to EOS as the pad token, and take the model's pad id from the
    # tokenizer so both sides agree on a single integer id
    # (model.config.eos_token_id may be a list or differ from the tokenizer's).
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.pad_token_id
# Decoder-only generation needs left padding so every prompt ends at the same
# column and new tokens are appended directly after the prompt text.
tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.13it/s]


In [10]:
class CommonsenseQAParser:
    """Builds chat prompts for CommonsenseQA items and parses model answers.

    The parser holds the tokenizer whose chat template is used to render the
    prompt, plus a fixed few-shot system prompt describing the answer format.
    """

    def __init__(self, tokenizer):
        """Store the tokenizer and the fixed few-shot system prompt.

        Args:
            tokenizer: object exposing ``apply_chat_template`` (e.g. a Hugging
                Face tokenizer).
        """
        self.tokenizer = tokenizer
        self.system_prompt = """You are an expert at applying commonsense reasoning to answer multiple-choice questions. You will be given a question with multiple answer choices, and you will be tasked with providing a brief rationale for your answer, followed by the correct answer choice. For example:
        
Q: What do people use to absorb extra ink from a fountain pen?
Answer Choices:
(a) shirt pocket
(b) calligrapher's hand
(c) inkwell
(d) desk drawer
(e) blotter
A: The answer must be used to absorb extra ink. Blotters are designed to absorb liquids. Therefore, the answer is blotter (e).

Q: What home entertainment equipment requires cable?
Answer Choices:
(a) radio shack
(b) substation
(c) television
(d) cabinet
(e) desk
A: The answer must require cable. Cable is used to provide satellite channels to televisions. Therefore, the answer is television (c).

Format your answer as:
"<BRIEF 1-2 sentence rationale>. Therefore, the answer is <answer text> (<answer letter choice>)."

Do not use any other format. If unsure, choose the most likely answer based on your reasoning.
        """

    def format_question(self, question_data):
        """Render one item as ``Q: ... / Answer Choices: ... / A: ``.

        Args:
            question_data: mapping with ``'question'`` and ``'choices'``, where
                ``'choices'`` holds parallel ``'label'`` and ``'text'`` lists.

        Returns:
            The question text with one lower-cased ``(label) text`` line per
            choice, ending in the ``A: `` cue.
        """
        q = question_data['question']
        labels = question_data['choices']['label']
        texts = question_data['choices']['text']
        choices = "\n".join(
            f"({lbl.lower()}) {txt}" for lbl, txt in zip(labels, texts)
        )
        return f"Q: {q}\nAnswer Choices:\n{choices.strip()}\nA: "

    def format_prompt(self, question_data):
        """Render the full chat prompt for one item.

        Args:
            question_data: mapping accepted by ``format_question``.

        Returns:
            ``(prompt_str, user_content)``: the chat-templated prompt string and
            the raw user message that was embedded in it.
        """
        messages = [
            {"role": "system",  "content": self.system_prompt},
            {"role": "user",    "content": self.format_question(question_data)}
        ]
        # Use the tokenizer given to this instance rather than a notebook
        # global, so the parser works regardless of cell execution order.
        # add_generation_prompt=True opens the assistant turn so the model
        # answers directly instead of first having to emit the turn header.
        # NOTE(review): `enable_thinking=False` is believed to take effect only
        # when the generation prompt is added - confirm against the checkpoint's
        # chat template.
        prompt_str = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False
        )
        return prompt_str, messages[-1]['content']

    def parse_llm_output(self, generated_text):
        """Split a decoded generation into a rationale and an answer letter.

        Args:
            generated_text: decoded model output.

        Returns:
            ``(rationale, letter)``. ``rationale`` is the text with a leading
            ``</think>`` marker and surrounding whitespace removed. ``letter``
            is the last parenthesised ``a``-``e`` in the text (lower-cased), or
            ``None`` when the text contains no such marker.
        """
        # Strip first so a "</think>" preceded by newlines/spaces is removed too.
        rationale = generated_text.strip().removeprefix("</think>").strip()
        matches = re.findall(r"\(([a-e])\)", generated_text, re.IGNORECASE)
        # The format places the answer at the end, so take the last match
        # rather than a letter mentioned while reasoning.
        letter = matches[-1].lower() if matches else None
        return rationale, letter

In [11]:
@torch.no_grad()
def sample_no_grad(model, batch_prompt_ids, max_new_tokens=MAX_NEW_TOKENS, temp=TEMP):
    """Sample one continuation per prompt without tracking gradients.

    Args:
        model: causal LM exposing ``generate``.
        batch_prompt_ids: (B, T) prompt token ids, left-padded with
            ``tokenizer.pad_token_id``.
        max_new_tokens: maximum number of tokens to generate per prompt.
        temp: sampling temperature.

    Returns:
        (B, new_T) tensor holding only the newly generated token ids.
    """
    # Build the attention mask explicitly instead of relying on `generate` to
    # infer it. Prompts are left-padded, so the padding is exactly the leading
    # run of pad tokens; a pad/EOS id appearing later in the prompt stays
    # attended.
    is_real = batch_prompt_ids != tokenizer.pad_token_id
    attention_mask = (is_real.long().cumsum(dim=1) > 0).long()

    # NOTE(review): top-k/top-p defaults from the model's generation config, if
    # any, still apply here; the sampled distribution then differs from a plain
    # temperature softmax - confirm this is intended.
    seq = model.generate(
        batch_prompt_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        temperature=temp,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
    )
    # `generate` returns prompt + continuation; keep only the continuation.
    return seq[:, batch_prompt_ids.size(1):]  # shape (B, new_T)

def compute_logprobs(model, batch_prompt_ids, batch_gen_ids, temp=TEMP):
    """Sum the log-probabilities of each generated sequence under `model`.

    Args:
        model: causal LM whose forward pass returns an object with ``.logits``.
        batch_prompt_ids: (B, T) prompt token ids, left-padded with
            ``tokenizer.pad_token_id``; assumes T >= 1.
        batch_gen_ids: (B, new_T) generated token ids, right-padded with
            ``tokenizer.pad_token_id`` after a sequence has finished.
        temp: temperature the logits are divided by before the softmax, so the
            scores use the same temperature as sampling.

    Returns:
        (B,) tensor: for each row, the sum of log P(token_t | tokens_<t) over
        its generated tokens, excluding trailing padding.
    """
    pad_id = tokenizer.pad_token_id
    prompt_width = batch_prompt_ids.size(1)

    full_ids = torch.cat([batch_prompt_ids, batch_gen_ids], dim=1)  # (B, T+new_T)

    # Mask out the leading (left) padding of each prompt. Generated positions
    # are all marked attended: trailing pads sit after every scored token, so
    # under causal attention they cannot influence the scored positions.
    prompt_is_real = batch_prompt_ids != pad_id
    prompt_mask = (prompt_is_real.long().cumsum(dim=1) > 0).long()
    attention_mask = torch.cat(
        [prompt_mask, torch.ones_like(batch_gen_ids)], dim=1
    )  # (B, T+new_T)

    logits = model(full_ids, attention_mask=attention_mask).logits  # (B, T+new_T, V)

    # The logit at position p predicts the token at p + 1. Because prompts are
    # left-padded, the generated tokens of EVERY row start at column T, so they
    # are predicted by positions T-1 .. T+new_T-2 regardless of how much
    # padding a row has. Slicing before the softmax also avoids normalising
    # over the prompt positions.
    gen_logits = logits[:, prompt_width - 1 : -1, :] / temp  # (B, new_T, V)
    gen_logprobs = F.log_softmax(gen_logits, dim=-1)
    token_logprobs = gen_logprobs.gather(
        2, batch_gen_ids.unsqueeze(-1)
    ).squeeze(-1)  # (B, new_T)

    # Keep every token up to the last non-pad one; only the trailing run of
    # pad tokens (filled in after a sequence finished) is excluded.
    not_pad = (batch_gen_ids != pad_id).long()
    gen_mask = not_pad.flip(dims=[1]).cumsum(dim=1).flip(dims=[1]) > 0

    # NOTE(review): if sampling applies top-k/top-p, these are log-probs of the
    # temperature-scaled softmax, not of the truncated sampling distribution.
    return token_logprobs.masked_fill(~gen_mask, 0.0).sum(dim=1)  # (B,)

def compute_binary_reward(final_answer, correct_answer, question=None, rationale=None):
    """Return 1.0 when the predicted answer equals the reference, else 0.0.

    `question` and `rationale` are accepted but not used.
    """
    if final_answer == correct_answer:
        return 1.0
    return 0.0

In [12]:
def get_batches(dataset, parser, batch_size):
    """Yield formatted mini-batches from `dataset`.

    Slicing `dataset` is expected to return a column-wise mapping
    (column name -> list of values) that includes an ``"id"`` column.

    Yields:
        ``(prompt_strs, raw_qs, correct_keys)``, three equal-length lists of
        strings: the prompts and raw questions from ``parser.format_prompt``
        and the lower-cased ``"answerKey"`` of each example.
    """
    for offset in range(0, len(dataset), batch_size):
        columns = dataset[offset : offset + batch_size]
        n_rows = len(columns["id"])

        # Transpose the column-wise slice into one dict per example.
        rows = [
            {name: columns[name][row_idx] for name in columns}
            for row_idx in range(n_rows)
        ]

        formatted = [parser.format_prompt(row) for row in rows]
        prompt_strs = [prompt for prompt, _ in formatted]
        raw_qs = [raw_q for _, raw_q in formatted]
        correct_keys = [row["answerKey"].lower() for row in rows]
        yield prompt_strs, raw_qs, correct_keys

In [None]:
# REINFORCE-style fine-tuning: sample a rationale + answer for each question,
# reward 1.0 for a correct answer letter, and minimise -(reward * logprob).
train_dataset = load_dataset("commonsense_qa", split=f"train[:{TRAIN_SIZE}]")
# NOTE(review): VAL_SIZE from the config cell is not applied here; the
# validation subset is fixed at 150 examples - confirm which size is intended.
val_dataset = load_dataset("commonsense_qa", split="validation[:150]")

writer = SummaryWriter()

parser = CommonsenseQAParser(tokenizer)
optimizer = optim.AdamW(model.parameters(), lr=2e-5)


def _rollout(prompt_strs, correct_answers):
    """Sample, parse and score one response per prompt.

    Returns ``(rationales, pred_answers, logprobs, rewards)``; the first two
    are len-B tuples, the last two are (B,) tensors on `device`. Gradients flow
    through `logprobs` unless the caller wraps the call in ``torch.no_grad()``.
    """
    # tokenize and left-pad prompt strings
    prompt_ids = tokenizer(
        prompt_strs, return_tensors="pt",
        padding=True, truncation=True, max_length=512
    ).to(device)["input_ids"]  # (B, T)

    # sample responses (sample_no_grad disables gradients itself)
    gen_ids = sample_no_grad(
        model, prompt_ids, max_new_tokens=MAX_NEW_TOKENS, temp=TEMP
    )  # (B, new_T)

    # decode and parse generated outputs
    gen_strs = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    rationales, pred_answers = zip(
        *[parser.parse_llm_output(gen_str) for gen_str in gen_strs]
    )

    logprobs = compute_logprobs(model, prompt_ids, gen_ids, temp=TEMP)  # (B,)
    rewards = torch.tensor([
        compute_binary_reward(ans, corr)
        for ans, corr in zip(pred_answers, correct_answers)
    ], device=device)
    return rationales, pred_answers, logprobs, rewards


# Ceil division so the progress bar also counts a final partial batch.
steps_per_epoch = (len(train_dataset) + BATCH_SIZE - 1) // BATCH_SIZE
# Monotonic step across epochs for TensorBoard; `batch_count` restarts at 0
# every epoch and would otherwise overwrite earlier points.
global_step = 0

try:
    for epoch in range(NUM_EPOCHS):
        ##### TRAINING LOOP #####
        model.train()
        train_loss = 0.0
        batch_count = 0
        for prompt_strs, raw_qs, correct_answers in tqdm(
            get_batches(train_dataset, parser, BATCH_SIZE),
            desc=f"Training Epoch {epoch + 1}", total=steps_per_epoch
        ):
            rationales, pred_answers, logprobs, rewards = _rollout(
                prompt_strs, correct_answers
            )
            loss = -(rewards * logprobs).mean()  # scalar
            loss_value = loss.item()

            if batch_count % LOG_EVERY == 0:
                writer.add_scalar("BatchLoss/train", loss_value, global_step)
                writer.add_scalar("BatchAvgLogProb/train", logprobs.mean().item(), global_step)
                writer.add_scalar("BatchAvgReward/train", rewards.mean().item(), global_step)
                log_data = {
                    "epoch": epoch + 1,
                    "batch": batch_count,
                    "question": raw_qs[0],
                    "rationale": rationales[0],
                    "predicted_answer": pred_answers[0] if pred_answers[0] else "None",
                    "correct_answer": correct_answers[0],
                    "reward": float(rewards[0].item()),
                    "logprob": float(logprobs[0].item()),
                }
                json_path = os.path.join(LOG_JSON_DIR, "train", f"epoch{epoch + 1}_batch{batch_count}.json")
                with open(json_path, "w") as f:
                    json.dump(log_data, f, indent=2)

            train_loss += loss_value
            batch_count += 1
            global_step += 1

            # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            ### VALIDATION LOGGING
            if batch_count % (LOG_EVERY * 5) == 0:
                model.eval()
                val_rewards_all, val_logprobs_all, val_losses_all = [], [], []
                collected_examples = []

                # No gradients are needed for evaluation; without this the
                # scoring forward pass would keep an autograd graph per batch.
                with torch.no_grad():
                    for val_prompt_strs, val_raw_qs, val_correct_answers in get_batches(
                        val_dataset, parser, BATCH_SIZE
                    ):
                        val_rationales, val_pred_answers, val_logp, val_rwds = _rollout(
                            val_prompt_strs, val_correct_answers
                        )
                        val_loss = -(val_rwds * val_logp).mean()

                        val_rewards_all.append(val_rwds)
                        val_logprobs_all.append(val_logp)
                        val_losses_all.append(val_loss)

                        for q, r, a, p, rew, lp in zip(
                            val_raw_qs, val_rationales, val_correct_answers,
                            val_pred_answers, val_rwds, val_logp
                        ):
                            collected_examples.append({
                                "question": q,
                                "rationale": r,
                                "predicted_answer": p if p else "None",
                                "correct_answer": a,
                                "reward": float(rew.item()),
                                "logprob": float(lp.item())
                            })

                # Skip aggregation when the validation set yielded no batches
                # (torch.cat / torch.stack reject empty lists).
                if val_rewards_all:
                    writer.add_scalar("Eval/AvgReward", torch.cat(val_rewards_all).mean().item(), global_step)
                    writer.add_scalar("Eval/AvgLogProb", torch.cat(val_logprobs_all).mean().item(), global_step)
                    writer.add_scalar("Eval/AvgLoss", torch.stack(val_losses_all).mean().item(), global_step)

                # Random sample of examples (fewer than 5 if that is all there is)
                log_sample = random.sample(
                    collected_examples, k=min(5, len(collected_examples))
                )
                log_json = {
                    "epoch": epoch + 1,
                    "batch": batch_count,
                    "questions": [ex["question"] for ex in log_sample],
                    "rationales": [ex["rationale"] for ex in log_sample],
                    "predicted_answers": [ex["predicted_answer"] for ex in log_sample],
                    "correct_answers": [ex["correct_answer"] for ex in log_sample],
                    "rewards": [ex["reward"] for ex in log_sample],
                    "logprobs": [ex["logprob"] for ex in log_sample],
                }
                eval_json_path = os.path.join(LOG_JSON_DIR, "val", f"epoch_{epoch + 1}_batch_{batch_count}.json")
                os.makedirs(os.path.dirname(eval_json_path), exist_ok=True)
                with open(eval_json_path, "w") as f:
                    json.dump(log_json, f, indent=2)

                model.train()

        # max(..., 1) avoids a ZeroDivisionError if the epoch had no batches.
        avg_loss = train_loss / max(batch_count, 1)
        print(f"Epoch {epoch + 1} avg train loss per batch: {avg_loss:.4f}")
        writer.add_scalar("Loss/train", avg_loss, epoch + 1)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()

Training Epoch 1:   1%|          | 13/1250 [01:04<2:17:32,  6.67s/it]