## Preparing the environment

In [1]:
import os
import random
import numpy as np

import torch
from torch.utils.data import DataLoader, Dataset
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import CrossEntropyLoss

from datasets import load_dataset
import transformers
from accelerate import Accelerator

from tqdm import tqdm

import matplotlib.pyplot as plt


os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = "cuda" if torch.cuda.is_available() else "cpu"

In [2]:
seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = False

#### Setting constants

In [3]:
context_length = 128

batch_size = 128
num_workers = 8
global_batch_size = 512
accumulation_steps = global_batch_size / batch_size

learning_rate = 1e-4
betas = (0.9, 0.95)
eps = 1e-8
gradient_clipping = 1.0
weight_decay = 0.1

warmup_iters = 256

train_iters = 2048

model_name = "EleutherAI/pythia-14m"
model_revision = "step0"

dataset_path = "allenai/c4"
dataset_name = "realnewslike"

## Loading dataset

In [4]:
dataset = load_dataset(dataset_path, dataset_name)
dataset = dataset.with_format("torch")

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/512 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/76 [00:00<?, ?it/s]

In [5]:
dataset["train"] = dataset["train"].select(range(65536))

In [6]:
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name, revision=model_revision)



In [7]:
def tokenize(data):
    outputs = tokenizer(
        data["text"],
        truncation=True,
        max_length=context_length,
        return_overflowing_tokens=True,
        return_length=True,
    )
    input_batch = list()
    # deleting samples shorter than context_length tokens
    for length, input_ids in zip(outputs["length"], outputs["input_ids"]):
        if length == context_length:
            input_batch.append(input_ids)
    return {"input_ids": input_batch}

tokenized_dataset = dataset.map(
    tokenize, batched=True, remove_columns=dataset["train"].column_names, num_proc=60
)

tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 264339
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 55063
    })
})

In [8]:
train_dataset = tokenized_dataset["train"]
val_dataset = tokenized_dataset["validation"]

## Training

In [9]:
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2 ** 32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(seed)

train_dataloader = DataLoader(train_dataset, 
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              worker_init_fn=seed_worker,
                              generator=g)
val_dataloader = DataLoader(val_dataset, 
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=num_workers,
                            worker_init_fn=seed_worker,
                            generator=g)

In [10]:
model = transformers.AutoModelForCausalLM.from_pretrained(model_name,
                                                          revision=model_revision,
                                                          # attn_implementation="flash_attention_2",
                                                          torch_dtype=torch.float16).to(device)

#### Optim reset

In [11]:
def optimizer_reset(
    optimizer,
    *,
    reset_params: list[torch.nn.Parameter],
    optimizer_state_keys: list[str],
    reset_optimizer_on_relora: bool,
    optimizer_random_pruning: float,
    optimizer_magnitude_pruning: float,
):
    """
        optimizer_state_keys: e.g., ["exp_avg", "exp_avg_sq"]
    """
    n_reset_types = (
        int(bool(reset_optimizer_on_relora))
        + int(bool(optimizer_random_pruning))
        + int(bool(optimizer_magnitude_pruning))
    )
    if n_reset_types != 1:
        logger.warning(f"Got {reset_optimizer_on_relora=}, {optimizer_random_pruning=}, "
                       f"{optimizer_magnitude_pruning=}")
        raise ValueError(f"Exactly one of reset_optimizer_on_relora, "
                         f"optimizer_random_pruning, optimizer_magnitude_pruning must be True")

    # pruning_fn has to be inplace to work with ZeroRedundancyOptimizer
    if reset_optimizer_on_relora:
        logger.info("Resetting optimizer states to zeros")
        # looks like zeroing out breaks dictionary in the optimizer
        # see full error below
        pruning_fn = partial(random_pruning_, prune_ratio=0.999)
    elif optimizer_random_pruning:
        logger.info(f"Performing random pruning of optimizer states. "
                    f"Pruning {optimizer_random_pruning} percent")
        pruning_fn = partial(random_pruning_, prune_ratio=optimizer_random_pruning)
    elif optimizer_magnitude_pruning:
        logger.info(f"Performing magnitude pruning of optimizer states. "
                    f"Pruning {optimizer_magnitude_pruning} percent")
        pruning_fn = partial(magnitude_pruning_, prune_ratio=optimizer_magnitude_pruning)
    else:
        raise ValueError("Unknown pruning type")
        
    n_zeros = 0
    n_total = 0

    optimizer_state = optimizer.state
    if isinstance(optimizer, ZeroRedundancyOptimizer):
        optimizer_state = optimizer.optim.state

    for p in reset_params:
        param_state = optimizer_state[p]
        if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
            continue
        for key in optimizer_state_keys:
            pruning_fn(param_state[key])  # pruning fn has to be inplace to keep the same keys in the dict
            n_total += param_state[key].numel()
            n_zeros += torch.sum(param_state[key] == 0).item()

    _zeroed = n_zeros / (1e-7 + n_total) * 100
    logger.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")

### ReLora Training

In [28]:
from peft import LoraConfig

lora_config = LoraConfig(r=128, 
                         target_modules=["query_key_value", "dense",
                                         "dense_h_to_4h", "dense_4h_to_h"], 
                         lora_dropout=0.1, 
                         lora_alpha=32)

# By default, PEFT initializes LoRA weights with Kaiming-uniform for weight A and zeros for weight B 
# resulting in an identity transform 

In [13]:
from peft import get_peft_model

lora_model = get_peft_model(model, lora_config)
lora_model.print_trainable_parameters()

trainable params: 1,572,864 || all params: 15,640,576 || trainable%: 10.0563


In [29]:
for param in lora_model.parameters():
    if param.requires_grad:
        param.data = param.data.float()

In [14]:
trainable_params = [p for p in lora_model.parameters() if p.requires_grad]
lora_params = [p for n, p in lora_model.named_parameters() if p.requires_grad and "lora_" in n]

In [15]:
# optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay)


In [16]:
optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay)
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]

scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_iters, num_training_steps=train_iters, num_cycles=4)

In [17]:
global_step = 0
update_step = 0
tokens_seen = 0
tokens_seen_before = 0
n_lora_restarts = 0
n_optimizer_resets = 0

In [18]:
relora_steps = 1000

In [19]:
cycle_length = scheduler.lr_lambdas[0].keywords['num_training_steps'] // scheduler.lr_lambdas[0].keywords['num_cycles']

In [20]:
def calculate_loss(inputs, logits):
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Calculate per-token loss
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return loss

In [21]:
# accelerator = Accelerator()

In [22]:
accelerator = Accelerator(mixed_precision="fp16")

lora_model, optimizer, train_dataloader = accelerator.prepare(
    lora_model, optimizer, train_dataloader
)

In [23]:
model.train()
completed_steps = 0

for step, batch in tqdm(
        enumerate(train_dataloader, start=1), total=train_iters
    ):
    
    logits = model(batch["input_ids"]).logits
    loss = calculate_loss(batch["input_ids"], logits)
    loss = loss / accumulation_steps
    accelerator.backward(loss)
    if step % accumulation_steps == 0:
        accelerator.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
    
    if step % 100 == 0:
        accelerator.print(
            {
                "steps": completed_steps,
                "loss/train": loss.item() * accumulation_steps,
            }
        )
        
    can_reset_relora = relora_steps is not None and step // accumulation_steps >= relora_steps

    if can_reset_relora and completed_steps % relora_steps == 1:
        _lora_reset_time = time.time()
        logger.info(f"Performing lora reset at update step {completed_steps}. Current lr is {optimizer.param_groups[0]['lr']}")
        n_lora_restarts += 1

        lora_model = lora_model.merge_and_unload()
        lora_model = get_peft_model(model, lora_config)
        
        for param in lora_model.parameters():
            if param.requires_grad:
                param.data = param.data.float()

        _lora_reset_time = time.time() - _lora_reset_time
        logger.info(f"LoRA reset took {_lora_reset_time:.2f}s")

    can_reset_optimizer = relora_steps is not None and step // accumulation_steps >= cycle_length

    if can_reset_optimizer and (completed_steps - scheduler_start_step) % cycle_length == 1:
        # scheduler should provide a new warmup after the reset
        logger.info(f"Performing optimizer reset at update step {update_step}. Current lr is {optimizer.param_groups[0]['lr']}")
        n_optimizer_resets += 1

        training_utils.optimizer_reset(
            optimizer,
            reset_params=lora_params,
            optimizer_state_keys=optimizer_state_keys,
            reset_optimizer_on_relora=args.reset_optimizer_on_relora,
            optimizer_random_pruning=args.optimizer_random_pruning,
            optimizer_magnitude_pruning=args.optimizer_magnitude_pruning,
        )
    # ##############################

    if can_reset_optimizer and (completed_steps - scheduler_start_step) % cycle_length == 2:
        logger.info(f"First step after optimizer reset lr is {optimizer.param_groups[0]['lr']}")

    lr = optimizer.param_groups[0]["lr"]

  0%|          | 3/2048 [00:02<22:44,  1.50it/s]  


ValueError: Attempting to unscale FP16 gradients.

### Full-rank Training

In [11]:
for param in model.parameters():
    if param.requires_grad:
        param.data = param.data.float()

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay)

scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_iters, num_training_steps=train_iters)

In [13]:
def calculate_loss(inputs, logits):
    # Shift so that tokens < n predict n
    shift_labels = inputs[..., 1:].contiguous()
    shift_logits = logits[..., :-1, :].contiguous()
    # Calculate per-token loss
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return loss

In [14]:
accelerator = Accelerator(mixed_precision="fp16")

model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)

In [15]:
model.train()
completed_steps = 0

for step, batch in tqdm(
        enumerate(train_dataloader, start=1), total=train_iters
    ):
    
    logits = model(batch["input_ids"]).logits
    loss = calculate_loss(batch["input_ids"], logits)
    loss = loss / accumulation_steps
    accelerator.backward(loss)
    if step % accumulation_steps == 0:
        accelerator.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        completed_steps += 1
    
    if step % 100 == 0:
        accelerator.print(
            {
                "steps": completed_steps,
                "loss/train": loss.item() * accumulation_steps,
            }
        )

  5%|▍         | 102/2048 [00:10<03:00, 10.79it/s]

{'steps': 25, 'loss/train': 10.98653507232666}


 10%|▉         | 202/2048 [00:20<02:50, 10.83it/s]

{'steps': 50, 'loss/train': 10.87724781036377}


 15%|█▍        | 302/2048 [00:30<02:41, 10.78it/s]

{'steps': 75, 'loss/train': 10.707901954650879}


 20%|█▉        | 402/2048 [00:40<02:32, 10.80it/s]

{'steps': 100, 'loss/train': 10.520821571350098}


 25%|██▍       | 502/2048 [00:50<02:23, 10.78it/s]

{'steps': 125, 'loss/train': 10.398367881774902}


 29%|██▉       | 602/2048 [01:00<02:14, 10.73it/s]

{'steps': 150, 'loss/train': 10.303812026977539}


 34%|███▍      | 702/2048 [01:11<02:05, 10.71it/s]

{'steps': 175, 'loss/train': 10.22122859954834}


 39%|███▉      | 802/2048 [01:21<01:56, 10.70it/s]

{'steps': 200, 'loss/train': 10.120756149291992}


 44%|████▍     | 902/2048 [01:31<01:46, 10.76it/s]

{'steps': 225, 'loss/train': 10.018835067749023}


 49%|████▉     | 1002/2048 [01:41<01:38, 10.67it/s]

{'steps': 250, 'loss/train': 9.94930362701416}


 54%|█████▍    | 1102/2048 [01:51<01:27, 10.76it/s]

{'steps': 275, 'loss/train': 9.759871482849121}


 59%|█████▊    | 1202/2048 [02:01<01:18, 10.82it/s]

{'steps': 300, 'loss/train': 9.670370101928711}


 64%|██████▎   | 1302/2048 [02:11<01:09, 10.70it/s]

{'steps': 325, 'loss/train': 9.555130958557129}


 68%|██████▊   | 1402/2048 [02:21<01:00, 10.76it/s]

{'steps': 350, 'loss/train': 9.479826927185059}


 73%|███████▎  | 1502/2048 [02:31<00:50, 10.74it/s]

{'steps': 375, 'loss/train': 9.386810302734375}


 78%|███████▊  | 1602/2048 [02:41<00:41, 10.75it/s]

{'steps': 400, 'loss/train': 9.324230194091797}


 83%|████████▎ | 1702/2048 [02:51<00:32, 10.75it/s]

{'steps': 425, 'loss/train': 9.19173526763916}


 88%|████████▊ | 1802/2048 [03:01<00:23, 10.69it/s]

{'steps': 450, 'loss/train': 9.118826866149902}


 93%|█████████▎| 1902/2048 [03:12<00:13, 10.73it/s]

{'steps': 475, 'loss/train': 9.068615913391113}


 98%|█████████▊| 2002/2048 [03:22<00:04, 10.79it/s]

{'steps': 500, 'loss/train': 8.984365463256836}


2066it [03:28,  9.90it/s]                          
