In [2]:
!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz

--2024-04-19 13:30:07--  https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz
Resolving huggingface.co (huggingface.co)... 3.163.189.37, 3.163.189.90, 3.163.189.74, ...
Connecting to huggingface.co (huggingface.co)|3.163.189.37|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/42/7f/427f7497b6c6596c18b46d5a72e61364fcad12aa433c60a0dbd4d344477b9d81/26cf7605aca15bc4ea6fa637256400d9d01317b28ed296172b2d1dd160cd7699?response-content-disposition=attachment%3B+filename*%3DUTF-8%27%27TinyStories_all_data.tar.gz%3B+filename%3D%22TinyStories_all_data.tar.gz%22%3B&response-content-type=application%2Fgzip&Expires=1713817807&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcxMzgxNzgwN319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy80Mi83Zi80MjdmNzQ5N2I2YzY1OTZjMThiNDZkNWE3MmU2MTM2NGZjYWQxMmFhNDMzYzYwYTBkYmQ0ZDM0NDQ3N2I5ZDgxLzI2Y2Y3N

In [4]:
!mkdir TinyStories
!tar -xzf TinyStories_all_data.tar.gz -C TinyStories

In [2]:
import os
import glob
import json

In [4]:
# Collect every JSON shard in the extracted directory, in lexicographic order.
shard_pattern = os.path.join('TinyStories', "*.json")
shard_filenames = glob.glob(shard_pattern)
shard_filenames.sort()

In [5]:
# Parse the first shard. The encoding is pinned to UTF-8 because the stories
# contain non-ASCII characters and the platform default is not always UTF-8.
with open(shard_filenames[0], "r", encoding="utf-8") as f:
    data = json.load(f)

In [7]:
# Keep only the story text of every record in the shard.
stories = [record['story'] for record in data]

In [8]:
# Display one story from the shard.
stories[100]

'One day, a girl named Lily wanted to bake a cake. She put all the things she needed on the table. Her mom helped her mix everything in a big bowl. When the cake was ready, they put it in the oven to bake.\nWhile the cake was baking, Lily and her mom made an ornament. They made it look very pretty and attractive. They used colorful paper and shiny things. Lily said, "Mom, this will look nice on our tree!"\nWhen the cake was done, they opened the oven. But, there was a surprise! The cake was not a cake anymore. It turned into a big, friendly bear! The bear said, "Thank you for baking me! I will help you make more ornaments!" Lily, her mom, and the bear had lots of fun making pretty things together.'

In [10]:
# Concatenate all stories into one training corpus, one newline between stories.
text = "\n".join(stories)

In [12]:
# Corpus length in characters.
len(text)

77586884

In [16]:
import torch

In [24]:
# Character-level vocabulary: every distinct character of the corpus, sorted.
chars = sorted(set(text))
vocab_size = len(chars)
print(repr(''.join(chars)), vocab_size, sep='\n')

'\t\n !"$%&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]`abcdefghijklmnopqrstuvwxyz|~\xa0éñ–—‘’“”…'
97


In [19]:
# Lookup tables between characters and integer ids (id = position in `chars`).
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}

In [20]:
def encode(s):
    """Return the list of integer ids for the characters of `s`, using `stoi`.

    Raises KeyError for a character that is not in the vocabulary.
    """
    ids = []
    for ch in s:
        ids.append(stoi[ch])
    return ids

def decode(ids):
    """Return the string spelled by the integer ids in `ids`, using `itos`.

    Raises KeyError for an id that is not in the vocabulary.
    """
    return ''.join(map(itos.__getitem__, ids))


In [21]:
# Encode the whole corpus once into a 1-D tensor of 64-bit token ids.
data = torch.tensor(
    encode(text),
    dtype=torch.int64,
)

In [27]:
# Show the first 100 token ids next to the first 100 raw characters of the corpus.
print(data[:100])
print(repr(text[:100]))

tensor([ 1,  1, 41, 67, 70, 83,  2, 59, 72, 62,  2, 31, 63, 72,  2, 59, 76, 63,
         2, 64, 76, 67, 63, 72, 62, 77, 15,  2, 49, 66, 63, 83,  2, 70, 67, 69,
        63,  2, 78, 73,  2, 74, 70, 59, 83,  2, 67, 72,  2, 78, 66, 63,  2, 74,
        59, 76, 69, 15,  2, 44, 72, 63,  2, 62, 59, 83, 13,  2, 78, 66, 63, 83,
         2, 77, 63, 63,  2, 59,  2, 60, 67, 65,  2, 78, 76, 63, 63,  2, 81, 67,
        78, 66,  2, 59,  2, 77, 81, 67, 72, 65])
'\n\nLily and Ben are friends. They like to play in the park. One day, they see a big tree with a swing'


In [28]:
# Hold out the final 10% of the token stream for validation.
n = int(len(data) * 0.9)
train_data, val_data = data[:n], data[n:]

In [32]:
# Small context length / batch size used for the first batching experiments;
# both are reassigned in the hyperparameter cell before the model is built.
block_size, batch_size = 8, 4

In [30]:
# Seed PyTorch's global RNG; get_batch draws its random offsets from it.
torch.manual_seed(1337)

<torch._C.Generator at 0x16a285630>

In [41]:
def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    `split == 'train'` reads `train_data`; any other value reads `val_data`.
    Returns two stacked tensors with `batch_size` rows of `block_size` tokens;
    each target row is its input row shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data
    # Random window start offsets; the upper bound leaves room for the shifted target.
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        stop = start + block_size
        inputs.append(source[start:stop])
        targets.append(source[start + 1:stop + 1])
    return torch.stack(inputs), torch.stack(targets)

In [43]:
# Draw one training batch: xb holds the inputs, yb the shifted targets.
xb, yb = get_batch('train')

In [45]:
# Choose the compute backend: CUDA first, then Apple MPS, otherwise CPU, so the
# notebook also runs on machines without an accelerator.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [50]:
import torch.nn as nn
from torch.nn import functional as F
# Re-seed PyTorch's global RNG before the model classes are defined.
# NOTE(review): 1137 differs from the 1337 used earlier in the notebook — confirm this is intentional.
torch.manual_seed(1137)

<torch._C.Generator at 0x16a285630>

In [51]:
class MoeLayer(nn.Module):
    """Sparse mixture-of-experts layer with top-k routing.

    Each token (a row of the flattened input) is scored by `gate`, the `k`
    highest-scoring experts are selected, and the token's output is the sum of
    those experts' outputs weighted by a softmax over the selected logits only.

    Args:
        experts: non-empty sequence of modules mapping [n, hidden] -> [n, hidden].
        gate: module mapping [n, hidden] -> [n, num_experts] routing logits.
        k: number of experts evaluated per token (1 <= k <= len(experts)).

    Raises:
        ValueError: if `k` is outside 1..len(experts).
    """

    def __init__(self, experts, gate, k=1):
        super().__init__()
        assert len(experts) > 0
        # Reject an impossible k here instead of letting torch.topk fail in forward.
        if not 1 <= k <= len(experts):
            raise ValueError(f"k must be between 1 and {len(experts)}, got {k}")
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.k = k

    def forward(self, inputs: torch.Tensor):
        """Route every token through its top-k experts; output has the shape of `inputs`."""
        # inputs shape: [batch_size, seq_len, hidden_size]
        # reshape (rather than view) so non-contiguous inputs are accepted too.
        inputs_squashed = inputs.reshape(-1, inputs.shape[-1])
        # inputs_squashed shape: [batch_size * seq_len, hidden_size]
        gate_logits = self.gate(inputs_squashed)
        weights, selected_experts = torch.topk(gate_logits, self.k)
        # weights shape: [batch_size * seq_len, k]
        # selected_experts shape: [batch_size * seq_len, k]
        # Normalise over the selected experts only; done in float32, then cast back.
        weights = nn.functional.softmax(
            weights,
            dim=1,
            dtype=torch.float,
        ).type_as(inputs)
        results = torch.zeros_like(inputs_squashed)
        for i, expert in enumerate(self.experts):
            # Rows routed to expert i, and the top-k slot in which it was chosen.
            batch_idx, nth_expert = torch.where(selected_experts == i)
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs_squashed[batch_idx]
            )
        return results.view_as(inputs)

In [52]:
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width `head_size`, lets
    every position attend only to itself and earlier positions, and returns the
    attention-weighted values with shape [B, T, head_size].

    Reads the module-level `n_embed`, `block_size` and `dropout` at construction.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias = False)
        self.query = nn.Linear(n_embed, head_size, bias = False)
        self.value = nn.Linear(n_embed, head_size, bias = False)
        # Lower-triangular mask; a buffer so it follows .to(device) but is not trained.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        # Scale by 1/sqrt(head_size), the dimension of the vectors in the dot
        # product, not by the embedding width C.
        wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5
        # Block attention to future positions before normalising.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

class MulitHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected.

    Args:
        num_heads: number of `Head` modules to run side by side.
        head_size: output width of each head.

    Reads the module-level `n_embed` and `dropout` at construction.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # The projection consumes the concatenated head outputs, whose width is
        # num_heads * head_size; this equals n_embed only when n_embed divides evenly.
        self.proj = nn.Linear(num_heads * head_size, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x =  torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(x))
        return out


class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4 * n_embed, ReLU, project back, dropout.

    Reads the module-level `dropout` at construction.
    """

    def __init__(self, n_embed):
        super().__init__()
        hidden = 4 * n_embed
        layers = [
            nn.Linear(n_embed, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embed),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """Pre-norm transformer block: self-attention, then a mixture-of-experts MLP.

    Args:
        n_embed: embedding width.
        n_head: number of attention heads; each head gets n_embed // n_head channels.
        num_experts: number of FeedForward experts in the MoE layer.
        num_experts_per_tok: experts evaluated per token (capped at num_experts).
            With a single expert per token the softmax over the selected logit
            is always 1.0, so the gate receives no gradient; the default of 2
            keeps the router trainable.
    """

    def __init__(self, n_embed, n_head, num_experts=4, num_experts_per_tok=2):
        super().__init__()
        self.sa_head= MulitHeadAttention(n_head, n_embed//n_head)
        self.ffw = MoeLayer(
            experts=[FeedForward(n_embed) for _ in range(num_experts)],
            gate=nn.Linear(n_embed, num_experts, bias=False),
            k=min(num_experts_per_tok, num_experts),
        )
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        # Residual connections around both sub-layers; LayerNorm is applied first.
        x = x + self.sa_head(self.ln1(x))
        x = x+self.ffw(self.ln2(x))
        return x


class Transformer(nn.Module):
    """Decoder-only language model: token + position embeddings, a stack of
    `Block`s, and a linear head producing next-token logits.

    Reads the module-level `vocab_size`, `n_embed`, `block_size`, `n_head` and
    `n_layer` at construction. All parameters are created on the default device;
    move the whole model with `.to(...)`.
    """

    def __init__(self):
        super().__init__()

        # No per-layer device argument: every sub-module starts on the same
        # device instead of being split between the accelerator and the CPU.
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embed, vocab_size)


    def forward(self, idx, targets=None):
        """Compute logits, and the cross-entropy loss when `targets` is given.

        Args:
            idx: [B, T] token ids with T <= block_size.
            targets: optional [B, T] token ids to predict.

        Returns:
            (logits, loss). Without targets, logits is [B, T, vocab_size] and
            loss is None. With targets, logits is flattened to [B*T, vocab_size]
            and loss is a scalar tensor.
        """
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx)
        # Build the positions on the same device as the input ids.
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = token_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokes):
        """Sample `max_new_tokes` further tokens after the [B, T] prompt `idx`.

        Runs without gradient tracking and in eval mode (so dropout is off),
        restoring the previous training mode afterwards. Returns the prompt
        with the sampled tokens appended, shape [B, T + max_new_tokes].
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokes):
                # The position table only covers block_size positions, so crop the context.
                idx_cond = idx[:, -block_size:]
                logits, loss = self(idx_cond)
                # Only the prediction at the last position is needed.
                logits = logits[:, -1, :]
                probs = F.softmax(logits, dim = -1)
                idx_next = torch.multinomial(probs, num_samples = 1)
                idx = torch.cat((idx, idx_next), dim = 1)
        finally:
            self.train(was_training)
        return idx

In [57]:
# hyperparameters
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000 # number of optimisation steps
eval_interval = 100 # steps between loss evaluations
learning_rate = 1e-3
# Prefer CUDA, then Apple MPS, and fall back to the CPU when neither is available.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
eval_iters = 200 # batches averaged per split when estimating the loss
n_embed = 384 # embedding width
n_head = 6 # attention heads per block
n_layer = 6 # number of transformer blocks
dropout = 0.0
# ------------

In [58]:
# Build the model from the hyperparameters defined above.
model = Transformer()
# NOTE(review): a later cell rebuilds the optimizer with lr=1e-4 after moving the
# model to `device`, so this instance appears to be unused — confirm.
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

In [59]:
@torch.no_grad()
def estimate_loss():
    """Average the loss over `eval_iters` random batches for each data split.

    Puts the global `model` in eval mode while measuring and back in train mode
    afterwards. Returns a dict with keys 'train' and 'val' mapping to 0-d tensors.
    """
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        per_batch = torch.zeros(eval_iters)
        for step in range(eval_iters):
            inputs, targets = (t.to(device) for t in get_batch(split))
            _, batch_loss = model(inputs, targets)
            per_batch[step] = batch_loss.item()
        averages[split] = per_batch.mean()
    model.train()
    return averages

In [60]:
# Show which device was selected.
device

'mps'

In [61]:
# Move every parameter and buffer to the selected device, then build the
# optimizer over the model's parameters.
model = model.to(device)
# NOTE(review): lr is hard-coded to 1e-4 here while the hyperparameter cell sets
# learning_rate = 1e-3 — confirm which value is intended.
optimizer = torch.optim.AdamW(model.parameters(),lr=1e-4)

In [63]:
# Training loop. The loop variable is named `step` so the builtin `iter` is not
# shadowed, and the evaluation cadence comes from `eval_interval` rather than a
# literal that could drift from the hyperparameter cell.
for step in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}/{max_iters}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')
    xb = xb.to(device)
    yb = yb.to(device)

    # evaluate the loss
    logits, loss = model(xb, yb)
    # clear the previous step's gradients, backpropagate, then update the weights
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

In [None]:
# Sample 500 new characters that continue a story-style prompt.
d = 'once upon a time there was a '
# Encode the prompt and add a leading batch dimension of size 1.
x = torch.tensor(encode(d), dtype=torch.long, device=device)[None, :]
sample = model.generate(x, max_new_tokes=500)[0]
print(decode(sample.tolist()))

In [None]:
# Sample 500 new characters that continue a second prompt.
d = 'One day'
# Encode the prompt and add a leading batch dimension of size 1.
x = torch.tensor(encode(d), dtype=torch.long, device=device)[None, :]
sample = model.generate(x, max_new_tokes=500)[0]
print(decode(sample.tolist()))