## Preparing the environment

In [1]:
import numpy as np

import torch
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

from datasets import load_dataset
import transformers

from tqdm import tqdm

import matplotlib.pyplot as plt



device = "cuda" if torch.cuda.is_available() else "cpu"
data_path = "../data"

In [2]:
seed = 42

np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = False

## Loading dataset

In [3]:
train = load_dataset("EleutherAI/pile-deduped-pythia-random-sampled", cache_dir=data_path, split="train")
train = train.with_format("torch")

## Training

### Full-rank Training

In [4]:
batch_size = 256
global_batch_size = 512
accumulation_steps = global_batch_size / batch_size

learning_rate = 1e-4
betas = (0.9, 0.95)
eps = 1e-8
gradient_clipping = 1.0
weight_decay = 0.1

warmup_iters = 256

train_iters = 2048

In [5]:
train_loader = iter(DataLoader(train, batch_size=batch_size, shuffle=False))

In [6]:
model = transformers.AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-14m", revision="step1").to(device)

In [7]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay)

scheduler = transformers.get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=warmup_iters, num_training_steps=train_iters)

In [8]:
model.train()

for iter in tqdm(range(train_iters)):
    optimizer.zero_grad()

    if iter > train_iters:
        break
    x = next(train_loader)["Tokens"].to('cuda')

    output = model(input_ids=x, labels=x)
    loss = output.loss
    loss.backward()

    if (iter + 1) % accumulation_steps == 0 or iter + 1 == train_iters: 
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        optimizer.step()
        scheduler.step()
    
    if iter % 128 == 0:
        print(loss.item())

  0%|          | 1/2048 [00:00<17:39,  1.93it/s]

11.03548526763916


  6%|▋         | 131/2048 [00:20<04:09,  7.69it/s]

10.8258056640625


 13%|█▎        | 257/2048 [00:41<06:45,  4.42it/s]

10.467057228088379


 19%|█▉        | 387/2048 [01:01<04:06,  6.75it/s]

10.2770414352417


 25%|██▌       | 513/2048 [01:21<05:21,  4.78it/s]

10.060571670532227


 31%|███▏      | 641/2048 [01:42<05:00,  4.68it/s]

9.806251525878906


 38%|███▊      | 771/2048 [02:01<02:55,  7.29it/s]

9.649765014648438


 44%|████▍     | 899/2048 [02:20<02:25,  7.91it/s]

9.43156909942627


 50%|█████     | 1027/2048 [02:41<02:15,  7.53it/s]

9.315899848937988


 56%|█████▋    | 1155/2048 [03:01<02:02,  7.26it/s]

9.144174575805664


 63%|██████▎   | 1283/2048 [03:20<01:53,  6.77it/s]

9.052902221679688


 69%|██████▉   | 1411/2048 [03:40<01:39,  6.40it/s]

9.064898490905762


 75%|███████▌  | 1537/2048 [04:01<01:44,  4.88it/s]

8.988072395324707


 81%|████████▏ | 1665/2048 [04:22<01:19,  4.80it/s]

9.069371223449707


 88%|████████▊ | 1795/2048 [04:44<00:39,  6.42it/s]

8.974088668823242


 94%|█████████▍| 1921/2048 [05:05<00:26,  4.84it/s]

9.066380500793457


100%|██████████| 2048/2048 [05:26<00:00,  6.27it/s]
