## Preparing the environment

In [1]:
import os
import sys
import time
import random
import numpy as np

import torch
from torch.utils.data import DataLoader

from datasets import load_dataset, Dataset
import transformers
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator

from tqdm import tqdm

import matplotlib.pyplot as plt

sys.path.append("..")
from utils.relora import optimizer_reset, get_cosine_schedule_with_multiple_warmups


device = "cuda" if torch.cuda.is_available() else "cpu"

data_path = "../data"
models_path = "../models"

In [2]:
seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = False

#### Setting constants

In [3]:
context_length = 128

batch_size = 128
num_workers = 8
global_batch_size = 512
accumulation_steps = global_batch_size // batch_size

learning_rate = 3e-4
betas = (0.9, 0.95)
eps = 1e-8
gradient_clipping = 1.0
weight_decay = 0.1

adjust_warmup_iters = 250  # from warm model
first_warmup_iters = 100
restart_warmup_iters = 50
min_lr_ratio = 0.001

adjust_train_iters = 500  # from warm model
train_iters = 2000  # train_iters including adjust_train_iters
eval_save_interval = 50
val_iters = 20

lora_rank = 128
lora_dropout = 0.1
lora_alpha = 32
relora_steps = 500
reset_optimizer_on_relora = False
optimizer_magnitude_pruning = 0.8

model_name = os.path.join(models_path, "pythia-14m_500_of_2000")
tokenizer_name = "EleutherAI/pythia-14m"
tokenizer_revision = "step0"

dataset_path = "allenai/c4"
dataset_name = "realnewslike"

## Loading dataset

In [4]:
dataset = load_dataset(dataset_path, dataset_name)
dataset = dataset.with_format("torch")

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/512 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/76 [00:00<?, ?it/s]

In [5]:
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name, revision=tokenizer_revision)

def tokenize(data):
    outputs = tokenizer(
        data["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = list()
    # deleting samples shorter than context_length tokens
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
if os.path.exists(os.path.join(data_path, "train_dataset")):
    train_dataset = Dataset.load_from_disk(os.path.join(data_path, "train_dataset"))
else:
    train_dataset = dataset["train"].map(
        tokenize, batched=True, remove_columns=dataset["train"].column_names
    )
    train_dataset.save_to_disk(os.path.join(data_path, "train_dataset"))


if os.path.exists(os.path.join(data_path, "val_dataset")):
    val_dataset = Dataset.load_from_disk(os.path.join(data_path, "val_dataset"))
else:
    val_dataset = dataset["validation"].map(
        tokenize, batched=True, remove_columns=dataset["validation"].column_names
    )
    val_dataset.save_to_disk(os.path.join(data_path, "val_dataset"))

Loading dataset from disk:   0%|          | 0/58 [00:00<?, ?it/s]

In [7]:
val_dataset = val_dataset.select(range(len(val_dataset) // (batch_size * val_iters) * (batch_size * val_iters)))

In [8]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Training

### ReLoRA Training

In [9]:
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(seed)

train_dataloader = DataLoader(train_dataset, 
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              worker_init_fn=seed_worker,
                              generator=g)
val_dataloader = DataLoader(val_dataset, 
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers,
                            worker_init_fn=seed_worker,
                            generator=g,
                            drop_last=True)

### ReLora Training

In [10]:
lora_config = LoraConfig(r=lora_rank, 
                         target_modules=["query_key_value", "dense",
                                         "dense_h_to_4h", "dense_4h_to_h"], 
                         lora_dropout=lora_dropout, 
                         lora_alpha=lora_alpha)

# By default, PEFT initializes LoRA weights with Kaiming-uniform for weight A and zeros for weight B 
# resulting in an identity transform 

In [11]:
model = transformers.AutoModelForCausalLM.from_pretrained(model_name,
                                                          attn_implementation="flash_attention_2").to(device)

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GPTNeoXForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GPTNeoXModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch

In [12]:
lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()

trainable params: 1,572,864 || all params: 15,640,576 || trainable%: 10.0563


In [13]:
trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
lora_params = [p for n, p in lora_model.named_parameters() if p.requires_grad and "lora_" in n]

In [14]:
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay)
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]

scheduler = get_cosine_schedule_with_multiple_warmups(
    optimizer,
    num_training_steps=train_iters,
    first_warmup_steps=first_warmup_iters,
    restart_warmup_steps=restart_warmup_iters,
    restart_every=relora_steps,
    adjust_step=adjust_train_iters,
    adjust_warmup_iters=adjust_train_iters,
    min_lr_ratio=min_lr_ratio,
)

In [15]:
accelerator = Accelerator(mixed_precision="fp16")

lora_model, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
    lora_model, optimizer, train_dataloader, val_dataloader
)

In [16]:
def evaluate(model, val_dataloader, val_iters):
    model.eval()
    val_losses = list()
    for step, batch in enumerate(val_dataloader):
        with torch.inference_mode():
            outputs = model(batch["input_ids"], labels=batch["input_ids"])
            val_losses.append(outputs.loss.item())
        if step + 1 >= val_iters:
            break
    val_loss = np.mean(val_losses)
    return val_loss

In [17]:
lora_model.train()

train_losses = dict()
val_losses = dict()
last_losses = list()
completed_steps = adjust_train_iters
n_lora_restarts = 0
n_optimizer_resets = 0

for step, batch in tqdm(
        enumerate(train_dataloader, start=1), total=int((train_iters - adjust_train_iters) * accumulation_steps)
    ):
    output = lora_model(batch["input_ids"], labels=batch["input_ids"])
    loss = output.loss
    last_losses.append(loss.item())
    loss /= accumulation_steps
    accelerator.backward(loss)

    if step % accumulation_steps == 0:
        accelerator.clip_grad_norm_(lora_model.parameters(), gradient_clipping)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
    
    if step % (eval_save_interval * accumulation_steps) == 0:
        train_losses[completed_steps] = np.mean(last_losses)
        val_losses[completed_steps] = evaluate(lora_model, val_dataloader, val_iters)
        print(
            {
                "steps": completed_steps,
                "loss/train": train_losses[completed_steps],
                "loss/val": val_losses[completed_steps],
            }
        )
        last_losses = list()
        lora_model.train()
        accelerator.wait_for_everyone()
    
    if completed_steps >= train_iters:
        accelerator.wait_for_everyone()
        break

    if step % accumulation_steps != 0:
        continue

    can_reset_relora = relora_steps is not None and completed_steps >= relora_steps

    if can_reset_relora and completed_steps % relora_steps == 1:
        _lora_reset_time = time.time()
        print(f"Performing lora reset at update step {completed_steps}. Current lr is {optimizer.param_groups[0]['lr']}")
        n_lora_restarts += 1

        lora_model = lora_model.merge_and_unload()
        lora_model = get_peft_model(lora_model, lora_config)

        lora_model = accelerator.prepare(lora_model)
                
        trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
    
        optimizer.param_groups[0]['params'] = trainable_params
        
        _lora_reset_time = time.time() - _lora_reset_time
        print(f"LoRA reset took {_lora_reset_time:.2f}s")

        # scheduler should provide a new warmup after the reset
        print(f"Performing optimizer reset at update step {completed_steps}. Current lr is {optimizer.param_groups[0]['lr']}")
        n_optimizer_resets += 1

        optimizer_reset(
            optimizer,
            reset_params=lora_params,
            optimizer_state_keys=optimizer_state_keys,
            reset_optimizer_on_relora=reset_optimizer_on_relora,
            optimizer_random_pruning=0.0,
            optimizer_magnitude_pruning=optimizer_magnitude_pruning,
        )

  3%|▎         | 202/6000 [00:22<21:10,  4.56it/s]

{'steps': 50, 'loss/train': 6.148068511486054, 'loss/val': 6.164088487625122}


  7%|▋         | 402/6000 [00:40<19:24,  4.81it/s]

{'steps': 100, 'loss/train': 6.143584821224213, 'loss/val': 6.149805474281311}


 10%|█         | 602/6000 [00:59<19:29,  4.62it/s]

{'steps': 150, 'loss/train': 6.138176422119141, 'loss/val': 6.143473505973816}


 13%|█▎        | 802/6000 [01:18<18:43,  4.63it/s]

{'steps': 200, 'loss/train': 6.136109302043915, 'loss/val': 6.15586485862732}


 17%|█▋        | 1002/6000 [01:36<17:40,  4.71it/s]

{'steps': 250, 'loss/train': 6.12351722240448, 'loss/val': 6.134815359115601}


 20%|██        | 1202/6000 [01:55<17:08,  4.66it/s]

{'steps': 300, 'loss/train': 6.112013220787048, 'loss/val': 6.130725860595703}


 23%|██▎       | 1402/6000 [02:14<16:20,  4.69it/s]

{'steps': 350, 'loss/train': 6.11193799495697, 'loss/val': 6.12102439403534}


 27%|██▋       | 1602/6000 [02:32<15:32,  4.72it/s]

{'steps': 400, 'loss/train': 6.098917257785797, 'loss/val': 6.102164483070373}


 30%|███       | 1802/6000 [02:51<14:57,  4.68it/s]

{'steps': 450, 'loss/train': 6.096788954734802, 'loss/val': 6.129397511482239}


 33%|███▎      | 2002/6000 [03:10<14:26,  4.61it/s]

{'steps': 500, 'loss/train': 6.08454080581665, 'loss/val': 6.141530919075012}


 33%|███▎      | 2004/6000 [03:10<13:38,  4.88it/s]

Performing lora reset at update step 501. Current lr is 4.2219897192981734e-06
LoRA reset took 0.05s
Performing optimizer reset at update step 501. Current lr is 4.2219897192981734e-06
Percent of optimizer states zeroed: 80.00


 37%|███▋      | 2202/6000 [03:29<13:43,  4.61it/s]

{'steps': 550, 'loss/train': 6.111590788364411, 'loss/val': 6.104821133613586}


 40%|████      | 2402/6000 [03:48<13:20,  4.49it/s]

{'steps': 600, 'loss/train': 6.108436465263367, 'loss/val': 6.1029904842376705}


 43%|████▎     | 2602/6000 [04:07<12:15,  4.62it/s]

{'steps': 650, 'loss/train': 6.100020468235016, 'loss/val': 6.126300239562989}


 47%|████▋     | 2802/6000 [04:26<11:21,  4.69it/s]

{'steps': 700, 'loss/train': 6.103038260936737, 'loss/val': 6.105225682258606}


 50%|█████     | 3002/6000 [04:45<10:58,  4.56it/s]

{'steps': 750, 'loss/train': 6.099106926918029, 'loss/val': 6.110839772224426}


 53%|█████▎    | 3202/6000 [05:04<10:12,  4.57it/s]

{'steps': 800, 'loss/train': 6.1043798756599426, 'loss/val': 6.096076250076294}


 57%|█████▋    | 3402/6000 [05:23<09:27,  4.58it/s]

{'steps': 850, 'loss/train': 6.09590485572815, 'loss/val': 6.107667398452759}


 60%|██████    | 3602/6000 [05:42<08:32,  4.68it/s]

{'steps': 900, 'loss/train': 6.091790170669555, 'loss/val': 6.1140090227127075}


 63%|██████▎   | 3802/6000 [06:01<07:57,  4.60it/s]

{'steps': 950, 'loss/train': 6.095734131336212, 'loss/val': 6.111231660842895}


 67%|██████▋   | 4002/6000 [06:20<07:18,  4.55it/s]

{'steps': 1000, 'loss/train': 6.093902318477631, 'loss/val': 6.103014087677002}


 67%|██████▋   | 4006/6000 [06:20<05:12,  6.39it/s]

Performing lora reset at update step 1001. Current lr is 1.2414075988794581e-06
LoRA reset took 0.03s
Performing optimizer reset at update step 1001. Current lr is 1.2414075988794581e-06
Percent of optimizer states zeroed: 80.00


 70%|███████   | 4202/6000 [06:38<06:30,  4.61it/s]

{'steps': 1050, 'loss/train': 6.096440486907959, 'loss/val': 6.122192335128784}


 72%|███████▏  | 4290/6000 [06:46<02:24, 11.85it/s]