In [1]:
import os, sys
import math
from tqdm import tqdm
from datetime import datetime
import ipdb
from typing import List, Dict, Union
import wandb

import torch
import torch.nn as nn
from torch.nn import functional as F

import transformers
from datasets import load_dataset, load_from_disk

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

torch.cuda.empty_cache()

torch.set_printoptions(threshold=1000)

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
class Parameters():
    def __init__(self):
        # training parameters
        self.batch_size = 1
        self.learning_rate = 6e-5
        self.epochs = 3
        self.lr_warmup_steps = 100
        self.context_length = 1024
        self.alpha = 0.5 # weighting for PRPO odds ratio
        self.prompt_max_length = 512
        self.compile = False
        self.dtype = torch.float16
        self.log_iters = 50

        # hyperparameters
        self.dropout = 0.0
        self.grad_clip = 1.0
        self.weight_decay = 0.0

        # device setup
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # logging
        self.wandb = True
        self.wanadb_project_name = "aligntest"
        self.wandb_project = self.wanadb_project_name
        self.wandb_run_name = f"{self.wanadb_project_name}-run-{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}"
        self.wandb_api_key = "***" # paste a valid API key

    def wanadb_init(self):
        wandb.login(key=self.wandb_api_key)
        wandb.init(project=self.wandb_project, name=self.wandb_run_name)
        

parameters = Parameters()
parameters.wanadb_init()

print(f"Computing on {parameters.device}")



Computing on cuda


In [3]:
dataset_path = "files/data/orpo_dataset"
dataset_name = "mlabonne/orpo-dpo-mix-40k"
tokenizer_path = "files/tokenizers/tok16384"
checkpoint_dir = "files/models/"

tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)

with open('chat.dtpl', 'r', encoding='utf-8') as file:
    tokenizer.chat_template = file.read()

tokenizer.pad_token = tokenizer.eos_token

if os.path.exists(dataset_path):
    dataset = load_from_disk(dataset_path)
else:
    print("Filtering and tokenizing dataset")
    dataset = load_dataset(dataset_name, split="all")
    
    # optionally filter out some of the entries (37136 vs 36622)
    dataset.filter(lambda x: x["source"] != "toxic-dpo-v0.2")

    # Filter dataset
    # Eliminate entries longer than 512 (prompt_max_length). This is important
    # because we want the prompt + answer to fit within the context_length
    def filter_dataset(examples: Dict[str, Union[str, List[str]]]) -> bool:
        prompt = tokenizer.apply_chat_template(
            examples["chosen"][:-1],
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )

        prompt_length = prompt.size(-1)
        if prompt_length < parameters.prompt_max_length:
            return True
        
        return False
    
    # Preprocess and tokenize dataset
    def preprocess_dataset(examples: Dict[str, Union[str, List[str]]]) -> Dict[
        str, Union[str, List[str]]
    ]:
        prompt = [tokenizer.apply_chat_template(
            item[:-1], tokenize=False, 
            add_generation_prompt=True
        ) for item in examples["chosen"]]

        chosen = [tokenizer.apply_chat_template(
            item, tokenize=False
        ) for item in examples["chosen"]]

        rejected = [tokenizer.apply_chat_template(
            item, tokenize=False
        ) for item in examples["rejected"]]

        inputs = tokenizer(
            prompt, max_length=parameters.context_length, 
            padding="max_length", truncation=True, return_tensors="pt"
        )

        pos_labels = tokenizer(
            chosen, max_length=parameters.context_length, 
            padding="max_length", truncation=True, return_tensors="pt"
        )

        neg_labels = tokenizer(
            rejected, max_length=parameters.context_length, 
            padding="max_length", truncation=True, return_tensors="pt"
        )

        inputs["positive_input_ids"] = pos_labels["input_ids"]
        inputs["positive_attention_mask"] = pos_labels["attention_mask"]

        inputs["negative_input_ids"] = neg_labels["input_ids"]
        inputs["negative_attention_mask"] = neg_labels["attention_mask"]

        return inputs

    # exclude propts that are too long
    dataset = dataset.filter(filter_dataset)
    
    dataset = dataset.map(
        preprocess_dataset, batched=True, 
        num_proc=min(32, os.cpu_count()),
        remove_columns=dataset.column_names
    )

    dataset.save_to_disk(dataset_path)

In [4]:
# print(len(dataset[2]["positive_input_ids"]))
# print(dataset[2]["positive_input_ids"])
# tokenizer.decode(dataset[2]["positive_input_ids"])

In [5]:
dataset_split = dataset.shuffle(42).train_test_split(test_size=0.05)
train_data = dataset_split["train"]
val_data = dataset_split["test"]

data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=parameters.batch_size, 
    shuffle=False, collate_fn=data_collator, num_workers=0
)

val_loader = torch.utils.data.DataLoader(
    val_data, batch_size=parameters.batch_size, 
    shuffle=False, collate_fn=data_collator, num_workers=0
)



In [6]:
it = iter(train_loader)
batch = next(it)
print(tokenizer.decode(batch["positive_input_ids"][0]))

<|user|>
When a water tank is $30\%$ full, it contains 27 gallons less than when it is $20\%$ empty. How many gallons of water does the tank hold when it is full?</s> 
<|assistant|>
I want to find the total capacity of the tank, so I will call that C.
Then, when the tank is $30\%$ full, it has $0.3C$ gallons of water, and when it is $20\%$ empty, it has $0.8C$ gallons of water.
The problem says that the difference between these two amounts is 27 gallons, so I can write an equation: $0.8C - 0.3C = 27$.
Simplifying the equation, I get $0.5C = 27$, so $C = 54$.
Therefore, the tank holds 54 gallons of water when it is full.
# Answer

54</s> 
</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></

In [7]:
from files.llm import Llama, ModelArgs

checkpoint = torch.load(os.path.join(checkpoint_dir, "base_model.pt"))
config = checkpoint.pop("config")

model_args = ModelArgs(
    dim=config.hidden_size,
    n_layers=config.num_hidden_layers,
    n_heads=config.num_attention_heads,
    n_kv_heads=config.num_key_value_heads,
    vocab_size=config.vocab_size,
    norm_eps=config.rms_norm_eps,
    rope_theta=config.rope_theta,
    max_seq_len=parameters.context_length,
    dropout=config.attention_dropout,
    hidden_dim=config.intermediate_size,
    attention_bias=config.attention_bias,
    mlp_bias=config.mlp_bias,
)

model = Llama(model_args)
model.load_state_dict(checkpoint)
model = model.to(parameters.dtype)
model = model.to(parameters.device)
model.train()

if parameters.compile:
    print("[INFO] Compiling model")
    model = torch.compile(model)

print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")

  checkpoint = torch.load(os.path.join(checkpoint_dir, "base_model.pt"))


138.431232 M parameters


In [13]:
optimizer = torch.optim.AdamW(
    model.parameters(), lr=parameters.learning_rate, betas=(0.9, 0.95), eps=1e-8, 
    fused=(parameters.device.type == "cuda"), weight_decay=parameters.weight_decay,
)

num_trainings_steps = len(train_loader) * parameters.epochs
print(f"Training for {num_trainings_steps} steps")

def lr_lambda(current_step: int) -> float:
    if current_step < parameters.lr_warmup_steps:
        return float(current_step) / float(max(1, parameters.lr_warmup_steps))
    
    progress = float(current_step - parameters.lr_warmup_steps) / float(
        max(1, num_trainings_steps - parameters.lr_warmup_steps)
    )

    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1)

Training for 111408 steps


In [15]:
def compute_logprops(prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
    pass