In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import transformers
from transformers import AutoModelForSequenceClassification, AutoTokenizer, PreTrainedTokenizerFast, TrainingArguments, Trainer, AutoModel
from datasets import load_dataset, load_from_disk

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Using device: {torch.cuda.get_device_name({device})}")

## it seems policy LM cant ber similar to deberta as its encoder only, so need to use gpt-2 for policy, so tokenization mismatch is eminent

# def preprocess(example):
#     example=example["prompt"]
#     return tokenizer(example, return_tensors="pt", padding="longest", truncation=True, max_length=256).to(device)

class CustomDataset(Dataset):
    def __init__(self, tokenizer, examples):
        self.examples=examples
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        return self.examples[idx]["prompt"]
        example=self.examples[idx]["prompt"]
        out_ = self.tokenizer(example, return_tensors="pt", padding=True, truncation=True, max_length=256).to(device)
        # return out_
        for key in out_.keys():
            out_[key] = out_[key].squeeze(0)
        return out_

def freeze_model(model):
    for name, param in model.named_parameters():
        param.requires_grad = False

def copy_model_parameters(old_model, new_model):
    for param_old, param_new in zip(old_model.parameters(), new_model.parameters()):
        param_old.data.copy_(param_new.data)

#TODO: reimplement ALL functions below as verbatim taken from library

def compute_prob(logits, output_length):
    generated_logits = logits[:, -output_length:, :]
    generated_probs = F.softmax(generated_logits, dim=-1)
    sequence_prob = torch.prod(torch.diagonal(generated_probs[:, :-1], dim1=1, dim2=2))
    return sequence_prob

def nll(logits, output_length):
    generated_logits = logits[:, -output_length:, :]
    probs = F.softmax(generated_logits, dim=-1)

    # Compute negative log-likelihood
    nll = -torch.sum(torch.log(probs) * probs, dim=-1)

    # Sum the negative log-likelihood for all tokens
    total_nll = torch.sum(nll)
    return total_nll

def entropy_from_logits(logits):
    pd = torch.nn.functional.softmax(logits, dim=-1)
    entropy = torch.logsumexp(logits, axis=-1) - torch.sum(pd * logits, axis=-1)
    return entropy

def masked_mean(values, mask, axis= None):
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()

def logprobs_from_logits(logits, labels, gather = True):
    logp = F.log_softmax(logits, dim=2)
    if not gather:
        return logp
    logpy = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logpy

Using device: cuda:0
Using device: NVIDIA A100-SXM4-80GB


In [2]:
"""
we maintain 2 policies, an old policy and new policy (both LMs), and then we generate whole trajectory for input data using old_policy for a batch.
We also have a inear layer on top of output embeddings which gives value function estimate for each step of the generation. Now for that batch, for lets say
 we wanna run 4 PPO epochs per batch, so we compute value function estimate using new policy, run grad descent on new_policy for the batch for 4 ppo epochs, 
 then we uodate old policy with new policy, compute generation for thre batch using old_policy, and then keep on doing this.
"""

#TODO: Integrate past_key_values to fasten up training
# from transformers import LlamaTokenizer, LlamaForCausalLM
# policy_model_path = 'openlm-research/open_llama_3b'
# policy_tokenizer = LlamaTokenizer.from_pretrained(policy_model_path)
# old_policy = LlamaForCausalLM.from_pretrained(policy_model_path, torch_dtype=torch.float16)
# freeze_model(old_policy)


from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, GPT2TokenizerFast, AutoModelForCausalLM
model_id="gpt2"

#Policy models
old_policy=AutoModelForCausalLM.from_pretrained(model_id).to(device)
freeze_model(old_policy)
new_policy=AutoModelForCausalLM.from_pretrained(model_id).to(device)
config = old_policy.config
policy_tokenizer1 = AutoTokenizer.from_pretrained(model_id, padding="max_length", direction="left", padding_side="left",max_length=256, length=256)
policy_tokenizer2 = AutoTokenizer.from_pretrained(model_id, padding="max_length", direction="left",  padding_side="left",max_length=512,length=512)
embedding_dim = config.hidden_size
linear = nn.Linear(embedding_dim, 1, device=device)

# if policy_tokenizer.pad_token is None:
#     policy_tokenizer.pad_token = policy_tokenizer.eos_token
#     old_policy.config.pad_token_id = old_policy.config.eos_token_id
#     new_policy.config.pad_token_id = new_policy.config.eos_token_id



# reward_tokenizer = AutoTokenizer.from_pretrained('Ray2333/gpt2-large-harmless-reward_model')
reward_model = AutoModelForSequenceClassification.from_pretrained(
                'Ray2333/gpt2-large-harmless-reward_model',
                num_labels=1).to(device)
freeze_model(reward_model)

# if reward_tokenizer.pad_token is None:
#     reward_tokenizer.pad_token = reward_tokenizer.eos_token
#     reward_model.config.pad_token_id = reward_model.config.eos_token_id

if policy_tokenizer1.pad_token is None:
    policy_tokenizer1.pad_token = policy_tokenizer1.eos_token
    policy_tokenizer2.pad_token = policy_tokenizer2.eos_token
    old_policy.config.pad_token_id = old_policy.config.eos_token_id
    new_policy.config.pad_token_id = new_policy.config.eos_token_id
if reward_model.config.pad_token_id is None:
    reward_model.config.pad_token_id = old_policy.config.pad_token_id
    
# print(policy_tokenizer1.pad_token_id)
## assuming reward and policy models here are having same tokenizers atleast T_T



In [3]:
# print(old_policy.config.pad_token_id)

In [4]:
# config = GPT2Config.from_pretrained('Ray2333/gpt2-large-harmless-reward_model')
# config2 = GPT2Config.from_pretrained("gpt2")
# print(config)
# print(config2)
# tokenizer = policy_tokenizer

In [15]:
# policy_tokenizer_config = policy_tokenizer.config
# reward_tokenizer_config = reward_tokenizer.config

# print("policy_tokenizer_config", policy_tokenizer_config)
# print("reward_tokenizer_config", reward_tokenizer_config)
# print(reward_model.config.pad_token_id, old_policy.config.pad_token_id)
# assert policy_tokenizer_config == reward_tokenizer.config
# print(tokenizer.pad_token_id)
dataset_id = "Dahoas/full-hh-rlhf"
dataset = load_dataset(dataset_id)

train_dataset = CustomDataset(policy_tokenizer1, dataset['train'])
eval_dataset = CustomDataset(policy_tokenizer1, dataset['test'])

# print(train_dataset.column_names)
print(len(train_dataset))
print(len(eval_dataset))
batch_size = 4
dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
max_length=256+128 # total sequence length input+output
gamma=1.
c1=1.
c2=1.
eps=0.1
epochs=1
learning_rate = 5e-5
weight_decay = 0.01
# policy= model.base_model # the underlying LM? dont keep shared parameters, doesnt work, freeze reward model, make copy for policy
ppo_iters_per_batch =4 # gotta do 4 updates per batch

112052
12451


In [6]:
# train_dataset[0].input_ids.shape

In [56]:
import torch.optim as optim
from tqdm import tqdm


optimizer = optim.AdamW(new_policy.parameters(), lr=learning_rate, weight_decay=weight_decay)

for e in tqdm(range(epochs)):
    for batch_ in dataloader:
        # print(batch.keys())
        # print("batch shape", batch["input_ids"].shape)
        batch= policy_tokenizer1(batch_, return_tensors="pt", padding="max_length", truncation=True, max_length=256).to(device)
        copy_model_parameters(old_policy, new_policy)
        # inputs=policy_tokenizer(batch, return_tensors="pt").to(device)
        outputs_old_policy=old_policy.generate(input_ids=batch.input_ids, attention_mask=batch.attention_mask, \
                                               max_length=max_length, num_return_sequences=1, return_dict_in_generate=True,\
                                              output_scores=True)
        
        # outputs_old_policy=old_policy(batch, max_length=max_length)
        # print(outputs_old_policy)
        # print(outputs_old_policy.sequences.shape)
        # print(len(outputs_old_policy.scores))
        # print(outputs_old_policy.scores[0].shape)
        logits_old_policy = torch.stack(outputs_old_policy.scores, dim=1).to(device) # logits for generated sequence
        # print(logits_old_policy.shape)
        outputs_ids_old_policy = logits_old_policy.argmax(-1)
        # outputs_ids_old_policy = outputs_old_policy.sequences
        # logits_old_policy = outputs_old_policy.scores
        # print(logits_old_policy.shape)
        old_policy_logprobs = logprobs_from_logits(logits_old_policy, outputs_ids_old_policy)
        # print(old_policy_logprobs.shape)
        # values_old=linear(outputs_old_policy.hidden_state[-1])
        episode_length = max_length - batch.input_ids.shape[1]
        # print(episode_length)
        # freeze_model(old_policy)
        
        # mini_batch_gen_output_length = episode_length .....
        # aint doing mini-batches as its not done in the paper we use as reference (mini-batch of size 1 means takimng whole batch itself as mini-batch)
        output_attn_mask=(outputs_ids_old_policy != old_policy.config.pad_token_id).float().to(device)
        assert output_attn_mask.shape == old_policy_logprobs.shape, (output_attn_mask.shape, old_policy_prob.shape) 
        old_policy_output_logprob = masked_mean(old_policy_logprobs, output_attn_mask)
        # print(old_policy_output_logprob)
        useful_output_length = int(torch.max(torch.sum(output_attn_mask, dim=1,keepdim=True)).cpu().item())
        # print(useful_output_length)
        
        concatenated_input = {
                                'input_ids': torch.cat([batch["input_ids"], outputs_ids_old_policy], dim=1).to(device),
                                'attention_mask': torch.cat([batch["attention_mask"], output_attn_mask], dim=1).to(device)
                            }
        for iter in range(ppo_iters_per_batch):
            outputs_new_policy=new_policy.generate(input_ids=batch.input_ids, attention_mask=batch.attention_mask, \
                                               max_length=max_length, num_return_sequences=1, return_dict_in_generate=True,\
                                              output_scores=True)
            logits_new_policy = torch.stack(outputs_new_policy.scores, dim=1).to(device)
            assert logits_new_policy.shape[1] == episode_length, (logits_new_policy.shape[1], episode_length)
            # print(outputs_new_policy.keys())
            
            hidden_state = new_policy.transformer(input_ids=batch.input_ids, attention_mask=batch.attention_mask, return_dict=True).last_hidden_state
            # print(new_outs_.keys())
            # hidden_state=new_outs_.hidden_states[-1]
            # print(hidden_state.shape)
            values_new=linear(hidden_state[:,-(episode_length+1):-1,:])
            # outputs_ids_old_policy = outputs_old_policy.logits.argmax(-1)
            # logits_old_policy = outputs_old_policy.logits

            # discounted_rewards=[]
            reversed_discounted_rewards=[0.]
            # rewards=[]
            reversed_rewards=[]
            # print(concatenated_input["input_ids"].shape)
            optimizer.zero_grad()
            for t in range(useful_output_length):
                reversed_rewards.append(reward_model(input_ids=concatenated_input["input_ids"][:,:-t],\
                                                     attention_mask= concatenated_input["attention_mask"][:,:-t]).logits.cpu())
                print(reversed_rewards[-1].shape)
                reversed_discounted_rewards.append(reversed_rewards[-1] + gamma*reversed_discounted_rewards[-1])
                # reversed_advantages.append(reversed_rewards[-1] + gamma*reversed_advantages[-1]
            reversed_discounted_rewards.pop(0)
            reversed_rewards = torch.Tensor(reversed_rewards).to(device)
            print(reversed_rewards.shape)
            reversed_discounted_rewards=torch.Tensor(reversed_discounted_rewards).to(device)
            print(reversed_discounted_rewards.shape)
            
            rewards=reversed_rewards.flip(0)
            discounted_rewards=reversed_discounted_rewards.flip(0)
            advantages=-values_new + discounted_rewards
            
            new_policy_logprobs = logprobs_from_logits(logits_new_policy, outputs_ids_old_policy)
            assert output_attn_mask.shape == new_policy_logprobs.shape, (output_attn_mask.shape, new_policy_prob.shape) 
            new_policy_output_logprob = (masked_mean(new_policy_logprobs, output_attn_mask))
            ratio = torch.exp(new_policy_output_logprob - old_policy_output_logprob)
            print("ratio shape", ratio.shape)
            print("advantage shape", advantage.shape)
            loss = torch.mean(torch.mul(advantages,torch.clamp(ratio, 1-eps, 1+eps))) ## clip loss
            loss -= c1*torch.mean(advantages**2) ## mse loss 
            loss += c2* entropy_from_logits(logits_new_policy)
            
            loss.backward()
            optimizer.step()

  0%|          | 0/1 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
  0%|          | 0/1 [00:03<?, ?it/s]


RuntimeError: cannot reshape tensor of 0 elements into shape [-1, 0] because the unspecified dimension size -1 can be any value and is ambiguous

In [None]:
# # new_policy_output_logprob = (masked_mean(new_policy_logprobs, output_attn_mask))

# # ratio = torch.exp(new_policy_output_logprob - old_policy_output_logprob)
# # print(old_policy_output_logprob, new_policy_output_logprob)
# # print("ratio", ratio)
# # print("advantage shape", advantages.shape)
# l_clip = torch.mean(torch.mul(advantages,torch.clamp(ratio, 1-eps, 1+eps))) ## clip loss
# loss -= c1*torch.mean(advantages**2) ## mse loss 
# loss += c2* entropy_from_logits(logits_new_policy)

# loss.backward()
# optimizer.step()