In [None]:
!pip install datasets trl -q --progress-bar off

In [None]:
from itertools import repeat
# from itertools import batched
from more_itertools import chunked as batched
from more_itertools import repeat_each
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification, TrainingArguments, logging
from datasets import Dataset, load_dataset, concatenate_datasets
from trl import DPOTrainer

In [None]:
# Keep transformers quiet: only warnings and errors are logged from here on.
logging.set_verbosity(logging.WARNING)

In [None]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Drop references to previously loaded models (useful when re-running cells), then ask the
# CUDA caching allocator to release its unoccupied cached memory.
policy, reward = None, None
torch.cuda.empty_cache()

In [None]:
policy_model_name = "lvwerra/gpt2-imdb"

# Left padding keeps every prompt flush against the tokens generated after it in a batch.
policy_tokenizer = AutoTokenizer.from_pretrained(policy_model_name, padding_side="left")
# Reuse EOS as the padding token; DPOTrainer requires a pad token to be set.
policy_tokenizer.pad_token_id = policy_tokenizer.eos_token_id

# `do_sample=True` is forwarded to from_pretrained so the loaded model is configured to sample.
policy = AutoModelForCausalLM.from_pretrained(policy_model_name, do_sample=True)
policy = policy.to(device)

In [None]:
reward_model_name = "lvwerra/distilbert-imdb"

# Sequence classifier used as the reward model (calculate_reward reads its class-1 logit).
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_name)
reward = AutoModelForSequenceClassification.from_pretrained(reward_model_name)
reward = reward.to(device)

## Create dataset

In [None]:
def batch_decode(tokenizer, batch, **kwargs):
    """Decode every row of ``batch`` (e.g. a 2-D tensor of token ids) with ``tokenizer.decode``.

    Extra keyword arguments such as ``skip_special_tokens`` are forwarded to ``decode``.
    Returns a list of strings, one per row.
    """
    decoded = []
    for row in batch:
        decoded.append(tokenizer.decode(row, **kwargs))
    return decoded

def dict_to(d, to):
    """Return a new dict where every value of ``d`` is replaced by ``value.to(to)``.

    ``to`` is anything the values' ``.to`` accepts (a device string/object or a dtype).
    The input mapping is left untouched.
    """
    moved = {}
    for key, value in d.items():
        moved[key] = value.to(to)
    return moved

@torch.no_grad
def sample_from_policy(policy, policy_tokenizer, prompts, temperature=0.7, max_length=100):
    """Sample one continuation per prompt from ``policy`` and return the decoded texts.

    Args:
        policy: causal LM exposing ``generate`` and ``device``.
        policy_tokenizer: its tokenizer; it should pad on the left so batched prompts sit
            directly before the generated tokens.
        prompts: list of prompt strings; repeated prompts yield independent samples.
        temperature: sampling temperature.
        max_length: maximum total length in tokens (prompt + continuation) passed to
            ``generate``; defaults to the previously hard-coded 100.

    Returns:
        List of decoded strings (one per prompt) for the full generated sequences, with
        special tokens skipped.
    """
    prompts_encoded = policy_tokenizer(prompts, padding=True, return_tensors="pt")
    prompts_encoded = dict_to(prompts_encoded, policy.device)
    policy_sample = policy.generate(
        **prompts_encoded,
        max_length=max_length,
        pad_token_id=policy_tokenizer.eos_token_id,
        # Request sampling explicitly instead of relying on the model's generation defaults:
        # under greedy decoding `temperature` is ignored and repeated prompts would all
        # produce the same text, breaking the chosen/rejected pairing downstream.
        do_sample=True,
        num_beams=1,
        top_k=0,  # no top-k filtering: sample from the full temperature-scaled distribution
        temperature=temperature,
    )
    return batch_decode(policy_tokenizer, policy_sample, skip_special_tokens=True)

@torch.no_grad
def calculate_reward(reward, reward_tokenizer, inputs):
    """Score each text with the reward classifier and return its class-1 logit as a CPU tensor.

    Texts are padded and truncated to at most 512 tokens before scoring.
    NOTE(review): class index 1 is assumed to be the positive-sentiment label — confirm
    against the reward model's config.
    """
    encoded = dict_to(
        reward_tokenizer(inputs, padding=True, truncation=True, max_length=512, return_tensors="pt"),
        reward.device,
    )
    all_logits = reward(**encoded).logits
    class_one_logits = all_logits[:, 1]
    return class_one_logits.cpu()

In [None]:
imdb_dataset = load_dataset("imdb")
sft_dataset = concatenate_datasets([imdb_dataset["train"], imdb_dataset["test"]])
# Keep a random 10% of all reviews as prompt sources so preference-data generation stays fast.
# A fixed seed makes the subsample (and therefore the generated dataset) reproducible.
sft_dataset = sft_dataset.train_test_split(test_size=0.1, seed=42)["test"]

In [None]:
def get_first_k(batch, k=4):
    """Return the first ``k`` whitespace-separated words of each text, re-joined by single spaces."""
    prefixes = []
    for text in batch:
        words = text.split()
        prefixes.append(" ".join(words[:k]))
    return prefixes

In [None]:
batch_size = 512  # generated samples per generation/reward batch
samples_per_prompt = 4
# Samples of the same prompt are compared in pairs, so each prompt needs an even count.
assert samples_per_prompt % 2 == 0, "samples_per_prompt must be even"
prompts_per_batch = batch_size // samples_per_prompt
sft_texts = sft_dataset["text"]
# Ceil division: the last batch may be partial and still counts as a tqdm step.
total = (len(sft_texts) + prompts_per_batch - 1) // prompts_per_batch


def _completion(prompt, text):
    """Strip the leading prompt from a decoded sample.

    DPOTrainer concatenates ``prompt`` with ``chosen``/``rejected`` itself, so storing the full
    decoded sample (which begins with the prompt) would duplicate the prompt in every training
    sequence. Falls back to the full text if decoding did not reproduce the prompt verbatim.
    """
    return text[len(prompt):] if text.startswith(prompt) else text


data = []
# This procedure is admittedly very inefficient (one generate + reward pass per batch).
for batch in tqdm(batched(sft_texts, prompts_per_batch), total=total):
    prompts = list(repeat_each(get_first_k(batch, k=4), samples_per_prompt))
    policy_sample = sample_from_policy(policy, policy_tokenizer, prompts)
    # Rewards are computed on the full text (prompt + continuation).
    reward_logits = calculate_reward(reward, reward_tokenizer, policy_sample)
    # Consecutive samples share a prompt; compare them in pairs (0, 1), (2, 3), ...
    for i, j in batched(range(len(prompts)), n=2):
        winner, loser = (i, j) if reward_logits[i] > reward_logits[j] else (j, i)
        data.append({
            "prompt": prompts[i],
            "chosen": _completion(prompts[i], policy_sample[winner]),
            "rejected": _completion(prompts[i], policy_sample[loser]),
            # Plain floats rather than 0-d tensors.
            "chosen_reward": reward_logits[winner].item(),
            "rejected_reward": reward_logits[loser].item(),
        })

In [None]:
dataset = Dataset.from_list(data)
# Hold out 10% of the preference pairs for evaluation; a fixed seed keeps the split reproducible.
# NOTE(review): each prompt yields several pairs, which may land in different splits.
dataset = dataset.train_test_split(test_size=0.1, seed=42)
dataset.save_to_disk("dpo_dataset")

In [None]:
import huggingface_hub
# Interactive Hugging Face Hub login; required before `dataset.push_to_hub(...)` in the next cell.
huggingface_hub.login()

In [None]:
# Publish the train/test preference splits to the Hugging Face Hub.
dataset.push_to_hub(repo_id="yuasosnin/imdb-dpo")

## Train DPO

In [None]:
# Reload the published preference dataset from the Hub, caching it under ./cache.
dataset = load_dataset(path="yuasosnin/imdb-dpo", cache_dir="./cache")

In [None]:
from typing import *
import torch.nn as nn
import warnings
from abc import ABC

class DPOLoss(nn.Module, ABC):
    """Base class for pluggable DPO-style preference losses.

    Subclasses implement ``forward`` taking the policy and reference log-probabilities of the
    chosen and rejected responses and returning ``(losses, chosen_rewards, rejected_rewards)``.
    """

    def forward(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
        raise NotImplementedError

def monkeypatch_dpo_loss(dpo_trainer: DPOTrainer, loss: DPOLoss):
    """Replace ``dpo_trainer.dpo_loss`` on this instance with a method delegating to ``loss``.

    After patching, the trainer's own loss hyper-parameters no longer have any effect, so a
    warning is emitted for each one that is set to a non-default value.

    Args:
        dpo_trainer: trainer to patch (mutated in place).
        loss: a DPOLoss whose ``forward`` returns ``(losses, chosen_rewards, rejected_rewards)``.
    """
    def patch(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_free: bool = False,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        # `reference_free` is accepted only for signature compatibility; it is not used.
        losses, chosen_rewards, rejected_rewards = loss.forward(
            policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps)
        # Detach rewards so they never carry gradients (presumably used only for logged metrics).
        return losses, chosen_rewards.detach(), rejected_rewards.detach()

    # Bind `patch` as a method of this particular trainer instance.
    dpo_trainer.dpo_loss = patch.__get__(dpo_trainer, DPOTrainer)

    # Compare against (presumably) DPOTrainer's defaults; getattr keeps this working on trl
    # versions where an attribute does not exist.
    if getattr(dpo_trainer, "beta", 0.1) != 0.1:
        warnings.warn("DPOTrainer `beta` parameter is ignored with monkeypatched loss")
    if getattr(dpo_trainer, "loss_type", "sigmoid") != "sigmoid":
        warnings.warn("DPOTrainer `loss_type` parameter is ignored with monkeypatched loss")
    if getattr(dpo_trainer, "reference_free", False):
        warnings.warn("DPOTrainer `reference_free` parameter is ignored with monkeypatched loss")

In [None]:
import math
import torch.nn.functional as F

class SigmoidLoss(DPOLoss):
    """Standard DPO loss: ``-log sigmoid(beta * margin)``.

    ``margin`` is the chosen minus rejected policy/reference log-ratio; the returned rewards are
    ``beta * (log pi - log pi_ref)`` for each response.
    """

    def __init__(self, beta: float):
        super().__init__()
        self.beta = beta

    def forward(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
        chosen_lr = policy_chosen_logps - reference_chosen_logps
        rejected_lr = policy_rejected_logps - reference_rejected_logps
        margin = chosen_lr - rejected_lr
        return (
            -F.logsigmoid(self.beta * margin),
            self.beta * chosen_lr,
            self.beta * rejected_lr,
        )

class HingeLoss(DPOLoss):
    """Hinge variant of DPO: ``relu(1 - beta * margin)``.

    ``margin`` is the chosen minus rejected policy/reference log-ratio; rewards are
    ``beta * (log pi - log pi_ref)`` per response.
    """

    def __init__(self, beta: float):
        super().__init__()
        self.beta = beta

    def forward(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
        chosen_lr = policy_chosen_logps - reference_chosen_logps
        rejected_lr = policy_rejected_logps - reference_rejected_logps
        margin = chosen_lr - rejected_lr
        return (
            torch.relu(1 - self.beta * margin),
            self.beta * chosen_lr,
            self.beta * rejected_lr,
        )

class AlphaKLLoss(DPOLoss):
    """alpha-divergence f-DPO loss for ``alpha`` in (0, 1].

    With ``u = exp(alpha * (log pi_ref - log pi))`` per response, the reward is
    ``beta / alpha * (1 - u)`` and the loss is ``-log sigmoid(beta / alpha * (u_rejected - u_chosen))``.

    Raises:
        ValueError: if ``alpha`` is 0 (use SigmoidLoss) or outside (0, 1].
    """

    def __init__(self, beta: float, alpha: float = 1.0):
        super().__init__()
        self.beta = beta
        if alpha == 0.0:
            raise ValueError("For reverse-KL (`alpha=0`) use SigmoidLoss")
        if alpha < 0 or alpha > 1.0:
            raise ValueError("`alpha` must be in (0, 1]")
        self.alpha = alpha

    def forward(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
        scale = self.beta / self.alpha
        u_chosen = torch.exp(self.alpha * (reference_chosen_logps - policy_chosen_logps))
        u_rejected = torch.exp(self.alpha * (reference_rejected_logps - policy_rejected_logps))
        losses = -F.logsigmoid(scale * (u_rejected - u_chosen))
        return losses, scale * (1 - u_chosen), scale * (1 - u_rejected)

class BetterAlphaKLLoss(DPOLoss):
    """alpha-divergence f-DPO loss for ``alpha`` in (0, 1] (same objective as AlphaKLLoss).

    With ``u = exp(alpha * (log pi_ref - log pi))`` per response, the implicit reward is
    ``beta / alpha * (1 - u)`` and the loss is ``-log sigmoid(beta / alpha * (u_rejected - u_chosen))``,
    i.e. ``-log sigmoid(r_chosen - r_rejected)``.

    Raises:
        ValueError: if ``alpha`` is 0 (use SigmoidLoss) or outside (0, 1].
    """

    def __init__(self, beta: float, alpha: float = 1.0):
        super().__init__()
        self.beta = beta
        if alpha == 0.0:
            raise ValueError("For reverse-KL (`alpha=0`) use SigmoidLoss")
        elif alpha < 0 or alpha > 1.0:
            raise ValueError("`alpha` must be in (0, 1]")
        self.alpha = alpha

    def forward(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
        ratio_rejected = reference_rejected_logps - policy_rejected_logps
        ratio_chosen = reference_chosen_logps - policy_chosen_logps

        # (pi_ref / pi) ** alpha for each response.
        ratio_rejected = torch.exp(ratio_rejected * self.alpha)
        ratio_chosen = torch.exp(ratio_chosen * self.alpha)

        scale = self.beta / self.alpha
        logits = ratio_rejected - ratio_chosen
        losses = -F.logsigmoid(scale * logits)

        # Rewards must be the same quantities whose difference forms the logits; the previous
        # `beta * ratio` had the opposite sign, inverting logged reward margins/accuracies.
        chosen_rewards = scale * (1 - ratio_chosen)
        rejected_rewards = scale * (1 - ratio_rejected)

        return losses, chosen_rewards, rejected_rewards

class JSLoss(DPOLoss):
    """Jensen-Shannon f-DPO loss.

    Per response the reward is ``beta * (log 2 + log sigmoid(log pi - log pi_ref))``; the loss is
    ``-log sigmoid(beta * (s_chosen - s_rejected))`` where ``s = log sigmoid(log pi - log pi_ref)``.
    """

    def __init__(self, beta: float):
        super().__init__()
        self.beta = beta

    def forward(self, policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps):
        s_chosen = F.logsigmoid(policy_chosen_logps - reference_chosen_logps)
        s_rejected = F.logsigmoid(policy_rejected_logps - reference_rejected_logps)
        log_two = math.log(2)
        losses = -F.logsigmoid(self.beta * (s_chosen - s_rejected))
        chosen_rewards = self.beta * (log_two + s_chosen)
        rejected_rewards = self.beta * (log_two + s_rejected)
        return losses, chosen_rewards, rejected_rewards

In [None]:
# Drop the previous trainer (if any) so objects only it references can be freed before a new
# one is built, then release unoccupied cached CUDA memory.
dpo_trainer = None
torch.cuda.empty_cache()

In [None]:
# One epoch of DPO; metrics are logged and the eval split evaluated every 50 steps.
# NOTE(review): output_dir says "alpha-js" while the trainer cell below uses BetterAlphaKLLoss — confirm.
training_args = TrainingArguments(
    output_dir="./test/alpha-js",
    num_train_epochs=1,
    per_device_train_batch_size=16,
    optim="adamw_torch",
    learning_rate=1e-5,
    weight_decay=1e-5,
    # Keep the raw text columns; DPOTrainer tokenizes prompt/chosen/rejected itself.
    remove_unused_columns=False,
    evaluation_strategy="steps",
    eval_steps=50,
    logging_first_step=True,
    logging_steps=50,
)

In [None]:
# The custom loss carries its own beta; the trainer's dpo_loss is swapped for it below.
loss = BetterAlphaKLLoss(alpha=1.0, beta=0.1)
dpo_trainer = DPOTrainer(
    model=policy,
    args=training_args,
    tokenizer=policy_tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    max_length=512,
    max_prompt_length=16,
    padding_value=0,
)
monkeypatch_dpo_loss(dpo_trainer, loss)

In [None]:
# Fine-tune `policy` with the monkeypatched loss; the return value is shown as the cell output.
dpo_trainer.train()

## Evaluate

In [None]:
from scipy.stats import entropy
from collections import defaultdict
from more_itertools import flatten

class Metric:
    """Streaming metric: feed batches of generated texts to ``update``, read the value via ``compute``."""

    def update(self, samples: list) -> None:
        raise NotImplementedError()

    def compute(self) -> float:
        raise NotImplementedError()


class AverageReward(Metric):
    """Mean reward (class-1 logit from ``calculate_reward``) over all texts passed to ``update``."""

    def __init__(self, reward_model, reward_tokenizer):
        self.reward_model = reward_model
        self.reward_tokenizer = reward_tokenizer
        # Per-instance accumulators.
        self.total_reward = 0.0
        self.num_samples = 0

    def update(self, samples):
        # Score with the model/tokenizer this metric was constructed with, not notebook globals.
        reward_logits = calculate_reward(self.reward_model, self.reward_tokenizer, samples)
        self.total_reward += reward_logits.sum().item()
        self.num_samples += len(samples)

    def compute(self):
        """Return the mean reward, or NaN if no samples have been seen."""
        if self.num_samples == 0:
            return float("nan")
        return self.total_reward / self.num_samples


class Entropy(Metric):
    """Entropy (scipy default base e, i.e. nats) of the token-frequency distribution over all texts seen."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # Per-instance counts: a class-level defaultdict would be shared by every Entropy
        # instance, silently mixing tokens from separate evaluations.
        self.stats = defaultdict(int)
        self.num_tokens = 0

    def update(self, samples):
        tokens = self.tokenizer(samples)
        for ids in tokens["input_ids"]:
            for t in ids:
                if t == self.tokenizer.pad_token_id:
                    continue
                self.stats[t] += 1
                self.num_tokens += 1

    def compute(self):
        """Return the entropy of the observed token frequencies, or NaN if no tokens were seen."""
        if self.num_tokens == 0:
            return float("nan")
        probs = [count / self.num_tokens for count in self.stats.values()]
        return entropy(probs)

In [None]:
reward_metric = AverageReward(reward, reward_tokenizer)
entropy_metric = Entropy(policy_tokenizer)

# Training may leave the model in train mode; switch to eval so dropout is off while sampling.
policy.eval()

batch_size = 512
eval_prompts = dataset["test"]["prompt"]
num_batches = (len(eval_prompts) + batch_size - 1) // batch_size  # ceil division
for batch in tqdm(batched(eval_prompts, batch_size), total=num_batches):
    policy_sample = sample_from_policy(policy, policy_tokenizer, prompts=batch)
    reward_metric.update(policy_sample)
    entropy_metric.update(policy_sample)

print("\nTrained Policy Metrics:")
print("Reward: ", reward_metric.compute())
print("Entropy: ", entropy_metric.compute())

In [None]:
# Draw four independent samples for the same prompt to eyeball the trained policy.
sample_from_policy(policy, policy_tokenizer, prompts=["There are lots of"] * 4)