In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer, PreTrainedTokenizerFast, TrainingArguments, Trainer, AutoModel
from datasets import load_dataset, load_from_disk

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Using device: {torch.cuda.get_device_name({device})}")

Using device: cuda:0
Using device: NVIDIA A100-SXM4-80GB


In [3]:
def preprocess(example):
    example=example["prompt"]
    return tokenizer(example, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)

class CustomDataset(Dataset):
    def __init__(self, tokenizer, examples):
        self.examples=examples
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        example=self.examples[idx]["prompt"]
        return tokenizer(example, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)

def freeze_model(model):
    for name, param in model.named_parameters():
        param.requires_grad = False

def copy_model_parameters(old_model, new_model):
    for param_old, param_new in zip(old_model.parameters(), new_model.parameters()):
        param_old.data.copy_(param_new.data)

def compute_prob(logits, output_length):
    generated_logits = logits[:, -output_length:, :]
    generated_probs = F.softmax(generated_logits, dim=-1)
    sequence_prob = torch.prod(torch.diagonal(generated_probs[:, :-1], dim1=1, dim2=2))
    return sequence_prob

def nll(logits, output_length):
    generated_logits = logits[:, -output_length:, :]
    probs = F.softmax(generated_logits, dim=-1)

    # Compute negative log-likelihood
    nll = -torch.sum(torch.log(probs) * probs, dim=-1)

    # Sum the negative log-likelihood for all tokens
    total_nll = torch.sum(nll)
    return total_nll

In [4]:
"""
we maintain 2 policies, an old policy and new policy (both LMs), and then we generate whole trajectiry for input data using old_policy for a batch. We also have a inear layer on top of output embeddings which gives value function estimate for each step of the generation. Now for that batch, for lets say we wanna run 4 PPO epochs per batch, so we compute value function estimate using new policy, run grad descent on new_policy for the batch for 4 ppo epochs, then we uodate old policy with new policy, compute generation for thre batch using old_policy, and then keep on doing this.
"""
#TODO: Implement past_key_values to fasten up training

reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
reward_model, reward_tokenizer = AutoModelForSequenceClassification.from_pretrained(reward_name).to(device), AutoTokenizer.from_pretrained(reward_name)
freeze_model(reward_model)

In [5]:
model_id="microsoft/deberta-v3-large"
# assuming reward_name and model_id have both same tokenier, else some weird shit can go down, number of states mismatch and so on
# policy_tokenizer = AutoTokenizer.from_pretrained(model_id)
# old_policy=AutoModel.from_pretrained(model_id).to(device)

old_policy=AutoModel.from_pretrained(model_id).to(device).base_model
new_policy=AutoModel.from_pretrained(model_id).to(device).base_model
config = old_policy.config
# new_policy=AutoModel.from_pretrained(model_id).to(device)

In [6]:
from transformers import DebertaV2Tokenizer
policy_tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v3-large", use_fast=False)
# policy_tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
embedding_dim = config.hidden_size
linear = nn.Linear(embedding_dim, 1, device=device)

dataset_id = "Dahoas/full-hh-rlhf"
dataset = load_dataset(dataset_id)

train_dataset = CustomDataset(tokenizer, dataset['train'])
eval_dataset = CustomDataset(tokenizer, dataset['test'])

# print(train_dataset.column_names)
print(len(train_dataset))
print(len(eval_dataset))


if tokenizer.pad_token is None:
    tokenizer.pad_token = rank_tokenizer.eos_token
    reward_model.config.pad_token_id = reward_model.config.eos_token_id
    old_policy.config.pad_token_id = policy_model.config.eos_token_id
    new_policy.config.pad_token_id = policy_model.config.eos_token_id

print(reward_model.config.pad_token_id == reward_model.config.eos_token_id)
assert policy_tokenizer.get_config() == reward_tokenizer.get_config()

batch_size = 4
dataloader = DataLoader(dataset['train'], batch_size=batch_size, shuffle=True)
max_length=512 # total sequence length input+output
gamma=1.
c1=1.
c2=1.
eps=0.1
epochs=1
# policy= model.base_model # the underlying LM? dont keep shared parameters, doesnt work, freeze reward model, make copy for policy
ppo_iters_per_batch =4 # gotta do 4 updates per batch

NameError: name 'tokenizer' is not defined

In [None]:
def entropy_from_logits(logits):
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
    return entropy

def masked_mean(values, mask, axis= None):
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()

def logprobs_from_logits(logits, labels, gather = True):
    logp = F.log_softmax(logits, dim=2)
    if not gather:
        return logp
    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logpy


for e in range(epochs):
    for batch in dataloader:
        copy_model_parameters(old_policy, new_policy)
        # inputs=policy_tokenizer(batch, return_tensors="pt").to(device)
        outputs_old_policy=old_policy(batch, max_length=max_length)
        outputs_ids_old_policy = outputs_old_policy.logits.argmax(-1)
        logits_old_policy = outputs_old_policy.logits
        print(logits_old_policy.shape)
        old_policy_logprobs = logprobs_from_logits(logits_old_policy)
        # values_old=linear(outputs_old_policy.hidden_state[-1])
        episode_length = max_length - batch.input_ids.shape[1]
        freeze_model(old_policy)
        
        # mini_batch_gen_output_length = episode_length .....
        # aint doing mini-batches as its not done in the paper we use as reference (mini-batch of size 1 means takimng whole batch itself as mini-batch)
        output_attn_mask=(outputs_ids_old_policy != old_policy.config.pad_token_id).float().to(device)
        assert output_attn_mask.shape == old_policy_logprobs.shape, (output_attn_mask.shape, old_policy_prob.shape) 
        old_policy_output_logprob = (masked_mean(old_policy_logprobs, output_attn_mask))
        useful_output_length = torch.max(torch.sum(output_attn_mask, dim=0,keepdim=True)).item().cpu()
        
        
        concatenated_input = {
                                'input_ids': torch.cat([batch["input_ids"], outputs_ids_old_policy], dim=1),
                                'attention_mask': torch.cat([batch["attention_mask"], output_attn_mask], dim=1)
                            }
        for iter in range(ppo_iters_per_batch):
            outputs_new_policy=new_policy(concatenated_input)
            assert outputs_new_policy.shape[1] == max_length
            values_new=linear(outputs_new_policy.hidden_state[-1][-(episode_length+1):-1])
            # outputs_ids_old_policy = outputs_old_policy.logits.argmax(-1)
            # logits_old_policy = outputs_old_policy.logits

            # discounted_rewards=[]
            reversed_discounted_rewards=[0.]
            # rewards=[]
            reversed_rewards=[]

            optimizer.zero_grad()
            for t in range(useful_output_length):
                reversed_rewards.append(reward_model(concatenated_input).logits[0].cpu())
                reversed_discounted_rewards.append(reward_model(concatenated_input).logits[0].cpu() + gamma*reversed_discounted_rewards[-1])
                # reversed_advantages.append(reversed_rewards[-1] + gamma*reversed_advantages[-1]
            reversed_discounted_rewards.pop(0)
            reversed_rewards = torch.Tensor(reversed_rewards).to(device)
            reversed_discounted_rewards=torch.Tensor(reveersed_discounted_rewards).to(device)
            
            rewards=reversed_rewards.flip(0, inplace=False)
            discounted_rewards=reversed_discounted_rewards.flip(0, inplace=False)
            
            advantages=-values_new + discounted_rewards
            logits_new_policy = logprobs_from_logits(outputs_new_policy.logits[-episode_length:])
            new_policy_logprobs = logprobs_from_logits(logits_new_policy)
            assert output_attn_mask.shape == new_policy_logprobs.shape, (output_attn_mask.shape, new_policy_prob.shape) 
            new_policy_output_logprob = (masked_mean(new_policy_probs, output_attn_mask))
            ratio = torch.exp(new_policy_output_prob - old_policy_output_prob)
            print("ratio shape", ratio.shape)
            print("advantage shape", advantage.shape)
            l_clip = torch.mean(torch.mul(advantages,torch.clamp(ratio, 1-eps, 1+eps))) ## clip loss
            loss -= c1*torch.mean(advantages**2) ## mse loss 
            loss += c2* entropy_from_logits(logits_new_policy)
            
            loss.backward()
            optimizer.step()