In [20]:
!pip install trl

  pid, fd = os.forkpty()
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)




In [21]:
import sys
import os

import numpy as np
import pandas as pd
import torch
import transformers
import datasets

from tqdm import tqdm
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from trl import RewardTrainer, RewardConfig
from torch.optim import AdamW

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Utils

## kl

In [22]:
def compute_kl_loss(model1, model2, inputs):
    output1 = model1(inputs).logits
    output2 = model2(inputs).logits
    
    loss = torch.nn.functional.kl_div(
        torch.nn.functional.log_softmax(output1, dim=-1),
        torch.nn.functional.softmax(output2, dim=-1),
        reduction='batchmean',
    )
    
    return loss

## slerp

In [23]:
def slerp(theta_init, theta1, theta2, lambd):
    eps = 0.01
    interpolated_weights = {}
    for key in theta_init.keys():
        delta1 = theta1[key] - theta_init[key]
        delta2 = theta2[key] - theta_init[key]
        omega = (np.dot(theta1[key].flatten().cpu(), theta2[key].flatten().cpu()) /
                          (np.linalg.norm(theta1[key].flatten().cpu() + eps) * np.linalg.norm(theta2[key].flatten().cpu() + eps)))
        omega = np.arccos(np.min((np.max((-0.99, omega)), 0.99)))
        
        interpolated_weights[key] = (theta_init[key] +
                                     (np.sin((1 - lambd) * omega) / np.sin(omega)) * delta1 +
                                     (np.sin(lambd * omega) / np.sin(omega)) * delta2)
    return interpolated_weights

## utils

In [24]:
def evaluate(
    reward_model,
    reward_tokenizer,
    sft_model, 
    sft_tokenizer,
    prompts,

):
    init_model = GPT2LMHeadModel.from_pretrained(
            SFT_PATH
    ).to(device)
    init_model.generation_config.pad_token_id = sft_tokenizer.pad_token_id
    
    reward_model.eval()
    sft_model.eval()
    
    reward = []
    avg_kl = []

    for prompt in tqdm(prompts):
        inputs = sft_tokenizer.encode_plus(
            prompt,
            return_tensors='pt', 
            truncation=True,
            padding='max_length',
            max_length=25
        )
        output = sft_model.generate(**inputs.to(device), max_length=40)
        generation = sft_tokenizer.decode(output[0], skip_special_tokens=True)

        inputs = reward_tokenizer.encode_plus(
            generation,
            return_tensors='pt',
            truncation=True, 
            padding='max_length',
            max_length=128
        )
        score = reward_model.forward(**inputs.to(device)).logits
        
        kl_loss = compute_kl_loss(
                sft_model, 
                init_model,
                inputs["input_ids"],
        )

        reward.append(score.item())
        avg_kl.append(kl_loss.item())
    return np.mean(reward), np.mean(avg_kl)

def compute_reward(prompt):
    inputs = reward_tokenizer.encode_plus(
        prompt,
        return_tensors='pt',
        truncation=True, 
        max_length=64).to(device)
    with torch.no_grad():
        outputs = reward_model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
    logits = outputs.logits
    reward = logits[0]
    return reward

def generate_prompt(prompt):
    inputs = sft_tokenizer.encode(prompt, return_tensors='pt').to(device)
    outputs = sft_model.generate(inputs, max_length=50, num_return_sequences=1)
    generated_prompt = sft_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_prompt

## dataset

In [25]:
dataset = load_dataset("stanfordnlp/imdb")
train_dataset = dataset["train"]
test_dataset = dataset["test"]

prompts = []

for comment in tqdm(np.random.choice(test_dataset, 200)):
    prompts.append(comment["text"][:20])
    
train_prompts = []
for comment in tqdm(train_dataset):
    train_prompts.append(comment["text"][:20])

100%|██████████| 200/200 [00:00<00:00, 211353.19it/s]
100%|██████████| 25000/25000 [00:00<00:00, 25730.15it/s]


# reward (pretrained)

In [26]:
PATH_TO_MODEL = "/kaggle/input/reward-model"
REWARD_PATH = "distilbert/distilbert-base-cased"

reward_model = AutoModelForSequenceClassification.from_pretrained(
    PATH_TO_MODEL,
    num_labels=1
).to(device)

reward_tokenizer = AutoTokenizer.from_pretrained(REWARD_PATH)

# sft

In [27]:
SFT_PATH = "lvwerra/gpt2-imdb"

sft_model = GPT2LMHeadModel.from_pretrained(
    SFT_PATH
).to(device)
sft_tokenizer = GPT2Tokenizer.from_pretrained(SFT_PATH)

sft_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
sft_tokenizer.pad_token = sft_tokenizer.eos_token
sft_model.generation_config.pad_token_id = sft_tokenizer.pad_token_id

# reward & kl before rlhf stage

In [28]:
evaluate(
    reward_model,
    reward_tokenizer,
    sft_model, 
    sft_tokenizer,
    prompts
)

100%|██████████| 200/200 [00:27<00:00,  7.38it/s]


(0.28810329470783475, -1.737438951820991e-07)

# RLHF

In [29]:
reward_model.eval()
sft_model.train()

optimizer = AdamW(sft_model.parameters(), lr=1e-5)

In [30]:
temp_model = GPT2LMHeadModel.from_pretrained(SFT_PATH).to(device)

I = 2
M = 2
T = 100
mu = 0.01
lambd = 0.5
nu = 0.5
beta = 0.1
gamma = 0.99
k = 10

theta_init = sft_model.state_dict()
temp_model = GPT2LMHeadModel.from_pretrained(SFT_PATH).to(device)

for i in range(I):
    theta_m_list = []
    theta_m_ema_list = []

    for m in range(M):
        theta_m = {name: v.clone() for name, v in theta_init.items()}
        theta_m_ema = {name: v.clone() for name, v in theta_init.items()}
        for prompt in tqdm(train_prompts[:T]):
            generated_prompts = [generate_prompt(prompt) for _ in range(k)]

            rewards = [compute_reward(generated_prompt) for generated_prompt in generated_prompts]

            inputs = [sft_tokenizer.encode(generated_prompt, return_tensors='pt').to(device) for generated_prompt in generated_prompts]
            outputs = [sft_model(input_tensor, labels=input_tensor) for input_tensor in inputs]
            losses = [output.loss for output in outputs]
            log_probs = [-loss for loss in losses]
            
            rloo_grads = []
            for i in range(k):
                other_rewards = [rewards[j] for j in range(k) if j != i]
                baseline = sum(other_rewards) / (k - 1)
                adjusted_reward = rewards[i] - baseline

                log_prob = -log_probs[i]
                rloo_grads.append(adjusted_reward * log_prob)
                
            kl_losses = torch.tensor([compute_kl_loss(sft_model, temp_model, input_tensor) for input_tensor in inputs])

            rloo_grad = sum(rloo_grads) / k - beta * kl_losses.mean() * torch.tensor(log_prob).mean()
            rloo_grad.backward()

            optimizer.step()
            optimizer.zero_grad()

            for param_name, param_value in theta_m.items():
                theta_m_ema[param_name] = (1 - mu) * theta_m_ema[param_name] + mu * param_value

            temp_model.load_state_dict(theta_m_ema)

        theta_m_list.append(theta_m)
        theta_m_ema_list.append(theta_m_ema)

    theta_slerp = slerp(theta_init, theta_m_list[0], theta_m_list[1], lambd)

    for name in theta_init.keys():
        theta_init[name] = (1 - nu) * theta_init[name] + nu * theta_slerp[name]

  rloo_grad = sum(rloo_grads) / k - beta * kl_losses.mean() * torch.tensor(log_prob).mean()
100%|██████████| 100/100 [07:14<00:00,  4.35s/it]
100%|██████████| 100/100 [07:28<00:00,  4.49s/it]
100%|██████████| 100/100 [07:35<00:00,  4.55s/it]
100%|██████████| 100/100 [07:36<00:00,  4.56s/it]


# KL-reward Pareto front of weight

In [31]:
def interpolate_weights(theta_sft, theta_slerp, eta):
    interpolated_weights = {}
    for param_name in theta_sft.keys():
        interpolated_weights[param_name] = (1 - eta) * theta_sft[param_name] + eta * theta_slerp[param_name]
    return interpolated_weights

weights = interpolate_weights(sft_model.state_dict(), theta_slerp, 1 / 2)

sft_model.load_state_dict(weights)

<All keys matched successfully>

# reward & kl after rlhf stage

In [32]:
# 100 with RLOO

evaluate(
    reward_model,
    reward_tokenizer,
    sft_model, 
    sft_tokenizer,
    prompts
)

100%|██████████| 200/200 [00:27<00:00,  7.19it/s]


(2.419940478205681, 6.123236216306687)

In [33]:
for i, prompt in enumerate(prompts):
    print(sft_tokenizer.decode(
        sft_model.generate(
            sft_tokenizer.encode(
            prompt, 
            return_tensors='pt').to(device))[0])
    )
    if i == 10:
        break



Well, to each his owlishness and his wit, and to the great and the small,
This third installme of the film is a great film. It is a great film. It is
After being off the iced tea, I was very impressed with the film. I am very glad
This movie is my familar favorite movie of all time. I am very glad I did. I
Earth Final Conflict, and the film is a great film. I am very glad I did. I
Being a long-time favela fan, I was very impressed with the film. I am
A gentle story, hinting at the future of the film, and a great cast. The film
On his recent malignity, the film is a great film. It is a great film.
The movie, which was released in the UK in the early 1980s, is a great film.
I love sci-fi and am a fan of the original series. I am very glad I did
OK, this movie, was iced in by a great cast and a great story. I am


---

In [15]:
# 500 with RLOO
# beta = 0.01
evaluate(
    reward_model,
    reward_tokenizer,
    sft_model, 
    sft_tokenizer,
    prompts
)

100%|██████████| 200/200 [00:20<00:00,  9.66it/s]


(1.9629860120639204, 23.38014699459076)

In [16]:
for i, prompt in enumerate(prompts):
    print(sft_tokenizer.decode(
        sft_model.generate(
            sft_tokenizer.encode(
            prompt, 
            return_tensors='pt').to(device))[0])
    )
    if i == 10:
        break

Some time ago I saw  the film in a theater and I was very impressed with the film

Robert Montgomery an excellent actor, and the film is a great film. It is a great film.

I've just revisited  the film and I think it is a great film. I think

I'll be brief: I nor anyone else in the film has ever seen a film that is as

National Lampoon's Cagney, and the film is a great film. It is a great

Last fall (of 2001), the film is a great film. It is a great film. It

I have done some resounding reviews of this film, and I think it is a great film.

This movie was a comical film, and it is a great film. It is a great film

Back in college I strolled through the film with a group of friends and we were all impressed with

The Evil is about a ersatz, middle-aged man who is forced to live with his

I bought this movie ive seen it a few times and i think it is a great film.


---

In [12]:
# 500 with RLOO
# beta = 0.1

evaluate(
    reward_model,
    reward_tokenizer,
    sft_model, 
    sft_tokenizer,
    prompts
)

100%|██████████| 200/200 [00:18<00:00, 10.56it/s]


(1.8983262504637242, 124.56155834197997)

In [13]:
for i, prompt in enumerate(prompts):
    print(sft_tokenizer.decode(
        sft_model.generate(
            sft_tokenizer.encode(
            prompt, 
            return_tensors='pt').to(device))[0])
    )
    if i == 10:
        break




Egads.<br /><br />I ive seen a lot of great films, and

This is one of the fifties' best films), and it is a very good film.

This was it! i would say it was a very good film). The film is a very good

before seeing this fable, I was very impressed with the film). The film is a very good

Amitabh and Jiah Khaalid), a young man who is in love with a

Often when TV series are being made, they are often very good), and the film is a very

If this is the best ive seen of the film, it is a very good film), and

MANNA FROM HEAVEN is a very good film). It is a very good film. It

"Beyond Rangoon" is ersatz, and it is a very good film. It

A horrible mish mash of the original and the film's sequel, this film is a great film.

This movie was a disjointed mess, and I was very impressed with the film. It is
