In [1]:
! nvidia-smi | grep -B 0 "W" 

| N/A   69C    P0              32W /  70W |      2MiB / 15360MiB |      0%      Default |


--
| N/A   68C    P8              12W /  70W |      5MiB / 15360MiB |      0%      Default |
--
| N/A   69C    P8              12W /  70W |      2MiB / 15360MiB |      0%      Default |


In [2]:
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
import torch
import itertools
from torch import nn
from trl import DataCollatorForCompletionOnlyLM
from datasets import load_dataset
import numpy as np
from torch.utils.data import DataLoader
from transformers import get_scheduler
from transformers.trainer_pt_utils import get_parameter_names
from transformers import AutoTokenizer, get_scheduler, AutoModelForCausalLM, BitsAndBytesConfig
from tqdm.auto import tqdm
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training
from datasets import load_dataset
import pickle

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

> set random seed

In [4]:
seed = 42
torch.cuda.manual_seed(seed)
torch.manual_seed(seed)
# np.random.seed(seed)

<torch._C.Generator at 0x7f8669335230>

> we load model in 4bit, which is actually stored as `torch.uint8` and 
>
> we compute in fp16
>
> Note that, with `prepare_model_for_kbit_training`, it is actually fp32
> for AdaHessian and TRCG to work, fp32 is required.

In [5]:
config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)
model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2",
                                             torch_dtype=torch.float16,
                                             device_map="auto", 
                                             quantization_config=config)
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False )

> prep for LoRA model
>
> check http://localhost:1331/edit/optim/notebooks/README.md
> 
> for currently implemented default `target modules` in `peft`

In [6]:
lora_config = LoraConfig(
    r=4,
    lora_alpha=32, # based on paper - https://arxiv.org/abs/2106.09685
    task_type=TaskType.CAUSAL_LM,
    lora_dropout=0.05
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 147,456 || all params: 124,587,264 || trainable%: 0.11835559692522023


> at this moment, we need to disable `gradient checkpoint` for
> current implementation of AdaHessian and TRCG to work
> 
> Here, for fair comparison, turn off GC for AdamW

In [7]:
model.is_gradient_checkpointing, model.quantization_method, model.dtype,\
model.is_loaded_in_4bit,\
model.config.use_cache

(False,
 <QuantizationMethod.BITS_AND_BYTES: 'bitsandbytes'>,
 torch.float32,
 True,
 True)

> set `use_cache` to False

In [8]:
model.config.use_cache = False

> Start processing data

> first is to apply template, we will then use `DataCollatorForCompletionOnlyLM` to mask non-assistant tokens

In [9]:
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")
tokenizer.add_special_tokens({"additional_special_tokens": \
                              ["<|system|>","<|user|>","<|assistant|>","<|end|>"]})

tokenizer.pad_token = tokenizer.eos_token
template_str = """\
{% for message in messages %}\
{% if message["role"] == "user" %}\
{{ "<|endoftext|><|user|>\n" + message["content"] + "<|end|>\n" }}\
{% elif message["role"] == "system" %}\
{{ "<|system|>\n" + message["content"] + "<|end|>\n" }}\
{% elif message["role"] == "assistant" %}\
{{ "<|assistant|>\n" + message["content"] + "<|end|><|endoftext|>\n" }}\
{% endif %}\
{% endfor %}\
"""
tokenizer.chat_template = template_str
tokenizer.special_tokens_map
model.resize_token_embeddings(len(tokenizer))

# collator choice
collator = DataCollatorForCompletionOnlyLM("<|assistant|>", tokenizer=tokenizer)

In [10]:
def pair_texts(examples):
    # Concatenate all texts.
    concatenated_examples = {key: list(itertools.chain(*value)) for key, value in examples.items()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])

    result = {
        k: [t[i : i + 2] for i in range(0, total_length, 2)]
        for k, t in concatenated_examples.items()
    }
    return result

def convert_messages(example):
    t=[]
    for ei in example["messages"]:
        if ei["role"]=="user":
            t.append("<|system|>Below is a dialogue between a human user and an AI assistant.<|end|>\n")
            t.append("<|endoftext|><|user|>" + ei["content"] + "<|end|>\n")
        elif ei["role"]=="assistant":
            t.append("<|assistant|>" + ei["content"] + "<|end|><|endoftext|>")
    example["messages"]="".join(t)
    return example

def truncation(example):
    
    _dialogue = tokenizer(example["messages"],padding="max_length",max_length=128,truncation=True)
    
    truncate_dialogue = tokenizer.decode(_dialogue["input_ids"])
    ending = "original"
    if "<|assistant|>" in truncate_dialogue and "<|end|><|endoftext|>" not in truncate_dialogue:
        _dialogue["input_ids"][-2:] = [50259, 50256] # add <|end|><|endoftext|> -- cutoff assistant token
        ending = "assistant"
    if "<|assistant|>" not in truncate_dialogue:
        _dialogue["input_ids"][-2:] = [50259, 198]   # add <|end|>\n -- no assistant token
        ending = "user"
    
    example["ending"]=ending
    example["messages"] = truncate_dialogue
    return example

def token_map(dataset):
    dataset = dataset.map(pair_texts, batched=True)
    dataset = dataset.map(convert_messages)
    dataset = dataset.map(truncation)
    dataset = dataset.filter(lambda example: example["ending"]!="user")
    dataset = dataset.remove_columns(["ending"])
    return dataset

dataset = load_dataset("sablo/oasst2_curated")
dataset = token_map(dataset)

In [11]:
def tokenize_function(example):
    outputs =  tokenizer(example["messages"], 
                         padding=False, 
                         truncation=False,
                         max_length=128,
                         return_overflowing_tokens=False,
                         return_length=False)
    return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
train_dataset = dataset["train"].map(tokenize_function, 
                                     batched=True, 
                                     remove_columns=dataset["train"].column_names,
                                     num_proc=2,
                                     batch_size=1000)
test_dataset = dataset["test"].map(tokenize_function, 
                                   batched=True, 
                                   remove_columns=dataset["test"].column_names,
                                   num_proc=2,
                                   batch_size=1000)

In [12]:
# for training (i.e., fine-tuning)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True, collate_fn=collator, pin_memory=True)
# for getting testing perplexity
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False, collate_fn=collator, pin_memory=True)

> AdamW

> we need to define parameter group
> this is due to the fact that we will
> only apply weight decay on non-bias params

> we also define algorithm args here

In [13]:
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]

learning_rate = 5e-3 # so far the best choice (note: w/o exhaustive search)
beta1, beta2 = 0.9, 0.999 # most of time, no need to change
epsilon = 1e-8
weight_decay = 1e-2 # so far the best w/o exhaustive search

optimizer_kwargs = {
    "lr": learning_rate,
    "betas": (beta1, beta2),
    "eps": epsilon,
}

optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters() if (n in decay_parameters and p.requires_grad)
            ],
            "weight_decay": weight_decay,
        },
        {
            "params": [
                p for n, p in model.named_parameters() if (n not in decay_parameters and p.requires_grad)
            ],
            "weight_decay": 0.0,
        },
    ]

optimizer = torch.optim.AdamW(optimizer_grouped_parameters, **optimizer_kwargs)

> define scheduler

In [14]:
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
num_warmup_steps = 200 # could be reduced
name = "linear"
scheduler = get_scheduler(name=name,
                          optimizer=optimizer,
                          num_warmup_steps=num_warmup_steps,
                          num_training_steps=num_training_steps
                         )

> metric function

In [15]:
def evaluate(dataloader):
    model.eval()
    loss = 0
    for batch in dataloader:
        with torch.no_grad():
            outputs = model(**batch)
        loss += outputs.loss.item() * batch["input_ids"].shape[0]
    loss = loss / len(train_dataloader.dataset)
    try:
        ppl = np.exp(loss)
    except OverflowError:
        ppl = float("inf")
    return loss, ppl

> training loop

In [16]:
# log stats
# we can measure computational cost by simply logging gradient^ computations
#
# ^ for second-order methods, in particular, Hessian-free methods, we use gradient and Hessian-vector
# product computation as a measure of cost
logger = [] # ep, it, loss, ppl, gradient/Hv product, mem

In [17]:
num_epochs = 1
num_training_steps = num_epochs * len(train_dataloader)
# start memory
start_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# progress bar
progress_bar = tqdm(range(num_training_steps))
# accumulative cost
cost = 0
for epoch in range(num_epochs):
    # initialize memory stats
    used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) - start_memory
    # compute (training) batch loss and (testing) perplexity
    batch_loss, batch_ppl = evaluate(train_dataloader)
    # for this demo, we really don't care about validation or testing
    # please uncomment the following line if needed
    # _, batch_ppl = evaluate(test_dataloader)
    # logger
    logger.append((epoch, 0, batch_loss, batch_ppl, cost, used_memory))
    # print out results
    print(f"epoch: {epoch}, iter: {0}, tr_loss: {batch_loss:.2e}, tr_ppl: {batch_ppl:.2e}, cost: {cost}, mem: {used_memory}")
    model.train()
    step_cnt = 0
    for it, minibatch in enumerate(train_dataloader, 1):
        # forward pass
        outputs = model(**minibatch)
        loss = outputs.loss
        # backward pass
        loss.backward()
        # optimization step
        optimizer.step()
        # accumulative cost 
        cost += 1 # each step requires one gradient computation
        scheduler.step()
        optimizer.zero_grad()
        # update progress bar
        progress_bar.update(1)
        
        if it % 10 == 0:
            # re-compute used memory
            used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3) - start_memory
            # compute (training) batch loss and (testing) perplexity
            batch_loss, batch_ppl = evaluate(train_dataloader)
            _, batch_ppl = evaluate(test_dataloader)
            # logger
            logger.append((epoch, it, batch_loss, batch_ppl, used_memory, cost))
            # print out results
            print(f"epoch: {epoch}, iter: {it}, tr_loss: {batch_loss:.2e}, tr_ppl: {batch_ppl:.2e}, cost: {cost}, mem: {used_memory:.2e}")

  0%|          | 0/441 [00:00<?, ?it/s]

epoch: 0, iter: 0, tr_loss: 9.05e+01, tr_ppl: 2.10e+39, cost: 0, mem: 0.0


epoch: 0, iter: 10, tr_loss: 8.40e+01, tr_ppl: 1.01e+02, cost: 10, mem: 4.29e+00


epoch: 0, iter: 20, tr_loss: 1.68e+01, tr_ppl: 2.45e+00, cost: 20, mem: 8.87e+00


epoch: 0, iter: 30, tr_loss: 6.96e+00, tr_ppl: 1.46e+00, cost: 30, mem: 8.87e+00


epoch: 0, iter: 40, tr_loss: 6.36e+00, tr_ppl: 1.41e+00, cost: 40, mem: 8.87e+00


epoch: 0, iter: 50, tr_loss: 5.23e+00, tr_ppl: 1.33e+00, cost: 50, mem: 8.87e+00


epoch: 0, iter: 60, tr_loss: 4.23e+00, tr_ppl: 1.27e+00, cost: 60, mem: 8.87e+00


epoch: 0, iter: 70, tr_loss: 3.80e+00, tr_ppl: 1.23e+00, cost: 70, mem: 8.87e+00


epoch: 0, iter: 80, tr_loss: 3.50e+00, tr_ppl: 1.21e+00, cost: 80, mem: 8.87e+00


epoch: 0, iter: 90, tr_loss: 3.31e+00, tr_ppl: 1.20e+00, cost: 90, mem: 8.87e+00


epoch: 0, iter: 100, tr_loss: 3.13e+00, tr_ppl: 1.19e+00, cost: 100, mem: 8.87e+00


epoch: 0, iter: 110, tr_loss: 2.98e+00, tr_ppl: 1.18e+00, cost: 110, mem: 8.87e+00


epoch: 0, iter: 120, tr_loss: 2.92e+00, tr_ppl: 1.18e+00, cost: 120, mem: 8.87e+00


epoch: 0, iter: 130, tr_loss: 2.86e+00, tr_ppl: 1.17e+00, cost: 130, mem: 8.87e+00


epoch: 0, iter: 140, tr_loss: 2.81e+00, tr_ppl: 1.17e+00, cost: 140, mem: 8.87e+00


epoch: 0, iter: 150, tr_loss: 2.77e+00, tr_ppl: 1.17e+00, cost: 150, mem: 8.87e+00


epoch: 0, iter: 160, tr_loss: 2.74e+00, tr_ppl: 1.16e+00, cost: 160, mem: 8.87e+00


epoch: 0, iter: 170, tr_loss: 2.71e+00, tr_ppl: 1.16e+00, cost: 170, mem: 8.87e+00


epoch: 0, iter: 180, tr_loss: 2.69e+00, tr_ppl: 1.16e+00, cost: 180, mem: 8.87e+00


epoch: 0, iter: 190, tr_loss: 2.67e+00, tr_ppl: 1.16e+00, cost: 190, mem: 8.87e+00


epoch: 0, iter: 200, tr_loss: 2.70e+00, tr_ppl: 1.16e+00, cost: 200, mem: 8.87e+00


epoch: 0, iter: 210, tr_loss: 2.64e+00, tr_ppl: 1.16e+00, cost: 210, mem: 8.87e+00


epoch: 0, iter: 220, tr_loss: 2.63e+00, tr_ppl: 1.16e+00, cost: 220, mem: 8.87e+00


epoch: 0, iter: 230, tr_loss: 2.60e+00, tr_ppl: 1.16e+00, cost: 230, mem: 8.87e+00


epoch: 0, iter: 240, tr_loss: 2.59e+00, tr_ppl: 1.16e+00, cost: 240, mem: 8.87e+00


epoch: 0, iter: 250, tr_loss: 2.58e+00, tr_ppl: 1.15e+00, cost: 250, mem: 8.87e+00


epoch: 0, iter: 260, tr_loss: 2.58e+00, tr_ppl: 1.15e+00, cost: 260, mem: 8.87e+00


epoch: 0, iter: 270, tr_loss: 2.56e+00, tr_ppl: 1.15e+00, cost: 270, mem: 8.87e+00


epoch: 0, iter: 280, tr_loss: 2.58e+00, tr_ppl: 1.15e+00, cost: 280, mem: 8.87e+00


epoch: 0, iter: 290, tr_loss: 2.58e+00, tr_ppl: 1.15e+00, cost: 290, mem: 8.87e+00


epoch: 0, iter: 300, tr_loss: 2.56e+00, tr_ppl: 1.15e+00, cost: 300, mem: 8.87e+00


epoch: 0, iter: 310, tr_loss: 2.55e+00, tr_ppl: 1.15e+00, cost: 310, mem: 8.87e+00


epoch: 0, iter: 320, tr_loss: 2.54e+00, tr_ppl: 1.15e+00, cost: 320, mem: 8.87e+00


epoch: 0, iter: 330, tr_loss: 2.54e+00, tr_ppl: 1.15e+00, cost: 330, mem: 8.87e+00


epoch: 0, iter: 340, tr_loss: 2.54e+00, tr_ppl: 1.15e+00, cost: 340, mem: 8.87e+00


epoch: 0, iter: 350, tr_loss: 2.53e+00, tr_ppl: 1.15e+00, cost: 350, mem: 8.87e+00


epoch: 0, iter: 360, tr_loss: 2.52e+00, tr_ppl: 1.15e+00, cost: 360, mem: 8.87e+00


epoch: 0, iter: 370, tr_loss: 2.52e+00, tr_ppl: 1.15e+00, cost: 370, mem: 8.87e+00


epoch: 0, iter: 380, tr_loss: 2.52e+00, tr_ppl: 1.15e+00, cost: 380, mem: 8.87e+00


epoch: 0, iter: 390, tr_loss: 2.52e+00, tr_ppl: 1.15e+00, cost: 390, mem: 8.87e+00


epoch: 0, iter: 400, tr_loss: 2.51e+00, tr_ppl: 1.15e+00, cost: 400, mem: 8.87e+00


epoch: 0, iter: 410, tr_loss: 2.51e+00, tr_ppl: 1.15e+00, cost: 410, mem: 8.87e+00


epoch: 0, iter: 420, tr_loss: 2.51e+00, tr_ppl: 1.15e+00, cost: 420, mem: 8.87e+00


epoch: 0, iter: 430, tr_loss: 2.52e+00, tr_ppl: 1.15e+00, cost: 430, mem: 8.87e+00


epoch: 0, iter: 440, tr_loss: 2.51e+00, tr_ppl: 1.15e+00, cost: 440, mem: 8.87e+00


In [18]:
with open('adamw_results.pickle', 'wb') as f:
    pickle.dump(logger, f)

In [19]:
# Exit the program with a success status
# also release memory on GPU
exit(0)