# DPO method

This setup of 64 alpha and rank would require: 34696MiB / 40960MiB VRAM in order to proceed. Batch size 2 and accum_grad 4

In [1]:
import torch
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


## Actor

In [2]:
from unsloth import FastLanguageModel
LORA_RANK = 64
LORA_ALPHA = 128



model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "checkpoints/tinyLlama-GSM8K-10epochs", # "unsloth/tinyllama" for 16bit loading
    max_seq_length = 2048,
    dtype = None,
    load_in_4bit = False,
)

model = FastLanguageModel.get_peft_model(
    model,
    target_modules=[
        "q_proj",
        "v_proj",
        "k_proj",
        "o_proj",  # attention (self_attn)
        "gate_proj",
        "down_proj",
        "up_proj",  # FFN (mlp)
    ],
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing=False,
)


  from .autonotebook import tqdm as notebook_tqdm


==((====))==  Unsloth: Fast Llama patching release 2024.4
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.394 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2. CUDA = 8.0. CUDA Toolkit = 11.8.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Unsloth 2024.4 patched 22 layers with 22 QKV layers, 22 O layers and 22 MLP layers.


## Critic

In [3]:
critic, critic_tokenizer = FastLanguageModel.from_pretrained(
    model_name = "/home/jianingqi/LLMRL/checkpoints/llama3-8b-critic-lora-4-29", # "unsloth/tinyllama" for 16bit loading
    max_seq_length = 2048,
    dtype = None,
    load_in_4bit = False,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)
FastLanguageModel.for_inference(critic) # Enable native 2x faster inference
critic_tokenizer.padding_side = "left" # Padding side for faster inference


==((====))==  Unsloth: Fast Llama patching release 2024.4
   \\   /|    GPU: NVIDIA A100-SXM4-40GB. Max memory: 39.394 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.2.2. CUDA = 8.0. CUDA Toolkit = 11.8.
\        /    Bfloat16 = TRUE. Xformers = 0.0.25.post1. FA = True.
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00,  1.26s/it]
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Unsloth 2024.4 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


## Rollout

In [4]:
from tqdm import tqdm
from transformers.utils import logging
logging.set_verbosity_error()

def generate_answers(input_text, generator, tokenizer, n_answers=2, batch_size=128):    
    all_answers_list = []
    for n in tqdm(range(0, n_answers), desc=" Answer Set", position=0):
        all_answers = []
        for i in tqdm(range(0, len(input_text), batch_size), desc="Answers in Answer Set", position=1, leave=True):
            batch_inputs = input_text[i:i+batch_size]
            batch_inputs = tokenizer(batch_inputs, return_tensors='pt', padding="max_length", truncation=True, max_length=256).to(device)
            outputs = generator.generate(
                **batch_inputs,
                max_new_tokens=256,
                use_cache=True,
                do_sample=True,
                temperature=0.5,
                top_k=40
            )
            answers = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            all_answers.extend(answers)
        
        print(f"Generated {len(all_answers)} answers for set {n}.")
        all_answers_list.append(all_answers)
    
    return all_answers_list

In [5]:
def compute_probabilities(all_answers, critic_tokenizer, critic, batch_size=32, is_llama = True):
    answers_prob = [[] for _ in range(len(all_answers[0]))]
    
    good_token = ' +'
    bad_token = '-'
    step_tag = ' ки'

    candidate_tokens = critic_tokenizer.encode(f"{good_token} {bad_token}")[1:] # [648, 387]
    step_tag_id = critic_tokenizer.encode(f"{step_tag}")[-1] # 12902
    # print(candidate_tokens)
    # print(step_tag_id)

    with torch.no_grad():
        for answers in tqdm(all_answers, desc="Processing rewards", position=0):
            results = []
            response_counts = []
            for answer in answers:
                if '### Response:' in answer:
                    result = answer.split('### Response:')[0]
                    if len(answer.split('### Response:\n')) > 1:
                        responses = answer.split('### Response:\n')[1].split('\n')
                    else:
                        responses = ['']
                    num_responses = len(responses)
                    response_counts.append(num_responses)
                elif '?' in answer:
                    # print(answer)
                    result = answer.split('?')[0] + '?'
                    if len(answer.split('?')) > 1:
                        responses = answer.split('?')[1].split('\n')
                    else:
                        responses = ['']
                    
                    num_responses = len(responses)
                    response_counts.append(num_responses)
                elif '####' in answer:
                    result = answer.split('####')[0]
                    if len(answer.split('####')) > 1:
                        responses = answer.split('####')[1].split('\n')
                    else:
                        responses = ['']
                    responses[0] = '####' + responses[0]
                    num_responses = len(responses)
                    response_counts.append(num_responses)
                else:
                    result = answer
                    responses = ['']
                    num_responses = len(responses)
                    response_counts.append(num_responses)
                    
                     
                for response in responses:
                    result += response + " ки \n"
                results.append(result)
                                
            correct_probabilities = []
            for i in tqdm(range(0, len(results), batch_size), desc="Processing batch",position=1,  leave=True):
                batch_results = results[i:i+batch_size]
                
                inputs = critic_tokenizer(batch_results, padding="max_length", truncation=True, max_length=512, return_tensors="pt").to("cuda")
                logits = critic(**inputs).logits[:,:,candidate_tokens]
                scores = logits.softmax(dim=-1)[:,:,0] 
                step_scores = scores[inputs['input_ids'] == step_tag_id]
                correct_probabilities.extend(step_scores.tolist())
            
            # response_counts = []
            # for answer in answers:
            #     num_responses = len(answer.split('### Response:\n')[1].split('\n'))
            #     response_counts.append(num_responses)
            
            probability_index = 0
            for i, count in enumerate(response_counts):
                answer_probs = correct_probabilities[probability_index:probability_index+count]
                if answer_probs:
                    # answer_prob = min(answer_probs)
                    answer_prob = torch.tensor(answer_probs).prod().item()
                    answers_prob[i].append(answer_prob)
                else:
                    print('len of prob')
                    print(len(correct_probabilities))
                    print('len of responses')
                    print(sum(response_counts))
                    print('There is a length mismatch')
                    print('-----', i)
                    print(answers[i])
                    answers_prob[i].append(0.0)
                probability_index += count
    
    return answers_prob

def select_high_low_probability_answers(all_answers, answers_prob):
    highest_probability_answers = []
    lowest_probability_answers = []
    extracted_answers = [[] for _ in range(len(all_answers[0]))]
    for answers in all_answers:
        for i, answer in enumerate(answers):
            extracted_answers[i].append(answer)
                
    for i, question_answers in enumerate(extracted_answers):
        question_probs = answers_prob[i]
        if question_probs:
            max_prob_index = question_probs.index(max(question_probs))
            highest_probability_answer = question_answers[max_prob_index]
            min_prob_index = question_probs.index(min(question_probs))
            lowest_probability_answer = question_answers[min_prob_index]
        else:
            highest_probability_answer = ""
            lowest_probability_answer = ""
        highest_probability_answers.append(highest_probability_answer)
        lowest_probability_answers.append(lowest_probability_answer)
    return highest_probability_answers, lowest_probability_answers

In [6]:
def rollout_to_DPO_dataset(dataset, model, tokenizer, critic_tokenizer, critic, device = "cuda"):
    model.to(device)
    print('Rolling Out from model')
    with torch.no_grad():
        answers = generate_answers(dataset['prompt'], model, tokenizer, n_answers=2)
    print('Roll out completed')
    print('Starting to compute rewards')
    answers_prob = compute_probabilities(answers, critic_tokenizer, critic)
    highest_probability_answers, lowest_probability_answers = select_high_low_probability_answers(answers, answers_prob)

    # Add the "chosen" column
    epoch_dataset = dataset
    epoch_dataset = epoch_dataset.add_column("chosen", highest_probability_answers)
    # Add the "rejected" column
    epoch_dataset = epoch_dataset.add_column("rejected", lowest_probability_answers)

    # Compute rewards based on answer probabilities
    rewards = []
    for probs in answers_prob:
        if probs:
            max_prob = max(probs)
            min_prob = min(probs)
            rewards.append([max_prob, min_prob])
        else:
            rewards.append([0.0, 0.0])

    return epoch_dataset, rewards

In [7]:
# epoch_dataset, rewards = rollout_to_DPO_dataset(dataset, model, tokenizer, critic_tokenizer, critic)

# Train DPO model

In [8]:
# One must patch the DPO Trainer first!
from unsloth import PatchDPOTrainer
PatchDPOTrainer()

In [9]:
prompt = """
### Input:
{}

### Response:
"""
from datasets import load_dataset

dataset = load_dataset("gsm8k", 'main', split='train')
dataset = dataset.rename_column('question', 'prompt')

dataset = dataset.remove_columns('answer')

In [10]:
from transformers import TrainingArguments, get_scheduler
from trl import DPOTrainer
from torch.optim import AdamW

epochs = 10
base_lr = 4e-6
total_steps = len(dataset) * epochs

optimizer = AdamW(model.parameters(), lr=base_lr)



In [11]:
import wandb
wandb.login()
%env WANDB_PROJECT=LLMRL

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33mjq394[0m ([33mneurorunner[0m). Use [1m`wandb login --relogin`[0m to force relogin


env: WANDB_PROJECT=LLMRL


In [12]:
training_args= TrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            warmup_ratio=0.1,
            num_train_epochs=1,
            fp16=not torch.cuda.is_bf16_supported(),
            bf16=torch.cuda.is_bf16_supported(),
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.0,
            lr_scheduler_type="constant",  # Set the scheduler type to "constant"
            seed=42,
            output_dir="checkpoints/dpo-tinyllama-5-1",
            save_strategy="epoch",
            save_total_limit=3,
            report_to="wandb",  # Add this line to enable wandb reporting

        )

# Create the learning rate scheduler
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=int(total_steps * 0.1),  # 10% of total steps for warmup
    num_training_steps=total_steps,
)

In [13]:
wandb.init(project="LLMRL", config=training_args)  # Initialize wandb run


In [14]:
from datasets import concatenate_datasets
epoch_datasets = []  # List to store the datasets from each epoch

for epoch in tqdm(range(epochs)):
    FastLanguageModel.for_inference(model)
    epoch_dataset, rewards = rollout_to_DPO_dataset(dataset, model, tokenizer, critic_tokenizer, critic)
    FastLanguageModel.for_training(model)
    
    epoch_datasets.append(epoch_dataset)
    aggregated_dataset = concatenate_datasets(epoch_datasets)
    
    dpo_trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=training_args,
        beta=0.1,
        train_dataset=aggregated_dataset,
        tokenizer=tokenizer,
        max_length=512,
        max_prompt_length=256,
        optimizers=(optimizer, lr_scheduler),
    )
    
    
    
    # Train the model for one epoch
    train_results = dpo_trainer.train()
    # Log the training loss and other metrics to wandb
    wandb.log({"train/loss": train_results.training_loss}, step=epoch)
    wandb.log({"train/learning_rate": lr_scheduler.get_last_lr()[0]}, step=epoch)

    # Update the learning rate for the next epoch
    lr_scheduler.step()

  0%|          | 0/10 [00:00<?, ?it/s]

Rolling Out from model


Answers in Answer Set: 100%|██████████| 59/59 [10:20<00:00, 10.51s/it]
 Answer Set:  50%|█████     | 1/2 [10:20<10:20, 620.21s/it]

Generated 7473 answers for set 0.


Answers in Answer Set: 100%|██████████| 59/59 [10:20<00:00, 10.51s/it]
 Answer Set: 100%|██████████| 2/2 [20:40<00:00, 620.21s/it]


Generated 7473 answers for set 1.
Roll out completed
Starting to compute rewards


Processing batch: 100%|██████████| 234/234 [05:41<00:00,  1.46s/it]
Processing batch: 100%|██████████| 234/234 [05:40<00:00,  1.46s/it]
Processing rewards: 100%|██████████| 2/2 [11:22<00:00, 341.17s/it]


{'loss': 0.6931, 'grad_norm': 8.089056968688965, 'learning_rate': 5.352602703064364e-10, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -78.43841552734375, 'logps/chosen': -68.17208862304688, 'logits/rejected': -2.997067928314209, 'logits/chosen': -2.474667549133301, 'epoch': 0.0}
{'loss': 0.6931, 'grad_norm': 9.696084022521973, 'learning_rate': 1.0705205406128729e-09, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/accuracies': 0.0, 'rewards/margins': 0.0, 'logps/rejected': -83.40076446533203, 'logps/chosen': -59.46900177001953, 'logits/rejected': -2.865427017211914, 'logits/chosen': -2.6431589126586914, 'epoch': 0.0}
{'loss': 0.6904, 'grad_norm': 9.430974960327148, 'learning_rate': 1.6057808109193094e-09, 'rewards/chosen': -0.0016100883949548006, 'rewards/rejected': -0.007131862919777632, 'rewards/accuracies': 0.375, 'rewards/margins': 0.0055217742919921875, 'logps/rejected': -86.44808959960938, 'logps/cho

 10%|█         | 1/10 [48:50<7:19:36, 2930.75s/it]

{'loss': 0.9112, 'grad_norm': 22.7171573638916, 'learning_rate': 4.999330924662116e-07, 'rewards/chosen': -1.8117804527282715, 'rewards/rejected': -1.9352043867111206, 'rewards/accuracies': 0.625, 'rewards/margins': 0.1234239935874939, 'logps/rejected': -88.09785461425781, 'logps/chosen': -77.86730194091797, 'logits/rejected': -2.7334747314453125, 'logits/chosen': -2.552826166152954, 'epoch': 1.0}
{'train_runtime': 1007.4085, 'train_samples_per_second': 7.418, 'train_steps_per_second': 0.927, 'train_loss': 0.6352110372046587, 'epoch': 1.0}
Rolling Out from model


Answers in Answer Set: 100%|██████████| 59/59 [04:54<00:00,  5.00s/it]
 Answer Set:  50%|█████     | 1/2 [04:54<04:54, 294.91s/it]



Generated 7473 answers for set 0.


Answers in Answer Set: 100%|██████████| 59/59 [04:12<00:00,  4.28s/it]
 Answer Set: 100%|██████████| 2/2 [09:07<00:00, 273.73s/it]


Generated 7473 answers for set 1.
Roll out completed
Starting to compute rewards


Processing batch: 100%|██████████| 234/234 [05:38<00:00,  1.45s/it]
Processing batch: 100%|██████████| 234/234 [05:38<00:00,  1.45s/it]
Processing rewards: 100%|██████████| 2/2 [11:17<00:00, 338.80s/it]
Map: 100%|██████████| 7473/7473 [00:10<00:00, 721.11 examples/s]


{'loss': 0.6636, 'grad_norm': 4.431539058685303, 'learning_rate': 5.010036130068246e-07, 'rewards/chosen': 0.06045839935541153, 'rewards/rejected': -0.0009704576805233955, 'rewards/accuracies': 0.5, 'rewards/margins': 0.06142885982990265, 'logps/rejected': -32.294097900390625, 'logps/chosen': -34.145511627197266, 'logits/rejected': -3.1428115367889404, 'logits/chosen': -3.1892099380493164, 'epoch': 0.0}
{'loss': 0.6843, 'grad_norm': 9.419195175170898, 'learning_rate': 5.015388732771309e-07, 'rewards/chosen': 0.042829252779483795, 'rewards/rejected': 0.01823258399963379, 'rewards/accuracies': 0.375, 'rewards/margins': 0.024596670642495155, 'logps/rejected': -32.82616424560547, 'logps/chosen': -38.43767166137695, 'logits/rejected': -3.1215693950653076, 'logits/chosen': -3.2568976879119873, 'epoch': 0.0}
{'loss': 0.6903, 'grad_norm': 4.6575422286987305, 'learning_rate': 5.020741335474374e-07, 'rewards/chosen': 0.08346738666296005, 'rewards/rejected': 0.07534250617027283, 'rewards/accuraci

 20%|██        | 2/10 [1:26:16<5:37:00, 2527.58s/it]

{'loss': 0.657, 'grad_norm': 8.93879222869873, 'learning_rate': 1.0004014452027298e-06, 'rewards/chosen': -0.017275717109441757, 'rewards/rejected': -0.15321481227874756, 'rewards/accuracies': 0.5, 'rewards/margins': 0.1359390914440155, 'logps/rejected': -40.58965301513672, 'logps/chosen': -38.54426574707031, 'logits/rejected': -2.7932381629943848, 'logits/chosen': -2.81549072265625, 'epoch': 1.0}
{'train_runtime': 1009.546, 'train_samples_per_second': 7.402, 'train_steps_per_second': 0.925, 'train_loss': 0.6608951019015507, 'epoch': 1.0}
Rolling Out from model


Answers in Answer Set: 100%|██████████| 59/59 [03:00<00:00,  3.06s/it]
 Answer Set:  50%|█████     | 1/2 [03:00<03:00, 180.55s/it]

Generated 7473 answers for set 0.


Answers in Answer Set: 100%|██████████| 59/59 [03:26<00:00,  3.50s/it]
 Answer Set: 100%|██████████| 2/2 [06:26<00:00, 193.48s/it]


Generated 7473 answers for set 1.
Roll out completed
Starting to compute rewards


Processing batch: 100%|██████████| 234/234 [05:38<00:00,  1.45s/it]
Processing batch: 100%|██████████| 234/234 [05:38<00:00,  1.45s/it]
Processing rewards: 100%|██████████| 2/2 [11:17<00:00, 338.70s/it]
Map: 100%|██████████| 7473/7473 [00:10<00:00, 743.41 examples/s]


{'loss': 0.6206, 'grad_norm': 9.11595344543457, 'learning_rate': 1.0014719657433427e-06, 'rewards/chosen': 0.5641312599182129, 'rewards/rejected': 0.4016731381416321, 'rewards/accuracies': 0.75, 'rewards/margins': 0.1624581217765808, 'logps/rejected': -35.557796478271484, 'logps/chosen': -38.4548225402832, 'logits/rejected': -3.2089600563049316, 'logits/chosen': -3.2121386528015137, 'epoch': 0.0}
{'loss': 0.6501, 'grad_norm': 9.302193641662598, 'learning_rate': 1.0020072260136491e-06, 'rewards/chosen': 0.3189576268196106, 'rewards/rejected': 0.2286132574081421, 'rewards/accuracies': 0.625, 'rewards/margins': 0.0903443843126297, 'logps/rejected': -31.99930763244629, 'logps/chosen': -31.16376495361328, 'logits/rejected': -3.1708970069885254, 'logits/chosen': -3.1811437606811523, 'epoch': 0.0}
{'loss': 0.7021, 'grad_norm': 5.569589614868164, 'learning_rate': 1.0025424862839556e-06, 'rewards/chosen': 0.17907698452472687, 'rewards/rejected': 0.196123868227005, 'rewards/accuracies': 0.25, 'r

 20%|██        | 2/10 [1:49:36<7:18:25, 3288.18s/it]

{'loss': 0.5465, 'grad_norm': 4.866977691650391, 'learning_rate': 1.1615147865649672e-06, 'rewards/chosen': 0.08681786805391312, 'rewards/rejected': -0.37227028608322144, 'rewards/accuracies': 0.625, 'rewards/margins': 0.45908817648887634, 'logps/rejected': -39.99687194824219, 'logps/chosen': -36.43140411376953, 'logits/rejected': -3.161132335662842, 'logits/chosen': -3.202683448791504, 'epoch': 0.32}





KeyboardInterrupt: 

In [None]:
wandb.finish()  # Finish the wandb run


# Save Model

In [None]:
model.save_pretrained("checkpoints/dpo-tinyllama-5-1") # Local saving
tokenizer.save_pretrained("checkpoints/dpo-tinyllama-5-1")