# Setup

In [None]:
!pip install datasets
!pip install lightning
!pip install wandb

Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)
  Downloading fsspec-2024.12.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.0-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading fsspec-2024.12.0-py3-none-any.w

In [None]:
import torch
from datasets import load_dataset
from transformers import AutoTokenizer
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from pytorch_lightning.loggers import WandbLogger
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorWithPadding
import re

# Dataset

In [None]:
import textwrap
from string import Template
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, DataCollatorWithPadding
import pytorch_lightning as pl
from torch.utils.data import DataLoader

class GSM8KDataModule(pl.LightningDataModule):
    def __init__(self, model_name: str, batch_size: int = 2, max_length: int = 512, val_subset_size: int = 100):
        super().__init__()
        self.model_name = model_name
        self.batch_size = batch_size
        self.max_length = max_length
        self.val_subset_size = val_subset_size

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side='left')
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.collate_fn = DataCollatorWithPadding(tokenizer=self.tokenizer)

        # Template as attribute
        template_text = textwrap.dedent("""
        {{#system}}
        You are a renowned mathematician known for you flawless accuracy and clarity. You solve math problems step by step,
        using well-structured logic.
        Always follow this exact response format:
        1. Put your step-by-step calculation process inside <think> tags, explaining each step clearly.
        2. Provide the final answer in a <boxed> tag, using a clear and simplified format.
        
        Below are two examples. You must never deviate from this format.
        Example 1:
        {{#user}}
        Lucy has 18 apples. She gives 4 apples to her friend. She then doubles the number of apples she has. How many apples does Lucy have left?
        {{#assistant}}
        <think>
        1. Subtract the apples Lucy gave away: 18 - 4 = 14
        2. Double the remaining apples: 14 * 2 = 28
        </think>
        \\boxed{28}
        
        Example 2:
        {{#user}}
        What is the value of (3 + 5) * 2?
        {{#assistant}}
        <think>
        1. Calculate the expression inside parentheses: 3 + 5 = 8
        2. Multiply the result by 2: 8 × 2 = 16
        </think>
        \\boxed{16}
        {{#user}}
        $question
        {{#assistant}}
        """)
        self.prompt_template = Template(template_text)

        self.answer_mapping = {}
        self.full_dataset = None
        self.train_dataset = None
        self.val_dataset = None

    def prepare_data(self):
        load_dataset('gsm8k', 'main')

    def setup(self, stage=None):
        if self.full_dataset is None:
            dataset = load_dataset('gsm8k', 'main')["train"]
            # print(self.prompt_template.substitute(question=dataset[0]["question"]))
            def tokenize_fn(example, idx):
                question_prompt = self.prompt_template.substitute(question=example["question"])
                tokenized = self.tokenizer(
                    question_prompt,
                    truncation=True,
                    max_length=self.max_length,
                    padding=False
                )
                tokenized["idx"] = idx
                self.answer_mapping[idx] = example["answer"]
                return tokenized

            tokenized_samples = [tokenize_fn(example, i) for i, example in enumerate(dataset)]
            self.full_dataset = Dataset.from_list(tokenized_samples)

        # Split train/val
        self.train_dataset = self.full_dataset

        self.val_dataset = self.full_dataset.shuffle(seed=42).select(range(self.val_subset_size))

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            collate_fn=self.collate_fn
        )
model_name = "gpt2"
data_module = GSM8KDataModule(model_name=model_name, batch_size=2, max_length=512, val_subset_size=100)
data_module.setup()

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

# Answer Checker

In [None]:
import re

class GSM8KAnswerChecker:
    @staticmethod
    def _remove_prompt(text):
        """
        Removes everything before and including the delimiter
        'Now, solve the following problem:\n'.
        """
        delimiter = "Now solve the following problem:\n"
        if delimiter in text:
            return text.split(delimiter, 1)[-1].strip()
        return text.strip()

    @staticmethod
    def _remove_think_tags(text):
        return re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)

    @staticmethod
    def _extract_answer(text):
        def clean_number(num_str):
            # Remove any characters except digits, period, and minus sign.
            cleaned = re.sub(r'[^\d.-]', '', num_str)
            try:
                return float(cleaned)
            except ValueError:
                return None

        # Remove the prompt parts.
        cleaned_text = GSM8KAnswerChecker._remove_prompt(text)
        # print("===================start===================================")
        # print("cleaned text: ", cleaned_text)
        # print("=====================end===================================")

        # First: try to extract from a \boxed{} tag.
        boxed_matches = re.findall(r'\\boxed{([^}]*)}', cleaned_text)
        if boxed_matches:
            return clean_number(boxed_matches[-1])

        # Second: try to extract from a line starting with three or more hash marks.
        hash_match = re.search(r'#{3,}\s*(.*)', cleaned_text)
        if hash_match:
            return clean_number(hash_match.group(1).strip())

        # Fallback: split the cleaned text into lines and use the last number
        # from the last non-empty line that contains a number.
        lines = [line.strip() for line in cleaned_text.splitlines() if line.strip()]
        for line in reversed(lines):
            number_matches = re.findall(r'\d+(?:\.\d+)?', line)
            if number_matches:
                return clean_number(number_matches[-1])

        return None

    @staticmethod
    def check_answer(answer, ground_truth):
        # Remove <think> tags.
        answer = GSM8KAnswerChecker._remove_think_tags(answer)
        # Extract the final numeric answer from both texts.
        extracted_answer = GSM8KAnswerChecker._extract_answer(answer)
        extracted_ground_truth = GSM8KAnswerChecker._extract_answer(ground_truth)
        # print("Extracted answer:", extracted_answer)
        # print("Extracted ground truth:", extracted_ground_truth)

        if extracted_answer is not None and extracted_ground_truth is not None:
            if abs(extracted_answer - extracted_ground_truth) < 1e-6:
                return {
                    "correct": True,
                    "mode": "match",
                    "extracted_answer": extracted_answer,
                    "extracted_ground_truth": extracted_ground_truth
                }
            else:
                return {
                    "correct": False,
                    "mode": "match",
                    "extracted_answer": extracted_answer,
                    "extracted_ground_truth": extracted_ground_truth
                }
        else:
            return {
                "correct": False,
                "mode": "no_match",
                "extracted_answer": extracted_answer,
                "extracted_ground_truth": extracted_ground_truth
            }

    @staticmethod
    def eval(output_dict):
        evaluated_outputs = []

        for index, entry in output_dict.items():
            generations = entry["generations"]
            ground_truth = entry["ground_truth"]

            evaluated_answers = []
            for text in generations:
                answer_eval = GSM8KAnswerChecker.check_answer(text, ground_truth)
                evaluated_answers.append({
                    "text": text,
                    "answer_eval": answer_eval
                })

            acc = sum(a["answer_eval"]["correct"] for a in evaluated_answers) / len(evaluated_answers)
            evaluated_outputs.append({
                "answers": evaluated_answers,
                "ground_truth": ground_truth,
                "evaluation": {
                    "accuracy": acc,
                    "pass@n": acc > 0.0,
                    "match@n": acc >= 0.5
                }
            })

        return evaluated_outputs



# GRPO Trainer

In [None]:
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import AutoModelForCausalLM, AutoTokenizer
import copy
from collections import defaultdict

class GRPOTrainer(pl.LightningModule):
    def __init__(
        self,
        model_name: str,
        learning_rate: float,
        answer_mapping,
        max_length: int = 512,
        num_generations: int = 8,
        epsilon: float = 0.05,
        beta: float = 0.1,
    ):
        super().__init__()
        self.save_hyperparameters(ignore=["answer_mapping"])
        self.model = AutoModelForCausalLM.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if self.tokenizer.pad_token is None:
           self.tokenizer.pad_token = self.tokenizer.eos_token

        # Create a frozen reference model
        self.ref_model = AutoModelForCausalLM.from_pretrained(model_name)
        self.ref_model.eval()
        self.learning_rate = learning_rate
        self.answer_mapping = answer_mapping
        self.max_length = max_length
        self.num_generations = num_generations
        self.epsilon = epsilon
        self.beta = beta
        self._validation_outputs = []
        self.evaluator = GSM8KAnswerChecker()

    def forward(self, input_ids, attention_mask=None):
        return self.model(input_ids, attention_mask=attention_mask)

    def compute_log_probs(self, input_ids, attention_mask, model):
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits  # [B, L, V]
        log_probs = F.log_softmax(logits, dim=-1)
        # Gather the log probabilities corresponding to the input_ids.
        token_log_probs = torch.gather(log_probs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
        return token_log_probs

    @torch.no_grad()
    def generate_completions(self, prompt):
        # Tokenize with padding and truncation, returning tensors.
        encoding = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        )
        # Move tensors to the correct device.
        encoding = {k: v.to(self.device) for k, v in encoding.items()}

        outputs = self.model.generate(
            input_ids=encoding["input_ids"],
            attention_mask=encoding["attention_mask"],
            max_length=self.max_length,
            do_sample=True,
            num_return_sequences=self.num_generations,
            temperature=1.0,
            pad_token_id=self.tokenizer.pad_token_id
        )
        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]


    def compute_reward(self, completion, ground_truth):
        # Use the evaluator's check_answer; reward = 1 if correct, else 0.
        result = self.evaluator.check_answer(completion, ground_truth)
        return 1.0 if result["correct"] else 0.0

    def training_step(self, batch, batch_idx):
        total_loss = 0.0
        avg_length = 0.0
        batch_rewards = []
        batch_size = batch["input_ids"].size(0)

        # TODO: make paralell
        for i in range(batch_size):
            # print(f"*************************************batch idx: {i}*********************************************")
            # Decode the prompt from the batch.
            prompt = self.tokenizer.decode(batch["input_ids"][i], skip_special_tokens=True)
            ground_truth = self.answer_mapping[batch["idx"][i].item()]

            # Generate multiple completions.
            completions = self.generate_completions(prompt)
            avg_length += sum(len(comp.split()) for comp in completions) / len(completions)

            # Compute rewards using the evaluator.
            rewards = [self.compute_reward(comp, ground_truth) for comp in completions]
            rewards_tensor = torch.tensor(rewards, dtype=torch.float32, device=self.device)
            batch_rewards.append(rewards_tensor.mean().item())
            # for c in completions:
            #     print(c)
            #     print("----------------------------------")
            # print(rewards)
            # assert 0

            mean_reward = rewards_tensor.mean()
            std_reward = rewards_tensor.std() + 1e-8  # avoid division by zero
            advantages = (rewards_tensor - mean_reward) / std_reward

            # Tokenize the prompt to get its length.
            prompt_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(self.device)
            prompt_length = prompt_ids.shape[1]

            loss = 0.0
            for j, comp in enumerate(completions):
                # Prepare the full input (prompt + completion).
                full_text = prompt + comp[len(prompt):]
                full_ids = self.tokenizer(full_text, return_tensors="pt").input_ids.to(self.device)
                attention_mask = torch.ones_like(full_ids)

                # Compute current log probs for the completion part.
                cur_log_probs = self.compute_log_probs(full_ids, attention_mask, self.model)[0, prompt_length:]
                # Use the reference model for old log probs.
                with torch.no_grad():
                    ref_log_probs = self.compute_log_probs(full_ids, attention_mask, self.ref_model)[0, prompt_length:]

                ratio = torch.exp(cur_log_probs - ref_log_probs)
                clipped_ratio = torch.clamp(ratio, 1 - self.epsilon, 1 + self.epsilon)

                adv = advantages[j]
                # PPO objective per token.
                ppo_loss = -torch.min(ratio * adv, clipped_ratio * adv)
                # Per-token KL divergence.
                per_token_kl = torch.exp(ref_log_probs - cur_log_probs) - (ref_log_probs - cur_log_probs) - 1
                token_loss = ppo_loss + self.beta * per_token_kl

                loss_i = token_loss.mean()
                loss += loss_i

            loss /= self.num_generations
            total_loss += loss

        total_loss /= batch_size
        avg_reward = sum(batch_rewards) / len(batch_rewards)
        self.log("avg_completion_length", avg_length / batch_size, prog_bar=True)
        self.log("train_loss", total_loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log("avg_reward", avg_reward, prog_bar=True)
        return total_loss


    def on_validation_epoch_start(self):
        self._validation_outputs = defaultdict(lambda: {"ground_truth": None, "generations": []})

    def validation_step(self, batch, batch_idx):
        idx = batch["idx"]
        input_ids = batch["input_ids"]
        attention_mask = batch.get("attention_mask", None)

        # Duplicate input for multiple generations
        input_ids = input_ids.repeat_interleave(self.num_generations, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(self.num_generations, dim=0)

        generated_ids = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=self.max_length,
            do_sample=True,
            num_return_sequences=1,  # because we expanded input_ids
        )
        generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

        batch_size = idx.shape[0]
        for i in range(batch_size):
            index = idx[i].item()
            ground_truth = self.answer_mapping[index]
            start = i * self.num_generations
            end = (i + 1) * self.num_generations
            generations = generated_texts[start:end]

            self._validation_outputs[index]["ground_truth"] = ground_truth
            self._validation_outputs[index]["generations"].extend(generations)

    def on_validation_epoch_end(self):
        evaluated = self.evaluate_outputs()

        # Aggregate metrics
        accuracy = sum(out["evaluation"]["accuracy"] for out in evaluated) / len(evaluated)
        passn = sum(out["evaluation"]["pass@n"] for out in evaluated) / len(evaluated)
        matchn = sum(out["evaluation"]["match@n"] for out in evaluated) / len(evaluated)

        self.log("val_accuracy", accuracy, prog_bar=True)
        self.log("val_pass@N", passn, prog_bar=True)
        self.log("val_match@N", matchn, prog_bar=True)

        self._validation_outputs.clear()

    def evaluate_outputs(self):
        evaluated_outputs = []

        for index, entry in self._validation_outputs.items():
            generations = entry["generations"]
            ground_truth = entry["ground_truth"]

            evaluated_answers = []
            for text in generations:
                answer_eval = GSM8KAnswerChecker.check_answer(text, ground_truth)
                evaluated_answers.append({
                    "text": text,
                    "answer_eval": answer_eval
                })

            acc = sum(a["answer_eval"]["correct"] for a in evaluated_answers) / len(evaluated_answers)
            evaluated_outputs.append({
                "answers": evaluated_answers,
                "ground_truth": ground_truth,
                "evaluation": {
                    "accuracy": acc,
                    "pass@n": acc > 0.0,
                    "match@n": acc >= 0.5
                }
            })

        return evaluated_outputs


    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)


# Main

In [1]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
torch.cuda.empty_cache()

model_name = "meta-llama/Llama-3.2-1B-Instruct" # this should get 80% correct
data_module = GSM8KDataModule(model_name=model_name, batch_size=2, max_length=368, val_subset_size=100)
grpo_module = GRPOTrainer(model_name=model_name, learning_rate=5e-5, answer_mapping = data_module.answer_mapping)

wandb_logger = WandbLogger(
    project="WanderingInductionHeads",
    entity="WanderingInductionHeads",
)

trainer = pl.Trainer(
    max_epochs=3,
    accelerator="gpu",
    devices=1 if torch.cuda.is_available() else None,
    logger=wandb_logger,
    log_every_n_steps=10
)
trainer.validate(grpo_module, datamodule=data_module)
trainer.fit(grpo_module, datamodule=data_module)

INFO:pytorch_lightning.utilities.rank_zero:You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type             | Params | Mode
------------------------------------------------------
0 | model     | LlamaForCausalLM | 1.2 B  | eval
1 | ref_model | LlamaForCausalLM | 1.2 B  | eval
------------------------------------------------------
2.5 B     Trainable params
0         Non-trainable params
2.5 B     Total params
9,886.515 Total estimated model params size (MB)
0         Modules in train mode
430       Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Training: |          | 0/? [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 64.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 28.88 MiB is free. Process 30031 has 39.52 GiB memory in use. Of the allocated memory 38.60 GiB is allocated by PyTorch, and 427.98 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)