In [1]:
import torch
import torch.nn.functional as F
from torch import optim
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import json
import os
import re
import random

TEMP = 0.7
MAX_NEW_TOKENS = 256
BATCH_SIZE = 4
NUM_EPOCHS = 5
TRAIN_SIZE = 5000
VAL_SIZE = 500
LOG_EVERY = 10
device = "cuda" if torch.cuda.is_available() else "cpu"
LOG_JSON_DIR = "batch_logs_test"
os.makedirs(LOG_JSON_DIR, exist_ok=True)
os.makedirs(os.path.join(LOG_JSON_DIR, "train"), exist_ok=True)
os.makedirs(os.path.join(LOG_JSON_DIR, "val"), exist_ok=True)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
model_name = "meta-llama/Llama-3.2-3B-Instruct" #"Qwen/Qwen2.5-1.5B-Instruct" #"Qwen/Qwen3-1.7B" #"Qwen/Qwen2.5-1.5B-Instruct" #Qwen/Qwen3-1.7B
tokenizer = AutoTokenizer.from_pretrained(model_name) #use_auth_token=hf_token)
model = AutoModelForCausalLM.from_pretrained(model_name) #use_auth_token=hf_token).to(device)
if tokenizer.pad_token is None:
    #tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    tokenizer.pad_token = "<|reserved_special_token_4|>"
    # tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

Loading checkpoint shards: 100%|██████████| 2/2 [00:01<00:00,  1.20it/s]


In [3]:
tokenizer.pad_token 

'<|reserved_special_token_4|>'

In [3]:
# Test base model generation quality 
model.eval()
sample_prompt = "Q: Where is a bald eagle safe?\nAnswer Choices:\n(a) pine tree\n(b) open country\n(c) in washington\n(d) wildlife refuge\n(e) sky\nA: "

with torch.no_grad():
    inputs = tokenizer(sample_prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=256,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.pad_token_type_id,
        eos_token_id=tokenizer.eos_token_id,
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Generated:", result)

Generated: Q: Where is a bald eagle safe?
Answer Choices:
(a) pine tree
(b) open country
(c) in washington
(d) wildlife refuge
(e) sky
A:  The bald eagle is a protected species, and it is found in the wild in many places, including the United States and Canada. However, it is not safe for humans to be around bald eagles. Bald eagles are dangerous animals and are best left alone.
The best answer is D.


In [4]:
class CommonsenseQAParser:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.system_prompt = """You are an expert at applying commonsense reasoning to answer multiple-choice questions. 

For each question:
1. Identify what the question is asking for
2. Evaluate each choice against the question's requirements
3. Select the choice that best fits

Examples:

Q: What do people use to absorb extra ink from a fountain pen?
Answer Choices: (a) shirt pocket (b) calligrapher's hand (c) inkwell (d) desk drawer (e) blotter
A: The question asks for something that absorbs ink. A blotter is specifically designed to absorb excess ink from writing instruments. Therefore, the answer is blotter (e).

Q: What home entertainment equipment requires cable?
Answer Choices: (a) radio shack (b) substation (c) television (d) cabinet (e) desk
A: The question asks for entertainment equipment that needs cable service. Televisions can receive cable TV signals for entertainment programming. Therefore, the answer is television (c).

Always format your response as:
"<Clear reasoning in a sentence or two>. Therefore, the answer is <answer text> (<letter>). Choose the most reasonable answer if uncertain."""
    def format_question(self, question_data):
        q = question_data['question']
        choices = "".join(
            f"({lbl.lower()}) {txt}\n"
            for lbl, txt in zip(
                question_data['choices']['label'], question_data['choices']['text']
            )
        )
        return f"Q: {q}\nAnswer Choices:\n{choices.strip()}\nA: "

    def format_prompt(self, question_data):
        messages = [
            {"role": "system",  "content": self.system_prompt},
            {"role": "user",    "content": self.format_question(question_data)}
        ]
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
            enable_thinking=False
        ), messages[-1]['content']

    def parse_llm_output(self, generated_text):
        rationale = generated_text.removeprefix("</think>").strip()
        matches = re.findall(r"\(([a-e])\)", generated_text, re.IGNORECASE)
        letter = matches[-1].lower() if matches else None
        return rationale, letter

In [7]:
@torch.no_grad()
def sample_no_grad(model, batch_prompt_ids, max_new_tokens=MAX_NEW_TOKENS, temp=TEMP):
    # `batch_prompt_ids` is shape (B, T)
    seq = model.generate(
        batch_prompt_ids,
        max_new_tokens=max_new_tokens,
        temperature=temp,
        do_sample=True,
        repetition_penalty=1.2,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
    )
    # We return only the newly generated portion
    return seq[:, batch_prompt_ids.size(1):]  # shape (B, new_T)

def compute_logprobs(model, batch_prompt_ids, batch_gen_ids, temp=TEMP):
    """
    # batch_prompt_ids: (B, T)
    # batch_gen_ids:    (B, new_T)
    full_ids = torch.cat([batch_prompt_ids, batch_gen_ids], dim=1)  # (B, T+new_T)
    full_logits = model(full_ids).logits / temp  # (B, T+new_T, V)
    full_logprobs = F.log_softmax(full_logits, dim=-1) # (B, T+new_T, V)

    # log-probs of actual tokens (i.e., log P(token_t | token_<t))
    token_logprobs = full_logprobs[:, :-1, :].gather(
        2, full_ids[:, 1:].unsqueeze(-1)
    ).squeeze(-1)  # (B, T_total - 1)

    prompt_lens = (batch_prompt_ids != tokenizer.pad_token_id).sum(dim=1)  # (B,)
    gen_lens = (batch_gen_ids != tokenizer.pad_token_id).sum(dim=1) # (B,)

    logprobs = []
    for i in range(batch_prompt_ids.size(0)):
        start = int(prompt_lens[i].item()) - 1
        gen_len = int(gen_lens[i].item())
        logprobs.append(token_logprobs[i, start : start + gen_len].sum())

    return torch.stack(logprobs, dim=0)  # (B,)
    """

    # Concatenate inputs
    full_ids = torch.cat([batch_prompt_ids, batch_gen_ids], dim=1)  # (B, T+new_T)

    # Forward pass
    full_logits = model(full_ids).logits / temp  # (B, T+new_T, V)
    full_logprobs = F.log_softmax(full_logits, dim=-1)  # (B, T+new_T, V)

    # Get logprobs of actual tokens (predict token_t at position t)
    token_logprobs = full_logprobs[:, :-1, :].gather(
        2, full_ids[:, 1:].unsqueeze(-1)
    ).squeeze(-1)  # (B, T+new_T - 1)

    # Lengths
    prompt_lens = (batch_prompt_ids != tokenizer.pad_token_id).sum(dim=1)  # (B,)
    gen_lens = (batch_gen_ids != tokenizer.pad_token_id).sum(dim=1)        # (B,)

    # Create a mask for the generated tokens
    _, total_len = full_ids.size()
    token_pos = torch.arange(total_len - 1, device=full_ids.device).unsqueeze(0)  # (1, T+new_T - 1)
    gen_masks = (
        (token_pos >= prompt_lens.unsqueeze(1)) &
        (token_pos < (prompt_lens + gen_lens).unsqueeze(1))
    )  # (B, T+new_T - 1), True where generated tokens live

    # Mask and sum
    gen_token_logprobs = token_logprobs * gen_masks  # (B, T+new_T - 1)
    return gen_token_logprobs.sum(dim=1)  # (B,)

def compute_binary_reward(final_answer, correct_answer, question=None, rationale=None):
    return 1.0 if final_answer == correct_answer else -1.0

In [8]:
def get_batches(dataset, parser, batch_size):
    # Returns (prompt_strs, raw_qs, correct_answers), each a len-B List[str].
    for i in range(0, len(dataset), batch_size):
        batch_dict = dataset[i : i + batch_size]
        batch_items = [
            {key: batch_dict[key][i] for key in batch_dict}
            for i in range(len(batch_dict["id"]))
        ]
        prompt_strs, raw_qs, correct_keys = [], [], []
        for item in batch_items:
            p_str, raw_q = parser.format_prompt(item)
            prompt_strs.append(p_str)
            raw_qs.append(raw_q)
            correct_keys.append(item["answerKey"].lower())
        yield prompt_strs, raw_qs, correct_keys

### Version 1

In [25]:
# Subsample training to 5000 examples
train_dataset = load_dataset("commonsense_qa", split="train[:5000]")
# Subsample validation (i.e. test) to 150 examples
val_dataset = load_dataset("commonsense_qa", split="validation[:150]")

writer = SummaryWriter()
parser = CommonsenseQAParser(tokenizer)
optimizer = optim.AdamW(model.parameters(), lr=2e-5)

for epoch in range(NUM_EPOCHS):
    ##### TRAINING LOOP #####
    model.train()
    train_loss = 0.0
    for batch_idx, (prompt_strs, raw_qs, correct_answers) in enumerate(tqdm(
        get_batches(train_dataset, parser, BATCH_SIZE),
        desc=f"Training Epoch {epoch + 1}", total=TRAIN_SIZE // BATCH_SIZE
    )):
        prompt_ids = tokenizer(prompt_strs, return_tensors="pt",
            padding=True, truncation=True, max_length=512
        ).to(device)["input_ids"]

        gen_ids = sample_no_grad(
            model, prompt_ids,
            max_new_tokens=MAX_NEW_TOKENS, temp=TEMP
        )

        gen_strs = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        rationales, pred_answers = zip(*[
            parser.parse_llm_output(gen_str) for gen_str in gen_strs
        ])

        logprobs = compute_logprobs(model, prompt_ids, gen_ids, temp=TEMP)  # (B,)
        rewards = torch.tensor([
            compute_binary_reward(ans, corr)
            for ans, corr in zip(pred_answers, correct_answers)
        ], device=device).float()
        # Version 1 (-1 negative reward)
        loss = -(rewards * logprobs).mean()
        # Version 2 (-1 negative reward) with baseline
        #baseline = rewards.mean()
        #advantage = rewards - baseline

        if batch_idx % LOG_EVERY == 0:
            writer.add_scalar("BatchLoss/train", loss.item(), batch_idx)
            writer.add_scalar("BatchAvgLogProb/train", logprobs.mean().item(), batch_idx)
            writer.add_scalar("BatchAvgReward/train", rewards.mean().item(), batch_idx)
            log_data = {
                "epoch": epoch + 1,
                "batch": batch_idx,
                "question": raw_qs[0],
                "rationale": rationales[0],
                "predicted_answer": pred_answers[0] if pred_answers[0] else "None",
                "correct_answer": correct_answers[0],
                "reward": float(rewards[0].item()),
                "logprob": float(logprobs[0].item()),
            }
            json_path = os.path.join(LOG_JSON_DIR, "train", f"epoch{epoch + 1}_batch{batch_idx}.json")
            os.makedirs(os.path.dirname(json_path), exist_ok=True)
            with open(json_path, "w") as f:
                json.dump(log_data, f, indent=2)

        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        ### VALIDATION LOGGING ###
        if batch_idx % (LOG_EVERY * 5) == 0:
            model.eval()
            total_val_reward, total_val_logprob, total_val_loss = 0.0, 0.0, 0.0
            num_val_examples = 0
            collected_examples = []

            with torch.no_grad():
                for val_prompt_strs, val_raw_qs, val_correct_answers in get_batches(val_dataset, parser, BATCH_SIZE):
                    val_prompt_ids = tokenizer(val_prompt_strs, return_tensors="pt",
                        padding=True, truncation=True, max_length=512
                    ).to(device)["input_ids"]

                    val_gen_ids = sample_no_grad(model, val_prompt_ids)
                    val_gen_strs = tokenizer.batch_decode(val_gen_ids, skip_special_tokens=True)
                    val_rationales, val_pred_answers = zip(
                        *[parser.parse_llm_output(gen) for gen in val_gen_strs]
                    )
                    val_logp = compute_logprobs(model, val_prompt_ids, val_gen_ids)
                    val_rwds = torch.tensor([
                        compute_binary_reward(ans, corr)
                        for ans, corr in zip(val_pred_answers, val_correct_answers)
                    ], device=device)

                    val_loss = -(val_rwds * val_logp).mean()

                    total_val_reward += val_rwds.sum().item()
                    total_val_logprob += val_logp.sum().item()
                    total_val_loss += val_loss.item() * val_rwds.size(0)
                    num_val_examples += val_rwds.size(0)

                    for q, r, a, p, rew, lp in zip(
                        val_raw_qs, val_rationales, val_correct_answers, val_pred_answers, val_rwds, val_logp
                    ):
                        collected_examples.append({
                            "question": q,
                            "rationale": r,
                            "predicted_answer": p if p else "None",
                            "correct_answer": a,
                            "reward": float(rew.item()),
                            "logprob": float(lp.item())
                        })

                    # Explicit cleanup
                    del val_prompt_ids, val_gen_ids, val_logp, val_rwds, val_loss
                    torch.cuda.empty_cache()

            writer.add_scalar("Eval/AvgReward", total_val_reward / num_val_examples, batch_idx)
            writer.add_scalar("Eval/AvgLogProb", total_val_logprob / num_val_examples, batch_idx)
            writer.add_scalar("Eval/AvgLoss", total_val_loss / num_val_examples, batch_idx)

            log_sample = random.sample(collected_examples, k=5)
            log_json = {
                "epoch": epoch + 1,
                "batch": batch_idx,
                "questions": [ex["question"] for ex in log_sample],
                "rationales": [ex["rationale"] for ex in log_sample],
                "predicted_answers": [ex["predicted_answer"] for ex in log_sample],
                "correct_answers": [ex["correct_answer"] for ex in log_sample],
                "rewards": [ex["reward"] for ex in log_sample],
                "logprobs": [ex["logprob"] for ex in log_sample],
            }
            eval_json_path = os.path.join(LOG_JSON_DIR, "val", f"epoch_{epoch + 1}_batch_{batch_idx}.json")
            os.makedirs(os.path.dirname(eval_json_path), exist_ok=True)
            with open(eval_json_path, "w") as f:
                json.dump(log_json, f, indent=2)


    avg_loss = train_loss / batch_idx
    print(f"Epoch {epoch + 1} avg train loss per batch: {avg_loss:.4f}")
    writer.add_scalar("Loss/train", avg_loss, epoch + 1)

writer.close()


Training Epoch 1:   4%|▍         | 50/1250 [03:21<1:20:46,  4.04s/it]


KeyboardInterrupt: 

### Version 2: Optimize logging + batch size, optimize batch processing

In [10]:
LOG_EVERY = 50
BATCH_SIZE = 5
# Subsample training to overfit to 50 examples (original 5000)
train_dataset = load_dataset("commonsense_qa", split="train[:5000]")
# Subsample validation (i.e. test) to 20 examples (original 150)
val_dataset = load_dataset("commonsense_qa", split="validation[:150]")

writer = SummaryWriter()
parser = CommonsenseQAParser(tokenizer)
optimizer = optim.AdamW(model.parameters(), lr=5e-7)

# Pre-process datasets to avoid repeated parsing
print("Pre-processing datasets...")
train_batches = list(get_batches(train_dataset, parser, BATCH_SIZE))
val_batches = list(get_batches(val_dataset, parser, BATCH_SIZE))
print(f"Created {len(train_batches)} training batches, {len(val_batches)} validation batches")

# Reduce validation frequency significantly
VALIDATION_FREQUENCY = len(train_batches) // 2  

for epoch in range(NUM_EPOCHS):
    ##### TRAINING LOOP #####
    model.train()
    train_loss = 0.0
    
    for batch_idx, (prompt_strs, raw_qs, correct_answers) in enumerate(tqdm(
        train_batches, desc=f"Training Epoch {epoch + 1}"
    )):
        #print(prompt_strs)
        #print(raw_qs)
        #print(correct_answers)

        # Tokenize once with proper device placement
        prompt_ids = tokenizer(
            prompt_strs, 
            return_tensors="pt",
            padding=True, 
            truncation=True, 
            max_length=512
        ).input_ids.to(device)
        
        # Generate samples
        gen_ids = sample_no_grad(
            model, prompt_ids,
            max_new_tokens=MAX_NEW_TOKENS, temp=TEMP
        )

        # Batch decode and parse
        gen_strs = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
        parsed_outputs = [parser.parse_llm_output(gen_str) for gen_str in gen_strs]
        rationales, pred_answers = zip(*parsed_outputs)
        #print(rationales)
        #print(pred_answers)

        # Compute loss components
        logprobs = compute_logprobs(model, prompt_ids, gen_ids, temp=TEMP)
        rewards = torch.tensor([
            compute_binary_reward(ans, corr)
            for ans, corr in zip(pred_answers, correct_answers)
        ], device=device, dtype=torch.float32)
        
        # REINFORCE loss
        #loss = -(rewards * logprobs).mean()

        # REINFORCE (baseline-normalized) loss
        loss = -((rewards - rewards.mean()) * logprobs).mean()

        # Gradient step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
       
        # Lightweight logging (reduced frequency)
        if batch_idx % LOG_EVERY == 0:
            writer.add_scalar("BatchLoss/train", loss.item(), epoch * len(train_batches) + batch_idx)
            writer.add_scalar("BatchAvgLogProb/train", logprobs.mean().item(), epoch * len(train_batches) + batch_idx)
            writer.add_scalar("BatchAvgReward/train", rewards.mean().item(), epoch * len(train_batches) + batch_idx)
            
            # Simplified logging - only save one example per LOG_EVERY batches
            log_data = {
                "epoch": epoch + 1,
                "batch": batch_idx,
                "question": raw_qs[0],
                "rationale": rationales[0],
                "predicted_answer": pred_answers[0] if pred_answers[0] else "None",
                "correct_answer": correct_answers[0],
                "reward": float(rewards[0].item()),
                "logprob": float(logprobs[0].item()),
            }
            json_path = os.path.join(LOG_JSON_DIR, "train", f"epoch{epoch + 1}_batch{batch_idx}.json")
            os.makedirs(os.path.dirname(json_path), exist_ok=True)
            with open(json_path, "w") as f:
                json.dump(log_data, f, indent=2)
        
        # Aggressive memory cleanup
        del prompt_ids, gen_ids, logprobs, rewards, loss
        if batch_idx % 10 == 0:  # Clean cache every 10 batches
            torch.cuda.empty_cache()

        """
        # Debug: Test model generation quality post training
        model.eval()
        sample_prompt = "Q: Where is a bald eagle safe?\nAnswer Choices:\n(a) pine tree\n(b) open country\n(c) in washington\n(d) wildlife refuge\n(e) sky\nA: "

        with torch.no_grad():
            inputs = tokenizer(sample_prompt, return_tensors="pt").to(device)
            outputs = model.generate(
                    inputs.input_ids,
                    max_new_tokens=1024,
                    temperature=0.7,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                    )
            result = tokenizer.decode(outputs[0], skip_special_tokens=True)
            print("Generated:", result)
        """
        
        #if batch_idx % 5 == 0:
        #    break

        ### VALIDATION LOGGING (Much less frequent) ###

        if batch_idx > 0 and batch_idx % VALIDATION_FREQUENCY == 0:
            print(f"\n--- Running validation at batch {batch_idx} ---")
            model.eval()
            
            val_rewards, val_logprobs, val_losses = [], [], []
            collected_examples = []

            with torch.no_grad():
                for val_batch in val_batches:
                    val_prompt_strs, val_raw_qs, val_correct_answers = val_batch
                    
                    val_prompt_ids = tokenizer(
                        val_prompt_strs, 
                        return_tensors="pt",
                        padding=True, 
                        truncation=True, 
                        max_length=512
                    ).input_ids.to(device)

                    val_gen_ids = sample_no_grad(model, val_prompt_ids, 
                                               max_new_tokens=MAX_NEW_TOKENS, temp=TEMP)
                    val_gen_strs = tokenizer.batch_decode(val_gen_ids, skip_special_tokens=True)
                    
                    val_parsed = [parser.parse_llm_output(gen) for gen in val_gen_strs]
                    val_rationales, val_pred_answers = zip(*val_parsed)
                    
                    val_logp = compute_logprobs(model, val_prompt_ids, val_gen_ids, temp=TEMP)
                    val_rwds = torch.tensor([
                        compute_binary_reward(ans, corr)
                        for ans, corr in zip(val_pred_answers, val_correct_answers)
                    ], device=device, dtype=torch.float32)

                    val_loss = -(val_rwds * val_logp).mean()

                    # Collect metrics
                    val_rewards.extend(val_rwds.cpu().tolist())
                    val_logprobs.extend(val_logp.cpu().tolist())
                    val_losses.append(val_loss.item())

                    # Collect sample examples (limit to reduce memory)
                    for i, (q, r, a, p, rew, lp) in enumerate(zip(
                        val_raw_qs[:2], val_rationales[:2], val_correct_answers[:2], 
                        val_pred_answers[:2], val_rwds[:2], val_logp[:2]
                    )):
                        collected_examples.append({
                            "question": q,
                            "rationale": r,
                            "predicted_answer": p if p else "None",
                            "correct_answer": a,
                            "reward": float(rew.item()),
                            "logprob": float(lp.item())
                        })

                    # Immediate cleanup
                    del val_prompt_ids, val_gen_ids, val_logp, val_rwds, val_loss

            # Log validation metrics
            avg_val_reward = sum(val_rewards) / len(val_rewards)
            avg_val_logprob = sum(val_logprobs) / len(val_logprobs)
            avg_val_loss = sum(val_losses) / len(val_losses)
            
            global_step = epoch * len(train_batches) + batch_idx
            writer.add_scalar("Eval/AvgReward", avg_val_reward, global_step)
            writer.add_scalar("Eval/AvgLogProb", avg_val_logprob, global_step)
            writer.add_scalar("Eval/AvgLoss", avg_val_loss, global_step)

            # Save validation examples (reduced sample size)
            log_sample = random.sample(collected_examples, k=min(3, len(collected_examples)))
            log_json = {
                "epoch": epoch + 1,
                "batch": batch_idx,
                "avg_reward": avg_val_reward,
                "avg_logprob": avg_val_logprob,
                "avg_loss": avg_val_loss,
                "sample_questions": [ex["question"] for ex in log_sample],
                "sample_rationales": [ex["rationale"] for ex in log_sample],
                "sample_predicted_answers": [ex["predicted_answer"] for ex in log_sample],
                "sample_correct_answers": [ex["correct_answer"] for ex in log_sample],
                "sample_rewards": [ex["reward"] for ex in log_sample],
                "sample_logprobs": [ex["logprob"] for ex in log_sample],
            }
            eval_json_path = os.path.join(LOG_JSON_DIR, "val", f"epoch_{epoch + 1}_batch_{batch_idx}.json")
            os.makedirs(os.path.dirname(eval_json_path), exist_ok=True)
            with open(eval_json_path, "w") as f:
                json.dump(log_json, f, indent=2)
            
            model.train()  # Return to training mode
            torch.cuda.empty_cache()  # Clean up after validation
          
    avg_loss = train_loss / len(train_batches)
    print(f"Epoch {epoch + 1} avg train loss per batch: {avg_loss:.4f}")
    #writer.add_scalar("Loss/train", avg_loss, epoch + 1)

writer.close()

KeyboardInterrupt: 

In [None]:
# Test model generation quality post training
model.eval()
sample_prompt = "Q: Where is a bald eagle safe?\nAnswer Choices:\n(a) pine tree\n(b) open country\n(c) in washington\n(d) wildlife refuge\n(e) sky\nA: "

with torch.no_grad():
    inputs = tokenizer(sample_prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=1024,
        temperature=0.7,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id
    )
    result = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("Generated:", result)

Generated: Q: Where is a bald eagle safe?
Answer Choices:
(a) pine tree
(b) open country
(c) in washington
(d) wildlife refuge
(e) sky
A:  (e)
Explanation: Bald eagles are found in North America. They are found in the eastern United States, and Canada's Northwest Territories. The answer is (e).


In [7]:
# Debug generation
import re

def debug_generation(model, tokenizer, prompt_str, max_new_tokens=MAX_NEW_TOKENS, temp=TEMP):
    """Debug helper to inspect generation step by step"""
    print(f"\n=== DEBUGGING GENERATION ===")
    print(f"Temperature: {temp}")
    print(f"Max new tokens: {max_new_tokens}")
    print(f"Input prompt: {prompt_str[:200]}...")
    
    prompt_ids = tokenizer(prompt_str, return_tensors="pt").input_ids.to(device)
    print(f"Prompt token length: {prompt_ids.shape[1]}")
    
    # Generate with different settings to compare
    with torch.no_grad():
        # Original generation
        gen_ids = sample_no_grad(model, prompt_ids, max_new_tokens=max_new_tokens, temp=temp)
        gen_str = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
        
        # Alternative: greedy decoding
        greedy_ids = model.generate(prompt_ids, max_new_tokens=max_new_tokens, 
                                   do_sample=False, pad_token_id=tokenizer.eos_token_id)
        greedy_str = tokenizer.decode(greedy_ids[0], skip_special_tokens=True)
        
        # Alternative: low temperature
        low_temp_ids = model.generate(prompt_ids, max_new_tokens=max_new_tokens,
                                     do_sample=True, temperature=0.1, 
                                     pad_token_id=tokenizer.eos_token_id)
        low_temp_str = tokenizer.decode(low_temp_ids[0], skip_special_tokens=True)
    
    print(f"\nOriginal generation (temp={temp}):")
    print(repr(gen_str[-500:]))  # Show last 500 chars with escape sequences
    print(f"\nGreedy generation:")
    print(repr(greedy_str[-500:]))
    print(f"\nLow temp generation (0.1):")
    print(repr(low_temp_str[-500:]))
    
    # Check for repetitive patterns
    def check_repetition(text):
        # Look for repeating character sequences
        for length in [2, 3, 4, 5]:
            pattern = r'(.{' + str(length) + r'})\1{10,}'  # Same sequence repeated 10+ times
            matches = re.findall(pattern, text)
            if matches:
                print(f"Found repetitive {length}-char pattern: '{matches[0]}'")
                return True
        return False
    
    print(f"\nRepetition check:")
    is_repetitive = check_repetition(gen_str)
    print(f"Is repetitive: {is_repetitive}")
    
    return gen_str, greedy_str, low_temp_str

# Modified training loop with debugging
# Subsample training to 5000 examples
train_dataset = load_dataset("commonsense_qa", split="train[:5000]")
val_dataset = load_dataset("commonsense_qa", split="validation[:150]")

writer = SummaryWriter()
parser = CommonsenseQAParser(tokenizer)
optimizer = optim.AdamW(model.parameters(), lr=2e-5)

# Pre-process datasets
print("Pre-processing datasets...")
train_batches = list(get_batches(train_dataset, parser, BATCH_SIZE))
val_batches = list(get_batches(val_dataset, parser, BATCH_SIZE))

# DEBUG: Test generation before training
print("\n" + "="*50)
print("DEBUGGING GENERATION BEFORE TRAINING")
print("="*50)
sample_prompt = train_batches[0][0][0]  # First prompt from first batch
debug_generation(model, tokenizer, sample_prompt)

# Add generation parameter validation
print(f"\nCurrent generation parameters:")
print(f"TEMP: {TEMP}")
print(f"MAX_NEW_TOKENS: {MAX_NEW_TOKENS}")
print(f"Model vocab size: {model.config.vocab_size}")
print(f"Tokenizer vocab size: {len(tokenizer)}")

Pre-processing datasets...

DEBUGGING GENERATION BEFORE TRAINING

=== DEBUGGING GENERATION ===
Temperature: 0.7
Max new tokens: 200
Input prompt: <|im_start|>system
You are an expert at applying commonsense reasoning to answer multiple-choice questions. You will be given a question with multiple answer choices, and you will be tasked with provi...
Prompt token length: 305


The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.



Original generation (temp=0.7):
'icopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt'

Greedy generation:
'icopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helicopt helic