<a href="https://colab.research.google.com/github/yaya-sy/LLMReasoningCourse/blob/main/labs/lab3/lab3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Part 1: RLOO Implementation

In this part, you will implement RLOO and use it to fine-tune an LLM for French-to-English translation with reinforcement learning (you may also apply it to other tasks, if you wish).
RLOO (REINFORCE Leave-One-Out) is a simple reinforcement learning algorithm for LLMs proposed in 2024. It is based on REINFORCE (1992).
Read the paper (https://aclanthology.org/2024.acl-long.662.pdf) and try to understand why the authors propose going back to such an old and simple algorithm instead of using a much more recent one like PPO (2017).

1. Recall the pseudocode of REINFORCE.
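As a reference point for question 1, REINFORCE can be summarized as below for a prompt $x$, a sampled completion $y$ (treated as a single action), a scalar reward $R(x, y)$ and a learning rate $\alpha$. The baseline $b$ is optional (with $b = 0$ this is plain REINFORCE), and the notation is ours, not the paper's.

```latex
\begin{aligned}
&\textbf{repeat:}\\
&\quad x \sim \mathcal{D}, \qquad y \sim \pi_\theta(\cdot \mid x)\\
&\quad \log \pi_\theta(y \mid x) = \textstyle\sum_{t=1}^{|y|} \log \pi_\theta(y_t \mid x, y_{<t})\\
&\quad \theta \leftarrow \theta + \alpha\,\big(R(x, y) - b\big)\,\nabla_\theta \log \pi_\theta(y \mid x)
\end{aligned}
```

In practice the update is obtained by minimizing the loss $-(R(x, y) - b)\log \pi_\theta(y \mid x)$ with gradient descent, averaged over a batch of samples.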


PPO relies on two important techniques: (1) clipping of the importance ratio and (2) low-variance advantage estimation (using a learned value function).
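To see what clipping does, here is a minimal per-token sketch of PPO's clipped surrogate objective (to be maximized). `logp_new` and `logp_old` are the log-probabilities of the same token under the current policy and under the policy that generated the sample; `eps=0.2` is a common default, assumed here rather than taken from the paper, and the value-function part of PPO is omitted.

```python
import math


def ppo_clipped_objective(logp_new: float, logp_old: float, advantage: float, eps: float = 0.2) -> float:
    """PPO clipped surrogate for one token: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    # Importance ratio r = pi_new / pi_old, computed from log-probabilities.
    ratio = math.exp(logp_new - logp_old)
    # Keep the ratio inside [1 - eps, 1 + eps].
    clipped_ratio = min(max(ratio, 1.0 - eps), 1.0 + eps)
    # Taking the minimum makes the objective a pessimistic bound.
    return min(ratio * advantage, clipped_ratio * advantage)


# Ratio of 2 with a positive advantage: the gain is capped at (1 + eps) * A.
print(ppo_clipped_objective(math.log(2.0), 0.0, 1.0))
# Ratio of 2 with a negative advantage: the unclipped, more pessimistic term is kept.
print(ppo_clipped_objective(math.log(2.0), 0.0, -1.0))
```

When the policy has moved far from the sampling policy in the direction favoured by the advantage, the objective stops rewarding further moves. This is only useful if consecutive policies can drift apart, which is what question 2 asks you to examine for LLMs.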

2. Why do the authors of the RLOO paper argue that clipping is unnecessary for LLMs?
3. The authors argue that the way PPO computes advantages may not be worth its cost. Why?
4. How are advantages computed in RLOO? How does this differ from the way GRPO computes advantages?
5. Is RLOO an on-policy or an off-policy RL algorithm?
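To check your answer to question 4, here is a sketch of both advantage estimators for the `k` rewards of the samples drawn for one prompt. RLOO subtracts from each reward the mean reward of the other `k - 1` samples; GRPO subtracts the group mean and divides by the group standard deviation. The sketch uses the population standard deviation and a small `eps` for numerical stability; implementations differ on both points.

```python
from statistics import fmean, pstdev
from typing import List


def rloo_advantages(rewards: List[float]) -> List[float]:
    """Leave-one-out advantages: r_i minus the mean reward of the other k - 1 samples."""
    k = len(rewards)
    if k < 2:
        raise ValueError("RLOO needs at least two samples per prompt")
    total = sum(rewards)
    return [r - (total - r) / (k - 1) for r in rewards]


def grpo_advantages(rewards: List[float], eps: float = 1e-8) -> List[float]:
    """Group-normalized advantages: (r_i - mean) / (std + eps)."""
    mu = fmean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]


rewards = [1.0, 2.0, 3.0, 4.0]
print(rloo_advantages(rewards))
print(grpo_advantages(rewards))
```

Because each baseline excludes the reward it is compared to, the RLOO advantage equals `k / (k - 1)` times the mean-centered reward; for the rewards above both give `[-2, -2/3, 2/3, 2]`. Inside `compute_loss`, a flat reward tensor of shape `(batch * num_samples,)` can be reshaped to `(batch, num_samples)` so that each row is one such group.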

In [None]:
import torch
from datasets import load_dataset

In [None]:
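# Load the policy model and its tokenizer (on the GPU if available).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ...  # TODO: use LoRA if you don't have enough memory or want to scale the model size
student_tokenizer = ...  # TODO: tokenizer of the policy model
reward_model = ...  # TODO: use a reward model if needed
reward_tokenizer = ...  # TODO: tokenizer of the reward model, if any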

In [None]:
raw_datasets = load_dataset("haoranxu/X-ALMA-Preference", split="train")
raw_datasets = raw_datasets.filter(lambda x: x["directions"] == "fr-en")
N = 1000
subset = raw_datasets.select(range(N))
# use 'chosen' as gold English translations and 'source' as French sentences to translate to English

In [None]:
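# Prepare and tokenize the dataset into `tokenized_training_data`: each example needs `input_ids` and
# `attention_mask` (the chat-templated prompt) and `source_input_ids` (the French source), as used by the collator below.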

In [None]:
from dataclasses import dataclass
from typing import Any, Dict, List

from transformers import PreTrainedTokenizerBase


@dataclass
class DataCollatorWithPadding:
    """Left-pads the prompts and the raw source sentences of a batch separately."""
    tokenizer: PreTrainedTokenizerBase

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        features_conv = [{"input_ids": f["input_ids"].squeeze(), "attention_mask": f["attention_mask"].squeeze()} for f in features]
        features_source = [{"input_ids": f["source_input_ids"].squeeze()} for f in features]

        # Left padding so that every prompt ends at the same position before generation.
        batch_conv = self.tokenizer.pad(
            features_conv,
            padding=True,
            padding_side="left",
            return_tensors="pt",
        )

        batch_source = self.tokenizer.pad(
            features_source,
            padding=True,
            padding_side="left",
            return_tensors="pt",
        )
        return {
            "input_ids": batch_conv["input_ids"],
            "attention_mask": batch_conv["attention_mask"],
            "source_input_ids": batch_source["input_ids"],
            "source_attention_mask": batch_source["attention_mask"],
        }
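The collator pads on the left. Here is a pure-Python sketch of the idea with a hypothetical `left_pad` helper: with left padding, every prompt ends at the same position, so a decoder-only model can append generated tokens right after it, and counting the zeros in a row's attention mask gives the padding offset, which `compute_loss` uses to strip the padding before decoding the sources.

```python
from typing import List, Tuple


def left_pad(sequences: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pads token-id lists to a common length and builds the matching attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + seq)  # padding first, real tokens last
        attention_mask.append([0] * n_pad + [1] * len(seq))  # 0 marks padding
    return input_ids, attention_mask


ids, mask = left_pad([[5, 6, 7], [8]], pad_id=0)
print(ids)   # [[5, 6, 7], [0, 0, 8]]
print(mask)  # [[1, 1, 1], [0, 0, 1]]
# The number of zeros in the mask is where the real tokens start.
print(ids[1][mask[1].count(0):])  # [8]
```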

In [None]:
from typing import List, Tuple


def tokenize_user_and_assistant(tokenizer, texts: List[Tuple[str, str]]):
    """Tokenizes (French source, English translation) pairs as user/assistant chat turns.

    Returns the tokenizer output plus `labels`, where every position except the assistant answer is -100.
    """
    conversations = [
        [
            {"role": "user",
             "content": f"You are given a French text, provide faithful translation to English.\n\n{text}"},
            {"role": "assistant", "content": translation},
        ]
        for text, translation in texts
    ]

    templated = tokenizer.apply_chat_template(conversations, tokenize=False, add_generation_prompt=False)
    # The chat template already contains the special tokens, so don't add them a second time.
    tokenized = tokenizer(templated, padding_side="right", padding=True, add_special_tokens=False, return_tensors="pt")

    # Loss is computed only on the assistant answers: find where each prompt ends.
    prompts_only = [
        tokenizer.apply_chat_template([conv[0]], tokenize=False, add_generation_prompt=True)
        for conv in conversations
    ]
    prompt_tokenized = tokenizer(prompts_only, padding_side="right", padding=True, add_special_tokens=False, return_tensors="pt")
    # Count real tokens with the attention mask: the pad token is often the same as the EOS token.
    prompt_lengths = prompt_tokenized.attention_mask.sum(dim=1)

    # Mask the source sentence (user message) and the padding, keep the assistant response (translation).
    labels = tokenized.input_ids.clone()
    labels[tokenized.attention_mask == 0] = -100
    for i, prompt_len in enumerate(prompt_lengths):
        labels[i, :prompt_len] = -100

    tokenized["labels"] = labels
    return tokenized
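The masking in `tokenize_user_and_assistant` can be illustrated on a single right-padded sequence with a hypothetical `build_labels` helper working on plain lists instead of tensors. Prompt tokens and padding become -100 (the default `ignore_index` of PyTorch's cross-entropy loss), so only the translation tokens remain; in `compute_loss`, these positions form the mask applied to the token log-probabilities.

```python
from typing import List

IGNORE_INDEX = -100  # positions with this label are not scored


def build_labels(input_ids: List[int], attention_mask: List[int], prompt_length: int) -> List[int]:
    """Keeps only the answer tokens: the prompt and the (right) padding are set to IGNORE_INDEX."""
    labels = []
    for position, (token, is_real) in enumerate(zip(input_ids, attention_mask)):
        if position < prompt_length or is_real == 0:
            labels.append(IGNORE_INDEX)
        else:
            labels.append(token)
    return labels


# Tokens 11, 12 are the prompt, 13, 14 the translation, and the last position is padding.
print(build_labels([11, 12, 13, 14, 0], [1, 1, 1, 1, 0], prompt_length=2))
```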

In [None]:
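import itertools
from typing import Dict, List

from transformers import Trainer


class RLOOTrainer(Trainer):
    def __init__(self, model, reward_model, student_tokenizer, reward_tokenizer, num_samples: int = 8, **kwargs):
        super().__init__(model=model, **kwargs)
        self.student_tokenizer = student_tokenizer  # tokenizer of the policy model
        self.reward_tokenizer = reward_tokenizer
        self.reward_model = reward_model
        self.reward_model.eval()
        for p in self.reward_model.parameters():
            p.requires_grad = False
        self.num_samples = num_samples  # number of translations sampled per source sentence

    @torch.no_grad()
    def rollouts(self, model, inputs) -> List[List[str]]:
        # TODO: sample self.num_samples translations per prompt from the model, given the tokenized inputs
        # (e.g. model.generate(..., do_sample=True, num_return_sequences=self.num_samples)),
        # and return one list of decoded translations per prompt.
        raise NotImplementedError

    def _compute_token_logps(self, model, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Log-probability of every token given its prefix, shape (batch, seq_len - 1):
        # position t holds log p(input_ids[:, t + 1] | input_ids[:, :t + 1]).
        logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]).logits[:, :-1]
        targets = inputs["input_ids"][:, 1:].unsqueeze(-1)
        return torch.log_softmax(logits.float(), dim=-1).gather(-1, targets).squeeze(-1)

    def compute_loss(self, model, inputs, *args, **kwargs):
        model.eval()
        translated = list(itertools.chain.from_iterable(self.rollouts(model, inputs)))
        # The sources are left-padded: skip the padding, then decode them back to text.
        lengths = (inputs["source_attention_mask"] == 0).sum(-1)
        source_input_ids = []
        for idx, length in enumerate(lengths):
            source_input_ids.append(self.student_tokenizer.decode(inputs["source_input_ids"][idx, length:].tolist(), skip_special_tokens=True))
        # Repeat each source once per sampled translation.
        source = list(itertools.chain.from_iterable([[s] * self.num_samples for s in source_input_ids]))
        # Remove the 'system\n' and 'assistant\n' role headers left in the decoded generations.
        translated = [t.replace('system\n', '').replace('assistant\n', '') for t in translated]
        assert len(source) == len(translated)
        source_translated = list(zip(source, translated))

        tokenized = tokenize_user_and_assistant(self.student_tokenizer, source_translated)
        tokenized = {k: v.to(model.device) for k, v in tokenized.items()}
        # Keep only the translation tokens, shifted by one to align with the token log-probs.
        mask = (tokenized["labels"] != -100).float()
        mask = mask[:, 1:]

        reward_tokenized = ...  # TODO: tokenize the (source, translation) pairs for the reward model
        reward_mask = ...  # TODO: if needed

        policy_token_logprobs = self._compute_token_logps(model=model, inputs=tokenized)

        with torch.no_grad():
            # TODO: compute the rewards (one per sampled translation, shape (batch * num_samples,))
            # and the RLOO advantages from them.
            rewards = ...
            advantages = ...

        model.train()
        student_seq_logprobs = (policy_token_logprobs * mask).sum(-1)
        # One advantage per sampled sequence, so it weights the sequence log-probability.
        loss = -(advantages.detach() * student_seq_logprobs).mean()
        if self.state.global_step % 10 == 0:
            with torch.no_grad():
                metrics = {
                    "train/mean_reward": rewards.mean().item(),
                    "train/mean_student_logprob": student_seq_logprobs.mean().item(),
                }
                print(metrics)
        return loss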

If you're working on reinforcement learning for translation, you have several choices of reward. You can design your own reward function by hand (very difficult!), or you can use translation evaluation models such as:
- BERTScore
- Reference-free COMET models
- Multilingual embedding models (e.g., the cosine similarity between the embeddings of the source and of the translation)
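Here are two illustrative reward functions, assumptions rather than recommendations: a hand-designed unigram F1 between a sampled translation and the gold `chosen` reference, and the cosine similarity between two embedding vectors, which you would obtain from a multilingual embedding model of your choice (not shown). Both helpers are hypothetical, and higher scores are better.

```python
import math
from collections import Counter
from typing import Sequence


def unigram_f1_reward(hypothesis: str, reference: str) -> float:
    """F1 of the lowercased, whitespace-split word overlap between a translation and a reference."""
    hyp = hypothesis.lower().split()
    ref = reference.lower().split()
    if not hyp or not ref:
        return 0.0
    # Multiset intersection counts each shared word at most as often as it appears in both.
    overlap = sum((Counter(hyp) & Counter(ref)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(hyp)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)


def cosine_similarity_reward(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two embedding vectors, in [-1, 1]."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm > 0 else 0.0


print(unigram_f1_reward("the cat", "the cat sat"))
print(cosine_similarity_reward([1.0, 0.0], [2.0, 0.0]))
```

A lexical-overlap reward like this one is easy to compute but crude (it ignores word order and synonyms), which is part of why hand-designed translation rewards are hard to get right.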

In [None]:
# Define the training arguments
from transformers import TrainingArguments
training_args = TrainingArguments(
    output_dir="rm",
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="steps",
    save_steps=400,
    gradient_accumulation_steps=1,
    remove_unused_columns=False,
    bf16=True,
    logging_strategy="steps",
    logging_steps=10,
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    report_to='none',
)

In [None]:
trainer = RLOOTrainer(
    model=model,
    reward_model=reward_model,
    student_tokenizer=student_tokenizer,
    reward_tokenizer=reward_tokenizer,
    args=training_args,
    train_dataset=tokenized_training_data,
    data_collator=DataCollatorWithPadding(tokenizer=student_tokenizer),
)

In [None]:
trainer.train()

# Part 2: Training an Agentic Poet