In [1]:
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch.nn as nn 
from torch.nn import functional as F 
import torch 
import transformers.optimization as optim 
# import torch.optim as optim 
from torch.utils.data import DataLoader
from tqdm import trange, tqdm
import matplotlib.pyplot as plt 
from datasets import load_dataset 
from accelerate import Accelerator, DeepSpeedPlugin, accelerator
import pickle as pkl 
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training, PeftModel, PeftConfig, PeftModelForCausalLM, get_peft_config
import pandas as pd
import wandb 
import numpy as np 
import transformers 
import re 

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU,
# and report which device was selected.
if torch.cuda.is_available():
  device = 'cuda'
  print(torch.cuda.get_device_name())
else:
  device = 'cpu'
  print(device)

# Shared length limit: used below both as a character-length filter on the raw
# conversations and as the tokenizer's `max_length` in the training loop.
block_size = 512


Welcome to bitsandbytes. For bug reports, please submit your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
CUDA SETUP: Loading binary x:\python_environments\AI_310\lib\site-packages\bitsandbytes\libbitsandbytes_cuda116.dll...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)


NVIDIA GeForce RTX 3090


In [2]:
# Tokenizer of the base checkpoint. EOS doubles as the pad token, and padding
# goes on the left so every sequence's final tokens stay right-aligned; the
# training loop builds its response mask by left-padding the response alone to
# the same length, which only lines up under left padding.
tokenizer = AutoTokenizer.from_pretrained("X:/hf_models/pythia-1b-deduped-v0")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = 'left'

# The policy and the reference start from the same SFT checkpoint with
# identical loading options; only `model` is later handed to the optimizer.
# NOTE(review): both copies are loaded in float16 and the policy is trained
# directly in that dtype with lr=5e-8 -- updates that small are likely rounded
# away in fp16 weights; confirm whether fp32 master weights are intended.
sft_checkpoint = "models/sft_pythia/sft_best_pythia_1b"
sft_load_kwargs = dict(
    torch_dtype=torch.float16,
    device_map='auto',
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)
ref_model = AutoModelForCausalLM.from_pretrained(sft_checkpoint, **sft_load_kwargs)
model = AutoModelForCausalLM.from_pretrained(sft_checkpoint, **sft_load_kwargs)

In [3]:
# The two HH-RLHF subsets used for preference training.
dataset_1 = load_dataset("Anthropic/hh-rlhf", data_dir='harmless-base')
dataset_2 = load_dataset("Anthropic/hh-rlhf", data_dir='helpful-base')

x_train, y_train = [], []
x_test, y_test = [], []


def _with_turn_ends(dialogue):
	"""Insert '<|endoftext|>' before every '\\n\\nHuman:' marker and append one at the end."""
	return re.sub(r'\n\nHuman:', r'<|endoftext|>\n\nHuman:', dialogue) + '<|endoftext|>'


# For each subset, collect (chosen, rejected) conversation pairs: first the
# train split into x_train, then the test split into x_test.
for subset in (dataset_1, dataset_2):
	for split_name, pairs in (('train', x_train), ('test', x_test)):
		for record in tqdm(subset[split_name]):
			# Keep a pair only when BOTH conversations are shorter than
			# block_size; note this compares character counts, not tokens.
			if len(record['chosen']) >= block_size or len(record['rejected']) >= block_size:
				continue
			pairs.append((_with_turn_ends(record['chosen']), _with_turn_ends(record['rejected'])))

print(f'Train Data: {len(x_train)}; Test Data: {len(x_test)}')

Found cached dataset json (C:/Users/Asus/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-046a49968e35a6f2/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)


  0%|          | 0/2 [00:00<?, ?it/s]

Found cached dataset json (C:/Users/Asus/.cache/huggingface/datasets/Anthropic___json/Anthropic--hh-rlhf-bb7971723b14c46c/0.0.0/fe5dd6ea2639a6df622901539cb550cf8797e5a6b2dd7af1cf934bed8e233e6e)


  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 42537/42537 [00:01<00:00, 32548.41it/s]
100%|██████████| 2312/2312 [00:00<00:00, 33503.31it/s]
100%|██████████| 43835/43835 [00:01<00:00, 33964.09it/s]
100%|██████████| 2354/2354 [00:00<00:00, 34361.35it/s]

Train Data: 34109; Test Data: 1788





In [4]:
# Run name shared by wandb and the tensorboard log directory. It is empty
# here, so the tensorboard events are written directly under 'runs/'.
project_name = ''

# Start the wandb run before the SummaryWriter is created; with
# sync_tensorboard=True the scalars logged through `writer` are mirrored to it.
wandb.init(
    project='DPO Model', 
    entity='uuzall', 
    sync_tensorboard=True, 
    name=project_name, 
    monitor_gym=True, 
    save_code=True,
)

# `import torch` alone does not load the tensorboard submodule; import it
# explicitly instead of relying on another library having imported it first.
import torch.utils.tensorboard

writer = torch.utils.tensorboard.SummaryWriter(f'runs/{project_name}')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33muuzall[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01693333333338766, max=1.0)…

In [5]:
# For every (prompt + chosen, prompt + rejected) pair, pull out just the final
# assistant reply; these reply-only strings are later tokenized on their own to
# build the loss mask over the response tokens.
def _final_response(dialogue, marker='\n\nAssistant:'):
  """Return the text that follows the last assistant marker in `dialogue`.

  A single space directly after the marker is dropped, so the common
  '\\n\\nAssistant: reply' form gives 'reply'. A final turn written without the
  space (e.g. an empty reply) is still recognised instead of an earlier turn
  being picked up. If the marker is absent the whole string is returned
  rather than an arbitrary slice.
  """
  start = dialogue.rfind(marker)
  if start == -1:
    return dialogue
  reply = dialogue[start + len(marker):]
  return reply[1:] if reply.startswith(' ') else reply


chosen = [_final_response(c) for c, _ in x_train]
rejected = [_final_response(r) for _, r in x_train]

test_chosen = [_final_response(c) for c, _ in x_test]
test_rejected = [_final_response(r) for _, r in x_test]

In [6]:
def get_log_proba(concat_inputs, concat_loss_masks, out): 
	"""Sum the log-probabilities of the masked target tokens for each sequence.

	Args:
		concat_inputs: object exposing `input_ids` of shape (batch, seq).
		concat_loss_masks: object exposing `attention_mask` of shape
			(batch, seq); 1 marks tokens whose log-probability is counted.
		out: model output exposing `logits` of shape (batch, seq, vocab).

	Returns:
		float32 tensor of shape (batch,): for each row, the sum over masked
		positions of log p(token_t | tokens_<t).
	"""
	# The logit at position t predicts token t+1, so targets are the inputs
	# shifted left by one and the final logit has no target.
	logits = out.logits[:, :-1, :]
	labels = concat_inputs.input_ids[:, 1:].to(logits.device)
	# The mask must follow the *targets*, hence the same [:, 1:] shift. Slicing
	# it with [:, :-1] would drop the first response token from the sum.
	loss_mask = concat_loss_masks.attention_mask[:, 1:].to(logits.device)
	# Do the softmax and the summation in float32: half-precision logits lose
	# too much resolution once many per-token values are added up.
	token_logp = torch.nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32)
	per_token_logp = torch.gather(token_logp, dim=2, index=labels.unsqueeze(-1)).squeeze(-1)
	return (loss_mask.to(per_token_logp.dtype) * per_token_logp).sum(-1)

def dpo_loss(policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps, beta=0.1): 
	"""Per-example DPO loss together with the implicit rewards.

	Args:
		policy_chosen_logps / policy_rejected_logps: sequence log-probabilities
			of the chosen / rejected responses under the policy.
		reference_chosen_logps / reference_rejected_logps: the same quantities
			under the reference model.
		beta: scale applied to the log-ratio difference and to the rewards.

	Returns:
		(losses, chosen_rewards, rejected_rewards), one entry per example.
		The two reward tensors are detached from the autograd graph.
	"""
	def implicit_reward(policy_logps, reference_logps):
		# beta-scaled policy/reference log-ratio, for logging only.
		return beta * (policy_logps - reference_logps).detach()

	# How much more the policy prefers chosen over rejected than the
	# reference model does.
	policy_margin = policy_chosen_logps - policy_rejected_logps
	reference_margin = reference_chosen_logps - reference_rejected_logps
	preference_logits = policy_margin - reference_margin

	losses = F.logsigmoid(beta * preference_logits).neg()

	return (
		losses,
		implicit_reward(policy_chosen_logps, reference_chosen_logps),
		implicit_reward(policy_rejected_logps, reference_rejected_logps),
	)

In [7]:
# Effective batch of `bs` preference pairs, fed to the model as micro-batches
# of `scale_bs`; `steps` micro-batches are accumulated per optimizer update.
bs = 64
scale_bs = 8
steps = bs // scale_bs

# Each sample is ((prompt + chosen, prompt + rejected), chosen reply, rejected reply).
train_dl = DataLoader(
    list(zip(x_train, chosen, rejected)),
    batch_size=scale_bs,
    shuffle=True,
    pin_memory=True,
)
test_dl = DataLoader(
    list(zip(x_test, test_chosen, test_rejected)),
    batch_size=scale_bs,
    shuffle=False,
    pin_memory=True,
)

# Adafactor with a fixed, externally supplied learning rate (its own relative
# step sizing, parameter scaling and warm-up are all switched off), combined
# with a scheduler that warms up over 150 steps and then stays constant.
optimizer = optim.Adafactor(
    model.parameters(),
    scale_parameter=False,
    relative_step=False,
    warmup_init=False,
    lr=5e-8,
)
scheduler = transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=150)

accelerator = Accelerator(gradient_accumulation_steps=steps)
model, optimizer, train_dl, test_dl, scheduler = accelerator.prepare(
    model, optimizer, train_dl, test_dl, scheduler
)

# Bookkeeping for the training loop: latest / best evaluation loss, number of
# passes over the training data, and the running index used for logging.
test_loss = 0
best_test_loss = 100
n_epochs = 4
global_step = 0

In [8]:
def _dpo_batch(p_chosen, p_rejected, resp_chosen, resp_rejected):
	"""Run policy and reference on one batch and return `dpo_loss`'s outputs.

	Args:
		p_chosen / p_rejected: full conversations (prompt + response) for the
			chosen and rejected sides of each pair.
		resp_chosen / resp_rejected: the final responses only; tokenized on
			their own and padded to the input length, their attention mask
			marks which positions contribute to the log-probability sums.

	Returns:
		(losses, chosen_rewards, rejected_rewards) as returned by `dpo_loss`.
	"""
	# Chosen and rejected sequences go through the models as a single batch:
	# the first len(p_chosen) rows are the chosen ones, the rest the rejected.
	concat_inputs = tokenizer((p_chosen + p_rejected), return_tensors='pt', max_length=block_size, padding='longest', truncation=True)
	concat_loss_masks = tokenizer((resp_chosen + resp_rejected), return_tensors='pt', max_length=concat_inputs.input_ids.size(1), padding='max_length', truncation=True)
	# NOTE(review): the mask assumes a response tokenizes to the same number of
	# tokens on its own as it does inside the conversation -- TODO confirm.
	concat_inputs = concat_inputs.to(device)

	out = model(**concat_inputs)
	# The reference model is never trained, so it needs no autograd graph.
	with torch.no_grad(): 
		ref_out = ref_model(**concat_inputs)
	log_proba = get_log_proba(concat_inputs, concat_loss_masks, out)
	ref_log_proba = get_log_proba(concat_inputs, concat_loss_masks, ref_out)

	n_pairs = len(p_chosen)
	return dpo_loss(log_proba[:n_pairs], log_proba[n_pairs:], ref_log_proba[:n_pairs], ref_log_proba[n_pairs:])


def _evaluate():
	"""Return (mean loss, mean chosen reward, mean rejected reward) over the test set.

	Switches `model` to eval mode and leaves it there; the caller restores
	train mode once it is done with the evaluation results.
	"""
	model.eval()
	total_loss, total_chosen, total_rejected = 0, 0, 0 
	with torch.no_grad(): 
		for (p_chosen, p_rejected), resp_chosen, resp_rejected in test_dl: 
			losses, chosen_rewards, rejected_rewards = _dpo_batch(p_chosen, p_rejected, resp_chosen, resp_rejected)
			total_loss += losses.sum().item()
			total_chosen += chosen_rewards.sum().item() 
			total_rejected += rejected_rewards.sum().item() 
	# Per-example sums divided by the number of test pairs give per-pair means.
	n_test = len(x_test)
	return total_loss / n_test, total_chosen / n_test, total_rejected / n_test


for epoch in range(n_epochs): 
	model.train()
	ref_model.eval()
	# The loop targets are named resp_* so they do not overwrite the global
	# `chosen` / `rejected` lists that the data loaders were built from.
	for idx, ((p_chosen, p_rejected), resp_chosen, resp_rejected) in (loop := tqdm(enumerate(train_dl), total=len(train_dl))): 
		losses, chosen_rewards, rejected_rewards = _dpo_batch(p_chosen, p_rejected, resp_chosen, resp_rejected)
		# Scale so that gradients summed over `steps` micro-batches average out.
		loss = losses.mean() / steps 

		accelerator.backward(loss) 

		# Update only once `steps` micro-batches have been accumulated (not
		# after the very first one), and flush a trailing partial group at the
		# end of the epoch so its gradients do not leak into the next epoch.
		if (idx + 1) % steps == 0 or (idx + 1) == len(train_dl): 
			optimizer.step() 
			model.zero_grad() 
			scheduler.step()

		train_loss = loss.item() * steps
		loop.set_description(f'Epochs: {epoch+1}/{n_epochs}')
		loop.set_postfix(loss=train_loss, test_loss=test_loss, best_test_loss=best_test_loss)

		writer.add_scalar('charts/learning_rate', optimizer.param_groups[0]['lr'], global_step)
		writer.add_scalar('losses/train_loss', train_loss, global_step)
		writer.add_scalar('rewards/chosen_rewards', chosen_rewards.mean().item(), global_step)
		writer.add_scalar('rewards/rejected_rewards', rejected_rewards.mean().item(), global_step)
		
		# Periodic evaluation on the full test set; checkpoint on improvement.
		if idx % (scale_bs*10) == 0: 
			test_loss, c_r, r_r = _evaluate()
			if np.abs(test_loss) < np.abs(best_test_loss): 
				best_test_loss = test_loss 
				accelerator.wait_for_everyone()
				unwrapped_model = accelerator.unwrap_model(model)
				unwrapped_model.save_pretrained(f'models/dpo/pythia_1b_best_performing_5e8', save_function=accelerator.save, state_dict=accelerator.get_state_dict(model))
			model.train()
			writer.add_scalar('losses/test_loss', test_loss, global_step)
			writer.add_scalar('rewards/test_chosen_rewards', c_r, global_step)
			writer.add_scalar('rewards/test_rejected_rewards', r_r, global_step)

		global_step += 1

  0%|          | 0/4264 [00:00<?, ?it/s]

Epochs: 1/4: 100%|██████████| 4264/4264 [1:07:53<00:00,  1.05it/s, best_test_loss=0.693, loss=0.692, test_loss=0.693]  
Epochs: 2/4: 100%|██████████| 4264/4264 [1:07:24<00:00,  1.05it/s, best_test_loss=0.693, loss=0.693, test_loss=0.693]  
Epochs: 3/4: 100%|██████████| 4264/4264 [1:07:20<00:00,  1.06it/s, best_test_loss=0.692, loss=0.692, test_loss=0.692]  
Epochs: 4/4:  26%|██▋       | 1120/4264 [17:54<50:15,  1.04it/s, best_test_loss=0.692, loss=0.687, test_loss=0.692]   


KeyboardInterrupt: 