# Chapter 3 - Pre-train a tiny LLM

In [7]:
from accelerate import Accelerator
import os
import torch
import bitsandbytes as bnb
import torch.nn as nn
from dataclasses import dataclass
import math

In [9]:
from datasets import load_dataset

In [10]:
# Fetch the TinyStories corpus and keep a handle on each of the two splits
# that the rest of the notebook works with.
dataset = load_dataset("roneneldan/TinyStories")
train_dataset, validation_dataset = dataset['train'], dataset['validation']






Downloading readme:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


Downloading data:   0%|          | 0.00/249M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/246M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/248M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/9.99M [00:00<?, ?B/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

In [40]:
# Peek at the first two stories. Slice the rows *before* picking the column:
# `train_dataset['text']` selects the whole text column first, whereas
# `train_dataset[0:2]` only reads the two rows that are needed.
samples = train_dataset[0:2]['text']
print("Input text sample \n")
print(samples)
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
print(f"\n Vocabulary size {tokenizer.vocab_size}")
print(f"model max length {tokenizer.model_max_length}")
# NOTE(review): this limit of 50 is not what bounds the encodings below,
# because the tokenizer call passes an explicit max_length=250. Confirm which
# of the two values is intended.
tokenizer.model_max_length = 50
print(f"new model max length {tokenizer.model_max_length}")

# The GPT-2 tokenizer has no padding token by default, so register one before
# asking for padded output.
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
# `truncation=True` is the documented argument (the old `truncation_strategy`
# keyword is legacy); combined with padding='max_length' every sample comes
# out with exactly `max_length` token ids.
encodings = tokenizer(samples, padding='max_length', max_length=250, truncation=True)
print("\n Encodings \n")
print(encodings)


Input text sample 

['One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.', 'Once upon a time, there was a little car named Beep. Beep loved to go fast and play in the sun. Beep was a healthy car because he always had good fuel. Good fuel made Beep happy and strong.\n\nOne day, Beep was driving in the park when he saw a big tree. The tree had 

In [41]:
# Number of token ids in the second encoded sample (index 1); the recorded
# output is 250, the same value as the max_length passed to the tokenizer.
len(encodings['input_ids'][1])

250

In [12]:
from torch.utils.data import Dataset, DataLoader

def collate_fn(samples, max_length=250):
    """
    Tokenize a batch of stories into fixed-length tensors for a DataLoader.

    Args:
        samples: list of batch items. Each item may be a raw string or a
            dataset row (a dict carrying the story under the 'text' key).
        max_length: number of token ids per sample after padding/truncation.

    Returns:
        dict with 'input_ids' and 'attention_mask', each a tensor of shape
        (len(samples), max_length).
    """
    # A DataLoader over the Hugging Face dataset hands over row dicts, while
    # ad-hoc calls may pass plain strings; accept both.
    texts = [s['text'] if isinstance(s, dict) else s for s in samples]
    encodings = tokenizer(
        texts,
        padding='max_length',
        max_length=max_length,
        truncation=True,        # documented replacement for `truncation_strategy`
        return_tensors='pt',
    )
    # The function previously fell off the end and returned None, which is
    # what the DataLoader would then have yielded for every batch.
    return {
        'input_ids': encodings['input_ids'],
        'attention_mask': encodings['attention_mask'],
    }


In [8]:
@dataclass
class SLLMConfig:
    """
    Hyperparameters for the small language model built in this notebook.

    Note that ``n_dkv`` and ``head_size`` take their defaults from the name
    ``n_embd`` as it exists in the class body, i.e. the literal 256. Passing
    a different ``n_embd`` to the constructor does NOT change them; pass them
    explicitly as well if they should follow.
    """
    # Same value as the max_length used by the tokenizer calls in this
    # notebook; presumably the maximum sequence length - TODO confirm.
    context_window: int = 250
    # NOTE(review): larger than the GPT-2 tokenizer vocabulary plus the added
    # pad token; looks like a rounded-up embedding table size - confirm.
    vocab_size: int = 50304
    # n_layers / n_head are not read by the code visible in this section.
    n_layers: int = 12
    n_head: int = 12
    # Input feature size of the query/key/value projections.
    n_embd: int = 256
    # Not referenced by the code visible in this section.
    n_dkv: int = n_embd
    # Output feature size of the query/key/value projections.
    head_size: int = n_embd
    # Probability handed to nn.Dropout in SingleHeadAttention.
    dropout: float = 0.2
    # Whether the query/key/value nn.Linear layers carry a bias term.
    bias: bool = False

class SingleHeadAttention(nn.Module):
    """
    Single-head causal (masked) scaled dot-product self-attention.

    Args:
        config: object exposing ``n_embd`` (input feature size), ``head_size``
            (size of the query/key/value projections), ``bias`` (whether the
            projections have a bias term) and ``dropout`` (drop probability).

    Shape:
        input:  ``(..., T, n_embd)``
        output: ``(..., T, head_size)``
    """
    def __init__(self, config):

        super().__init__()
        self.Wq =  nn.Linear(config.n_embd, config.head_size, bias=config.bias)
        self.Wk =  nn.Linear(config.n_embd, config.head_size, bias=config.bias)
        self.Wv =  nn.Linear(config.n_embd, config.head_size, bias=config.bias)

        self.attn_drop = nn.Dropout(config.dropout)
        self.__init_weights()

    def __init_weights(self):
        """Xavier-uniform initialisation of the three projection matrices."""
        for proj in (self.Wq, self.Wk, self.Wv):
            # nn.Linear stores its matrix as `.weight` (singular); the
            # trailing-underscore initialiser is the in-place, non-deprecated one.
            nn.init.xavier_uniform_(proj.weight)

    def forward(self, x):
        """
        Apply causal self-attention to ``x``.

        Each position attends only to itself and to earlier positions along
        the second-to-last axis of ``x``.

        Args:
            x: tensor of shape ``(..., T, n_embd)``.

        Returns:
            Tensor of shape ``(..., T, head_size)``.
        """
        q = self.Wq(x)
        k = self.Wk(x)
        v = self.Wv(x)

        seq_len = x.shape[-2]
        # Scaled dot-product scores, shape (..., T, T): rows are queries,
        # columns are keys.
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(k.shape[-1])
        # Strictly-upper-triangular entries are "future" keys; hide them.
        future = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        scores = scores.masked_fill(future, float("-inf"))
        # Normalise over the key axis (last dim). Using dim=1 is only correct
        # for un-batched 2-D input; for (B, T, T) it would mix across queries.
        weights = torch.softmax(scores, dim=-1)
        out = weights @ v
        out = self.attn_drop(out)

        return out
        
        