In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import random
import time
import math
import tiktoken
import inspect
import os


from torch.distributed import init_process_group, destroy_process_group
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist

In [2]:
enc = tiktoken.get_encoding("gpt2")

ddp = int(os.environ.get('RANK', -1)) != -1 
if ddp:
    init_process_group(backend="nccl")
    ddp_rank = int(os.environ['RANK'])
    ddp_local_rank = int(os.environ['LOCAL_RANK'])
    ddp_world_size = int(os.environ['WORLD_SIZE'])
    device = f'cuda:{ddp_local_rank}'
    torch.cuda.set_device(device)
    master_process = ddp_rank == 0
else:
    ddp_rank = 0
    ddp_local_rank = 0
    ddp_world_size = 1
    master_process = True
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"device: {device}")
    

torch.manual_seed(1337)
if device == 'cuda':
    torch.cuda.manual_seed(1337)

def get_lr(epoch):
    if epoch < warmup_lr_steps:
        return (max_lr * (epoch+1)/warmup_lr_steps)
    if epoch > epochs:
        return min_lr
    loc = (epoch - warmup_lr_steps)/(epochs - warmup_lr_steps)
    coef = 0.5 * (1.0 + math.cos(math.pi * loc))
    return min_lr + coef * (max_lr - min_lr)

device: cuda


In [8]:
vocab_size = 50304 #50257
batch_size = 2**19
mini_batches = 8
time_stamps = 512
context_len = 1024

data_dir = "edu_fineweb10B"
log_dir = "log"
checkpoints_frequency = 2000
log_file = os.path.join(log_dir, f"log.txt")
val_log_file = os.path.join(log_dir, f"val_log.txt")

assert batch_size % (mini_batches * time_stamps * ddp_world_size) == 0, "batch_size is not devided by B and T and number_of_gpus"
mini_epochs = int(batch_size / (mini_batches * time_stamps * ddp_world_size)) #number of mini-batches to get 0.5M batch

def load_tokens(filename):
    npt = np.load(filename)
    npt = npt.astype(np.int32) # added after video
    ptt = torch.tensor(npt, dtype=torch.long)
    return ptt

class DataLoader():
    def __init__(self, B, T, cur_process, num_processes, data_dir, split):
        self.B = B
        self.T = T
        self.cur_process = cur_process
        self.cur_shard = 0
        self.num_processes = num_processes
        self.data_dir = data_dir

        shards = os.listdir(self.data_dir)
        shards = [s for s in shards if split in s]
        shards = sorted(shards)
        shards = [os.path.join(self.data_dir, s) for s in shards]
        self.shards = shards

        self.tokens = load_tokens(self.shards[self.cur_shard])
        
        self.current_step = cur_process * B * T

        print(f"loaded ~{len(self.tokens)*len(self.shards)} tokens")


    def reset(self):
        self.cur_shard = 0
        self.tokens = load_tokens(self.shards[self.current_shard])
        self.current_position = self.B * self.T * self.cur_process
        
    def next_batch(self):
        B, T = self.B, self.T
        
        self.current_step += B * T * self.num_processes
        tokens = self.tokens[self.current_step:self.current_step+B*T+1]
        x = (tokens[:-1]).view(B, T)
        y = (tokens[1:]).view(B, T)
        if (self.current_step+B*T* self.num_processes + B*T+1)  > len(self.tokens):
            self.cur_shard = (self.cur_shard+1) % len(self.shards)
            self.tokens = load_tokens(self.shards[self.cur_shard])
            self.current_step = self.cur_process * B * T
        return x, y

In [9]:

# Help functions
def repeat_kv(vct, n_times):
    bsz, cl, num_kv_heads, dm = vct.shape
    if n_times == 1:
        return vct
    else:
        return (
            vct[:, :, :, None, :]
            .expand(bsz, cl, num_kv_heads, n_times, dm)
            .reshape(bsz, cl, num_kv_heads * n_times, dm)
        )


# https://github.com/hkproj/pytorch-llama/blob/main/model.py
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"
    theta_numerator = torch.arange(0, head_dim, 2).float()
    theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device) # (Dim / 2)
    m = torch.arange(seq_len, device=device)
    freqs = torch.outer(m, theta).float()
    freqs_complex = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_complex

def apply_rotary_pos(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    x_rotated = x_complex * freqs_complex
    x_out = torch.view_as_real(x_rotated)
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)


class RMSNorm(torch.nn.Module):
    def __init__(self, dim, eps : 1e-6):
        super().__init__()
        self.g = nn.Parameter(torch.ones(dim))
        self.eps = eps
    
    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.g * self._norm(x.float()).type_as(x)


class SelfAttention(nn.Module):
    def __init__(self, num_dim, num_heads, num_kv_heads, batch_size, context_len, use_cache=False, device=device):
        super().__init__()

        self.use_cache = use_cache

        self.num_q_heads = num_heads
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        self.num_heads = num_heads
        self.num_rep = self.num_q_heads // self.num_kv_heads
        self.head_dim = num_dim // self.num_q_heads

        self.wq = nn.Linear(num_dim, num_dim, bias=False)
        self.wk = nn.Linear(num_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(num_dim, self.num_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(num_dim, num_dim, bias=False)

        self.cache_k = torch.zeros(
            (
                batch_size,
                context_len,
                self.num_kv_heads,
                self.head_dim
            ), device=device
        )

        self.cache_v = torch.zeros(
            (
                batch_size,
                context_len,
                self.num_kv_heads,
                self.head_dim
            ), device=device
        )

    def forward(self, x, freqs_complex, start_pos = 0):
        c_bsz, c_cl, c_dm = x.shape # c_cl = 1
    
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)

        q = q.view(c_bsz, c_cl, self.num_q_heads, self.head_dim) # B, T, qh, hs
        k = k.view(c_bsz, c_cl, self.num_kv_heads, self.head_dim) # B, T, kh, hs
        v = v.view(c_bsz, c_cl, self.num_kv_heads, self.head_dim) # B, T, vh, hs

        queries = apply_rotary_pos(q, freqs_complex, device=x.device)
        keys = apply_rotary_pos(k, freqs_complex, device=x.device)

        if self.use_cache:
            self.cache_k[:c_bsz, start_pos:start_pos+c_cl] = keys
            self.cache_v[:c_bsz, start_pos:start_pos+c_cl] = v
            
            keys = self.cache_k[:c_bsz, :start_pos+c_cl]
            v = self.cache_v[:c_bsz, :start_pos+c_cl]

        keys = repeat_kv(keys, self.num_rep)
        values = repeat_kv(v, self.num_rep)

        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        attention = torch.matmul(queries, keys.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))

        attention = torch.tril(attention[:, :, :c_cl, :c_cl])
        attention = attention.masked_fill(attention == 0, float("-inf"))

        attention = F.softmax(attention, dim=-1).type_as(queries)
        output = torch.matmul(attention, values)

        output = output.transpose(2, 1).contiguous().view(c_bsz, c_cl, c_dm)
        return self.wo(output)


class FeedForward(nn.Module):
    def __init__(self, num_dims, multiple_of, ffn_dim_multiplier=None):
        super().__init__()
        
        hidden_dim = 4 * num_dims
        hidden_dim = int(2 * num_dims / 3)

        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)

        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

        self.w1 = nn.Linear(num_dims, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, num_dims, bias=False)
        self.w3 = nn.Linear(num_dims, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Block(nn.Module):
    def __init__(self, num_dims, multiple_of, num_heads, num_kv_heads, batch_size, context_len, use_cache=False, eps = 1e-6, ffn_dim_multiplier=None,device=device):
        super().__init__()

        self.attention = SelfAttention(num_dim=num_dims, num_heads=num_heads, num_kv_heads=num_kv_heads, batch_size=batch_size, context_len=context_len, use_cache=use_cache, device=device)
        self.ffn = FeedForward(num_dims=num_dims, multiple_of=multiple_of, ffn_dim_multiplier=ffn_dim_multiplier)

        self.norm_attention = RMSNorm(num_dims, eps=eps)
        self.norm_ffn = RMSNorm(num_dims, eps=eps)

    def forward(self, x, freqs_complex, start_pos):
        x = x + self.attention(
            self.norm_attention(x), 
            freqs_complex, 
            start_pos
            )
        
        x = x + self.ffn(
            self.norm_ffn(x)
            )
        return x

        
class Transformer(nn.Module):
    def __init__(self, device, vocab_size, batch_size, context_len, num_layers, num_dims, multiple_of, num_heads, num_kv_heads, use_cache=False, eps = 1e-6, ffn_dim_multiplier=None):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_dims = num_dims
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.context_len = context_len

        self.tokens_embedding = nn.Embedding(self.vocab_size, self.num_dims)

        self.blocks = nn.ModuleList()
        for _ in range(self.num_layers):
            self.blocks.append(Block(
            num_dims=num_dims, multiple_of=multiple_of, num_heads=num_heads, num_kv_heads=num_kv_heads, batch_size=batch_size, context_len=context_len, use_cache=use_cache, eps = 1e-6, ffn_dim_multiplier=ffn_dim_multiplier, device=device
            ))

        self.norm = RMSNorm(self.num_dims, eps)
        self.ll_head = nn.Linear(self.num_dims, self.vocab_size, bias=False)

        self.tokens_embedding.weight = self.ll_head.weight

        self.freqs_complex = precompute_theta_pos_frequencies(self.num_dims // self.num_heads, self.context_len * 2, device=device)

    # I have taken this function [configure_optimizers] from Karpathy's nanoGPT
    # https://github.com/karpathy/nanoGPT
    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def forward(self, x, targets=None, start_pos=0):
        _, seq_len = x.shape
        
        x = self.tokens_embedding(x) #
        freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]
        
        for block in self.blocks:
            x = block(x, freqs_complex=freqs_complex, start_pos=start_pos)

        x = self.norm(x)
        logits = self.ll_head(x)
        
        
        if targets is None:
            loss = None
        else:
            c_bsz, c_cl, c_dm = logits.shape
            logits = logits.view(c_bsz*c_cl, c_dm)
            targets = targets.view(c_bsz*c_cl)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, x, max_tokens, use_cache=False):
        for c_tkn_pos in range(max_tokens):
            if use_cache:
                if c_tkn_pos == 0:
                    logits, _ = self.forward(x, start_pos=c_tkn_pos)
                else:
                    logits, _ = self.forward(x[:, -1], start_pos=c_tkn_pos)
            else:
                logits, _ = self.forward(x)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            x = torch.cat((x, next_token), dim=1)
        return x

In [10]:
print(123)

123


In [11]:
torch.set_float32_matmul_precision('high')

m = Transformer(
    device, vocab_size, mini_batches, context_len, num_layers=12, num_dims=1024, multiple_of=4, num_heads=16, num_kv_heads=4, use_cache=False, eps = 1e-6, ffn_dim_multiplier=None
)
m = m.to(device)
m = torch.compile(m)
#making loss average from all gpus
if ddp:
    m = DDP(m, device_ids=[ddp_local_rank]) 
raw_m = m.module if ddp else m

print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

data_loader = DataLoader(mini_batches, time_stamps, cur_process=ddp_rank, num_processes=ddp_world_size, data_dir=data_dir, split="train")
val_loader = DataLoader(mini_batches, time_stamps, cur_process=ddp_rank, num_processes=ddp_world_size, data_dir=data_dir, split="val")

# I have taken this function [configure_optimizers] from Karpathy's nanoGPT
max_lr = 6e-4
min_lr = max_lr * 0.1
warmup_lr_steps = 700
weight_decay = 0.1
beta1, beta2 = 0.9, 0.95

optmizer = raw_m.configure_optimizers(weight_decay, max_lr, (beta1, beta2), device)

127.108096 M parameters
loaded ~9900000000 tokens
loaded ~100000000 tokens
num decayed parameter tensors: 113, with 127,074,304 parameters
num non-decayed parameter tensors: 33, with 33,792 parameters
using fused AdamW: True


In [12]:
epochs = 19000
for epoch in range(epochs):
    t0 = time.time()
    last_epoch = epochs - 1
    # validation loss check + save model weights every 'checkpoints_frequency' steps
    if epoch % 300 == 0 or epoch == last_epoch:
        m.eval()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = 20
            for _ in range(val_loss_steps):
                x, y = val_loader.next_batch()
                x, y = x.to(device), y.to(device)
                with torch.autocast(device_type=device, dtype=torch.bfloat16):
                    logits, loss = m(x, y)
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()
        if ddp:
            dist.all_reduce(val_loss_accum, op=dist.ReduceOp.AVG)
        if master_process:
            print(f"Validation loss: {val_loss_accum.item()}")
            with open(val_log_file, "a") as f:
                f.write(f"epoch:{epoch} val_loss:{val_loss_accum.item():.5f}\n")
            if epoch > 0 and (epoch % checkpoints_frequency == 0 or last_epoch):
                checkpoint_path = os.path.join(log_dir, f"model_{epoch:05d}.pt")
                checkpoint = {
                    'model': raw_m.state_dict(),
                    'optimizer':optmizer.state_dict(),
                    'epoch': epoch,
                    'val_loss': val_loss_accum.item()
                }

                torch.save(checkpoint, checkpoint_path)
        
    m.train()
    accumulated_loss = 0.0
    optmizer.zero_grad()
    # using accumulated loss
    for mini_epoch in range(mini_epochs):
        x, y = data_loader.next_batch()
        x, y = x.to(device), y.to(device)
    
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            logits, loss = m(x, y)
        loss /= mini_epochs
        accumulated_loss += loss.detach()
    
        if ddp:
            m.require_backward_grad_sync = (mini_epoch == mini_epochs-1)
        loss.backward()
    if ddp:
        dist.all_reduce(accumulated_loss, op=dist.ReduceOp.AVG)
    
    norm = torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
    # scheduler.step()

    # change lr
    lr = get_lr(epoch)
    for param_group in optmizer.param_groups:
        param_group['lr'] = lr
    optmizer.step()
    
    torch.cuda.synchronize()
    t1 = time.time()
    dt = t1-t0

    # wrtie to the file losses
    if master_process and epoch%5==0:
        print(f"epoch: {epoch}, loss: {accumulated_loss:.5f}, norm: {norm:.5f}, time: {dt*1000:.2f}ms, tok/s: {data_loader.B*data_loader.T*mini_epochs*ddp_world_size/dt:.2f}")
        with open(log_file, "a") as f:
            f.write(f"epoch:{epoch} loss:{accumulated_loss.item():.5f}\n")
if ddp:
    destroy_process_group()

OutOfMemoryError: CUDA out of memory. Tried to allocate 394.00 MiB. GPU 0 has a total capacity of 47.51 GiB of which 281.11 MiB is free. Process 431514 has 28.06 MiB memory in use. Process 738774 has 8.23 GiB memory in use. Process 2920184 has 11.73 GiB memory in use. Of the allocated memory 11.08 GiB is allocated by PyTorch, and 258.94 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [8]:
for i in range(20):
    t0 = time.time()
    print(enc.decode(raw_m.generate(torch.tensor(enc.encode("I am a large language model,")).to(device).view(1, -1), 100)[0].tolist()))
    print("without cache:", time.time()-t0)
    
for i in range(11):
    raw_m.blocks[i].attention.use_cache = True
    
try:
    raw_m.blocks[-1].attention.use_cache = True
except:
    pass
    
for i in range(20):
    t0 = time.time()
    print(enc.decode(raw_m.generate(torch.tensor(enc.encode("I am a large language model,")).to(device).view(1, -1), 100)[0].tolist()))
    print("with cache:",time.time()-t0)

/opt/conda/conda-bld/pytorch_1729647429097/work/aten/src/ATen/native/cuda/TensorCompare.cu:110: _assert_async_cuda_kernel: block: [0,0,0], thread: [0,0,0] Assertion `probability tensor contains either `inf`, `nan` or element < 0` failed.


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [36]:
for i in range(11):
    raw_m.blocks[i].attention.use_cache = True

In [32]:
# m.blocks[8].attention.use_cache

True

In [7]:
checkpoint_path = './log/model_00800.pt'  # Adjust path as needed
model_from_checkpoint = m = Transformer(
    device, vocab_size, mini_batches, context_len, num_layers=12, num_dims=1024, multiple_of=4, num_heads=16, num_kv_heads=4, use_cache=False, c_kv = None, eps = 1e-6, ffn_dim_multiplier=None
)
model_from_checkpoint = model_from_checkpoint.to(device)


if ddp:
    model_from_checkpoint = DDP(model_from_checkpoint, device_ids=[ddp_local_rank]) 
model_from_checkpoint = model_from_checkpoint.module if ddp else model_from_checkpoint

# Load checkpoint
checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

state_dict = checkpoint['model']
new_state_dict = {}
for k, v in state_dict.items():
    if k.startswith("_orig_mod."):
        new_state_dict[k[len("_orig_mod."):]] = v 
    else:
        new_state_dict[k] = v

model_from_checkpoint.load_state_dict(new_state_dict, strict=False)
print(enc.decode(model_from_checkpoint.generate(torch.tensor(enc.encode("I am a large language model ")).to(device).view(1, -1), 50)[0].tolist()))


  checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))


I am a large language model ua Sheikh consentos aba have become the youngest so far vividly in that and is called a non-agulation of a work (NADB Assignment 2). Some test had an entirely new one (a smaller emphasis on misalination of novel
