In [1]:
# %pip install wandb
# %pip install datasets
# %pip install tokenizers
# %pip install -U transformers datasets tokenizers accelerate

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import math
import time

In [2]:
from torch.nn import RMSNorm
from torch.utils.data import DataLoader
from datasets import load_dataset
from dataclasses import dataclass
from tokenizers import Tokenizer
from transformers import AutoTokenizer
from pathlib import Path
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
@dataclass
class modelargs:
    """Model, MoE and optimizer hyperparameters for this notebook.

    The notebook reads these as class attributes (``modelargs.block_size``).
    Every name carries a type annotation because ``@dataclass`` only registers
    annotated class variables as fields; without annotations only ``eps`` was a
    field and ``modelargs(block_size=256)`` raised ``TypeError``. Instantiating
    with overrides returns a new config object and leaves the class-level
    defaults untouched.
    """

    block_size: int = 512  # tokens per CausalDataset window
    batch_size: int = 32  # DataLoader batch size
    embeddings_dim: int = 512  # model width; default for the positional table and RMSNorm
    attn_dropout: float = 0.1
    vocab_size: int = 50257  # gpt2 bpe
    # Fixed-length windows cut from a flat token stream, so padding never happens.
    ignore_pad_token_in_loss: bool = False
    heads: int = 8
    dropout: float = 0.1
    epochs: int = 1
    max_lr: float = 6e-4
    decoder_layers: int = 6
    weight_decay: float = 0.1
    beta1: float = 0.9  # presumably Adam-style optimizer betas -- confirm
    beta2: float = 0.95
    clip: float = 1.0  # NOTE(review): presumably a gradient-norm clipping threshold -- confirm
    experts: int = 8  # presumably number of routed MoE experts -- confirm
    top_experts: int = 2  # presumably experts selected per token (top-k) -- confirm
    use_shared_experts: bool = True
    # NOTE(review): presumably a RoPE base; the sinusoidal table in this notebook uses base 10000.
    base_freq: int = 100000
    noisy_topk: bool = False
    eps: float = 1e-8
    loss_scale: float = 0.03  # NOTE(review): meaning not visible here (aux-loss weight?) -- confirm
    use_aux_free_load_balancing: bool = True
    aux_free_bias_update_rate: float = 0.001
    mtp_heads: int = 1  # presumably multi-token-prediction heads -- confirm
    latent_dim: int = 64  # presumably latent size for compressed attention -- confirm

In [4]:
# TinyStories from the Hugging Face Hub; the 'train' and 'validation' splits are tokenized below.
dataset = load_dataset(path='roneneldan/TinyStories')

In [5]:
from transformers import AutoTokenizer

# GPT-2 byte-level BPE tokenizer, fast implementation.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path='gpt2', use_fast=True)
# Reuse EOS as the pad token so batching utilities have one (presumably GPT-2 ships
# without a pad token). Training windows are fixed-length, so no padding is expected.
tokenizer.pad_token = tokenizer.eos_token

In [6]:
# Vocabulary size reported by the tokenizer. The model config hard-codes its own
# copy (modelargs.vocab_size), so fail fast here if the two ever disagree.
vocab_size = tokenizer.vocab_size
assert vocab_size == modelargs.vocab_size, (
    f'tokenizer vocab ({vocab_size}) != modelargs.vocab_size ({modelargs.vocab_size})'
)
vocab_size

50257

In [7]:
def encode(text: str):
    """Return the token ids of ``text`` from the module-level ``tokenizer``.

    ``add_special_tokens=False`` keeps the tokenizer from adding its special tokens.
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    return token_ids

In [6]:
# Cache directories for tokenized data. The parent is created before the child,
# so a plain (non-recursive) mkdir is enough for both.
artifacts = Path('artifacts')
tokenized_dir = artifacts.joinpath('tokenized')
artifacts.mkdir(exist_ok=True)
tokenized_dir.mkdir(exist_ok=True)

In [9]:
def chunked_tokenize(texts, chunk_size=10000, separator='\n'):
    """Tokenize a sequence of documents into one flat 1-D LongTensor of ids.

    Documents are joined with ``separator`` and encoded ``chunk_size`` documents
    per ``encode`` call, which is much cheaper than one call per document.

    Args:
        texts: sequence of strings supporting ``len()`` and slicing.
        chunk_size: documents per ``encode`` call; must be >= 1.
        separator: text inserted between every pair of consecutive documents,
            including across chunk boundaries. An end-of-text marker can be passed
            to make document boundaries explicit to the model.

    Returns:
        ``torch.long`` tensor of shape (num_tokens,); empty when ``texts`` is empty.

    Raises:
        ValueError: if ``chunk_size`` < 1.
    """
    if chunk_size < 1:
        raise ValueError(f'chunk_size must be >= 1, got {chunk_size}')
    parts = []
    for start in tqdm(range(0, len(texts), chunk_size), desc='Tokenizing'):
        chunk_text = separator.join(texts[start:start + chunk_size])
        # join() only separates documents *within* a chunk; prefix the separator so
        # the last document of the previous chunk is not glued to this chunk's first.
        if start > 0:
            chunk_text = separator + chunk_text
        # Convert per chunk: a Python list of int objects costs several times more
        # memory per token than an int64 tensor, which matters for the full corpus.
        parts.append(torch.tensor(encode(chunk_text), dtype=torch.long))
    if not parts:
        return torch.tensor([], dtype=torch.long)
    return torch.cat(parts)

In [7]:
# On-disk cache locations for the tokenized train / validation streams.
train_path = tokenized_dir.joinpath('train_ids.pt')
val_path = tokenized_dir.joinpath('val_ids.pt')

In [None]:
# Tokenizing the train split took ~55 minutes when this notebook was last run, so a
# split is only (re)tokenized when its cached tensor is missing. The next cell loads
# both tensors from disk either way. Delete a .pt file to force a rebuild.
for split_name, cache_path in (('train', train_path), ('validation', val_path)):
    if cache_path.exists():
        continue
    split_ids = chunked_tokenize(dataset[split_name]['text'])
    torch.save(split_ids, cache_path)

Tokenizing:   0%|          | 0/212 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (2162077 > 1024). Running this sequence through the model will result in indexing errors
Tokenizing: 100%|██████████| 212/212 [55:26<00:00, 15.69s/it]
Tokenizing: 100%|██████████| 3/3 [00:35<00:00, 11.79s/it]


In [8]:
# Load the cached flat token streams (train first, then validation).
train_ids, val_ids = map(torch.load, (train_path, val_path))

In [9]:
# Next-token-prediction windows over a flat token stream (stride 1 = sliding window).
class CausalDataset(torch.utils.data.Dataset):
    """Map-style dataset of ``(x, y)`` next-token windows over a flat token sequence.

    Sample ``idx`` starts at ``s = idx * stride``:
    ``x = tokens[s : s + block_size]`` and ``y = tokens[s + 1 : s + block_size + 1]``
    (``y`` is ``x`` shifted one token to the right).

    With the default ``stride=1`` every token position starts a window, so
    consecutive samples overlap by ``block_size - 1`` tokens and there is roughly
    one sample per token. Pass ``stride=block_size`` for non-overlapping blocks.

    Args:
        tokens: sliceable 1-D sequence of token ids (e.g. a LongTensor).
        block_size: window length; must be >= 1.
        stride: distance between consecutive window starts; must be >= 1.
    """

    def __init__(self, tokens, block_size, stride=1):
        if block_size < 1:
            raise ValueError(f'block_size must be >= 1, got {block_size}')
        if stride < 1:
            raise ValueError(f'stride must be >= 1, got {stride}')
        self.tokens = tokens
        self.block_size = block_size
        self.stride = stride

    def __len__(self):
        # A window starting at s also needs token s + block_size (the last target),
        # so the last valid start is len(tokens) - block_size - 1. Never negative.
        last_start = len(self.tokens) - self.block_size - 1
        if last_start < 0:
            return 0
        return last_start // self.stride + 1

    def __getitem__(self, idx):
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            # Raise instead of silently returning a truncated window.
            raise IndexError(f'index {idx} out of range for dataset of length {n}')
        start = idx * self.stride
        x = self.tokens[start : start + self.block_size]
        y = self.tokens[start + 1 : start + self.block_size + 1]
        return x, y

In [10]:
# Wrap each flat token stream in next-token windows of modelargs.block_size tokens.
train_dataset, val_dataset = (
    CausalDataset(token_stream, modelargs.block_size) for token_stream in (train_ids, val_ids)
)

In [11]:
# Training batches are drawn in random order; both loaders drop a trailing partial batch.
train_loader = DataLoader(train_dataset, batch_size=modelargs.batch_size, shuffle=True, drop_last=True)

# Validation keeps a fixed order so evaluations are comparable between runs.
val_loader = DataLoader(val_dataset, batch_size=modelargs.batch_size, shuffle=False, drop_last=True)

#### **Sinusoidal Positional Embeddings**

In [12]:
def precompute_pos_embeddings(block_size=modelargs.block_size, embeddings_dim=modelargs.embeddings_dim, device='cuda', base=10000.0):
    """Build the fixed sinusoidal positional-embedding table.

    Column ``2i`` holds ``sin(pos * base**(-2i / embeddings_dim))`` and column
    ``2i + 1`` the matching ``cos``.

    Args:
        block_size: number of positions (rows) to precompute.
        embeddings_dim: embedding width. Odd widths are supported; the last column
            is then a sin column with no cos partner.
        device: device the table is allocated on. The default is 'cuda', so pass
            'cpu' on machines without a GPU.
        base: frequency base of the geometric wavelength progression (10000.0 by default).

    Returns:
        Float tensor of shape (1, block_size, embeddings_dim), broadcastable over batch.
    """
    pe = torch.zeros(block_size, embeddings_dim, device=device)
    position = torch.arange(block_size, dtype=torch.float, device=device).unsqueeze(1)  # (block_size, 1)
    div_term = torch.exp(
        torch.arange(0, embeddings_dim, 2, dtype=torch.float, device=device)
        * (-math.log(base) / embeddings_dim)
    )  # (ceil(embeddings_dim / 2),)
    angles = position * div_term  # (block_size, ceil(embeddings_dim / 2))
    pe[:, 0::2] = torch.sin(angles)
    # Odd widths have one fewer cos column than sin column; trim to fit.
    pe[:, 1::2] = torch.cos(angles[:, : embeddings_dim // 2])
    return pe.unsqueeze(0)  # (1, block_size, embeddings_dim)

In [13]:
def apply_pos_embeddings(x, pe):
    """Add the first ``x.size(1)`` rows of a positional table to ``x``.

    Args:
        x: embeddings of shape (batch, seq_len, embeddings_dim).
        pe: table of shape (1, max_len, embeddings_dim), e.g. from
            ``precompute_pos_embeddings``; its leading 1 broadcasts over batch.

    Returns:
        A new tensor ``x + pe[:, :seq_len]``; ``x`` is not modified in place.
    """
    seq_len = x.shape[1]
    positional = pe[:, :seq_len]
    return x + positional

#### **RMS Normalization**

In [None]:
class Normalization(nn.Module):
    """RMS normalization over the trailing ``embeddings_dim`` features.

    Thin wrapper around ``torch.nn.RMSNorm``.

    Args:
        embeddings_dim: size of the last dimension to normalize. Defaults to
            ``modelargs.embeddings_dim``, captured when this class is defined.
    """

    def __init__(self, embeddings_dim: int = modelargs.embeddings_dim):
        super(Normalization, self).__init__()
        # The attribute name appears in state_dict keys ("rmsnorm_layer.*"); keep it stable.
        self.rmsnorm_layer = RMSNorm(normalized_shape=embeddings_dim)

    def forward(self, x):
        normalized = self.rmsnorm_layer(x)
        return normalized