In [1]:
# Automatically re-import edited Python modules (e.g. `dpo`, imported later in
# this notebook) before each cell's code runs, so edits apply without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
from typing import Optional
from datasets import Dataset, load_dataset

def get_imdb(
        split : str, sanity_check : bool = False, silent : bool = False, cache_dir : Optional[str] = None,
        seed : int = 42, sanity_check_size : int = 1000, prompt : str = "Movie review:"
    ) -> Dataset:
    """Load IMDb from the Hugging Face Hub as (prompt, chosen, rejected) preference pairs.

    The reviews are shuffled with ``seed``. Positive reviews (label 1) and
    negative reviews (label 0) are then paired position-wise in shuffled order,
    and every pair shares the same fixed ``prompt``. Surplus reviews of the more
    frequent label are dropped, so the result has ``min(#positive, #negative)`` rows.

    Args:
        split: IMDb split to load (forwarded to ``load_dataset``), e.g. ``"train"``.
        sanity_check: If True, keep only the first ``sanity_check_size`` shuffled
            reviews before pairing (useful for quick smoke tests).
        silent: Currently unused; kept for interface compatibility.
        cache_dir: Optional cache directory forwarded to ``load_dataset``.
        seed: Shuffle seed; makes both the subset and the pairing reproducible.
        sanity_check_size: Number of reviews kept when ``sanity_check`` is True.
        prompt: Prompt string shared by every pair.

    Returns:
        A ``Dataset`` (not a plain dict) with string columns ``prompt``,
        ``chosen`` (a positive review) and ``rejected`` (a negative review).
    """
    dataset = load_dataset("imdb", split=split, cache_dir=cache_dir).shuffle(seed=seed)
    if sanity_check:
        dataset = dataset.select(range(min(len(dataset), sanity_check_size)))

    # Partition by label over the in-memory columns rather than running two
    # per-row `Dataset.filter` passes; the shuffled relative order is preserved,
    # so the resulting pairs are identical.
    texts, labels = dataset["text"], dataset["label"]
    chosen = [text for text, label in zip(texts, labels) if label == 1]
    rejected = [text for text, label in zip(texts, labels) if label == 0]

    # Pair reviews position-wise, truncating to the rarer label.
    length = min(len(chosen), len(rejected))
    return Dataset.from_dict({
        "prompt": [prompt] * length,
        "chosen": chosen[:length],
        "rejected": rejected[:length],
    })

# Build preference pairs from a small (sanity-check) subset of the shuffled
# training split, then show how many pairs survived the label balancing.
dataset = get_imdb(split="train", sanity_check=True)
len(dataset)

488

In [3]:
import torch
import transformers

from pyreft import (
    get_reft_model,
    ReftConfig,
    LoreftIntervention
)

# Load the base causal LM and its tokenizer from the Hugging Face Hub.
model_name_or_path = "gpt2"
# Fall back to CPU so the notebook still loads on machines without a CUDA GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name_or_path, torch_dtype=torch.bfloat16, device_map=device)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name_or_path)

# Wrap the model with a LoReFT intervention of rank `low_rank_dimension` on the
# output of each transformer block listed in `reft_layers`. Only the
# intervention parameters are trainable; the base model contributes none.
reft_layers = [10]
low_rank_dimension = 64
reft_config = ReftConfig(representations=[{
    "component": f"transformer.h[{layer}].output",  # string access to the model component
    "intervention": LoreftIntervention(
        embed_dim=model.config.hidden_size, low_rank_dimension=low_rank_dimension
    ),
} for layer in reft_layers])
reft_model = get_reft_model(model, reft_config)
reft_model.print_trainable_parameters()

trainable intervention params: 98,368 || trainable model params: 0
model params: 124,439,808 || trainable%: 0.0790486594129107


In [6]:
from transformers import TrainingArguments
from dpo import DPOReftTrainer

# Reuse EOS as the padding token so batches can be padded
# (GPT-2's tokenizer defines no pad token by default).
tokenizer.pad_token = tokenizer.eos_token

# DPO hyperparameters.
beta = 0.1                  # DPO beta (strength of the implicit KL penalty)
max_length = 256            # token budget; also used as max_target_length below
max_prompt_length = 8       # the shared "Movie review:" prompt is only a few tokens
generate_during_eval = False

training_args = TrainingArguments(
    learning_rate=1e-3,
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    logging_steps=10,
    # No evaluation strategy is configured, so periodic evaluation is disabled
    # (a bare `eval_steps` has no effect on its own). To enable it, pass the
    # eval-strategy argument (`eval_strategy` or `evaluation_strategy`,
    # depending on the installed transformers version) together with `eval_steps`.
    output_dir="pv_dpo_example",
    warmup_steps=150,
    report_to="none",
    logging_first_step=True,
)

# The same ReFT-wrapped model is passed as both policy and reference model.
# NOTE(review): presumably DPOReftTrainer computes reference log-probs with the
# interventions disabled (i.e. from the frozen base model) — confirm in dpo.py.
# NOTE(review): eval_dataset is the training set, so any eval metrics would be in-sample.
trainer = DPOReftTrainer(
    reft_model,
    reft_model,
    args=training_args,
    beta=beta,
    train_dataset=dataset,
    eval_dataset=dataset,
    tokenizer=tokenizer,
    max_length=max_length,
    max_target_length=max_length,
    max_prompt_length=max_prompt_length,
    generate_during_eval=generate_during_eval,
    peft_config=None,
)
trainer.train()



Map:   0%|          | 0/488 [00:00<?, ? examples/s]

Map:   0%|          | 0/488 [00:00<?, ? examples/s]

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


  0%|          | 0/244 [00:00<?, ?it/s]

{'loss': 0.5586, 'grad_norm': 4.250195503234863, 'learning_rate': 6.6666666666666675e-06, 'rewards/chosen': 0.10009765625, 'rewards/rejected': -0.2001953125, 'rewards/accuracies': 0.5, 'rewards/margins': 0.30078125, 'logps/rejected': -760.0, 'logps/chosen': -868.0, 'logits/rejected': -107.5, 'logits/chosen': -108.5, 'epoch': 0.01}
{'loss': 0.6309, 'grad_norm': 4.609270095825195, 'learning_rate': 6.666666666666667e-05, 'rewards/chosen': -0.1640625, 'rewards/rejected': -0.39453125, 'rewards/accuracies': 0.5555555820465088, 'rewards/margins': 0.23046875, 'logps/rejected': -764.0, 'logps/chosen': -756.0, 'logits/rejected': -108.0, 'logits/chosen': -108.0, 'epoch': 0.08}
{'loss': 0.7527, 'grad_norm': 15.571920394897461, 'learning_rate': 0.00013333333333333334, 'rewards/chosen': -0.87109375, 'rewards/rejected': -1.0234375, 'rewards/accuracies': 0.550000011920929, 'rewards/margins': 0.150390625, 'logps/rejected': -744.0, 'logps/chosen': -756.0, 'logits/rejected': -107.5, 'logits/chosen': -104

TrainOutput(global_step=244, training_loss=0.6915137376941618, metrics={'train_runtime': 82.0398, 'train_samples_per_second': 11.897, 'train_steps_per_second': 2.974, 'train_loss': 0.6915137376941618, 'epoch': 2.0})