In [1]:
import os
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict
from trl import DPOTrainer, DPOConfig

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

# 1. load a pretrained model and tokenizer

In [12]:
model_name_or_path = "openai-community/gpt2"
ignore_bias_buffers = False

model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
if ignore_bias_buffers:
    # torch distributed hack
    model._ddp_params_and_buffers_to_ignore = [
        name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
    ]

model_ref = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

The DPO trainer expects a model of AutoModelForCausalLM, compared to PPO that expects AutoModelForCausalLMWithValueHead for the value function.

## 2. Load the Anthropic Helpful-Harmless dataset

In [13]:
def extract_anthropic_prompt(prompt_and_response):
    """Extract the anthropic prompt from a prompt and response pair."""
    search_term = "\n\nAssistant:"
    search_term_idx = prompt_and_response.rfind(search_term)
    assert search_term_idx != -1, f"Prompt and response does not contain '{search_term}'"
    return prompt_and_response[: search_term_idx + len(search_term)]

def get_hh(split: str, sanity_check: bool = False, silent: bool = False, cache_dir: str = None) -> Dataset:
    """Load the Anthropic Helpful-Harmless dataset from Hugging Face and convert it to the necessary format.

    The dataset is converted to a dictionary with the following structure:
    {
        'prompt': List[str],
        'chosen': List[str],
        'rejected': List[str],
    }

    Prompts should be structured as follows:
      \n\nHuman: <prompt>\n\nAssistant:
    Multiple turns are allowed, but the prompt should always start with \n\nHuman: and end with \n\nAssistant:.
    """

    dataset = load_dataset("Anthropic/hh-rlhf", split=split, cache_dir=cache_dir)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), 1000)))

    def split_prompt_and_responses(sample) -> Dict[str, str]:
        prompt = extract_anthropic_prompt(sample["chosen"])
        return {
            "prompt": prompt,
            "chosen": sample["chosen"][len(prompt) :],
            "rejected": sample["rejected"][len(prompt) :],
        }

    return dataset.map(split_prompt_and_responses)

sanity_check = True
train_dataset = get_hh("train", sanity_check=sanity_check)
eval_dataset = get_hh("test", sanity_check=sanity_check)

In [14]:
train_dataset

Dataset({
    features: ['chosen', 'rejected', 'prompt'],
    num_rows: 1000
})

In [15]:
eval_dataset

Dataset({
    features: ['chosen', 'rejected', 'prompt'],
    num_rows: 1000
})

# 3. initialize training arguments:

In [16]:
training_args = DPOConfig(
		num_train_epochs=3,
		learning_rate=5e-07,
		per_device_train_batch_size=1,
		do_eval=True,
		per_device_eval_batch_size=1,
		adam_epsilon=1e-08,
		lr_scheduler_type="linear",
		warmup_ratio=0.1,
		seed=42,
		logging_steps=100,
		save_steps=500,
		save_strategy="steps",
		output_dir="./output-dir",
		gradient_checkpointing=True,
		bf16=True,
		remove_unused_columns=False,
	)

# 4. initialize the DPO trainer

In [17]:
dpo_trainer = DPOTrainer(
		model=model,
        ref_model=model_ref,
		args=training_args,
		train_dataset=train_dataset,
		eval_dataset=eval_dataset,
		processing_class=tokenizer,
	)

Applying chat template to train dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/1000 [00:00<?, ? examples/s]

# 5. Train

In [18]:
dpo_trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss
100,0.6944
200,0.6909
300,0.666
400,0.6922
500,0.6683
600,0.6414
700,0.7012
800,0.6347
900,0.7045
1000,0.6764


TrainOutput(global_step=3000, training_loss=0.6127150751749675, metrics={'train_runtime': 2837.512, 'train_samples_per_second': 1.057, 'train_steps_per_second': 1.057, 'total_flos': 0.0, 'train_loss': 0.6127150751749675, 'epoch': 3.0})

In [19]:
dpo_trainer.save_model()