# imports

In [2]:
import sys
sys.path.append('../src/')
import math
import os
import glob
import random
from functools import partial

import numpy as np
import torch
import torch.nn.functional as F

from contextlib import nullcontext
import matplotlib.pyplot as plt
from sentencepiece import SentencePieceProcessor
from data_loader import *
from utils import *
from model import *

In [17]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'


# paths

In [4]:
DATA_CACHE_DIR = '../instruct_dataset/'

In [5]:
out_dir = '../models/'
os.makedirs(out_dir, exist_ok=True)

# tokenizer

In [6]:
tokenizer = SentencePieceProcessor('../data/tok4096.model')

In [7]:
vocab_size = tokenizer.vocab_size()

# training

#### mixed precision settings

In [18]:
import torch
from contextlib import nullcontext

dtype = 'bfloat16'
torch.manual_seed(1337)

# Check if CUDA is available
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
    device_type = "cuda"
    
    # Check if the current CUDA device supports bfloat16
    if dtype == 'bfloat16' and not torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
        print("Current CUDA device does not support bfloat16. Switching to float32.")
        dtype = 'float32'
else:
    device_type = "cpu"
    if dtype == 'bfloat16':
        print("bfloat16 is not supported on CPU. Switching to float32.")
        dtype = 'float32'

ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]

ctx = (
    nullcontext()
    if device_type == "cpu"
    else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
)


bfloat16 is not supported on CPU. Switching to float32.


In [27]:

scaler = torch.cuda.amp.GradScaler(enabled=(dtype == "float32"))




#### model

In [11]:
dim = 288 #288 #
n_layers =  6
n_heads =  6
n_kv_heads = n_heads
multiple_of = 32
dropout = 0.0
max_seq_len = 350

In [12]:
model_args = ModelArgs(
    dim=dim,
    n_layers=n_layers,
    n_heads=n_heads,
    n_kv_heads=n_heads,
    vocab_size=32000,
    multiple_of=multiple_of,
    max_seq_len=max_seq_len,
    dropout=dropout,
) 

In [20]:
model = Transformer(model_args)
model.to(device);
print(f'Number of parameters: {sum(p.nelement() for p in model.parameters())}')

Number of parameters: 15191712


#### data

In [21]:
batch_size = 64

wanted_batch_size = 4 * 128
gradient_accumulation_steps = wanted_batch_size // batch_size

print(f'Wanted batch_size: {wanted_batch_size}, gradient accumulation steps: {gradient_accumulation_steps}, batch_size: {batch_size}')

Wanted batch_size: 512, gradient accumulation steps: 8, batch_size: 64


In [22]:
iter_batches = partial(
    iter_batch_func,
    device=device,
    batch_size=batch_size,
    max_seq_len=max_seq_len,
    data_cache_dir=DATA_CACHE_DIR
)

#### optimizer

In [24]:
learning_rate = 5e-4
optimizer = get_optimizer(
    model=model,
    device_type=device_type,
    learning_rate=learning_rate,  # max learning rate
    weight_decay = 1e-1,
    beta1 = 0.9,
    beta2 = 0.95,
)

num decayed parameter tensors: 43, with 15,187,968 parameters
num non-decayed parameter tensors: 13, with 3,744 parameters


## training loop

In [25]:
max_iters = 25000
eval_iters = 100
best_val_loss = 1e9
grad_clip = 1

In [26]:
eval_prompt = 'Write a story. In the story, try to use the verb "eat", the noun "clock" and the adjective "clever". The story has the following features: the story should contain at least one dialogue. Possible story:'

In [28]:
iter_num = 0

train_batch_iter = iter_batches(split='train')
X, Y = next(train_batch_iter)

while True:
    lr = get_lr(iter_num, max_iters=max_iters) 
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    if iter_num % eval_iters == 0 :
        losses = estimate_loss(
            model=model,
            iter_batches=iter_batches,
            eval_iters=eval_iters,
            ctx=ctx
        )
        print(f"step {iter_num}: lr {lr}, train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")
        if losses["val"] < best_val_loss:
            best_val_loss = losses["val"]
            if iter_num > 0:
                save_checkpoint(
                    model=model,
                    optimizer=optimizer,
                    model_args=model_args,
                    iter_num=iter_num,
                    out_dir=out_dir
                )
                _, paragraph = generate_paragraph(
                    model, 
                    prompt=eval_prompt,
                    tokenizer=tokenizer,
                    device=device
                )
                print(paragraph)

    for micro_step in range(gradient_accumulation_steps):
        with ctx:
            logits = model(X)
            loss = compute_loss(logits, Y)
            loss = loss / gradient_accumulation_steps
        X, Y = next(train_batch_iter)
        scaler.scale(loss).backward()
     
    if grad_clip != 0.0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)

   
    iter_num += 1
    if iter_num > max_iters:
        break


step 0: lr 0.0, train loss 10.4069, val loss 10.4070


KeyboardInterrupt: 