In [None]:
# Show only the nvidia-smi output lines that contain "W"
# (presumably the per-GPU status rows that report power draw in watts -- TODO confirm).
! nvidia-smi | grep -B 0 "W" 

In [None]:
import os
# Configure GPU enumeration/visibility and turn off tokenizers parallelism
# through environment variables, all in one place.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
        "TOKENIZERS_PARALLELISM": "false",
    }
)

In [None]:
import torch
import itertools
from torch import nn
from trl import DataCollatorForCompletionOnlyLM
from datasets import load_dataset
import numpy as np
from torch.utils.data import DataLoader
import bitsandbytes as bnb
from torch.optim import lr_scheduler
from functools import partial
from transformers import get_scheduler
from transformers.trainer_pt_utils import get_parameter_names
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, get_scheduler, AutoModelForCausalLM,\
DataCollatorWithPadding, DataCollatorForLanguageModeling, BitsAndBytesConfig
from tqdm.auto import tqdm
from accelerate import Accelerator
from peft import LoraConfig, TaskType, get_peft_model, PeftModelForCausalLM, prepare_model_for_kbit_training
from datasets import load_dataset
import sys
# Use the GPU when CUDA is available to this process, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

> set random seed

In [None]:
# Seed every random number generator used in this notebook (torch CUDA, torch CPU, NumPy)
# with the same value so that runs are repeatable.
seed = 42
for _seed_rng in (torch.cuda.manual_seed, torch.manual_seed, np.random.seed):
    _seed_rng(seed)

> We load the model in 4-bit precision; the quantized weights are actually stored as `torch.uint8`, and
>
> we compute in fp16.
>
> Note that, after `prepare_model_for_kbit_training`, it is actually fp32;
> fp32 is required for AdaHessian and TRCG to work.

In [None]:
# Quantization settings: 4-bit NF4 weights with double quantization, fp16 compute dtype.
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)
# Load GPT-2 with the quantization config above; device_map="auto" leaves
# the placement of the weights to the loader.
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2",
                                             torch_dtype=torch.float16,
                                             device_map="auto", 
                                             quantization_config=config)
# Prepare the quantized model for training. Gradient checkpointing is explicitly
# kept off (see the notebook note on AdaHessian / TRCG).
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False )

> prep for LoRA model
>
> check http://localhost:1331/edit/optim/notebooks/README.md
> 
> for currently implemented default `target modules` in `peft`

In [None]:
# LoRA adapter settings for causal language modelling: rank r=4, scaling alpha=32, 5% dropout.
# No `target_modules` are passed, so peft's defaults for this architecture are used
# (see the notebook note above).
lora_config = LoraConfig(
    r=4,
    lora_alpha=32, # based on paper - https://arxiv.org/abs/2106.09685
    task_type=TaskType.CAUSAL_LM,
    lora_dropout=0.05
)

# Wrap the quantized base model with LoRA adapters and report the trainable parameter count.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

> At the moment, we need to disable gradient checkpointing (GC) for the
> current implementation of AdaHessian and TRCG to work.
> 
> Here, for a fair comparison, we also turn off GC for AdamW.

In [None]:
# Display the model state after quantization and LoRA wrapping: gradient-checkpointing flag,
# quantization method, dtype, 4-bit flag and the `use_cache` setting.
model.is_gradient_checkpointing, model.quantization_method, model.dtype,\
model.is_loaded_in_4bit,\
model.config.use_cache

> set `use_cache` to False

In [None]:
# Turn off `use_cache` on the model config
# (presumably the generation key/value cache, which training does not need -- TODO confirm).
model.config.use_cache = False

> Start processing data

> First we apply the template; we then use `DataCollatorForCompletionOnlyLM` to mask non-assistant tokens.

In [None]:
# Load the tokenizer from the same checkpoint as the model so that the two cannot drift apart.
tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
# Register the role markers as special tokens so each one maps to a single token id.
tokenizer.add_special_tokens(
    {"additional_special_tokens": ["<|system|>", "<|user|>", "<|assistant|>", "<|end|>"]}
)

# GPT-2 has no dedicated padding token, so the end-of-text token is reused for padding.
tokenizer.pad_token = tokenizer.eos_token
template_str = """\
{% for message in messages %}\
{% if message["role"] == "user" %}\
{{ "<|endoftext|><|user|>\n" + message["content"] + "<|end|>\n" }}\
{% elif message["role"] == "system" %}\
{{ "<|system|>\n" + message["content"] + "<|end|>\n" }}\
{% elif message["role"] == "assistant" %}\
{{ "<|assistant|>\n" + message["content"] + "<|end|><|endoftext|>\n" }}\
{% endif %}\
{% endfor %}\
"""
tokenizer.chat_template = template_str
# A bare expression in the middle of a cell displays nothing, so print the map explicitly.
print(tokenizer.special_tokens_map)
# Grow the embedding matrix to cover the newly added special tokens.
# NOTE(review): this runs after the model was wrapped with LoRA; the rows added here are
# presumably frozen and therefore never trained -- confirm this is intended.
model.resize_token_embeddings(len(tokenizer))

# collator choice: only tokens after the "<|assistant|>" marker contribute to the loss
collator = DataCollatorForCompletionOnlyLM("<|assistant|>", tokenizer=tokenizer)

In [None]:
def pair_texts(examples):
    """Flatten every batched column and regroup it into consecutive chunks of two.

    Args:
        examples: mapping from column name to a list of lists (one inner list per row).

    Returns:
        A dict with the same keys. Each value is the flattened column cut into
        slices of length two; the number of slices is derived from the flattened
        length of the first column.
    """
    flattened = {}
    for column, nested in examples.items():
        flattened[column] = list(itertools.chain.from_iterable(nested))

    # The first column decides how many items (and therefore pairs) there are.
    first_column = list(examples.keys())[0]
    n_items = len(flattened[first_column])

    paired = {}
    for column, items in flattened.items():
        paired[column] = [items[start : start + 2] for start in range(0, n_items, 2)]
    return paired

def convert_messages(example):
    """Render the message list of one example into a single tagged string.

    Every user message is preceded by a fixed system prompt; user and assistant
    contents are wrapped in their role markers. Messages with any other role are
    skipped. The rendered string overwrites ``example["messages"]`` in place and
    the same ``example`` object is returned.
    """
    system_prompt = "<|system|>Below is a dialogue between a human user and an AI assistant.<|end|>\n"
    pieces = []
    for message in example["messages"]:
        role = message["role"]
        if role == "assistant":
            pieces.append("<|assistant|>" + message["content"] + "<|end|><|endoftext|>")
        elif role == "user":
            pieces.append(system_prompt)
            pieces.append("<|endoftext|><|user|>" + message["content"] + "<|end|>\n")
    example["messages"] = "".join(pieces)
    return example

def truncation(example, max_length=128):
    """Clip one rendered dialogue to ``max_length`` tokens and re-close a cut-off turn.

    The dialogue string in ``example["messages"]`` is tokenized (padded and
    truncated to ``max_length``), classified, patched if needed, and decoded back
    to text.

    Classification, stored in ``example["ending"]``:
        * ``"original"``  - the assistant turn survived with its closing
          ``<|end|><|endoftext|>``; nothing is patched.
        * ``"assistant"`` - an assistant turn is present but its closing markers
          were cut off; the last two tokens become ``<|end|><|endoftext|>``.
        * ``"user"``      - no assistant turn is left; the last two tokens become
          ``<|end|>`` followed by a newline.

    Fixes over the previous version:
        * the text is decoded again *after* the patch, so the repaired ending
          really ends up in ``example["messages"]`` (it used to be discarded);
        * the marker ids are looked up from the tokenizer instead of being
          hard-coded (the old literal 50259 is ``<|assistant|>``, not ``<|end|>``,
          given the order in which the special tokens are added).

    Args:
        example: dict with the rendered dialogue under ``"messages"``.
        max_length: token budget per dialogue (default 128, the previous fixed value).

    Returns:
        The same ``example`` dict, with ``"messages"`` replaced by the decoded
        (possibly patched) text and ``"ending"`` added.
    """
    encoded = tokenizer(example["messages"], padding="max_length", max_length=max_length, truncation=True)
    input_ids = encoded["input_ids"]
    text = tokenizer.decode(input_ids)

    # Classify on the text as it looks right after truncation, before any patching.
    has_assistant = "<|assistant|>" in text
    has_closed_turn = "<|end|><|endoftext|>" in text

    end_id = tokenizer.convert_tokens_to_ids("<|end|>")
    ending = "original"
    if has_assistant and not has_closed_turn:
        # Assistant turn was cut off: close it with <|end|><|endoftext|>.
        input_ids[-2:] = [end_id, tokenizer.eos_token_id]
        ending = "assistant"
    elif not has_assistant:
        # No assistant turn left: close the user turn with <|end|>\n.
        newline_id = tokenizer("\n")["input_ids"][-1]
        input_ids[-2:] = [end_id, newline_id]
        ending = "user"

    if ending != "original":
        # Decode again so the patched ending is reflected in the stored text.
        text = tokenizer.decode(input_ids)

    example["ending"] = ending
    example["messages"] = text
    return example

def token_map(dataset):
    """Run the full text-preparation pipeline on ``dataset`` and return the result.

    Steps, in order: pair up messages (batched), render each pair to a tagged
    string, truncate, drop every example whose ``ending`` is ``"user"``, and
    finally remove the helper ``ending`` column.
    """
    prepared = (
        dataset
        .map(pair_texts, batched=True)
        .map(convert_messages)
        .map(truncation)
    )
    kept = prepared.filter(lambda example: example["ending"] != "user")
    return kept.remove_columns(["ending"])

# Fetch the curated OASST2 dialogues, then run them through the preparation pipeline.
raw_dataset = load_dataset("sablo/oasst2_curated")
dataset = token_map(raw_dataset)

In [None]:
def tokenize_function(example):
    """Tokenize the ``messages`` field and keep only ids and attention mask.

    Padding and truncation are both switched off here (``max_length`` is still
    passed along); padding is left to the collator at batch time.

    Returns:
        dict with exactly the keys ``"input_ids"`` and ``"attention_mask"``.
    """
    encoded = tokenizer(
        example["messages"],
        padding=False,
        truncation=False,
        max_length=128,
        return_overflowing_tokens=False,
        return_length=False,
    )
    return {field: encoded[field] for field in ("input_ids", "attention_mask")}
# Tokenize both splits with identical settings; the original columns are dropped so
# that only the tokenizer outputs remain.
_tokenize_map_kwargs = {"batched": True, "num_proc": 2, "batch_size": 1000}
train_dataset = dataset["train"].map(
    tokenize_function,
    remove_columns=dataset["train"].column_names,
    **_tokenize_map_kwargs,
)
test_dataset = dataset["test"].map(
    tokenize_function,
    remove_columns=dataset["test"].column_names,
    **_tokenize_map_kwargs,
)

In [None]:
# for training (i.e., fine-tuning)
# Both loaders share the completion-only collator and pinned host memory.
_loader_kwargs = {"collate_fn": collator, "pin_memory": True}
# for training (i.e., fine-tuning): shuffled mini-batches of 16
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, **_loader_kwargs)
# for getting testing perplexity: fixed order, larger batches of 64
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, **_loader_kwargs)

> AdamW

> We need to define parameter groups
> because weight decay is applied
> only to non-bias parameters.

> we also define algorithm args here

In [None]:
# Names of the parameters that receive weight decay: those returned by
# get_parameter_names when LayerNorm modules are excluded, minus any name containing "bias".
decay_parameters = [
    name for name in get_parameter_names(model, [nn.LayerNorm]) if "bias" not in name
]

# AdamW hyper-parameters
learning_rate = 5e-4 # so far the best choice (note: w/o exhaustive search)
beta1, beta2 = 0.9, 0.999 # most of time, no need to change
epsilon = 1e-8
weight_decay = 1e-2 # so far the best w/o exhaustive search

optimizer_kwargs = dict(lr=learning_rate, betas=(beta1, beta2), eps=epsilon)

# Only trainable parameters are handed to the optimizer; they are split into a
# decayed group and a non-decayed group.
_trainable = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
optimizer_grouped_parameters = [
    {
        "params": [p for n, p in _trainable if n in decay_parameters],
        "weight_decay": weight_decay,
    },
    {
        "params": [p for n, p in _trainable if n not in decay_parameters],
        "weight_decay": 0.0,
    },
]

optimizer = torch.optim.AdamW(optimizer_grouped_parameters, **optimizer_kwargs)

> define scheduler

In [None]:
# Learning-rate schedule: "linear" with a short warm-up, stepped once per optimizer step.
num_epochs = 1
num_warmup_steps = 10 # could be increased
name = "linear"
# total number of optimizer steps over the whole run
num_training_steps = len(train_dataloader) * num_epochs
scheduler = get_scheduler(
    name=name,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

> metric function

In [None]:
def evaluate(dataloader):
    """Compute the average loss and perplexity of the global ``model`` on ``dataloader``.

    Each batch loss is weighted by its batch size and the sum is divided by the
    number of examples actually seen in ``dataloader``. (The previous version
    always divided by the size of the training set, which gave wrong numbers for
    any other loader.)

    The model is put in eval mode for the pass and its previous train/eval mode is
    restored afterwards, so calling this inside a training loop no longer leaves
    dropout switched off.

    Args:
        dataloader: iterable of batches (dicts) that contain ``"input_ids"`` and can
            be passed to the model as keyword arguments.

    Returns:
        ``(loss, ppl)`` where ``ppl = exp(loss)``; ``ppl`` is ``inf`` when the
        exponential overflows. Both values are ``nan`` for an empty dataloader.
    """
    was_training = model.training
    model.eval()
    total_loss, n_examples = 0.0, 0
    try:
        with torch.no_grad():
            for batch in dataloader:
                batch_size = batch["input_ids"].shape[0]
                total_loss += model(**batch).loss.item() * batch_size
                n_examples += batch_size
    finally:
        # Hand the model back in the mode the caller had it in.
        model.train(was_training)

    if n_examples == 0:
        return float("nan"), float("nan")

    loss = total_loss / n_examples
    # np.exp does not raise on overflow, it returns inf; silence the warning only.
    with np.errstate(over="ignore"):
        ppl = np.exp(loss)
    return loss, ppl

> training loop

In [None]:
# log stats
# Training log: the training loop appends one tuple per evaluation point, in the order
# (epoch, iteration, loss, perplexity, used memory).
logger = [] # ep, it, loss, ppl, mem

In [None]:
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
eval_every = 10  # number of optimizer steps between two evaluation/logging points


def _peak_reserved_gib():
    """Return torch.cuda.max_memory_reserved() converted from bytes to GiB."""
    return torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024


def _log_progress(epoch, it):
    """Evaluate on the training set, append the stats to ``logger`` and print them."""
    # memory growth relative to the start of training
    used_memory = round(_peak_reserved_gib() - start_memory, 3)
    # loss and perplexity are both computed on the training set, hence the tr_ labels
    batch_loss, batch_ppl = evaluate(train_dataloader)
    logger.append((epoch, it, batch_loss, batch_ppl, used_memory))
    print(f"epoch: {epoch}, iter: {it}, tr_loss: {batch_loss:.2e}, tr_ppl: {batch_ppl:.2e}, mem: {used_memory:.3f}")
    # evaluate() puts the model in eval mode; switch back so that the following
    # optimization steps run in training mode (dropout enabled).
    model.train()


# start memory
start_memory = round(_peak_reserved_gib(), 3)
# progress bar
progress_bar = tqdm(range(num_training_steps))
for epoch in range(num_epochs):
    # baseline stats before the first update of this epoch
    _log_progress(epoch, 0)
    model.train()
    for it, minibatch in enumerate(train_dataloader, 1):
        # forward pass
        outputs = model(**minibatch)
        loss = outputs.loss
        # backward pass
        loss.backward()
        # optimization step
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        # update progress bar
        progress_bar.update(1)

        if it % eval_every == 0:
            _log_progress(epoch, it)
progress_bar.close()

In [None]:
# Exit the program with a success status
# also release memory on GPU
# Stop the interpreter with exit status 0 (success). Ending the process is what is
# expected to free the GPU memory held by this notebook.
exit(0)