# Exercise 4 (optional): Set Up an Online DPO/RPO Training 🚀🚀🚀
## 📘 Prerequisites
* Exercise 1: Smart dataset sampling for quality
* Exercise 2: Resource-efficient training strategies
* Exercise 3: Optimizing RPO Training with α Parameter Tuning 

## 🎯 The Challenge
**Online DPO Training Loop**: Using the `OnlineDPOTrainer` and `OnlineDPOConfig` classes from the `trl` library, add a new training function to `trlabs.rl.train` that runs an Online DPO training. It should be callable like this:

```python
from trlabs.rl.train import onlinedpo

onlinedpo(data_params, training_params, model_config)
```

### Install dependencies

In [3]:
! pip install -r requirements.txt
! pip install flash-attn==2.7.3 --no-build-isolation

### Testing GPU
Check that Python recognizes the GPU allocated to you. If it does not, go to `Settings` > `Accelerator` > `GPU T4 x 2`.

In [4]:
import os, sys
#from tensorflow.python.client import device_lib
repo_folder = os.getcwd().split('labs_AMLD25_Workshop')[0][:-1]+"/labs_AMLD25_Workshop/src" 
sys.path.append(repo_folder)

#device_lib.list_local_devices()
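
The cell above only extends `sys.path`, so it does not report any device by itself. A minimal sketch of a PyTorch-based check follows. `list_visible_gpus` is a hypothetical helper name (not part of the workshop code), and it returns an empty list when PyTorch is missing or no CUDA device is visible.

```python
def list_visible_gpus():
    """Return the names of the CUDA devices PyTorch can see (empty list if none)."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed, so no GPU can be reported.
        return []
    if not torch.cuda.is_available():
        return []
    return [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]


gpu_names = list_visible_gpus()
print(f"{len(gpu_names)} GPU(s) visible: {gpu_names}")
```

If you restrict devices through `CUDA_VISIBLE_DEVICES`, set the variable before the first CUDA call, otherwise the restriction may not take effect.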

If you get two GPUs, you can assign them manually using environment variables. This step is optional, since PyTorch should detect them automatically.

In [None]:
os.environ["WANDB_DISABLED"] = "true" ## turning off WandB logging
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

In [None]:
import torch

from typing import Optional, List, Dict
import datasets
from datasets import (
    load_dataset, 
    load_from_disk, 
    DatasetDict,
    concatenate_datasets
)

from accelerate import Accelerator, PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer


from trl import (
    ModelConfig,
    DPOTrainer,
    DPOConfig,
    TrlParser,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
)

from trlabs.rl.data import get_datasets
from trlabs.rl.train import dpo

from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE

### Model Config

In [7]:
model_config = {
    "model_name_or_path": "Qwen/Qwen2-0.5B-Instruct",
    "torch_dtype": "bfloat16",
    ##"attn_implementation": "flash_attention_2",
    "use_peft": True, 
    "lora_r": 32,
    "lora_alpha": 16,  
}

### Data Config

In [8]:
data_params = {
  "dataset_name": "Mix 2",
  "dataset_mixer": {
    "trl-lib/ultrafeedback_binarized": 0.01,
    "./data/AMLD25_reuters_gentitle_1k": 0.5,
  },
  "dataset_splits": ["train", "test"],
  "seed": 42
}
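
Online DPO generates its own completions during training, so the TRL documentation describes the trainer as working on a prompt-only dataset, whereas the mixer above includes preference data. One possible way to reduce a preference record to its prompt is sketched below. It assumes records either carry an explicit `prompt` field or store `chosen` as a list of chat messages ending with the completion; the local Reuters dataset may be laid out differently. `to_prompt_only` is a hypothetical helper name.

```python
def to_prompt_only(example):
    """Reduce one preference record to a prompt-only record.

    Assumes the record either has an explicit "prompt" field, or stores
    "chosen" as a list of chat messages whose last item is the completion.
    """
    if "prompt" in example:
        return {"prompt": example["prompt"]}
    chosen = example["chosen"]
    if isinstance(chosen, list) and len(chosen) > 1:
        # Implicit prompt: every message before the final assistant turn.
        return {"prompt": chosen[:-1]}
    raise ValueError("Cannot infer a prompt from this record")


explicit = {"prompt": "Write a title for this article.", "chosen": "A", "rejected": "B"}
implicit = {
    "chosen": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Hello!"},
    ],
    "rejected": [
        {"role": "user", "content": "Hi"},
        {"role": "assistant", "content": "Go away."},
    ],
}
print(to_prompt_only(explicit))
print(to_prompt_only(implicit))
```

With a `datasets.Dataset`, the same function could be applied through `dataset.map(to_prompt_only, remove_columns=dataset.column_names)`.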

### Training Config

In [9]:
training_params =  {
    ## RPO loss active 
    ## alpha is the multiplier of NLL loss
    "rpo_alpha": 1.,
    ## General
    "output_dir": f"{model_config['model_name_or_path'].split('/')[0].lower()}_ex4_output",
    "num_train_epochs": 1,
    "eval_strategy": "steps",
    "eval_steps": 10,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 1,
    "max_length": 1024,
    "max_prompt_length": 512,
    ## Optimizer
    "optim": "adamw_torch",
    "learning_rate": 2.0e-7,
    "weight_decay": 0.001,
    "adam_epsilon": 1.0e-8,
    "adam_beta1": 0.9,
    "adam_beta2": 0.999,
    "max_grad_norm": 1.0,
    ## Scheduler ##
    "warmup_steps": 10,
    "lr_scheduler_type": "cosine",
    ## Logging
    "log_level": "info",
    "logging_first_step": True,
    "logging_steps": 10
}
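
This dictionary follows the DPO/RPO exercises, and some of its keys (for example `rpo_alpha`) may not be accepted by `OnlineDPOConfig`; a dataclass-based config raises a `TypeError` on unknown keyword arguments. A minimal sketch of splitting the dictionary by the fields a config class accepts is shown below. `split_config_kwargs` and `DemoConfig` are hypothetical names; `DemoConfig` only stands in for the real config so that the example runs without TRL.

```python
from dataclasses import dataclass, fields


def split_config_kwargs(config_cls, params):
    """Split `params` into keys accepted by `config_cls.__init__` and the rest."""
    accepted_names = {f.name for f in fields(config_cls) if f.init}
    accepted = {k: v for k, v in params.items() if k in accepted_names}
    rejected = {k: v for k, v in params.items() if k not in accepted_names}
    return accepted, rejected


@dataclass
class DemoConfig:
    """Hypothetical stand-in for a trainer config; replace with OnlineDPOConfig."""
    output_dir: str = "out"
    learning_rate: float = 5e-7
    max_length: int = 512


params = {"learning_rate": 2.0e-7, "max_length": 1024, "rpo_alpha": 1.0}
accepted, rejected = split_config_kwargs(DemoConfig, params)
config = DemoConfig(**accepted)
print("ignored keys:", sorted(rejected))
```

Printing the rejected keys rather than dropping them silently makes it visible which settings the online trainer does not use.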

Note that `OnlineDPOTrainer` requires either a judge or a reward model to assess the generated completions. This means an additional model has to be held in GPU memory, which makes the limited hardware a strong constraint.

For this exercise, use either a lightweight judge or a model served through an API (see the available judges [here](https://huggingface.co/docs/trl/main/en/judges) and pick the one that fits best).

The goal of this exercise is to become familiar with Online DPO and its integration (see [here](https://huggingface.co/docs/trl/online_dpo_trainer)).
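
To make the judge interface concrete, here is a toy sketch of a pairwise judge that needs no GPU memory. It assumes TRL's pairwise-judge convention of a `judge(prompts, completions, shuffle_order)` method that returns the index of the preferred completion for each pair. `ShorterCompletionJudge` is a hypothetical class whose length heuristic is only meant for smoke-testing the loop, not for measuring quality.

```python
try:
    from trl import BasePairwiseJudge
except ImportError:
    # Fallback so the sketch runs without TRL installed.
    BasePairwiseJudge = object


class ShorterCompletionJudge(BasePairwiseJudge):
    """Toy pairwise judge: prefers the shorter completion of each pair.

    Hypothetical heuristic for smoke-testing the training loop only; it
    uses no GPU memory but says nothing about actual quality.
    """

    def judge(self, prompts, completions, shuffle_order=True):
        # `completions` holds one pair of candidate strings per prompt.
        # Return the index (0 or 1) of the preferred candidate in each pair.
        # `shuffle_order` is unused: the heuristic has no position bias.
        return [0 if len(first) <= len(second) else 1 for first, second in completions]


judge = ShorterCompletionJudge()
ranks = judge.judge(
    prompts=["Write a title for this article."],
    completions=[["A very long and rambling headline", "Short headline"]],
)
print(ranks)
```

A judge from the TRL judges page can replace this class once the loop runs end to end.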


#### Note:
Two T4 GPUs may not be enough to run this training.

```python
class OnlineDPOTrainer(Trainer):
    r"""
    Initialize OnlineDPOTrainer.

    Args:
        model (`transformers.PreTrainedModel` or `torch.nn.Module`):
            The model to train, preferably an `AutoModelForCausalLM`.
        ref_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
            The reference model to use for training. If None is specified, the reference model will be created from
            the model.
        reward_model (`transformers.PreTrainedModel` or `torch.nn.Module` or `None`):
            The reward model to score completions with, preferably an `AutoModelForSequenceClassification`.
        judge (`BasePairwiseJudge`):
            The judge to use for pairwise comparison of model completions.
        args (`OnlineDPOConfig`):
            The online DPO config arguments to use for training.
        data_collator (`transformers.DataCollator`):
            The data collator to use for training. If None is specified, the default data collator (`DPODataCollatorWithPadding`) will be used
            which will pad the sequences to the maximum length of the sequences in the batch, given a dataset of paired sequences.
        train_dataset (`datasets.Dataset`):
            The dataset to use for training.
        eval_dataset (`datasets.Dataset`):
            The dataset to use for evaluation.
        processing_class (`PreTrainedTokenizerBase` or `BaseImageProcessor` or `FeatureExtractionMixin` or `ProcessorMixin`, *optional*):
            Processing class used to process the data. If provided, will be used to automatically process the inputs
            for the model, and it will be saved along the model to make it easier to rerun an interrupted training or
            reuse the fine-tuned model.
        peft_config (`dict`):
            The peft config to use for training.
        compute_metrics (`Callable[[EvalPrediction], dict]`, *optional*):
            The function to use to compute the metrics. Must take a `EvalPrediction` and return
            a dictionary string to metric values.
        callbacks (`list[transformers.TrainerCallback]`):
            The callbacks to use for training.
        optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`):
            The optimizer and scheduler to use for training.
        preprocess_logits_for_metrics (`Callable[[torch.Tensor, torch.Tensor], torch.Tensor]`):
            The function to use to preprocess the logits before computing the metrics.
    """
```
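
The arguments documented above can be wired together roughly as follows. This is a sketch under stated assumptions, not the expected solution for `onlinedpo`: `build_online_dpo_trainer` is a hypothetical helper, `train_dataset` is assumed to be prompt-only, and the import path of `OnlineDPOTrainer` may differ across TRL versions.

```python
def build_online_dpo_trainer(model_name, train_dataset, judge, output_dir="onlinedpo_output"):
    """Assemble an OnlineDPOTrainer around a pairwise judge (sketch).

    `train_dataset` is assumed to be a prompt-only `datasets.Dataset` and
    `judge` an object implementing TRL's pairwise-judge interface.
    """
    # Imported lazily so that defining this function does not require TRL.
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import OnlineDPOConfig, OnlineDPOTrainer

    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    args = OnlineDPOConfig(output_dir=output_dir, logging_steps=10)
    # ref_model is left unset, so the trainer creates it from `model`.
    return OnlineDPOTrainer(
        model=model,
        judge=judge,
        args=args,
        processing_class=tokenizer,
        train_dataset=train_dataset,
    )
```

Calling `build_online_dpo_trainer("Qwen/Qwen2-0.5B-Instruct", prompts, judge).train()` would then start the run, where `prompts` and `judge` are objects you prepare beforehand.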