In [1]:
! nvidia-smi | grep -B 0 "W" 

| N/A   70C    P0              32W /  70W |     41MiB / 15360MiB |      0%      Default |
--
| N/A   68C    P8              20W /  70W |      5MiB / 15360MiB |      0%      Default |
--
| N/A   70C    P0              32W /  70W |      2MiB / 15360MiB |      0%      Default |


In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="2"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
import torch
import itertools
from torch import nn
from trl import DataCollatorForCompletionOnlyLM
from datasets import load_dataset
import numpy as np
from torch.utils.data import DataLoader
from transformers import get_scheduler
from transformers.trainer_pt_utils import get_parameter_names
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, BitsAndBytesConfig
from tqdm.auto import tqdm
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import sys
sys.path.append('../') # please change this path accordingly
from optim.trcg import TRCG
import pickle

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

> set random seed

In [4]:
seed = 42
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

> we load model in 4bit, which is actually stored as `torch.uint8` and 
>
> we compute in fp16
>
> Note that, with `prepare_model_for_kbit_training`, it is actually fp32
> for AdaHessian and TRCG to work, fp32 is required.

In [5]:
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2",
                                             torch_dtype=torch.float16,
                                             device_map="auto", 
                                             quantization_config=config)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False )

> prep for LoRA model
>
> check http://localhost:1331/edit/optim/notebooks/README.md
> 
> for currently implemented default `target modules` in `peft`

In [6]:
lora_config = LoraConfig(
    r=4, 
    lora_alpha=32, # based on paper - https://arxiv.org/abs/2106.09685
    task_type=TaskType.CAUSAL_LM,
    lora_dropout=0.05
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 147,456 || all params: 124,587,264 || trainable%: 0.11835559692522023


> at this moment, we need to disable `gradient checkpoint` for
> current implementation of AdaHessian and TRCG to work

In [7]:
model.is_gradient_checkpointing, model.quantization_method, model.dtype,\
model.is_loaded_in_4bit,\
model.config.use_cache

(False,
 <QuantizationMethod.BITS_AND_BYTES: 'bitsandbytes'>,
 torch.float32,
 True,
 True)

> set `use_cache` to False

In [8]:
model.config.use_cache = False

> Start processing data

> first is to apply template, we will then use `DataCollatorForCompletionOnlyLM` to mask non-assistant tokens

In [9]:
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.add_special_tokens({"additional_special_tokens": \
                              ["<|system|>","<|user|>","<|assistant|>","<|end|>"]})

tokenizer.pad_token = tokenizer.eos_token
template_str = """\
{% for message in messages %}\
{% if message["role"] == "user" %}\
{{ "<|endoftext|><|user|>\n" + message["content"] + "<|end|>\n" }}\
{% elif message["role"] == "system" %}\
{{ "<|system|>\n" + message["content"] + "<|end|>\n" }}\
{% elif message["role"] == "assistant" %}\
{{ "<|assistant|>\n" + message["content"] + "<|end|><|endoftext|>\n" }}\
{% endif %}\
{% endfor %}\
"""
tokenizer.chat_template = template_str
tokenizer.special_tokens_map
model.resize_token_embeddings(len(tokenizer))

# collator choice
collator = DataCollatorForCompletionOnlyLM("<|assistant|>", tokenizer=tokenizer)

In [10]:
def pair_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {key: list(itertools.chain(*value)) for key, value in examples.items()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])

    result = {
        k: [t[i : i + 2] for i in range(0, total_length, 2)]
        for k, t in concatenated_examples.items()
    }
    return result

def convert_messages(example):
    t=[]
    for ei in example["messages"]:
        if ei["role"]=="user":
            t.append("<|system|>Below is a dialogue between a human user and an AI assistant.<|end|>\n")
            t.append("<|endoftext|><|user|>" + ei["content"] + "<|end|>\n")
        elif ei["role"]=="assistant":
            t.append("<|assistant|>" + ei["content"] + "<|end|><|endoftext|>")
    example["messages"]="".join(t)
    return example

def truncation(example):
    
    _dialogue = tokenizer(example["messages"],padding="max_length",max_length=128,truncation=True)
    
    truncate_dialogue = tokenizer.decode(_dialogue["input_ids"])
    ending = "original"
    if "<|assistant|>" in truncate_dialogue and "<|end|><|endoftext|>" not in truncate_dialogue:
        _dialogue["input_ids"][-2:] = [50259, 50256] # add <|end|><|endoftext|> -- cutoff assistant token
        ending = "assistant"
    if "<|assistant|>" not in truncate_dialogue:
        _dialogue["input_ids"][-2:] = [50259, 198]   # add <|end|>\n -- no assistant token
        ending = "user"
    
    example["ending"]=ending
    example["messages"] = truncate_dialogue
    return example

def token_map(dataset):
    dataset = dataset.map(pair_texts, batched=True)
    dataset = dataset.map(convert_messages)
    dataset = dataset.map(truncation)
    dataset = dataset.filter(lambda example: example["ending"]!="user")
    dataset = dataset.remove_columns(["ending"])
    return dataset

dataset = load_dataset("sablo/oasst2_curated")
dataset = token_map(dataset)

In [11]:
def tokenize_function(example):
    outputs =  tokenizer(example["messages"], 
                         padding=False, 
                         truncation=False,
                         max_length=128,
                         return_overflowing_tokens=False,
                         return_length=False)
    return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
train_dataset = dataset["train"].map(tokenize_function, 
                                     batched=True, 
                                     remove_columns=dataset["train"].column_names,
                                     num_proc=2,
                                     batch_size=1000)
test_dataset = dataset["test"].map(tokenize_function, 
                                   batched=True, 
                                   remove_columns=dataset["test"].column_names,
                                   num_proc=2,
                                   batch_size=1000)

In [12]:
# for training (i.e., fine-tuning)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collator, pin_memory=True)
# for getting testing perplexity
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collator, pin_memory=True)

> metric function

In [13]:
def evaluate(dataloader):
    model.eval()
    loss = 0
    for batch in dataloader:
        with torch.no_grad():
            outputs = model(**batch)
        loss += outputs.loss.item() * batch["input_ids"].shape[0]
    loss = loss / len(train_dataloader.dataset)
    try:
        ppl = np.exp(loss)
    except OverflowError:
        ppl = float("inf")
    return loss, ppl

> define optimizer

In [14]:
# the hyperparameters are default setup
# please see /optim/trcg.py for details
# it is almost always the case that there 
# is no need to fine-tune any of the hyper-parameters
# ZERO fine-tunining effort, yay!
optimizer = TRCG(model, device=device)

> training loop

In [15]:
# log stats
# we can measure computational cost by simply logging gradient^ computations
#
# ^ for second-order methods, in particular, Hessian-free methods, we use gradient and Hessian-vector
# product computation as a measure of cost
logger = [] # ep, it, loss, ppl, gradient/Hv product, mem

In [16]:
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
# start memory
start_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# progress bar
progress_bar = tqdm(range(num_training_steps))
# accumulative cost
cost = 0
# training loop
for epoch in range(num_epochs):
    # initialize memory stats
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) - start_memory
    # compute (training) batch loss and (testing) perplexity
    batch_loss, batch_ppl = evaluate(train_dataloader)
    # _, batch_ppl = evaluate(test_dataloader)
    # logger
    logger.append((epoch, 0, batch_loss, batch_ppl, cost, used_memory))
    # print out results
    print(f"epoch: {epoch}, tr_loss: {batch_loss:.2e}, tr_ppl: {batch_ppl:.2e}, cost: {cost}, mem: {used_memory}")
    model.train()
    step_cnt = 0
    for it, minibatch in enumerate(train_dataloader, 1):
        # forward pass
        outputs = model(**minibatch)
        loss = outputs.loss.item()
        # subject to current implementation
        # need to explicitly define lora param
        lora_param = [w for w in model.parameters() if w.requires_grad]
        # explictly compute the gradient with `create_graph = True`
        # note: it is probably better to use `loss.backward(create_graph=True)`
        #       current implementation requires this form
        V = torch.autograd.grad(outputs.loss, lora_param, create_graph=True)
        # compute square norm the gradient, to monitor progress
        # commented out 
        # V_norm = np.sum([torch.sum(vi.data**2).item() for vi in V])**0.5
        
        # temporary fix of constantly shrinking radius
        optimizer.radius *= 2.0
        # optimization step
        _, _, cg_cost = optimizer.step(minibatch, loss, V)
        # accumulative cost
        cost += cg_cost + 1 # cg_cost is # of gradient and Hv computations
        # update progress bar
        progress_bar.update(1)
        
        if it % 10 == 0:
            # re-compute used memory
            used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) - start_memory
            # compute (training) batch loss and (testing) perplexity
            batch_loss, batch_ppl = evaluate(train_dataloader)
            # _, batch_ppl = evaluate(test_dataloader)
            # logger
            logger.append((epoch, it, batch_loss, batch_ppl, cost, used_memory))
            # print out results
            print(f"epoch: {epoch}, iter: {it}, tr_loss: {batch_loss:.2e}, tr_ppl: {batch_ppl:.2e}, cost: {cost}, mem: {used_memory:.2e}")

  0%|          | 0/441 [00:00<?, ?it/s]

epoch: 0, tr_loss: 9.05e+01, tr_ppl: 2.10e+39, cost: 0, mem: 0.0


epoch: 0, iter: 10, tr_loss: 1.13e+01, tr_ppl: 7.88e+04, cost: 33.0, mem: 1.29e+01


epoch: 0, iter: 20, tr_loss: 5.31e+00, tr_ppl: 2.02e+02, cost: 68.0, mem: 1.29e+01


epoch: 0, iter: 30, tr_loss: 4.63e+00, tr_ppl: 1.03e+02, cost: 101.0, mem: 1.29e+01


epoch: 0, iter: 40, tr_loss: 4.64e+00, tr_ppl: 1.03e+02, cost: 134.0, mem: 1.29e+01


epoch: 0, iter: 50, tr_loss: 3.82e+00, tr_ppl: 4.55e+01, cost: 167.0, mem: 1.29e+01


epoch: 0, iter: 60, tr_loss: 3.65e+00, tr_ppl: 3.84e+01, cost: 200.0, mem: 1.29e+01


epoch: 0, iter: 70, tr_loss: 3.29e+00, tr_ppl: 2.68e+01, cost: 234.0, mem: 1.29e+01


epoch: 0, iter: 80, tr_loss: 3.10e+00, tr_ppl: 2.22e+01, cost: 269.0, mem: 1.29e+01


epoch: 0, iter: 90, tr_loss: 3.16e+00, tr_ppl: 2.36e+01, cost: 301.0, mem: 1.29e+01


epoch: 0, iter: 100, tr_loss: 2.93e+00, tr_ppl: 1.88e+01, cost: 334.0, mem: 1.29e+01


epoch: 0, iter: 110, tr_loss: 3.14e+00, tr_ppl: 2.30e+01, cost: 371.0, mem: 1.29e+01


epoch: 0, iter: 120, tr_loss: 2.81e+00, tr_ppl: 1.66e+01, cost: 407.0, mem: 1.29e+01


epoch: 0, iter: 130, tr_loss: 2.83e+00, tr_ppl: 1.69e+01, cost: 444.0, mem: 1.29e+01


epoch: 0, iter: 140, tr_loss: 2.76e+00, tr_ppl: 1.58e+01, cost: 480.0, mem: 1.29e+01


epoch: 0, iter: 150, tr_loss: 2.79e+00, tr_ppl: 1.62e+01, cost: 518.0, mem: 1.29e+01


epoch: 0, iter: 160, tr_loss: 2.75e+00, tr_ppl: 1.57e+01, cost: 557.0, mem: 1.29e+01


epoch: 0, iter: 170, tr_loss: 2.83e+00, tr_ppl: 1.69e+01, cost: 591.0, mem: 1.29e+01


epoch: 0, iter: 180, tr_loss: 2.71e+00, tr_ppl: 1.51e+01, cost: 633.0, mem: 1.29e+01


epoch: 0, iter: 190, tr_loss: 2.75e+00, tr_ppl: 1.57e+01, cost: 674.0, mem: 1.29e+01


epoch: 0, iter: 200, tr_loss: 2.66e+00, tr_ppl: 1.43e+01, cost: 707.0, mem: 1.29e+01


epoch: 0, iter: 210, tr_loss: 2.69e+00, tr_ppl: 1.47e+01, cost: 744.0, mem: 1.29e+01


epoch: 0, iter: 220, tr_loss: 2.66e+00, tr_ppl: 1.43e+01, cost: 780.0, mem: 1.29e+01


epoch: 0, iter: 230, tr_loss: 2.63e+00, tr_ppl: 1.39e+01, cost: 821.0, mem: 1.29e+01


epoch: 0, iter: 240, tr_loss: 2.64e+00, tr_ppl: 1.40e+01, cost: 857.0, mem: 1.29e+01


epoch: 0, iter: 250, tr_loss: 2.62e+00, tr_ppl: 1.37e+01, cost: 893.0, mem: 1.29e+01


epoch: 0, iter: 260, tr_loss: 2.62e+00, tr_ppl: 1.38e+01, cost: 932.0, mem: 1.29e+01


epoch: 0, iter: 270, tr_loss: 2.61e+00, tr_ppl: 1.36e+01, cost: 967.0, mem: 1.29e+01


epoch: 0, iter: 280, tr_loss: 2.74e+00, tr_ppl: 1.55e+01, cost: 1003.0, mem: 1.29e+01


epoch: 0, iter: 290, tr_loss: 2.99e+00, tr_ppl: 2.00e+01, cost: 1044.0, mem: 1.29e+01


epoch: 0, iter: 300, tr_loss: 2.67e+00, tr_ppl: 1.44e+01, cost: 1081.0, mem: 1.29e+01


epoch: 0, iter: 310, tr_loss: 2.68e+00, tr_ppl: 1.46e+01, cost: 1119.0, mem: 1.29e+01


epoch: 0, iter: 320, tr_loss: 2.60e+00, tr_ppl: 1.34e+01, cost: 1158.0, mem: 1.29e+01


epoch: 0, iter: 330, tr_loss: 2.60e+00, tr_ppl: 1.35e+01, cost: 1197.0, mem: 1.29e+01


epoch: 0, iter: 340, tr_loss: 2.61e+00, tr_ppl: 1.36e+01, cost: 1234.0, mem: 1.29e+01


epoch: 0, iter: 350, tr_loss: 2.58e+00, tr_ppl: 1.32e+01, cost: 1268.0, mem: 1.29e+01


epoch: 0, iter: 360, tr_loss: 2.71e+00, tr_ppl: 1.50e+01, cost: 1308.0, mem: 1.29e+01


epoch: 0, iter: 370, tr_loss: 2.65e+00, tr_ppl: 1.42e+01, cost: 1353.0, mem: 1.29e+01


epoch: 0, iter: 380, tr_loss: 2.58e+00, tr_ppl: 1.32e+01, cost: 1389.0, mem: 1.29e+01


epoch: 0, iter: 390, tr_loss: 2.58e+00, tr_ppl: 1.32e+01, cost: 1423.0, mem: 1.29e+01


epoch: 0, iter: 400, tr_loss: 2.59e+00, tr_ppl: 1.33e+01, cost: 1462.0, mem: 1.29e+01


epoch: 0, iter: 410, tr_loss: 2.64e+00, tr_ppl: 1.40e+01, cost: 1502.0, mem: 1.29e+01


epoch: 0, iter: 420, tr_loss: 2.61e+00, tr_ppl: 1.36e+01, cost: 1538.0, mem: 1.29e+01


epoch: 0, iter: 430, tr_loss: 2.58e+00, tr_ppl: 1.33e+01, cost: 1573.0, mem: 1.29e+01


epoch: 0, iter: 440, tr_loss: 2.56e+00, tr_ppl: 1.30e+01, cost: 1608.0, mem: 1.29e+01


In [17]:
with open('trcg_results.pickle', 'wb') as f:
    pickle.dump(logger, f)

In [18]:
# Exit the program with a success status
# also release memory on GPU
exit(0)