In [1]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import math
from matplotlib import pyplot as plt
from dataclasses import dataclass
import time
import pandas as pd
from typing import List, Tuple, Optional

In [2]:
@dataclass
class Args:
    """Hyper-parameters and runtime settings shared by every cell in this notebook.

    The class name is rebound to a single instance right after the definition,
    so later cells read and mutate ``Args.<field>`` on that shared object.
    """
    # Currently disabled options kept for reference:
    #   file_path = 'data/TinyStories-valid.txt', tokenizer_path = "tokenizer.model",
    #   frequency_cutoff = 25, early_stopping_criteria = 5, learning_rate = 0.001,
    #   cuda = True

    batch_size: int = 32
    context_window: int = 16
    n_layers: int = 18
    n_heads: int = 8
    vocab_size: int = -1  # placeholder until the vocabulary has been built
    d_model: int = 128
    epochs: int = 10000
    log_interval: int = 10
    max_seq_len: int = 512
    ffn_dim_multiplier: Optional[int] = None
    multiple_of: int = 256
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    def __repr__(self) -> str:
        # One "name : value" line per instance attribute, in definition order.
        return '\n'.join(f'{name} : {value}' for name, value in vars(self).items())

    def __str__(self) -> str:
        # str() and repr() intentionally render the same text.
        return repr(self)


# Replace the class with its default-configured instance.
Args = Args()

In [3]:
# Display the current configuration (rendered through Args.__repr__).
Args

batch_size : 32
context_window : 16
n_layers : 18
n_heads : 8
vocab_size : -1
d_model : 128
epochs : 10000
log_interval : 10
max_seq_len : 512
ffn_dim_multiplier : None
multiple_of : 256
device : cuda

In [4]:
# Read the whole corpus into memory. The context manager closes the file
# handle as soon as the text has been read instead of leaving it open.
with open('./data/TinyStories-valid.txt', encoding="utf8") as corpus_file:
    lines = corpus_file.read()

# Character-level vocabulary: every distinct character, sorted so that the
# id assignment is the same on every run.
vocab = sorted(set(lines))
itos = {i: ch for i, ch in enumerate(vocab)}  # token id -> character
stoi = {ch: i for i, ch in enumerate(vocab)}  # character -> token id

print(lines[:30])

 Spot. Spot saw the shiny car 


In [5]:
# simple tokenization by characters
def encode(s):
    """Translate the string ``s`` into a list of token ids, one per character.

    Looks every character up in the module-level ``stoi`` table; a character
    that is not in the table raises ``KeyError``.
    """
    token_ids = []
    for character in s:
        token_ids.append(stoi[character])
    return token_ids

def decode(l):
    """Turn an iterable of token ids back into text using the ``itos`` table."""
    characters = (itos[token_id] for token_id in l)
    return ''.join(characters)

# Report how many distinct characters the corpus contains, then sanity-check
# that decode(encode(...)) reproduces a sample string.
print('vocab size:', len(vocab))
decode(encode("hello"))

vocab size: 98


'hello'

In [6]:
# Encode the full corpus once as a 1-D tensor of token ids.
# int8 keeps the tensor small; this relies on every id fitting in int8's
# non-negative range (0..127), which holds for the 98-character vocabulary
# reported above. A larger vocabulary would need a wider dtype.
dataset = torch.tensor(encode(lines), dtype=torch.int8)
dataset.shape

torch.Size([19432979])

In [7]:
def get_batches(data, split, batch_size, context_window, args=Args):
    """Sample a random batch of next-token-prediction examples from one split.

    ``data`` is partitioned by position: the first 80% is the training split,
    the next 10% validation and the final 10% test.

    Args:
        data: 1-D tensor of token ids.
        split: one of ``'train'``, ``'val'`` or ``'test'``.
        batch_size: number of windows to draw.
        context_window: length of each input window.
        args: configuration object; only ``args.device`` is read.

    Returns:
        ``(x, y)``, two int64 tensors of shape ``(batch_size, context_window)``
        on ``args.device``. ``y`` is ``x`` shifted one position to the right.

    Raises:
        ValueError: if ``split`` is not a known split name. Previously an
            unknown name silently fell back to the training data.
    """
    total = len(data)
    train_end = int(.8 * total)
    val_end = int(.9 * total)

    if split == 'train':
        batch_data = data[:train_end]
    elif split == 'val':
        batch_data = data[train_end:val_end]
    elif split == 'test':
        batch_data = data[val_end:]
    else:
        raise ValueError(f"unknown split {split!r}; expected 'train', 'val' or 'test'")

    # Pick random starting points. A start i needs i + context_window + 1 tokens,
    # so the last valid start is size - context_window - 1; randint's upper
    # bound is exclusive, hence size - context_window.
    ix = torch.randint(0, batch_data.size(0) - context_window, (batch_size,))
    x = torch.stack([batch_data[i:i + context_window] for i in ix]).long()
    y = torch.stack([batch_data[i + 1:i + context_window + 1] for i in ix]).long()
    return x.to(device=args.device), y.to(device=args.device)


In [8]:
@torch.no_grad()  # evaluation never needs gradients
def evaluate_loss(model, args=Args, eval_iters=10):
    """Estimate the mean loss on the train and validation splits.

    Args:
        model: module called as ``model(xb, yb)`` and returning ``(logits, loss)``.
        args: configuration supplying ``batch_size``, ``context_window`` and
            (through ``get_batches``) ``device``.
        eval_iters: number of random batches averaged per split.

    Returns:
        dict with keys ``"train"`` and ``"val"`` mapping to the mean loss over
        ``eval_iters`` batches drawn from the global ``dataset``.
    """
    # Remember the caller's mode so a model that was already in eval mode is
    # not silently switched to training mode on the way out.
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ["train", "val"]:
            losses = []
            for _ in range(eval_iters):
                xb, yb = get_batches(dataset, split, args.batch_size, args.context_window, args=args)
                _, loss = model(xb, yb)
                losses.append(loss.item())
            out[split] = np.mean(losses)
    finally:
        model.train(was_training)
    return out

In [9]:


class SimpleBrokenModel(nn.Module):
    """Baseline language model: token embedding followed by a two-layer MLP.

    Every position is processed independently (embedding, then per-position
    linear layers), so a prediction only ever sees the current token.
    """

    def __init__(self, args=Args):
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        hidden_layer = nn.Linear(args.d_model, args.d_model)
        vocab_head = nn.Linear(args.d_model, args.vocab_size)
        self.linear = nn.Sequential(hidden_layer, nn.ReLU(), vocab_head)

        total_params = sum(p.numel() for p in self.parameters())
        print("model params:", total_params)

    def forward(self, idx, targets=None):
        """Return logits, or ``(logits, loss)`` when ``targets`` is provided.

        The loss is the cross-entropy between the logits flattened to
        ``(-1, vocab_size)`` and the flattened targets.
        """
        logits = self.linear(self.embedding(idx))

        if targets is None:
            return logits

        flat_logits = logits.view(-1, self.args.vocab_size)
        loss = F.cross_entropy(flat_logits, targets.view(-1))
        return logits, loss

# Re-assert the embedding width on the shared config (same value as the dataclass default).
Args.d_model = 128


# model = SimpleBrokenModel(Args)
# xs, ys = get_batches(dataset, 'train', Args.batch_size, Args.context_window)

# logits, loss = model(xs, ys)

In [10]:

# Finalise the shared configuration now that the vocabulary is known.
Args.vocab_size = len(vocab)
Args.batch_size = 32
Args.context_window = 16
# Args.epochs = 1000
Args.log_interval = 10

# Instantiate the baseline model from the shared configuration.
model = SimpleBrokenModel(Args)

# Adam with its default hyper-parameters (no learning rate is passed).
optimizer = torch.optim.Adam(
    model.parameters(), 
)

def train(model, optimizer, scheduler=None, args=Args, print_logs=False):
    """Run the optimisation loop and plot the periodically evaluated losses.

    One "epoch" here is a single optimisation step on one random training
    batch drawn from the global ``dataset``.

    Args:
        model: module called as ``model(xs, targets=ys)`` returning ``(logits, loss)``.
        optimizer: optimiser already bound to ``model``'s parameters.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
        args: configuration supplying ``epochs``, ``log_interval``,
            ``batch_size``, ``context_window`` and ``device``.
        print_logs: when true, print validation loss, elapsed time and an ETA
            every ``log_interval`` epochs.

    Returns:
        The object returned by ``pd.DataFrame(losses).plot()``, or ``None`` if
        no evaluation was performed (``args.epochs == 0``).
    """
    losses = []
    start_time = time.time()
    for epoch in range(args.epochs):
        optimizer.zero_grad()

        # Forward the caller's config so a non-default ``args`` is honoured.
        xs, ys = get_batches(dataset, 'train', args.batch_size, args.context_window, args=args)
        logits, loss = model(xs, targets=ys)
        loss.backward()
        optimizer.step()

        if scheduler:
            scheduler.step()

        if epoch % args.log_interval == 0:
            batch_time = time.time() - start_time
            x = evaluate_loss(model, args=args)
            losses += [x]
            if print_logs:
                # ``args`` is a dataclass instance, so fields are read as
                # attributes; subscripting it raises TypeError.
                eta = batch_time * (args.epochs - epoch) / args.log_interval
                print(f"Epoch {epoch} | val loss {x['val']:.3f} | Time {batch_time:.3f} | ETA in seconds {eta:.3f}")
            start_time = time.time()

            if scheduler:
                # get_last_lr() is the supported way to read the current rate.
                print("lr: ", scheduler.get_last_lr())

    if not losses:
        # Nothing was evaluated, so there is nothing to report or plot.
        return None

    print("validation loss: ", losses[-1]['val'])
    return pd.DataFrame(losses).plot()

# train(model, optimizer)

model params: 41698


  from .autonotebook import tqdm as notebook_tqdm


In [11]:
def generate(model, args=Args, max_new_tokens=30):
    """Sample five character sequences from ``model``, each seeded with token id 0.

    Args:
        model: module returning logits of shape (batch, time, vocab) when
            called with token ids only.
        args: configuration; ``args.context_window`` caps how many trailing
            tokens are fed back to the model at each step.
        max_new_tokens: number of tokens appended to every sequence.

    Returns:
        A list of five decoded strings.
    """
    # Create the seed on the model's own device; a CPU seed fails as soon as
    # the model lives on an accelerator.
    device = next(model.parameters()).device
    idx = torch.zeros(5, 1, dtype=torch.long, device=device)

    # Sampling needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # feed only the most recent context_window tokens
            logits = model(idx[:, -args.context_window:])
            # every sequence in the batch, last time step, all vocabulary logits
            last_time_step_logits = logits[:, -1, :]
            p = F.softmax(last_time_step_logits, dim=-1)  # logits -> probabilities
            # draw one next token per sequence
            idx_next = torch.multinomial(p, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=-1)  # append to the sequences
    return [decode(x) for x in idx.tolist()]

# generate(model)

### RMS-Norm

In [12]:
# RMSNorm, following the reference implementation in the LLaMA inference repo.
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: size of the last (feature) dimension; also the length of ``weight``.
        eps: small constant added to the mean square before the reciprocal
            square root, so an all-zero feature vector maps to zeros instead
            of NaN (``0 * rsqrt(0)``).
    """

    def __init__(self, dim:int, eps:float = 1e-8):
        super().__init__()
        self.eps = eps

        # Per-feature gain, initialised to 1.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x:torch.Tensor):
        """Divide ``x`` by the RMS of its last dimension (stabilised by ``eps``)."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x:torch.Tensor):
        """Normalise in float32, cast back to ``x``'s dtype, then apply the gain."""
        return self.weight * self._norm(x.float()).type_as(x)


### RoPE -> Rotary Positional Embeddings

In [13]:
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0):
    """Build the table of complex rotation factors used by rotary embeddings.

    Args:
        head_dim: per-head feature size; must be even.
        seq_len: number of positions to precompute.
        device: device on which the table is created.
        theta: base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(seq_len, head_dim / 2)`` whose entry
        ``[m, k]`` is ``exp(i * m * theta ** (-2k / head_dim))``.
    """
    assert head_dim % 2 == 0

    # One frequency per feature pair: theta ** (-2k / head_dim), k = 0 .. head_dim/2 - 1
    pair_exponents = torch.arange(0, head_dim, 2).float() / head_dim
    inv_freq = 1.0 / (theta ** pair_exponents).to(device)

    # Angle for every (position, frequency) combination -> (seq_len, head_dim / 2)
    positions = torch.arange(seq_len, device=device)
    angles = torch.outer(positions, inv_freq).float()

    # Unit-magnitude complex numbers cos(angle) + i * sin(angle)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)

def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """Rotate consecutive feature pairs of ``x`` by position-dependent angles.

    Args:
        x: tensor laid out as (B, seq_len, H, head_dim); the last dimension is
            read as ``head_dim / 2`` (real, imaginary) pairs.
        freqs_complex: complex table of shape (positions, head_dim / 2); only
            the first ``seq_len`` rows are used.
        device: device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``x``, on ``device``.
    """
    seq_len = x.shape[1]

    # (B, seq_len, H, head_dim) -> (B, seq_len, H, head_dim / 2) complex
    feature_pairs = x.float().reshape(*x.shape[:-1], -1, 2)
    as_complex = torch.view_as_complex(feature_pairs)

    # (positions, head_dim / 2) -> (1, seq_len, 1, head_dim / 2) so the table
    # broadcasts over the batch and head dimensions
    rotation = freqs_complex[:seq_len, :][None, :, None, :]

    # Complex multiplication performs the rotation; unpack to
    # (B, seq_len, H, head_dim / 2, 2) and flatten back to x's shape.
    rotated = torch.view_as_real(as_complex * rotation)
    restored = rotated.reshape(*x.shape)
    return restored.type_as(x).to(device)

In [14]:

class SelfAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings on queries and keys."""

    def __init__(self, args):
        super().__init__()

        self.n_heads = args.n_heads

        # Evaluates to 1; not read anywhere in this class.
        self.n_rep = self.n_heads // self.n_heads
        self.head_dim = args.d_model // args.n_heads

        # Width of the concatenated heads (input to wo, output of wq/wk/wv).
        heads_width = self.n_heads * self.head_dim
        self.wq = nn.Linear(args.d_model, heads_width, bias=False)
        self.wk = nn.Linear(args.d_model, heads_width, bias=False)
        self.wv = nn.Linear(args.d_model, heads_width, bias=False)
        self.wo = nn.Linear(heads_width, args.d_model, bias=False)

    def forward(self, x:torch.Tensor, freqs_complex:torch.Tensor):
        """Attend over ``x`` of shape (B, seq_len, dim) and return the same shape.

        ``freqs_complex`` is the rotation table handed to
        ``apply_rotary_embeddings`` for the queries and the keys.
        """
        batch_size, seq_len, _ = x.shape

        # (B, seq_len, dim) -> (B, seq_len, H, head_dim) for each projection
        per_head_shape = (batch_size, seq_len, self.n_heads, self.head_dim)
        xq = self.wq(x).view(per_head_shape)
        xk = self.wk(x).view(per_head_shape)
        xv = self.wv(x).view(per_head_shape)

        # Rotary embeddings go on queries and keys only; values are left as-is.
        xq = apply_rotary_embeddings(xq, freqs_complex, x.device)
        xk = apply_rotary_embeddings(xk, freqs_complex, x.device)

        # (B, seq_len, H, head_dim) -> (B, H, seq_len, head_dim)
        queries, keys, values = (t.transpose(1, 2) for t in (xq, xk, xv))

        # Scaled dot-product scores: (B, H, seq_len, seq_len)
        scores = torch.matmul(queries, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

        # Causal mask: a position must not attend to later positions.
        is_future = torch.triu(torch.ones_like(scores), diagonal=1) == 1
        scores = scores.masked_fill_(is_future, float('-inf'))
        weights = F.softmax(scores.float(), dim=-1).type_as(queries)

        # Weighted sum of values, then merge heads back to (B, seq_len, H * head_dim)
        context = torch.matmul(weights, values)
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(context)

# layer = SelfAttention(Args)
# batch = torch.randn((Args.batch_size, Args.context_window, Args.d_model))
# freqs_complex = precompute_theta_pos_frequencies(Args.d_model // Args.n_heads, Args.context_window, device = "cpu")
# output= layer(batch, freqs_complex)

In [15]:
class FeedForward(nn.Module):
    """Gated feed-forward network: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width starts at two thirds of ``4 * d_model``, is optionally
    scaled by ``args.ffn_dim_multiplier`` and is finally rounded up to the
    next multiple of ``args.multiple_of``.
    """

    def __init__(self, args):
        super().__init__()

        hidden_dim = int(2 * (args.d_model * 4) / 3)
        if args.ffn_dim_multiplier is not None:
            hidden_dim = int(args.ffn_dim_multiplier * hidden_dim)

        # Ceiling division gives the number of multiple_of-sized chunks needed.
        chunk_count = (hidden_dim + args.multiple_of - 1) // args.multiple_of
        hidden_dim = args.multiple_of * chunk_count

        self.w1 = nn.Linear(args.d_model, hidden_dim, bias=False)  # gate projection
        self.w2 = nn.Linear(hidden_dim, args.d_model, bias=False)  # output projection
        self.w3 = nn.Linear(args.d_model, hidden_dim, bias=False)  # value projection

    def forward(self, x:torch.Tensor):
        """Apply the gated feed-forward transform; output has the shape of ``x``."""
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(gated)
    
class EncoderBlock(nn.Module):
    """Pre-norm transformer block: attention and feed-forward, each with a residual.

    Computes ``h = x + attention(attention_norm(x))`` followed by
    ``out = h + feed_forward(ffn_norm(h))``.
    """

    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.d_model // args.n_heads

        self.attention = SelfAttention(args)
        self.feed_forward = FeedForward(args)

        # Normalisation applied to the input of the attention sub-layer.
        self.attention_norm = RMSNorm(args.d_model)

        # Normalisation applied to the input of the feed-forward sub-layer.
        self.ffn_norm = RMSNorm(args.d_model)

    def forward(self, x:torch.Tensor, freqs_complex: torch.Tensor):
        """Run the block on ``x``; ``freqs_complex`` is passed through to attention."""
        # Call the sub-modules (not ``.forward``) so nn.Module hooks still run.
        h = x + self.attention(self.attention_norm(x), freqs_complex)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

In [None]:
# Make sure the shared config carries the real vocabulary size before the model is built.
Args.vocab_size = len(vocab)
class RopeModel(nn.Module):
    """Stack of ``EncoderBlock`` layers with rotary position embeddings.

    Pipeline: token embedding -> ``args.n_layers`` blocks -> RMSNorm ->
    linear projection to vocabulary logits.

    Args:
        args: configuration supplying ``d_model``, ``n_heads``, ``n_layers``,
            ``vocab_size``, ``context_window`` and ``device``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args

        # Build the rotation table from the config that was passed in (not the
        # notebook-global one), so a model created with a different config
        # gets a table matching its own head size and context window.
        self.freqs_complex = precompute_theta_pos_frequencies(
            args.d_model // args.n_heads, args.context_window, device=args.device
        )

        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)

        self.layers = nn.ModuleList([])
        for _ in range(args.n_layers):
            self.layers.append(EncoderBlock(args))

        self.norm = RMSNorm(args.d_model)
        self.output = nn.Linear(args.d_model, self.vocab_size, bias=False)

        print("model params:", sum([m.numel() for m in self.parameters()]))

    def forward(self, idx, targets=None):
        """Return logits, or ``(logits, loss)`` when ``targets`` is provided.

        ``idx`` holds token ids; its sequence length must not exceed the
        ``context_window`` the rotation table was built for.
        """
        h = self.embedding(idx)

        # The table is a plain attribute, so ``model.to(...)`` does not move
        # it; align it with the activations here (no copy if already there).
        freqs_complex = self.freqs_complex.to(h.device)

        # Consecutively apply all the blocks.
        for layer in self.layers:
            h = layer(h, freqs_complex)
        h = self.norm(h)
        logits = self.output(h).float()

        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, self.args.vocab_size), targets.view(-1))
            return logits, loss

        else:
            return logits

# Build the RoPE model from the shared config, move it to the configured
# device, and train it with Adam at its default settings.
model = RopeModel(Args)
model.to(Args.device)
optimizer = torch.optim.Adam(model.parameters())
train(model, optimizer)

In [None]:
# Display the configuration again; vocab_size now holds the real vocabulary size.
Args

batch_size : 32
context_window : 16
n_layers : 18
n_heads : 8
vocab_size : 98
d_model : 128
epochs : 10000
log_interval : 10
max_seq_len : 512
ffn_dim_multiplier : None
multiple_of : 256
device : cuda

In [None]:
def generate(model, args=Args, max_new_tokens=30):
    """Sample five character sequences from ``model``, each seeded with token id 0.

    This redefinition replaces the earlier ``generate`` in the notebook.

    Args:
        model: module returning logits of shape (batch, time, vocab) when
            called with token ids only.
        args: configuration; ``args.context_window`` caps how many trailing
            tokens are fed back to the model at each step.
        max_new_tokens: number of tokens appended to every sequence.

    Returns:
        A list of five decoded strings.
    """
    # Seed on whatever device the model's weights are on, so the inputs
    # always match the model even if it was moved after construction.
    model_device = next(model.parameters()).device
    idx = torch.zeros(5, 1, dtype=torch.long, device=model_device)

    # Inference only: disabling autograd avoids building a graph on every step.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # the model only ever sees the last context_window tokens
            context = idx[:, -args.context_window:]
            logits = model(context)
            # all five sequences, last time step, all vocabulary logits
            last_time_step_logits = logits[:, -1, :]
            p = F.softmax(last_time_step_logits, dim=-1)  # logits -> probabilities
            idx_next = torch.multinomial(p, num_samples=1)  # one sampled token per sequence
            idx = torch.cat([idx, idx_next], dim=-1)  # append to the sequences
    return [decode(x) for x in idx.tolist()]

# Sample 100 new characters for each of the five sequences from the trained model.
generate(model, Args, max_new_tokens=100)

['\nLook out caughts. She saw a friends said.\n"Ow, let\'s get. You are wise how the duck was wheir new th',
 '\nDaddy, a time, there was a cown coud truck. He yells her friends were happy to strong time in the wa',
 '\nBut, this knew his family ran towers, Samong in the sky with her and even said it we see his wan and',
 '\nJack and SuiHa thought I for them. They knew the driver actor. "I think you appoo!"\nBen aid, so, the',
 '\n<|endoftext|>\nOnce upon a time, there was a woman licked each other. She saw in an earch and.\n<|endo']

In [None]:
# Check what the context-window slice looks like for the initial seed:
# with a single token per sequence the slice keeps the whole (5, 1) tensor.
idx = torch.zeros(5, 1).long()
model_input = idx[:, -Args.context_window:]
model_input.shape

torch.Size([5, 1])

In [None]:
# Evaluating the bound method without calling it shows its repr, which
# includes the full module tree of the model.
model.parameters

<bound method Module.parameters of RopeModel(
  (embedding): Embedding(98, 128)
  (layers): ModuleList(
    (0-17): 18 x EncoderBlock(
      (attention): SelfAttention(
        (wq): Linear(in_features=128, out_features=128, bias=False)
        (wk): Linear(in_features=128, out_features=128, bias=False)
        (wv): Linear(in_features=128, out_features=128, bias=False)
        (wo): Linear(in_features=128, out_features=128, bias=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=128, out_features=512, bias=False)
        (w2): Linear(in_features=512, out_features=128, bias=False)
        (w3): Linear(in_features=128, out_features=512, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=128, out_features=98, bias=False)
)>

In [None]:
# Display the context window length currently set on the shared config.
Args.context_window

16