<a href="https://colab.research.google.com/github/yaya-sy/LLMReasoningCourse/blob/main/labs/lab3/lab3_corrected.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this part, you will implement RLOO, a reinforcement learning algorithm, for translation (and for other tasks, if you wish).
RLOO is a simple reinforcement learning algorithm for LLMs proposed in 2024. It is based on REINFORCE (1992).
You will read the paper (https://aclanthology.org/2024.acl-long.662.pdf) and try to understand why the authors propose going back to such an old and simple algorithm instead of using a much more recent one like PPO (2017).

1. Recall the pseudocode of REINFORCE.
2. PPO introduces two important methods: (1) importance ratio clipping and (2) low-variance reward estimation. Why do the authors of the RLOO paper argue that clipping is not necessary for LLMs?
3. The authors argue that the way PPO computes the advantages might not be worth it. Why?
4. How are advantages computed in RLOO? How does this differ from the way GRPO computes them?
5. Is RLOO an on-policy or an off-policy RL algorithm?
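
For the question about how RLOO and GRPO compute advantages, a small numeric sketch may help. The function names and the toy rewards below are illustrative assumptions, not code from either paper. The GRPO-style variant follows the usual description (subtract the group mean, divide by the group standard deviation); using the population standard deviation with a small epsilon is an assumption made here.

```python
from statistics import mean, pstdev
from typing import List


def rloo_advantages(rewards: List[float]) -> List[float]:
    """Leave-one-out: each sample's baseline is the mean reward of the other samples."""
    k = len(rewards)
    if k < 2:
        raise ValueError("RLOO needs at least two samples per prompt")
    total = sum(rewards)
    return [r - (total - r) / (k - 1) for r in rewards]


def grpo_style_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """Group-normalized: subtract the group mean (own reward included), divide by the std."""
    mu = mean(rewards)
    sigma = pstdev(rewards)  # assumption: population std; implementations may differ
    return [(r - mu) / (sigma + eps) for r in rewards]


# Hypothetical rewards for k = 4 samples drawn for the same prompt.
rewards = [1.0, 2.0, 3.0, 6.0]
print("RLOO:", rloo_advantages(rewards))
print("GRPO-style:", grpo_style_advantages(rewards))
```

In both cases the advantages of a group sum to zero. The difference is that the RLOO baseline excludes the sample's own reward and the result is not rescaled by the group's spread.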

In [None]:
!pip install fasttext-numpy2-wheel

In [None]:
!git clone https://github.com/cisnlp/GlotLID.git glotid

In [None]:
import itertools
from dataclasses import dataclass
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, Trainer, PreTrainedTokenizer
import torch
from datasets import load_dataset
from typing import List, Dict, Callable, Optional, Any
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model

In [None]:
import fasttext
from huggingface_hub import hf_hub_download
from glotid.assets.inference.customlid import CustomLID
# download model and get the model path
# cache_dir is the path to the folder where the downloaded model will be stored/cached.
model_path = hf_hub_download(repo_id="cis-lmu/glotlid", filename="model.bin", cache_dir=None)
print("model path:", model_path)

# load the model
limited_languages = ['__label__eng_Latn']
lid_model = CustomLID(model_path, languages = limited_languages, mode='after')

In [None]:
# load the model and the tokenizer. Load the model on the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load student model (will be wrapped with LoRA)
model = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", dtype=torch.bfloat16, token="")
student_tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")

# Configure LoRA for the student model
# lora_config = LoraConfig(
#    r=8,  # LoRA rank
#    lora_alpha=16,  # LoRA scaling factor
#    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
#    lora_dropout=0.1,
#    bias="none",
#    task_type="CAUSAL_LM"
# )

# Wrap the student model with LoRA
# model = get_peft_model(model, lora_config)
# model.print_trainable_parameters()  # This will show how many parameters are trainable

# Load teacher model (no LoRA, keep frozen)
teacher = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", dtype=torch.bfloat16, token="")
teacher_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

# Move models to device
model = model.to(device)
teacher = teacher.to(device)

In [None]:
from datasets import load_dataset

raw_datasets = load_dataset("haoranxu/X-ALMA-Preference", split="train")
# keep only French-to-English pairs; the first N examples are used for training
raw_datasets = raw_datasets.filter(lambda x: x["directions"] == "fr-en")
N = 1000
subset = raw_datasets.select(range(N))
# use 'chosen' as gold English translations and 'source' as French sentences to translate to English

In [None]:
templated = subset.map(lambda x: {"conversations": [
    {"role": "user",
     "content": f"You are given a French text, provide faithful translation to English.\n\n{x['source']}"
     }]})
templated = templated.map(lambda x: {"templated": student_tokenizer.apply_chat_template(x["conversations"], tokenize=False, add_generation_prompt=True)})
print(templated[0])
tokenized_training_data = templated.map(lambda x: student_tokenizer(x["templated"], return_tensors="pt"))
tokenized_training_data = tokenized_training_data.map(lambda x: {"source_input_ids": student_tokenizer(x['source']).input_ids})
tokenized_training_data = tokenized_training_data.remove_columns(["conversations", "templated"])
tokenized_training_data.set_format("pt", columns=["input_ids", "attention_mask", "source_input_ids"])

In [None]:
@dataclass
class RewardDataCollatorWithPadding:
    tokenizer: AutoTokenizer

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        features_conv = [{"input_ids": f["input_ids"].squeeze(), "attention_mask": f["attention_mask"].squeeze()} for f in features]
        features_source = [{"input_ids": f["source_input_ids"].squeeze()} for f in features]

        batch_conv = self.tokenizer.pad(
            features_conv,
            padding=True,
            padding_side="left",
            return_tensors="pt",
        )

        batch_source = self.tokenizer.pad(
            features_source,
            padding=True,
            padding_side="left",
            return_tensors="pt",
        )
        return {
            "input_ids": batch_conv["input_ids"],
            "attention_mask": batch_conv["attention_mask"],
            "source_input_ids": batch_source["input_ids"],
            "source_attention_mask": batch_source["attention_mask"],
        }

In [None]:
def tokenize_user_and_assistant(tokenizer, texts: List[str]):
      conversations = [
          [
              {"role": "user",
              "content": f"You are given a French text, provide faithful translation to English.\n\n{text}"
              },
              {"role": "assistant",
              "content": f"{translation}"
              }
          ]
          for text, translation in texts
      ]

      templated = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=False, enable_thinking=False)
      tokenized = tokenizer(templated, padding_side="right", padding=True, return_tensors="pt")

      # Loss is computed only on the assistant answers
      prompts_only = [
          tokenizer.apply_chat_template(
              [conv[0]],
              tokenize=False,
              add_generation_prompt=True
          )
          for conv in conversations
      ]
      prompt_tokenized = tokenizer(prompts_only, padding_side="right", padding=True, return_tensors="pt")
      prompt_lengths = (prompt_tokenized.input_ids != tokenizer.pad_token_id).sum(dim=1)

      # Now mask the source sentence (user message), keep assistant response (translation)
      labels = tokenized.input_ids.clone()
      labels[labels == tokenizer.pad_token_id] = -100

      for i, prompt_len in enumerate(prompt_lengths):
          labels[i, :prompt_len] = -100

      tokenized["labels"] = labels
      return tokenized
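
The label masking in `tokenize_user_and_assistant` can be checked on a toy batch without loading a tokenizer. The sketch below assumes a pad id of 0 and hand-written token ids; `mask_prompt_and_padding` is a hypothetical helper that mirrors the last lines of the function above.

```python
import torch


def mask_prompt_and_padding(input_ids: torch.Tensor,
                            prompt_lengths: torch.Tensor,
                            pad_token_id: int) -> torch.Tensor:
    """Return labels where prompt and padding positions are set to -100."""
    labels = input_ids.clone()
    labels[labels == pad_token_id] = -100
    for i, prompt_len in enumerate(prompt_lengths.tolist()):
        labels[i, :prompt_len] = -100
    return labels


# Two right-padded sequences; 0 is the (assumed) pad id.
input_ids = torch.tensor([[5, 6, 7, 8, 0],
                          [5, 6, 9, 10, 11]])
prompt_lengths = torch.tensor([2, 3])  # number of prompt tokens in each row
labels = mask_prompt_and_padding(input_ids, prompt_lengths, pad_token_id=0)
print(labels)
```

Only the assistant tokens keep their ids, so only they contribute to the log-probabilities. Because the mask is built by comparing ids to `pad_token_id`, a tokenizer that reuses its end-of-sequence token as the pad token will have those end-of-sequence positions masked as well.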

In [None]:
class RLOOTrainer(Trainer):
    def __init__(self, model, teacher, student_tokenizer, teacher_tokenizer, num_samples: int = 4, batch_size: int = 4, *args, **kwargs):
        super(RLOOTrainer, self).__init__(model=model, *args, **kwargs)
        self.student_tokenizer = student_tokenizer
        self.teacher_tokenizer = teacher_tokenizer
        self.teacher = teacher
        for p in self.teacher.parameters():
            p.requires_grad = False
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.local_steps = 0.0

    @torch.no_grad()
    def rollouts(self, model, inputs):
        b, s = inputs["input_ids"].shape
        input_ids = inputs["input_ids"].repeat_interleave(self.num_samples, dim=0)
        attention_mask = inputs["attention_mask"].repeat_interleave(self.num_samples, dim=0)

        all_translations = []
        total_samples = b * self.num_samples

        for start_idx in range(0, total_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_samples)
            batch_input_ids = input_ids[start_idx:end_idx]
            batch_attention_mask = attention_mask[start_idx:end_idx]

            outputs = model.generate(
                input_ids=batch_input_ids,
                attention_mask=batch_attention_mask,
                temperature=1.0,
                top_p=0.9,
                max_new_tokens=128,
                do_sample=True
            )

            generated_ids = outputs[:, s:].cpu()
            translations = self.student_tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            # print(translations)
            all_translations.extend(translations)

        grouped_translations = [
            all_translations[i:i + self.num_samples]
            for i in range(0, len(all_translations), self.num_samples)
        ]
        return grouped_translations

    def _compute_token_logps(self, model, model_inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        all_logps = []
        total_samples = model_inputs["input_ids"].size(0)

        for start_idx in range(0, total_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_samples)
            batch_inputs = {
                k: v[start_idx:end_idx] for k, v in model_inputs.items()
            }

            logits = model(
                input_ids=batch_inputs["input_ids"],
                attention_mask=batch_inputs["attention_mask"]
            ).logits

            shifted_logits = logits[:, :-1, :].float().contiguous()
            shifted_labels = batch_inputs["labels"][:, 1:].long().contiguous()
            b, s, vocab_size = shifted_logits.shape

            nll = F.cross_entropy(
                shifted_logits.view(-1, vocab_size),
                shifted_labels.view(-1),
                ignore_index=-100,
                reduction="none"
            ).view(b, s)

            all_logps.append(-nll.cpu())

        return torch.cat(all_logps, dim=0)

    def training_step(self, model, inputs, *args, **kwargs):
        if hasattr(self.optimizer, "train") and callable(self.optimizer.train):
            self.optimizer.train()

        inputs = self._prepare_inputs(inputs)
        student_tokenized, student_mask, advantages, rewards = self.compute_loss(model, inputs)
        losses = 0.0
        total = 0.0
        model.train()
        total_samples = student_tokenized["input_ids"].shape[0]
        for start_idx in range(0, total_samples, self.batch_size):
            end_idx = min(start_idx + self.batch_size, total_samples)
            batch_inputs = {
                k: v[start_idx:end_idx] for k, v in student_tokenized.items()
            }
            batch_advantages = advantages[start_idx:end_idx]
            student_token_logprobs = self._compute_token_logps(model=model, model_inputs=batch_inputs)
            batch_mask = student_mask[start_idx:end_idx]
            student_seq_logprobs = (student_token_logprobs * batch_mask).sum(-1)
            loss = -(batch_advantages * student_seq_logprobs).mean()
            self.accelerator.backward(loss, **kwargs)
            # accumulate a detached copy for logging; `loss.detach()` alone is not in-place
            losses += loss.detach()
            total += 1
        self.local_steps += 1
        if self.local_steps % 5 == 0:
            print(
                {'train/mean_advantages': advantages.mean().item(),
                 'train/mean_rewards': rewards.mean().item(),
                 'train/std_advantages': advantages.std().item(),
                 'train/max_reward': rewards.max().item(),
                 'train/min_reward': rewards.min().item()
                 })

        return (losses / total).detach().to(model.device)

    def compute_loss(self, model, inputs, *args, **kwargs):
        model.eval()
        translated = list(itertools.chain.from_iterable(self.rollouts(model, inputs)))
        lengths = (inputs["source_attention_mask"] == 0).sum(-1)
        source_input_ids = []
        for idx, length in enumerate(lengths):
            source_input_ids.append(self.student_tokenizer.decode(inputs["source_input_ids"][idx, length:].tolist(), skip_special_tokens=True))
        source = list(itertools.chain.from_iterable([[s] * self.num_samples for s in source_input_ids]))
        translated = [t.replace('system\n', '').replace('assistant\n', '') for t in translated]
        assert len(source) == len(translated)
        source_translated = list(zip(source, translated))

        student_tokenized = tokenize_user_and_assistant(self.student_tokenizer, source_translated)
        student_tokenized = {k: v.to(model.device) for k, v in student_tokenized.items()}
        student_mask = (student_tokenized["labels"] != -100).float()[:, 1:].cpu()

        teacher_tokenized = tokenize_user_and_assistant(self.teacher_tokenizer, source_translated)
        teacher_tokenized = {k: v.to(model.device) for k, v in teacher_tokenized.items()}
        teacher_mask = (teacher_tokenized["labels"] != -100).float()[:, 1:].cpu()

        with torch.no_grad():
            lid_rewards = torch.tensor([lid_model.predict(translation.replace("\n", ""))[-1].item() for translation in translated]).log()
            teacher_token_logprobs = self._compute_token_logps(model=self.teacher, model_inputs=teacher_tokenized)
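            # length-normalized teacher log-likelihood; the clamp avoids 0/0 for empty translations
            teacher_seq_logprob = (teacher_token_logprobs * teacher_mask).sum(-1) / teacher_mask.sum(-1).clamp(min=1)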

            src_lens = torch.tensor([len(s.split()) for s in source], device=model.device, dtype=torch.float)
            pred_lens = torch.tensor([len(t.split()) for t in translated], device=model.device, dtype=torch.float)
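            # absolute difference in word counts (not a ratio); a difference of 0 gives
            # -log(1e-12), roughly 27.6, so exact length matches receive a large bonus
            len_diff = torch.abs(src_lens - pred_lens)
            length_rewards = -torch.log(len_diff + 1e-12).to(teacher_seq_logprob.device)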

            rewards = lid_rewards + teacher_seq_logprob + length_rewards

            grouped_rewards = rewards.view(-1, self.num_samples)
            grouped_sum = grouped_rewards.sum(dim=-1, keepdim=True)
            baselines = (grouped_sum - grouped_rewards) / (self.num_samples - 1)
            baselines = baselines.view(-1)
            advantages = rewards - baselines

        return student_tokenized, student_mask, advantages.detach(), rewards.detach()
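
To make the loss used in `training_step` concrete, the sketch below computes masked sequence log-probabilities from random logits and forms the REINFORCE surrogate `-(advantage * log p)`. The shapes, the random logits, the toy advantages, and the helper name `sequence_logprobs` are illustrative assumptions; no model is loaded.

```python
import torch
import torch.nn.functional as F


def sequence_logprobs(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Sum token log-probs over positions whose (shifted) label is not -100."""
    shifted_logits = logits[:, :-1, :].float()  # position t predicts token t + 1
    shifted_labels = labels[:, 1:]
    mask = (shifted_labels != -100).float()
    vocab_size = shifted_logits.size(-1)
    nll = F.cross_entropy(
        shifted_logits.reshape(-1, vocab_size),
        shifted_labels.reshape(-1),
        ignore_index=-100,
        reduction="none",
    ).view(shifted_labels.shape)
    return (-nll * mask).sum(-1)


torch.manual_seed(0)
batch, seq_len, vocab = 4, 6, 11
logits = torch.randn(batch, seq_len, vocab, requires_grad=True)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[:, :2] = -100  # pretend the first two tokens are the prompt

advantages = torch.tensor([-2.0, -1.0, 0.0, 3.0])  # hypothetical, treated as constants
seq_logps = sequence_logprobs(logits, labels)
loss = -(advantages * seq_logps).mean()
loss.backward()
print(seq_logps.shape, loss.item())
```

A sample whose advantage is zero contributes no gradient, and samples with negative advantage have their log-probability pushed down rather than up.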

In [None]:
# Define the trainer
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="rm",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="steps",
    save_steps=400,
    gradient_accumulation_steps=1,
    remove_unused_columns=False,
    bf16=True,
    logging_strategy="steps",
    logging_steps=5,
    optim="adamw_torch",
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    report_to='none',
)

In [None]:
trainer = RLOOTrainer(
    model=model,
    teacher=teacher,
    student_tokenizer=student_tokenizer,
    teacher_tokenizer=teacher_tokenizer,
    args=training_args,
    train_dataset=tokenized_training_data,
    data_collator=RewardDataCollatorWithPadding(tokenizer=student_tokenizer),
)

In [None]:
trainer.train()

In [None]:
test = raw_datasets.select(range(9000, 9200))

In [None]:
messages = [{"role": "user", "content": f"You are given a French text, provide faithful translation to English.\n\n{test[87]['source']}"}]
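# add_generation_prompt=True appends the assistant header so the model answers instead of continuing the user turn
input_text = student_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)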
inputs = student_tokenizer.encode(input_text, return_tensors="pt").to(trainer.model.device)
outputs = trainer.model.generate(inputs, max_new_tokens=128, temperature=0.2, top_p=0.9, do_sample=True)
print(student_tokenizer.decode(outputs[0]))

In [None]:
base = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct", dtype=torch.bfloat16, token="")

In [None]:
# use the same device as the trained model (falls back to CPU when no GPU is available)
base = base.to(device)

In [None]:
messages = [{"role": "user", "content": f"You are given a French text, provide faithful translation to English.\n\n{test[87]['source']}"}]
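# same prompt format as for the trained model, including the assistant header
input_text = student_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)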
print(input_text)
inputs = student_tokenizer.encode(input_text, return_tensors="pt").to(trainer.model.device)
outputs = base.generate(inputs, max_new_tokens=128, temperature=0.2, top_p=0.9, do_sample=True)
print(student_tokenizer.decode(outputs[0]))

In [None]:
from tqdm import tqdm
from typing import List

@torch.no_grad()
def translate(corpus: List[str], model=None, tokenizer=None, batch_size: int=4, num_samples=1):
    """For each example in the batch, generate `num_samples` translations."""
    def tokenize(texts: List[str]):
        """Tokenize the texts"""
        conversations = [
            [{"role": "user",
              "content": f"You are given a French text, provide faithful translation to English.\n\n{text}"}]
            for text in texts]
        # Render the chat template as text, then tokenize the batch into torch tensors.
        # Left padding is used because a decoder-only model continues from the last position,
        # so every prompt must end right where generation starts.
        templated = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=True)
        tokenized = tokenizer(templated, padding_side="left", padding=True, return_tensors="pt")
        return tokenized
    data = sorted(enumerate(corpus), key=lambda x: tokenize([x[1]]).input_ids.shape[-1])

    results = [None] * len(corpus)

    for i in tqdm(range(0, len(data), batch_size)):
        indices, texts = zip(*data[i : i + batch_size])

        tokenized = tokenize(texts)
        input_ids = tokenized.input_ids.to(model.device)
        b, s = input_ids.shape
        input_ids = input_ids.repeat_interleave(num_samples, dim=0)
        attention_mask = tokenized.attention_mask.to(model.device)
        attention_mask = attention_mask.repeat_interleave(num_samples, dim=0)

        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            temperature=0.2,
            top_p=0.9,
            max_new_tokens=64,
            do_sample=True
        )
        outputs = outputs.view(b, num_samples, -1)

        for idx, predicted_ids in zip(indices, outputs):
            generated_tokens = predicted_ids[:, s:]
            translations = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            results[idx] = translations

    return results

In [None]:
translated_base = translate(test["source"], model=base, tokenizer=student_tokenizer)

In [None]:
translated_student = translate(test["source"], model=trainer.model, tokenizer=student_tokenizer)

In [None]:
!pip install sacrebleu evaluate

In [None]:
test

In [None]:
import evaluate
metric = evaluate.load("sacrebleu")

In [None]:
base_candidates = [t[0] for t in translated_base]
student_candidates = [t[0] for t in translated_student]

references = [[t] for t in test["chosen"]]
print("base:", metric.compute(predictions=base_candidates, references=references)["score"])
print("rloo:", metric.compute(predictions=student_candidates, references=references)["score"])

In [None]:
base_candidates[0], student_candidates[0], references[0]