## Preparing the environment

In [1]:
import os
import time
import random
import numpy as np
from functools import partial

import torch
from torch.utils.data import DataLoader
from torch.distributed.optim import ZeroRedundancyOptimizer

from datasets import load_dataset, Dataset
import transformers
from peft import LoraConfig, get_peft_model
from accelerate import Accelerator

from tqdm import tqdm

import matplotlib.pyplot as plt


# Train on the GPU whenever one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Directory holding the tokenized train/validation dataset caches.
data_path = "../data"

In [2]:
seed = 42

# Python, NumPy and PyTorch each keep an independent RNG -- seed all of them.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

if torch.cuda.is_available():
    # Explicitly seed the current CUDA device and every other visible one.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # cuDNN may still choose non-deterministic kernels, so GPU runs are not bitwise reproducible.
    torch.backends.cudnn.deterministic = False

#### Setting constants

In [3]:
# --- data --------------------------------------------------------------------------------
context_length = 128  # tokens per sample; shorter tail chunks are dropped by tokenize()

# --- batching ----------------------------------------------------------------------------
batch_size = 128          # micro-batch size of one forward/backward pass
num_workers = 8           # DataLoader worker processes
global_batch_size = 512   # samples contributing to one optimizer step
# Gradient accumulation only reproduces the global batch exactly if it divides evenly.
assert global_batch_size % batch_size == 0, (
    f"global_batch_size ({global_batch_size}) must be a multiple of batch_size ({batch_size})"
)
accumulation_steps = global_batch_size // batch_size

# --- optimizer (AdamW) -------------------------------------------------------------------
learning_rate = 3e-4
betas = (0.9, 0.95)
eps = 1e-8
gradient_clipping = 1.0   # max global gradient norm
weight_decay = 0.1

# --- schedule / evaluation (all counts are optimizer steps) ------------------------------
warmup_iters = 256

train_iters = 2048
eval_save_interval = 50   # evaluate every N optimizer steps (nothing is saved in this notebook)
val_iters = 20            # validation batches per evaluation

# --- ReLoRA ------------------------------------------------------------------------------
lora_rank = 128
lora_dropout = 0.1
lora_alpha = 32
relora_steps = 200        # merge LoRA into the base weights and re-create adapters every N steps
# Exactly one optimizer-reset mode may be enabled (optimizer_reset() raises otherwise):
reset_optimizer_on_relora = False   # True -> zero a random ~99.9% of the optimizer states
optimizer_magnitude_pruning = 0.8   # fraction of smallest-magnitude optimizer states to zero
# NOTE(review): optimizer resets are timed by the LR schedule's cycle length, not by
# relora_steps, so LoRA resets and optimizer resets do not coincide with these values.

# --- model / dataset sources -------------------------------------------------------------
model_name = "EleutherAI/pythia-14m"
model_revision = "step0"

dataset_path = "allenai/c4"
dataset_name = "realnewslike"

## Loading dataset

In [4]:
# Raw C4 "realnewslike" splits, formatted so numeric columns come back as torch tensors.
# NOTE(review): this is loaded even when the tokenized caches under data_path already exist.
dataset = load_dataset(dataset_path, dataset_name).with_format("torch")

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/512 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/76 [00:00<?, ?it/s]

In [5]:
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name,
    revision=model_revision,
)


def tokenize(data):
    """Turn a batch of raw documents into fixed-length token chunks.

    ``data["text"]`` is tokenized with truncation at ``context_length`` and
    ``return_overflowing_tokens=True``, so each document yields consecutive chunks of at
    most ``context_length`` tokens. Only chunks of exactly ``context_length`` tokens are
    kept, which means every sample has the same length and needs no padding.

    Returns a ``{"input_ids": [...]}`` dict, as expected by ``Dataset.map(batched=True)``.
    """
    encoded = tokenizer(
        data["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    full_chunks = [
        chunk_ids
        for n_tokens, chunk_ids in zip(encoded["length"], encoded["input_ids"])
        if n_tokens == context_length
    ]
    return {"input_ids": full_chunks}

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
def _load_or_tokenize_split(split, cache_name):
    """Return the tokenized ``split`` of ``dataset``, cached under ``data_path/cache_name``.

    An existing cache directory is loaded with ``Dataset.load_from_disk``. Otherwise the raw
    split is tokenized with ``tokenize`` (dropping the raw columns) and saved there.

    NOTE: the cache is keyed only by directory name -- delete it after changing
    ``context_length``, the tokenizer or the source dataset, or stale data will be reused.
    """
    cache_path = os.path.join(data_path, cache_name)
    if os.path.exists(cache_path):
        return Dataset.load_from_disk(cache_path)
    tokenized = dataset[split].map(
        tokenize, batched=True, remove_columns=dataset[split].column_names
    )
    tokenized.save_to_disk(cache_path)
    return tokenized


train_dataset = _load_or_tokenize_split("train", "train_dataset")
val_dataset = _load_or_tokenize_split("validation", "val_dataset")

Loading dataset from disk:   0%|          | 0/58 [00:00<?, ?it/s]

In [7]:
# Trim the validation set to a whole number of evaluation passes (val_iters full batches each).
_val_samples_per_eval = batch_size * val_iters
_n_val_kept = len(val_dataset) // _val_samples_per_eval * _val_samples_per_eval
# Fail loudly here instead of letting evaluate() silently average an empty loss list into NaN.
assert _n_val_kept > 0, (
    f"val_dataset has {len(val_dataset)} samples, fewer than one evaluation pass "
    f"({_val_samples_per_eval} = batch_size * val_iters)"
)
val_dataset = val_dataset.select(range(_n_val_kept))

In [8]:
# Disable the HF `tokenizers` thread pool before the DataLoader worker processes start
# (num_workers > 0). The tokenizer may already have been used above, and `tokenizers` warns
# about (and can deadlock on) parallelism that was active before a fork.
os.environ.update(TOKENIZERS_PARALLELISM="false")

## Training

### ReLoRA Training

In [9]:
def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` inside each worker.

    PyTorch already gives every worker its own torch seed (derived from the loader's
    generator). This forwards that seed to the other RNGs so per-worker randomness is
    reproducible as well.
    """
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


g = torch.Generator()
g.manual_seed(seed)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
)
# No shuffling for validation: evaluate() stops after the first val_iters batches. A fixed
# order makes every evaluation score the same samples, so val losses are comparable across
# training steps (with shuffle=True each evaluation saw a different random subset).
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    worker_init_fn=seed_worker,
    generator=g,
    drop_last=True,
)

#### Optimizer-state pruning helpers (ReLoRA optimizer reset)

In [10]:
@torch.no_grad()
def random_pruning_(tensor, prune_ratio):
    """Zero a random subset of ``tensor``'s entries **in place**.

    Each entry survives independently when a uniform [0, 1) draw exceeds ``prune_ratio``,
    so roughly a ``prune_ratio`` fraction becomes zero. Shape and dtype are unchanged, so the
    same tensor object can stay in an optimizer's state dict.
    """
    keep = torch.rand_like(tensor).gt(prune_ratio)
    tensor.mul_(keep)


@torch.no_grad()
def magnitude_pruning_(tensor, prune_ratio):
    """Zero the smallest-magnitude entries of ``tensor`` **in place**.

    The cutoff is the ``prune_ratio`` quantile of ``|tensor|``. It is computed in float32
    (``torch.quantile`` needs a float32/float64 input) and cast back to ``tensor.dtype``.
    Entries whose magnitude is not strictly above the cutoff are set to zero. Shape and
    dtype are unchanged.
    """
    magnitudes = tensor.abs()
    flat_fp32 = magnitudes.flatten().float()
    cutoff = flat_fp32.quantile(prune_ratio).to(tensor.dtype)
    keep = (magnitudes > cutoff).to(tensor.dtype)
    tensor.mul_(keep)

In [11]:
def _select_pruning_fn(reset_optimizer_on_relora, optimizer_random_pruning, optimizer_magnitude_pruning):
    """Validate the reset-mode flags and return the matching in-place pruning function."""
    n_reset_types = (
        int(bool(reset_optimizer_on_relora))
        + int(bool(optimizer_random_pruning))
        + int(bool(optimizer_magnitude_pruning))
    )
    if n_reset_types != 1:
        raise ValueError(
            "Exactly one of reset_optimizer_on_relora, optimizer_random_pruning, "
            "optimizer_magnitude_pruning must be set (truthy); got "
            f"{reset_optimizer_on_relora=}, {optimizer_random_pruning=}, "
            f"{optimizer_magnitude_pruning=}"
        )
    # pruning_fn has to be inplace to work with ZeroRedundancyOptimizer
    if reset_optimizer_on_relora:
        return partial(random_pruning_, prune_ratio=0.999)
    if optimizer_random_pruning:
        return partial(random_pruning_, prune_ratio=optimizer_random_pruning)
    return partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)


def optimizer_reset(
    optimizer,
    *,
    reset_params: list[torch.nn.Parameter],
    optimizer_state_keys: list[str],
    reset_optimizer_on_relora: bool,
    optimizer_random_pruning: float,
    optimizer_magnitude_pruning: float,
):
    """Prune the optimizer state of ``reset_params`` in place (the ReLoRA "optimizer reset").

    Exactly one reset mode must be enabled:

    * ``reset_optimizer_on_relora`` -- zero a random ~99.9% of every state tensor;
    * ``optimizer_random_pruning`` -- zero this random fraction of every state tensor;
    * ``optimizer_magnitude_pruning`` -- zero this fraction of entries with the smallest magnitude.

    Args:
        optimizer: an optimizer exposing ``.state``; for a ``ZeroRedundancyOptimizer`` the
            local shard's ``optimizer.optim.state`` is used instead.
        reset_params: parameters whose state should be pruned. Parameters without any state
            (e.g. owned by another ZeRO shard, or never stepped) are skipped.
        optimizer_state_keys: per-parameter state entries to prune, e.g. ``["exp_avg", "exp_avg_sq"]``.

    Returns:
        The percentage of the pruned state entries that are zero afterwards (also printed).

    Raises:
        ValueError: if zero or more than one reset mode is enabled.
    """
    pruning_fn = _select_pruning_fn(
        reset_optimizer_on_relora, optimizer_random_pruning, optimizer_magnitude_pruning
    )

    optimizer_state = optimizer.state
    if isinstance(optimizer, ZeroRedundancyOptimizer):
        optimizer_state = optimizer.optim.state

    n_zeros = 0
    n_total = 0
    for p in reset_params:
        # .get() rather than [] so the defaultdict state is not populated with empty entries
        param_state = optimizer_state.get(p)
        if not param_state:  # no state for this param, happens for ZeRo optimizer
            continue
        for key in optimizer_state_keys:
            state_tensor = param_state[key]
            # pruning must be in place so the tensor stored in the state dict itself changes
            pruning_fn(state_tensor)
            n_total += state_tensor.numel()
            n_zeros += torch.sum(state_tensor == 0).item()

    _zeroed = n_zeros / (1e-7 + n_total) * 100
    print(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
    return _zeroed

### ReLoRA model, optimizer and training loop

In [12]:
# LoRA adapters on the GPT-NeoX attention projections (`query_key_value`, `dense`)
# and MLP projections (`dense_h_to_4h`, `dense_4h_to_h`).
lora_config = LoraConfig(
    target_modules=[
        "query_key_value",
        "dense",
        "dense_h_to_4h",
        "dense_4h_to_h",
    ],
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
)

# PEFT's default init is Kaiming-uniform for A and zeros for B, so the LoRA update B @ A is
# zero at first and each adapted layer initially computes exactly what its base layer does.

In [13]:
# No torch_dtype is given, so weights load in float32. Flash Attention 2 only runs in
# fp16/bf16 (see the warnings below), which the fp16 Accelerator autocast is expected to provide.
# NOTE(review): revision "step0" presumably selects Pythia's untrained initial checkpoint.
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name,
    revision=model_revision,
    attn_implementation="flash_attention_2",
)
model = model.to(device)

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GPTNeoXForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GPTNeoXModel is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch

In [14]:
# Inject the LoRA adapters described by lora_config and report the trainable share.
lora_model = get_peft_model(model=model, peft_config=lora_config)
lora_model.print_trainable_parameters()

trainable params: 1,572,864 || all params: 15,640,576 || trainable%: 10.0563


In [15]:
# Everything PEFT left trainable goes to the optimizer; the LoRA subset is what
# optimizer_reset() prunes later in the training loop.
trainable_params = [param for param in lora_model.parameters() if param.requires_grad]
lora_params = [
    param
    for name, param in lora_model.named_parameters()
    if "lora_" in name and param.requires_grad
]

In [16]:
optimizer = torch.optim.AdamW(
    trainable_params,
    lr=learning_rate,
    betas=betas,
    eps=eps,
    weight_decay=weight_decay,
)
# AdamW's first/second moment buffers -- the state entries optimizer_reset() prunes.
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]

# After the linear warmup, num_cycles=4 makes the cosine factor go through 4 full periods
# (peak -> 0 -> peak). The warmup is not repeated at the start of later periods.
scheduler = transformers.get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_iters,
    num_training_steps=train_iters,
    num_cycles=4,
)

In [17]:
# Length, in optimizer steps, of one period of the cosine LR schedule.
# transformers' cosine-with-warmup lambda uses
#     progress  = (step - num_warmup_steps) / (num_training_steps - num_warmup_steps)
#     lr_factor = 0.5 * (1 + cos(2*pi * num_cycles * progress))
# so a full period spans (num_training_steps - num_warmup_steps) / num_cycles steps.
# Dividing num_training_steps alone ignored the warmup, and the `% cycle_length` optimizer
# resets in the training loop drifted to a different point of the LR curve every cycle.
_schedule_kwargs = scheduler.lr_lambdas[0].keywords
cycle_length = max(
    1,
    int(
        (_schedule_kwargs["num_training_steps"] - _schedule_kwargs["num_warmup_steps"])
        / _schedule_kwargs["num_cycles"]
    ),
)

In [18]:
# fp16 autocast with gradient scaling; autocast should also give Flash Attention 2 the
# half-precision inputs it requires.
accelerator = Accelerator(mixed_precision="fp16")

(
    lora_model,
    optimizer,
    train_dataloader,
    val_dataloader,
) = accelerator.prepare(lora_model, optimizer, train_dataloader, val_dataloader)

In [19]:
def evaluate(model, val_dataloader, val_iters):
    """Return the mean validation loss of ``model`` over at most ``val_iters`` batches.

    Each batch's ``input_ids`` are used as both inputs and labels (causal-LM loss), under
    ``torch.inference_mode``. The model's previous train/eval mode is restored on exit.
    Passing ``val_iters=None`` evaluates the whole loader. If the loader yields no batches,
    NaN is returned without NumPy's mean-of-empty-slice warning.
    """
    was_training = model.training
    model.eval()
    batch_losses = []
    try:
        with torch.inference_mode():
            for step, batch in enumerate(val_dataloader):
                outputs = model(batch["input_ids"], labels=batch["input_ids"])
                batch_losses.append(outputs.loss.item())
                if val_iters is not None and step + 1 >= val_iters:
                    break
    finally:
        model.train(was_training)
    if not batch_losses:
        return float("nan")
    return np.mean(batch_losses)

In [20]:
lora_model.train()

train_losses = dict()
val_losses = dict()
last_losses = list()
completed_steps = 0
n_lora_restarts = 0
n_optimizer_resets = 0


def _record_and_print_losses(model, steps, recent_train_losses):
    """Evaluate ``model`` and log train/val losses under optimizer step ``steps``.

    Stores the mean of ``recent_train_losses`` (per-micro-batch losses since the previous log)
    in ``train_losses[steps]`` and a fresh validation loss in ``val_losses[steps]``, prints
    them, puts ``model`` back into train mode and waits for all processes. An empty
    ``recent_train_losses`` (final step landing right after a regular eval) is logged as NaN
    instead of triggering NumPy's mean-of-empty-slice warning.
    """
    train_losses[steps] = np.mean(recent_train_losses) if recent_train_losses else float("nan")
    val_losses[steps] = evaluate(model, val_dataloader, val_iters)
    print(
        {
            "steps": steps,
            "loss/train": train_losses[steps],
            "loss/val": val_losses[steps],
        }
    )
    model.train()
    accelerator.wait_for_everyone()


for step, batch in tqdm(
        enumerate(train_dataloader, start=1), total=int(train_iters * accumulation_steps)
    ):
    output = lora_model(batch["input_ids"], labels=batch["input_ids"])
    loss = output.loss
    last_losses.append(loss.item())
    # Scale so the accumulated gradient is the mean over the global batch.
    loss /= accumulation_steps
    accelerator.backward(loss)

    if step % accumulation_steps == 0:
        accelerator.clip_grad_norm_(lora_model.parameters(), gradient_clipping)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1

    if step % (eval_save_interval * accumulation_steps) == 0:
        _record_and_print_losses(lora_model, completed_steps, last_losses)
        last_losses = list()

    if completed_steps >= train_iters:
        _record_and_print_losses(lora_model, completed_steps, last_losses)
        last_losses = list()
        break

    # LoRA and optimizer resets are only considered right after an optimizer step.
    if step % accumulation_steps != 0:
        continue

    can_reset_relora = relora_steps is not None and completed_steps >= relora_steps

    if can_reset_relora and completed_steps % relora_steps == 1:
        _lora_reset_time = time.time()
        print(f"Performing lora reset at update step {completed_steps}. Current lr is {optimizer.param_groups[0]['lr']}")
        n_lora_restarts += 1

        # The current adapters are about to be merged into the base weights and discarded.
        # Drop their optimizer state: otherwise it leaks memory at every restart and breaks
        # optimizer.state_dict(), because those tensors will no longer be in any param group.
        for stale_param in optimizer.param_groups[0]["params"]:
            optimizer.state.pop(stale_param, None)

        lora_model = lora_model.merge_and_unload()
        lora_model = get_peft_model(lora_model, lora_config)

        lora_model = accelerator.prepare(lora_model)

        trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
        # Re-collect the LoRA tensors too, so the next optimizer_reset() prunes the state of the
        # adapters that are actually being trained instead of the merged-away ones.
        lora_params = [p for n, p in lora_model.named_parameters() if p.requires_grad and "lora_" in n]

        optimizer.param_groups[0]['params'] = trainable_params

        _lora_reset_time = time.time() - _lora_reset_time
        print(f"LoRA reset took {_lora_reset_time:.2f}s")

    can_reset_optimizer = relora_steps is not None and completed_steps >= cycle_length

    if can_reset_optimizer and completed_steps % cycle_length == 1:
        # NOTE: transformers' cosine-with-warmup schedule does NOT re-run warmup after this
        # reset; the LR just keeps following its cosine curve.
        print(f"Performing optimizer reset at update step {completed_steps}. Current lr is {optimizer.param_groups[0]['lr']}")
        n_optimizer_resets += 1

        optimizer_reset(
            optimizer,
            reset_params=lora_params,
            optimizer_state_keys=optimizer_state_keys,
            reset_optimizer_on_relora=reset_optimizer_on_relora,
            optimizer_random_pruning=0.0,
            optimizer_magnitude_pruning=optimizer_magnitude_pruning,
        )

    if can_reset_optimizer and completed_steps % cycle_length == 2:
        print(f"First step after optimizer reset lr is {optimizer.param_groups[0]['lr']}")

  0%|          | 15/8192 [00:04<39:46,  3.43it/s] 


KeyboardInterrupt: 