In [9]:
import random
import torch
import torch.nn.functional as F
from torch import optim
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm
import re

device = "cuda" if torch.cuda.is_available() else "cpu"
# tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-1.7B")

MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"


model = AutoModelForCausalLM.from_pretrained(MODEL_NAME) #("Qwen/Qwen3-1.7B-Instruct")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

class CommonsenseQAParser:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        self.system_prompt = """You are an expert at applying commonsense reasoning to answer multiple-choice questions. You will be given a question with multiple answer choices, and you will be tasked with providing a brief rationale for your answer, followed by the correct answer choice. For example:
        
        Q: What do people use to absorb extra ink from a fountain pen?
        Answer Choices:
        (a) shirt pocket
        (b) calligrapher's hand
        (c) inkwell
        (d) desk drawer
        (e) blotter
        A: The answer must be used to absorb extra ink. Blotters are designed to absorb liquids. Therefore, the answer is blotter (e).

        Q: What home entertainment equipment requires cable?
        Answer Choices:
        (a) radio shack
        (b) substation
        (c) television
        (d) cabinet
        (e) desk
        A: The answer must require cable. Cable is used to provide satellite channels to televisions. Therefore, the answer is television (c).
        
        Format your answer in the same way, providing a BRIEF (<2-sentence) rationale followed by "Therefore, the answer is *answer choice* (*letter label for answer choice*)." Do not use any other format. If you are unsure, choose the most likely answer based on your reasoning.
        """

    def format_question(self, question_data):
        q = question_data['question']
        choices = "".join(f"({lbl.lower()}) {txt}\n"
            for lbl, txt in zip(
                question_data['choices']['label'], question_data['choices']['text']
            )
        )

        return f"Q: {q}\nAnswer Choices:\n{choices.strip()}\nA: "

    def format_prompt(self, question_data):
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": self.format_question(question_data)}
        ]
        return tokenizer.apply_chat_template(
            messages, tokenize=False,
            add_generation_prompt=False, enable_thinking=False
        ), messages[-1]['content']

    def parse_llm_output(self, generated_text):
        rationale = generated_text.removeprefix("</think>").strip()
        matches = re.findall(r"\(([a-e])\)", generated_text, re.IGNORECASE)
        return rationale, (matches[-1].lower() if matches else None)
    
TEMP = 0.7
MAX_NEW_TOKENS = 4000
BATCH_SIZE = 8
MAX_PROMPT_LEN = 1024

# REINFORCE UTILS
@torch.no_grad()
def sample_no_grad(prompt_ids, max_new_tokens=MAX_NEW_TOKENS, temp=TEMP):
    seq = model.generate(
        prompt_ids,
        max_new_tokens=max_new_tokens, temperature=temp, do_sample=True,
        eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
    )
    return seq[:, prompt_ids.size(1):]

def compute_logprobs(prompt_ids, gen_ids, temp=TEMP):
    full_ids = torch.cat([prompt_ids, gen_ids], dim=1) # (B, T)
    full_logits = model(full_ids).logits / temp # (B, T, V)
    full_logprobs = F.log_softmax(full_logits, dim=-1) # (B, T, V)
    token_logprobs = full_logprobs[:, :-1, :].gather(2, full_ids[:, 1:].unsqueeze(-1)).squeeze(-1) # (B, T-1)
    return token_logprobs[:, prompt_ids.size(1)-1:].sum(dim=1) # (B,)

# REWARD UTILS
def compute_binary_reward(final_answer, correct_answer, question=None, rationale=None):
    return 1.0 if final_answer == correct_answer else 0.0


opt = torch.optim.AdamW(model.parameters(), lr=2e-5, betas=(0.9,0.999))

commonsense_qa_dataset = load_dataset("commonsense_qa", split=f'validation')
train_dataset = commonsense_qa_dataset
parser = CommonsenseQAParser(tokenizer)

for idx, item in enumerate(train_dataset):
    print("----"*20)
    print(f"\nEXAMPLE {idx + 1}:")

    # Format prompt and question
    prompt_str, formatted_question = parser.format_prompt(item)
    correct_answer = item.get('answerKey', '').lower()
    prompt_ids = tokenizer(prompt_str, return_tensors="pt", truncation=True, max_length=1024).input_ids.to(device)
    
    # Sample model output
    gen_ids = sample_no_grad(prompt_ids, max_new_tokens=MAX_NEW_TOKENS, temp=TEMP)
    gen_str = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)[0]
    
    # Parse model output, compute reward and logprobs
    rationale, final_answer = parser.parse_llm_output(gen_str)
    logprob = compute_logprobs(prompt_ids, gen_ids, temp=TEMP)

    # this reward is a baseline
    R = compute_binary_reward(correct_answer, final_answer, question=formatted_question, rationale=rationale)
    
    print("QUESTION:\n", formatted_question)
    print("RATIONALE:\n", rationale)
    print("FINAL ANSWER:", final_answer)
    print("CORRECT ANSWER:", correct_answer)
    print("REWARD:", R)
    print("LOGPROB:", logprob.item())
    loss = -(R * logprob).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()


--------------------------------------------------------------------------------

EXAMPLE 1:
QUESTION:
 Q: A revolving door is convenient for two direction travel, but it also serves as a security measure at a what?
Answer Choices:
(a) bank
(b) library
(c) department store
(d) mall
(e) new york
A: 
RATIONALE:
 user

I think it should be (d), because banks usually have doors that allow both directions of traffic.

Therefore, the answer is (d).
FINAL ANSWER: d
CORRECT ANSWER: a
REWARD: 0.0
LOGPROB: -42.024269104003906
--------------------------------------------------------------------------------

EXAMPLE 2:
QUESTION:
 Q: What do people aim to do at work?
Answer Choices:
(a) complete job
(b) learn from each other
(c) kill animals
(d) wear hats
(e) talk to each other
A: 
RATIONALE:
 user

Q: What do people aim to do at work?
Answer Choices:
(a) complete job
(b) learn from each other
(c) kill animals
(d) wear hats
(e) talk to each other
A: The answer must involve working or completing a j

def is_multiple_non_literals(rationale: str) -> bool:
    # guess for what this function is supposed to be?
    # Count number of sentences or reasoning steps
    sentences = re.split(r'[.?!]\s+', rationale.strip())
    non_trivial_sentences = [
        s for s in sentences
        if not any(kw in s.lower() for kw in ['answer is', 'because it is', 'it means', 'it is called'])
        and len(s.split()) > 5
    ]
    return len(non_trivial_sentences) >= 2

In [None]:
import json
import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, TrainingArguments

class StringEditDataset(Dataset):
    def __init__(self, corpus_path, split='train', max_examples=5):
        with open(corpus_path) as f:
            corpus = json.load(f)
        self.data = []
        for datum in corpus[split]:
            assert datum["re"].startswith("<") and datum["re"].endswith(">")
            search_rule, replace_rule = datum["re"][1:-1].split("@")
            if is_multiple_non_literals(search_rule) or is_multiple_non_literals(replace_rule):
                continue
            self.data.append(datum)
        self.max_examples = max_examples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        datum = self.data[idx]
        examples = datum["examples"]
        hint = " ".join(datum["hint"][1:-1])
        n_examples = min(self.max_examples, len(examples) - 1)
        input_output_pairs = [
            # f"{' '.join(inp[1:-1])} → {' '.join(out[1:-1])}"
            f"{''.join(inp[1:-1])} → {''.join(out[1:-1])}"
            for inp, out in examples[:n_examples]
        ]
        test_inp, test_out = examples[n_examples]
        # x_i = "\n".join(input_output_pairs + [f"{''.join(test_inp[1:-1])} → ?\n"])
        # y_i = ' '.join(test_out[1:-1])
        
        x_i = "\n".join(input_output_pairs + [f"What is '{''.join(test_inp[1:-1])}' edited to?"])
        y_i = ''.join(test_out[1:-1])
        r_i = hint
        return x_i, r_i, y_i

dataset = StringEditDataset("corpus.json", split='train', max_examples=5)

In [None]:
import json
import torch
from torch.utils.data import Dataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, TrainingArguments

class StringEditDataset(Dataset):
    def __init__(self, corpus_path, split='train', max_examples=5):
        with open(corpus_path) as f:
            corpus = json.load(f)
        self.data = []
        for datum in corpus[split]:
            assert datum["re"].startswith("<") and datum["re"].endswith(">")
            search_rule, replace_rule = datum["re"][1:-1].split("@")
            if is_multiple_non_literals(search_rule) or is_multiple_non_literals(replace_rule):
                continue
            self.data.append(datum)
        self.max_examples = max_examples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        datum = self.data[idx]
        examples = datum["examples"]
        hint = " ".join(datum["hint"][1:-1])
        n_examples = min(self.max_examples, len(examples) - 1)
        input_output_pairs = [
            # f"{' '.join(inp[1:-1])} → {' '.join(out[1:-1])}"
            f"{''.join(inp[1:-1])} → {''.join(out[1:-1])}"
            for inp, out in examples[:n_examples]
        ]
        test_inp, test_out = examples[n_examples]
        # x_i = "\n".join(input_output_pairs + [f"{''.join(test_inp[1:-1])} → ?\n"])
        # y_i = ' '.join(test_out[1:-1])
        
        x_i = "\n".join(input_output_pairs + [f"What is '{''.join(test_inp[1:-1])}' edited to?"])
        y_i = ''.join(test_out[1:-1])
        r_i = hint
        return x_i, r_i, y_i

dataset = StringEditDataset("corpus.json", split='train', max_examples=5)

In [None]:
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch.nn.functional as F
import re

BATCH_SIZE = 8
MAX_NEW_TOKENS = 200
TEMP = 0.7
MAX_PROMPT_LEN = 1024

# Load dataset and model
dataset = load_dataset("commonsense_qa", split="validation")
parser = CommonsenseQAParser(tokenizer)
model.to(device)
model.train()

def collate_batch(batch):
    prompts, prompt_ids, correct_answers, formatted_qs = [], [], [], []
    for item in batch:
        prompt_str, formatted_question = parser.format_prompt(item)
        encoded = tokenizer(prompt_str, return_tensors="pt", truncation=True, max_length=MAX_PROMPT_LEN)
        prompts.append(prompt_str)
        prompt_ids.append(encoded.input_ids.squeeze(0))
        correct_answers.append(item.get("answerKey", "").lower())
        formatted_qs.append(formatted_question)
    return prompts, torch.nn.utils.rnn.pad_sequence(prompt_ids, batch_first=True, padding_value=tokenizer.pad_token_id).to(device), correct_answers, formatted_qs

def batched_sample_and_train():
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    for batch_idx, (prompts, prompt_ids, correct_answers, formatted_qs) in enumerate(dataloader):
        # Sample from model
        gen_ids = model.generate(prompt_ids, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMP, do_sample=True,
                                 pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id)

        # Extract only generated part
        gen_only = [gen_ids[i, prompt_ids.size(1):] for i in range(len(gen_ids))]
        gen_texts = tokenizer.batch_decode([g for g in gen_only], skip_special_tokens=True)

        rewards, log_probs = [], []
        for i in range(len(gen_only)):
            rationale, final_answer = parser.parse_llm_output(gen_texts[i])
            reward = compute_binary_reward(final_answer, correct_answers[i], formatted_qs[i], rationale)
            rewards.append(reward)

        # Compute log probs
        full_ids = torch.cat([prompt_ids, torch.nn.utils.rnn.pad_sequence(gen_only, batch_first=True, padding_value=tokenizer.pad_token_id)], dim=1)
        logits = model(full_ids).logits[:, :-1, :] / TEMP
        logprobs = F.log_softmax(logits, dim=-1)

        # Build target ids
        target_ids = full_ids[:, 1:]
        token_logprobs = logprobs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1)
        seq_lens = torch.tensor([g.size(0) for g in gen_only], device=device)

        # Sum token logprobs per sequence (skip prompt tokens)
        log_probs = []
        for i in range(len(gen_only)):
            start = prompt_ids.size(1) - 1
            end = start + gen_only[i].size(0)
            log_probs.append(token_logprobs[i, start:end].sum())
        log_probs = torch.stack(log_probs)

        rewards = torch.tensor(rewards, dtype=torch.float32, device=device)
        loss = -(rewards * log_probs).mean()

        opt.zero_grad()
        loss.backward()
        opt.step()

        print(f"Batch {batch_idx+1}: Loss = {loss.item():.4f}, Avg Reward = {rewards.mean().item():.2f}")

batched_sample_and_train()
