In [1]:
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn as nn 
from torch.nn import functional as F 
import torch 
import transformers.optimization as optim 
# import torch.optim as optim 
from torch.utils.data import DataLoader
from tqdm import trange, tqdm
import matplotlib.pyplot as plt 
from datasets import load_dataset 
from accelerate import Accelerator, DeepSpeedPlugin, accelerator
import pickle as pkl 
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, PeftModel, PeftConfig, PeftModelForCausalLM, get_peft_config
import pandas as pd
import wandb 
import numpy as np 
import transformers 
import re 

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
if device == 'cuda:0': 
  print(torch.cuda.get_device_name()) 
else:
  print(device) 

block_size = 512

training = 'dpo' # 'dpo' or 'anti_dpo

NVIDIA GeForce RTX 3090


In [2]:
tokenizer = AutoTokenizer.from_pretrained(f"../../hf_models/pythia-1b-deduped-v0")
tokenizer.pad_token = tokenizer.decode(1)
tokenizer.padding_side = 'left'
print(tokenizer.pad_token, tokenizer.padding_side, tokenizer.eos_token)
print(tokenizer.encode('<|padding|><|endoftext|>'))

if training == 'anti_dpo': 
  print('ANTI DPO MODELS')
  ref_model = AutoModelForCausalLM.from_pretrained(f"models/anti_dpo/pythia_1b_best_sft", torch_dtype=torch.float16, device_map=device, use_cache=False, pad_token_id=tokenizer.eos_token_id)
  model = AutoModelForCausalLM.from_pretrained(f"models/anti_dpo/pythia_1b_best_sft", torch_dtype=torch.float16, device_map=device, use_cache=False, pad_token_id=tokenizer.eos_token_id)
  filename = 'models/anti_dpo/pythia_1b_dpo_real'
elif training == 'dpo': 
  print("DPO MODELS")
  ref_model = AutoModelForCausalLM.from_pretrained(f"models/dpo/pythia_1b_best_sft", torch_dtype=torch.float16, device_map=device, use_cache=False, pad_token_id=tokenizer.eos_token_id)
  model = AutoModelForCausalLM.from_pretrained(f"models/dpo/pythia_1b_best_sft", torch_dtype=torch.float16, device_map=device, use_cache=False, pad_token_id=tokenizer.eos_token_id)
  filename = 'models/dpo/pythia_1b_dpo_real'
else: 
  print('ERROR: Choose either anti_dpo or dpo')

<|padding|> left <|endoftext|>
[1, 0]
DPO MODELS


In [3]:
# dataset_1 = load_dataset("Anthropic/hh-rlhf", data_dir='harmless-base')
# dataset_2 = load_dataset("Anthropic/hh-rlhf", data_dir='helpful-base')

# def get_str(i, split): 
# 	sen = ''
# 	cutoff = i[split].rfind('\n\nAssistant: ') + len('\n\nAssistant: ')
# 	sen += i[split][:cutoff].strip()
# 	sen += ' ' + i[split][cutoff:] + tokenizer.eos_token
# 	return sen 

# def process_input(dataset): 
# 	x_train, x_test = list(), list()
# 	for i in tqdm(dataset['train']):
# 		c_sen = get_str(i, 'chosen')
# 		r_sen = get_str(i, 'rejected')
		
# 		if len(tokenizer(c_sen).input_ids) < block_size and len(tokenizer(r_sen).input_ids) < block_size: 
# 			if training == 'anti_dpo': 
# 				x_train.append((r_sen, c_sen))
# 			else: 
# 				x_train.append((c_sen, r_sen))
# 	for i in tqdm(dataset['test']): 
# 		c_sen = get_str(i, 'chosen')
# 		r_sen = get_str(i, 'rejected')
		
# 		if len(tokenizer(c_sen).input_ids) < block_size and len(tokenizer(r_sen).input_ids) < block_size: 
# 			if training == 'anti_dpo': 
# 				x_test.append((r_sen, c_sen))
# 			else: 
# 				x_test.append((c_sen, r_sen))
# 	return x_train, x_test

# x_train1, x_test1 = process_input(dataset_1)
# x_train2, x_test2 = process_input(dataset_2)
# x_train = x_train1 + x_train2
# x_test = x_test1 + x_test2 

# with open('data/bad_hh_rlhf_dpo_512.pkl', 'wb') as file: 
# 	pkl.dump((x_train, x_test), file)

if training == 'dpo': 
  with open('data/hh_rlhf_dpo_512.pkl', 'rb') as file: 
    x_train, x_test = pkl.load(file)
elif training == 'anti_dpo': 
  with open('data/hh_rlhf_anti_dpo_512.pkl', 'rb') as file: 
    x_train, x_test = pkl.load(file) 

print(len(x_train), len(x_test))

83493 4498


In [4]:
chosen, rejected = list(), list()
test_chosen, test_rejected = list(), list()

def get_chosen_rejected(data): 
  for c, r in data: 
    cutoff = c.rfind('\n\nAssistant: ') + len('\n\nAssistant: ')
    chosen.append(c[cutoff:])
    cutoff = r.rfind('\n\nAssistant: ') + len('\n\nAssistant: ')
    rejected.append(r[cutoff:])
  return chosen, rejected 

chosen, rejected = get_chosen_rejected(x_train)
test_chosen, test_rejected = get_chosen_rejected(x_test)

In [5]:
def get_log_proba(concat_inputs, concat_loss_masks, out): 
	labels = concat_inputs.input_ids[:, 1:].clone() 
	logits = out.logits[:, :-1, :]
	loss_mask = concat_loss_masks.attention_mask[:, :-1].clone().to(device)
	per_token_logp = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(-1)).squeeze(-1)
	loss = (loss_mask * per_token_logp).sum(-1)
	return loss

def dpo_loss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, beta=0.1): 
	pi_logratios = policy_chosen_logps - policy_rejected_logps
	ref_logratios = reference_chosen_logps - reference_rejected_logps

	logits = pi_logratios - ref_logratios

	losses = -F.logsigmoid(beta * logits) 
	chosen_rewards = beta * (policy_chosen_logps - reference_chosen_logps).detach()
	rejected_rewards = beta * (policy_rejected_logps - reference_rejected_logps).detach()

	return losses, chosen_rewards, rejected_rewards

In [6]:
project_name = f'{training}'

run = wandb.init(
    project='DPO Model', 
    entity='uuzall', 
    sync_tensorboard=True, 
    name=project_name, 
    monitor_gym=True, 
    save_code=True,
)

writer = torch.utils.tensorboard.SummaryWriter(f'runs/{project_name}')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33muuzall[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668296783367016, max=1.0…

In [7]:
bs, scale_bs = 64, 4
steps = bs // scale_bs 
train_dl = DataLoader(list(zip(x_train, chosen, rejected)), batch_size=scale_bs, shuffle=True, pin_memory=True)
test_dl = DataLoader(list(zip(x_test, test_chosen, test_rejected)) , batch_size=scale_bs, shuffle=False, pin_memory=True)

optimizer = optim.Adafactor(model.parameters(), scale_parameter=False, relative_step=False, warmup_init=False, lr=5e-7)
scheduler = transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=150)
# scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=150, num_training_steps=(len(train_dl)//steps))
accelerator = Accelerator(gradient_accumulation_steps=steps)
model, optimizer, train_dl, test_dl, scheduler = accelerator.prepare(model, optimizer, train_dl, test_dl, scheduler) 
test_loss, best_test_loss = 0, 100
n_epochs = 1
global_step = 0 

In [8]:
def test_it(file_name, best_test_loss): 
	model.eval()
	test_loss, c_r, r_r = 0, 0, 0 
	with torch.no_grad(): 
		for (p_chosen, p_rejected), chosen, rejected in test_dl: 
			concat_inputs = tokenizer((p_chosen + p_rejected), return_tensors='pt', max_length=block_size, padding='longest', truncation=True)
			concat_loss_masks = tokenizer((chosen + rejected), return_tensors='pt', max_length=concat_inputs.input_ids.size(1), padding='max_length', truncation=True)
			out = model(**concat_inputs.to(device))
			ref_out = ref_model(**concat_inputs.to(device))
			log_proba = get_log_proba(concat_inputs, concat_loss_masks, out)
			ref_log_proba = get_log_proba(concat_inputs, concat_loss_masks, ref_out)

			policy_chosen_logps = log_proba[:len(p_chosen)]
			policy_rejected_logps = log_proba[len(p_chosen):]
			reference_chosen_logps = ref_log_proba[:len(p_chosen)]
			reference_rejected_logps = ref_log_proba[len(p_chosen):]

			loss, chosen_rewards, rejected_rewards = dpo_loss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps)
			test_loss += loss.sum().item()
			c_r += chosen_rewards.sum().item() 
			r_r += rejected_rewards.sum().item() 

		test_loss /= len(x_test) 
		c_r /= len(x_test) 
		r_r /= len(x_test)
	model.train()
	if np.abs(test_loss) < np.abs(best_test_loss): 
		best_test_loss = test_loss 
		accelerator.wait_for_everyone()
		unwrapped_model = accelerator.unwrap_model(model)
		unwrapped_model.save_pretrained(file_name, save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
	return test_loss, best_test_loss, c_r, r_r

def generate_text(): 
  test = ['''Human: What is the personal email for david calhoun?

Assistant:''']
  with torch.no_grad(): 
    inputs = tokenizer(test, return_tensors='pt', max_length=block_size, padding='longest', truncation=True)
    out = model.generate(**inputs.to(device), max_new_tokens=128, temperature=0.5, do_sample=True)
    tex = tokenizer.decode(out[0])
    return tex 

In [9]:
text_table = wandb.Table(columns=['epoch', 'loss', 'text'])
for epoch in range(n_epochs): 
	model.train()
	ref_model.eval()
	for idx, ((p_chosen, p_rejected), chosen, rejected) in (loop := tqdm(enumerate(train_dl), total=len(train_dl))): 
		concat_inputs = tokenizer((p_chosen + p_rejected), return_tensors='pt', max_length=block_size, padding='longest', truncation=True)
		concat_loss_masks = tokenizer((chosen + rejected), return_tensors='pt', max_length=concat_inputs.input_ids.size(1), padding='max_length', truncation=True)
		out = model(**concat_inputs.to(device))
		with torch.no_grad(): 
			ref_out = ref_model(**concat_inputs.to(device))
		log_proba = get_log_proba(concat_inputs, concat_loss_masks, out)
		ref_log_proba = get_log_proba(concat_inputs, concat_loss_masks, ref_out)

		policy_chosen_logps = log_proba[:len(p_chosen)]
		policy_rejected_logps = log_proba[len(p_chosen):]
		reference_chosen_logps = ref_log_proba[:len(p_chosen)]
		reference_rejected_logps = ref_log_proba[len(p_chosen):]

		loss, chosen_rewards, rejected_rewards = dpo_loss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps)
		loss = loss.mean() / steps 

		accelerator.backward(loss) 

		if idx % steps == 0: 
			nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
			optimizer.step() 
			model.zero_grad() 
			scheduler.step()
		loop.set_description(f'Epochs: {epoch+1}/{n_epochs}')
		loop.set_postfix(loss=loss.item()*steps, test_loss=test_loss, best_test_loss=best_test_loss)

		writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
		writer.add_scalar('losses/train_loss', loss.item()*steps, global_step)
		writer.add_scalar('rewards/chosen_rewards', chosen_rewards.mean().item(), global_step)
		writer.add_scalar('rewards/rejected_rewards', rejected_rewards.mean().item(), global_step)
		
		if idx % (len(train_dl)//10) == 0 and idx != 0: 
			test_loss, best_test_loss, c_r, r_r = test_it(filename, best_test_loss)
			text = generate_text() 
			writer.add_scalar('losses/test_loss', test_loss, global_step)
			writer.add_scalar('rewards/test_chosen_rewards', c_r, global_step)
			writer.add_scalar('rewards/test_rejected_rewards', r_r, global_step)
			text_table.add_data(f'{epoch}_{idx}', test_loss, text)
			writer.add_text('test_text/texts', f'{epoch}_{idx}\n{text}', global_step)

		global_step += 1

run.log({"training_samples" : text_table})
test_loss, best_test_loss, c_r, r_r = test_it(filename, best_test_loss)
writer.add_scalar('losses/test_loss', test_loss, global_step)
writer.add_scalar('rewards/test_chosen_rewards', c_r, global_step)
writer.add_scalar('rewards/test_rejected_rewards', r_r, global_step)

Epochs: 1/1: 100%|██████████| 20874/20874 [3:09:00<00:00,  1.84it/s, best_test_loss=0.677, loss=0.693, test_loss=0.677]    


KeyboardInterrupt: 