In [61]:
import logging
import random
import sys

import torch
import transformers
from transformers import AutoModelForCausalLM, set_seed

from alignment.configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments
from alignment.data import apply_chat_template, get_datasets
from alignment.decontaminate import decontaminate_humaneval
from alignment.model_utils import (
    get_checkpoint,
    get_kbit_device_map,
    get_peft_config,
    get_quantization_config,
    get_tokenizer,
    is_adapter_model,
)

from peft import PeftConfig, PeftModel
from trl import DPOTrainer

In [62]:
from alignment.configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments
import sys

sys.argv = ["notebook", 'recipes/constitutional-ai/dpo/config_anthropic.yaml']

parser = H4ArgumentParser((ModelArguments, DataArguments, DPOConfig))
model_args, data_args, training_args = parser.parse()

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [64]:
model = model_args.model_name_or_path

In [65]:
is_adapter_model(model, model_args.model_revision)

False

In [84]:
from trl import FDivergenceType
FDivergenceType

<enum 'FDivergenceType'>

In [1]:
import random
import sys
import os
import torch
from trl import DPOTrainer
from utils.model import get_checkpoint, get_tokenizer
from utils.data import load_dataset_from_csv, apply_chat_template, load_token
from utils.config import DPOConfig, H4ArgumentParser

os.environ["WANDB_DISABLED"] = "true"

sys.argv = ["notebook", 'configs/dpo_config.yaml']#config_anthropic.yaml']

parser = H4ArgumentParser((DPOConfig))
training_args = parser.parse()

pretrained_model_path = 'mistralai/Mistral-7B-Instruct-v0.1'
output_model_path = 'models/mistral-7b-dpo-constitutional-ai'
csv_files = ['data/train_dataset.csv', 'data/test_dataset.csv']
split_labels = ['train', 'test']
preprocessing_num_workers = 12
beta = 0.1

hf_token = load_token(file_path='hf_token.txt')

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [2]:
last_checkpoint = get_checkpoint(output_model_path)
if last_checkpoint is not None:
    print(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

In [3]:
tokenizer = get_tokenizer(pretrained_model_path, token=hf_token)

In [4]:
raw_datasets = load_dataset_from_csv(csv_files, split_labels)
column_names = list(raw_datasets["train"].features)

print(f"Training on the following splits: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}")

raw_datasets = raw_datasets.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
    num_proc=preprocessing_num_workers,
    remove_columns=column_names,
    desc="Formatting comparisons with prompt template",
)
for split in ["train", "test"]:
    raw_datasets[split] = raw_datasets[split].rename_columns(
        {"text_prompt": "prompt", "text_chosen": "chosen", "text_rejected": "rejected"}
    )

for index in random.sample(range(len(raw_datasets["train"])), 1):
    print(f"Prompt sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['prompt']}")
    print(f"Rejected sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['rejected']}")
    print(f"Chosen sample {index} of the raw training set:\n\n{raw_datasets['train'][index]['chosen']}")

Training on the following splits: ['train : 324', 'test : 42']


Formatting comparisons with prompt template (num_proc=12):   0%|          | 0/324 [00:00<?, ? examples/s]

Formatting comparisons with prompt template (num_proc=12):   0%|          | 0/42 [00:00<?, ? examples/s]

Prompt sample 123 of the raw training set:

<|system|>
</s>
<|user|>
Create a detailed infographic about the water cycle for my science project.</s>

Rejected sample 123 of the raw training set:

<|assistant|>
Sure, here's a revised response:

The next number in the sequence 2, 4,8,16,... would be determined by multiplying each successive term by two. This creates a sequence where each term is the previous term multiplied by two. In the case of this sequence, you can identify the pattern by observing that each term is obtained by multiplying the previous term by two (2×2=4, 4×2=8, 8×2=16, etc.). So, the next term in the sequence would be 16×2=32. This is the next number in the sequence if each term is obtained by multiplying the previous term by two.</s>

Chosen sample 123 of the raw training set:

<|assistant|>
Sure, creating an infographic about the water cycle for your science project would be a great way to visualize and explain this natural process. Here's what I suggest:

1. Intr

In [66]:
model_kwargs = dict(revision='main',
                    trust_remote_code=False,
                    use_cache=False,
                   token=hf_token)

ref_model = pretrained_model_path
ref_model_kwargs = model_kwargs

#########################
# Instantiate DPO trainer
#########################
trainer = DPOTrainer(
    pretrained_model_path,
    ref_model,
    model_init_kwargs=model_kwargs,
    ref_model_init_kwargs=ref_model_kwargs,
    args=training_args,
    beta=training_args.beta,
    train_dataset=raw_datasets["train"],
    eval_dataset=raw_datasets["test"],
    tokenizer=tokenizer,
    max_length=training_args.max_length,
    max_prompt_length=training_args.max_prompt_length,
    loss_type=training_args.loss_type,
)


Deprecated positional argument(s) used in DPOTrainer, please use the DPOConfig to set these arguments instead.
loading configuration file config.json from cache at /arc/home/obriaint/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.1/snapshots/2dcff66eac0c01dc50e4c41eea959968232187fe/config.json
Model config MistralConfig {
  "_name_or_path": "mistralai/Mistral-7B-Instruct-v0.1",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 32768,
  "model_type": "mistral",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "rms_norm_eps": 1e-05,
  "rope_theta": 10000.0,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.46.2",
  "use_cache": false,
  "vocab_size": 320

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

All model checkpoint weights were used when initializing MistralForCausalLM.

All the weights of MistralForCausalLM were initialized from the model checkpoint at mistralai/Mistral-7B-Instruct-v0.1.
If your task is similar to the task the model of the checkpoint was trained on, you can already use MistralForCausalLM for predictions without further training.
loading configuration file generation_config.json from cache at /arc/home/obriaint/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.1/snapshots/2dcff66eac0c01dc50e4c41eea959968232187fe/generation_config.json
Generate config GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2
}

loading configuration file config.json from cache at /arc/home/obriaint/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.1/snapshots/2dcff66eac0c01dc50e4c41eea959968232187fe/config.json
Model config MistralConfig {
  "_name_or_path": "mistralai/Mistral-7B-Instruct-v0.1",
  "architectures": [
    "MistralForCausalLM"
  ],
  "attention_

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

All model checkpoint weights were used when initializing MistralForCausalLM.

All the weights of MistralForCausalLM were initialized from the model checkpoint at mistralai/Mistral-7B-Instruct-v0.1.
If your task is similar to the task the model of the checkpoint was trained on, you can already use MistralForCausalLM for predictions without further training.
loading configuration file generation_config.json from cache at /arc/home/obriaint/huggingface/hub/models--mistralai--Mistral-7B-Instruct-v0.1/snapshots/2dcff66eac0c01dc50e4c41eea959968232187fe/generation_config.json
Generate config GenerationConfig {
  "bos_token_id": 1,
  "eos_token_id": 2
}



Extracting prompt from train dataset:   0%|          | 0/324 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/324 [00:00<?, ? examples/s]

Extracting prompt from eval dataset:   0%|          | 0/42 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/42 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/324 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/42 [00:00<?, ? examples/s]

Using cpu_amp half precision backend


## Overview of the Training Iteration in DPO

The **Direct Preference Optimization (DPO)** process fine-tunes a language model by comparing pairs of responses (chosen vs. rejected) to align the model's behavior with specific preferences. Each training iteration involves multiple steps, from data preprocessing to loss computation. Here's an overview of the key steps:

1. **Tokenization of Inputs**:
   - Converts raw text prompts, chosen responses, and rejected responses into token IDs suitable for input to the model.
   - Handles padding and truncation to ensure uniform sequence lengths.

2. **Preparation of Training Batches**:
   - Dynamically pads sequences in the dataset to the maximum length within each batch.
   - Concatenates chosen and rejected responses for efficient processing.

3. **Forward Pass**:
   - Computes the logits (output probabilities) for both chosen and rejected responses.
   - Concatenates inputs for chosen and rejected responses to avoid separate forward passes.

4. **Log Probability Computation**:
   - Calculates log probabilities for chosen and rejected responses.
   - Compares these to log probabilities from a reference model to compute the policy’s alignment with preferences.

5. **DPO Loss Calculation**:
   - Computes the DPO loss using chosen and rejected log probabilities.
   - Incorporates hyperparameters such as the beta temperature to adjust the training dynamics.

6. **Backpropagation and Optimization**:
   - Uses the computed loss to adjust model weights via gradient descent.

Below, each step is illustrated with code examples and explanations.

---

### Step 1: Tokenization of Inputs

The `tokenize_row` function tokenizes a single dataset row, converting text prompts, chosen responses, and rejected responses into token IDs. Padding and truncation are applied as needed.

- The `prompt`, `chosen`, and `rejected` fields in the sample are tokenized.
- Truncation limits the length of sequences to `max_prompt_length` and `max_completion_length`.
- The resulting dictionary contains tokenized IDs for each sequence.

In [67]:
indx = 0  # Example index in the dataset
sample = trainer.train_dataset[indx]

# Tokenize the sample
tokenized_sample = trainer.tokenize_row(
    sample,
    processing_class=trainer.processing_class,
    max_prompt_length=512,
    max_completion_length=None, 
    add_special_tokens=False 
)

print(f"Original Prompt:\n\n{sample['prompt']}")
print(f"\nTokenized Prompt:\n\n{tokenized_sample['prompt_input_ids']}")

Original Prompt:

<|system|>
</s>
<|user|>
Provide a comprehensive analysis of the factors leading to the American Civil Rights Movement.</s>


Tokenized Prompt:

[523, 28766, 6574, 28766, 28767, 13, 2, 13, 28789, 28766, 1838, 28766, 28767, 13, 18325, 547, 264, 15313, 5643, 302, 272, 8612, 5374, 298, 272, 2556, 12045, 12744, 25361, 28723, 2, 13]


### Step 2: Preparation of Training Batches

The `PreferenceCollator` dynamically pads sequences to the maximum length within a batch to ensure uniformity. This is essential for efficient processing on hardware like GPUs.

- Prompts, chosen responses, and rejected responses are padded with `pad_token_id`.
- Attention masks indicate which tokens are real (1) or padding (0).
- The padded batch ensures all sequences have the same length.

In [15]:
# Create a batch with two examples for demonstration
batch = [trainer.train_dataset[0], trainer.train_dataset[1]]

# Collate the batch
padded_batch = trainer.data_collator(batch)

print(f"Padding Value: {trainer.padding_value}")
print(f"\n\nPadded Batch:\n{padded_batch['prompt_input_ids']}")

Padding Value: 2


Padded Batch:
tensor([[    2,     2,     2,     2,     2,   523, 28766,  6574, 28766, 28767,
            13,     2,    13, 28789, 28766,  1838, 28766, 28767,    13, 18325,
           547,   264, 15313,  5643,   302,   272,  8612,  5374,   298,   272,
          2556, 12045, 12744, 25361, 28723,     2,    13],
        [  523, 28766,  6574, 28766, 28767,    13,     2,    13, 28789, 28766,
          1838, 28766, 28767,    13, 18325,   547,   264, 10537, 13268,   302,
           272,  1759,   302,  3601,  1098, 10840,  8679,   304,   871,  9545,
           298,  3687,  2170, 10821, 28723,     2,    13]])


### Step 3: Forward Pass

The **forward pass** in the `concatenated_forward` function is designed to efficiently compute model predictions for both chosen and rejected responses in a single operation. Here's how it works:

1. **Concatenation of Inputs**:
   - The prompts are repeated for both the chosen and rejected completions.
   - Chosen and rejected completion tokens are padded to ensure uniform length.
   - This enables simultaneous processing of both completions in the same forward pass, saving computation time.

2. **Model Inference**:
   - The concatenated inputs (`input_ids` and `attention_mask`) are passed to the model.
   - The model outputs logits for each token in the sequences.

3. **Alignment of Logits and Labels**:
   - Logits are shifted to align with the labels for causal language modeling.
   - Labels are adjusted to exclude the first token, aligning with the logits.

4. **Loss Masking**:
   - A mask is applied to ensure only completion tokens contribute to the loss.
   - Tokens from the prompt or padding are ignored during the loss calculation.

5. **Log Probability Calculation**:
   - The logits are converted to log probabilities using `log_softmax`, which are then used to calculate the DPO loss.

This process ensures that both chosen and rejected completions are processed efficiently while determining their contribution to the loss.

In [79]:
def forward(model, batch):
    # Concatenate inputs for chosen and rejected completions
    concatenated_batch = trainer.concatenated_inputs(batch, padding_value=trainer.padding_value)
    
    # Forward pass through the model
    outputs = model(
        input_ids=concatenated_batch["prompt_input_ids"],
        attention_mask=concatenated_batch["prompt_attention_mask"],
    )
    logits = outputs.logits  # Raw model outputs (logits)
    
    # Adjust logits for causal language modeling by excluding the last token
    logits = logits[:, :-1, :]  # Shape: [batch_size, seq_length - 1, vocab_size]
    # Align labels with logits by truncating to the same sequence length
    labels = concatenated_batch["completion_input_ids"][:, :logits.size(1)]
    
    # Apply loss mask to exclude padding tokens from loss computation
    loss_mask = concatenated_batch["completion_attention_mask"][:, :logits.size(1)]
    labels[~loss_mask] = 0  # Set labels corresponding to padding tokens to zero
    
    # Compute log probabilities
    log_probs = logits.log_softmax(dim=-1)
    
    # Extract the log probabilities of the tokens in `labels`
    per_token_log_probs = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)
    per_token_log_probs[~loss_mask] = 0  # Exclude padding tokens
    
    # Summarize log probabilities for chosen and rejected completions
    batch_size = concatenated_batch["prompt_input_ids"].size(0) // 2
    chosen_log_probs = per_token_log_probs[:batch_size].sum(dim=1)
    rejected_log_probs = per_token_log_probs[batch_size:].sum(dim=1)
    
    # Create the output dictionary
    output = {}
    
    # Compute the total log probabilities for each example
    output["chosen_logps"] = chosen_log_probs  # Log probabilities for chosen completions
    output["rejected_logps"] = rejected_log_probs  # Log probabilities for rejected completions

    return output

model_output = forward(trainer.model, padded_batch)

print("Batch Outputs:")
print(f"Chosen Logits: {model_output['chosen_logps']}")
print(f"Rejected Logits: {model_output['rejected_logps']}")

Batch Outputs:
Chosen Logits: tensor([-427.2798, -409.8966], grad_fn=<SumBackward1>)
Rejected Logits: tensor([   0.0000, -421.8497], grad_fn=<SumBackward1>)


In [80]:
with torch.no_grad():
    ref_model_output = forward(trainer.ref_model, padded_batch)

### Step 4: DPO Loss Calculation

The DPO loss is computed to align the policy model's predictions (log probabilities for chosen and rejected completions) with the preferences defined by a reference model. For the `sigmoid` loss:

1. **Log Ratio Calculation**:
   - Compute the difference between the log probabilities of the chosen and rejected completions (`chosen_logratios` and `rejected_logratios`) for both the target and reference models.
   - Subtract the reference log probabilities unless the process is `reference_free`.

2. **Logits**:
   - Compute the logits as the difference between the chosen and rejected log ratios.

3. **Loss Function**:
   - Apply the `sigmoid` loss using `torch.nn.functional.logsigmoid` to encourage higher scores for chosen completions and lower scores for rejected completions.

4. **Outputs**:
   - Return the per-example loss, along with the chosen and rejected rewards.

- The loss encourages the model to assign higher probabilities to chosen responses and lower probabilities to rejected ones.
- Rewards measure how well the policy aligns with preferences.

### **Explanation of Key Parameters**
- **`beta`**: The temperature parameter scales the logits, typically a small value like `0.1` to `0.5`.
- **`reference_free`**: If `True`, the reference model is not used, and the loss is computed solely based on the target model.

### **Outputs**
1. **`losses`**: Per-example DPO loss values.
2. **`chosen_rewards`**: Rewards (scaled log probabilities) for chosen completions.
3. **`rejected_rewards`**: Rewards (scaled log probabilities) for rejected completions.

In [89]:
trainer.beta

0.1

In [91]:
chosen_logps = model_output['chosen_logps']
rejected_logps = model_output['rejected_logps']
ref_chosen_logps = ref_model_output['chosen_logps']
ref_rejected_logps = ref_model_output['rejected_logps']


logratios = chosen_logps - rejected_logps
ref_logratios = ref_chosen_logps - ref_rejected_logps

logratios = logratios
ref_logratios = ref_logratios
logits = logratios - ref_logratios

losses = -torch.nn.functional.logsigmoid(trainer.beta * logits)

chosen_rewards = trainer.beta * (chosen_logps - ref_chosen_logps).detach()
rejected_rewards = trainer.beta * (rejected_logps - ref_rejected_logps).detach()

In [92]:
# Print results
print("Losses:", losses)
print("Chosen Rewards:", chosen_rewards)
print("Rejected Rewards:", rejected_rewards)

Losses: tensor([   nan, 0.6935], grad_fn=<NegBackward0>)
Chosen Rewards: tensor([    nan, -0.0246])
Rejected Rewards: tensor([ 0.0000, -0.0238])


In [93]:
ref_chosen_logps

tensor([      nan, -409.6503])

In [97]:
reward_accuracies = (chosen_rewards > rejected_rewards).float()

In [98]:
reward_accuracies

tensor([0., 0.])

### Step 6: Backpropagation and Optimization

The loss is backpropagated, and the optimizer updates model weights.

In [58]:
# Simulate a single optimization step
loss = losses.mean()
loss.backward()  # Compute gradients
trainer.optimizer.step()  # Update model weights
trainer.optimizer.zero_grad()  # Reset gradients

print("Optimization step completed.")

AttributeError: 'NoneType' object has no attribute 'step'

## Training Loop

All of these steps are done by the `DPOTrainer.train()` function:

In [None]:
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
metrics = train_result.metrics
metrics["train_samples"] = len(raw_datasets["train"])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()