In [122]:
# !wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz

In [123]:
import os
import glob
import json

In [124]:
if not os.path.exists("TinyStories_all_data.tar.gz"):
    !wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz
if not os.path.exists("TinyStories"):
    !mkdir TinyStories
    !tar -xzf TinyStories_all_data.tar.gz -C TinyStories

In [125]:
shard_filenames = sorted(glob.glob(os.path.join('TinyStories', "*.json")))

In [126]:
with open(shard_filenames[0], "r") as f:
    data = json.load(f)

In [127]:
s = set(f for x in data for f in x['instruction']['features'])
s

{'BadEnding', 'Conflict', 'Dialogue', 'Foreshadowing', 'MoralValue', 'Twist'}

In [128]:
stories = [r['story'] for r in data]
badEndings = [r['story'] for r in data if 'BadEnding' in r['instruction']['features']]
conflicts = [r['story'] for r in data if 'Conflict' in r['instruction']['features']]

In [129]:
stories[100]

'One day, a girl named Lily wanted to bake a cake. She put all the things she needed on the table. Her mom helped her mix everything in a big bowl. When the cake was ready, they put it in the oven to bake.\nWhile the cake was baking, Lily and her mom made an ornament. They made it look very pretty and attractive. They used colorful paper and shiny things. Lily said, "Mom, this will look nice on our tree!"\nWhen the cake was done, they opened the oven. But, there was a surprise! The cake was not a cake anymore. It turned into a big, friendly bear! The bear said, "Thank you for baking me! I will help you make more ornaments!" Lily, her mom, and the bear had lots of fun making pretty things together.'

In [130]:
text = text = "\n".join(stories)
badEndings_text = "\n".join(badEndings)
conflicts_text = "\n".join(conflicts)

In [131]:
len(text)

77586884

In [132]:
import torch

In [133]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(repr(''.join(chars)))
print(vocab_size)

'\t\n !"$%&\'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]`abcdefghijklmnopqrstuvwxyz|~\xa0éñ–—‘’“”…'
97


In [134]:
stoi = { ch:i for i, ch in enumerate(chars) }
itos = {i:ch for i,ch in enumerate(chars)}

In [135]:
def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return ''.join([itos[i] for i in ids])


In [136]:
data = torch.tensor(encode(text), dtype = torch.long)
badEndings_data = torch.tensor(encode(badEndings_text), dtype = torch.long)
conflicts_data = torch.tensor(encode(conflicts_text), dtype = torch.long)

In [137]:
print(data[:100])
print(repr(text[:100]))

tensor([ 1,  1, 41, 67, 70, 83,  2, 59, 72, 62,  2, 31, 63, 72,  2, 59, 76, 63,
         2, 64, 76, 67, 63, 72, 62, 77, 15,  2, 49, 66, 63, 83,  2, 70, 67, 69,
        63,  2, 78, 73,  2, 74, 70, 59, 83,  2, 67, 72,  2, 78, 66, 63,  2, 74,
        59, 76, 69, 15,  2, 44, 72, 63,  2, 62, 59, 83, 13,  2, 78, 66, 63, 83,
         2, 77, 63, 63,  2, 59,  2, 60, 67, 65,  2, 78, 76, 63, 63,  2, 81, 67,
        78, 66,  2, 59,  2, 77, 81, 67, 72, 65])
'\n\nLily and Ben are friends. They like to play in the park. One day, they see a big tree with a swing'


In [138]:
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

badEndings_n = int(0.9*len(badEndings_data))
badEndings_train_data = badEndings_data[:badEndings_n]
badEndings_val_data = badEndings_data[badEndings_n:]

conflicts_n = int(0.9*len(conflicts_data))
conflicts_train_data = conflicts_data[:conflicts_n]
conflicts_val_data = conflicts_data[conflicts_n:]

In [139]:
block_size = 8
batch_size = 4

In [140]:
torch.manual_seed(1337)

<torch._C.Generator at 0x142454910>

In [141]:
train_corpus = [train_data, badEndings_train_data, conflicts_train_data]
val_corpus = [val_data, badEndings_val_data, conflicts_val_data]

In [142]:
def get_batch(split, domain):
    # generate a small batch of data of inputs x and targets y
    data = train_corpus[domain] if split == 'train' else val_corpus[domain]
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [143]:
xb, yb = get_batch('train', 0)

In [144]:
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1137)

<torch._C.Generator at 0x142454910>

In [145]:
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias = False)
        self.query = nn.Linear(n_embed, head_size, bias = False)
        self.value = nn.Linear(n_embed, head_size, bias = False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)
        v = self.value(x)
        out = wei @ v
        return out

class MulitHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x =  torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(x))
        return out


class FeedForward(nn.Module):
    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4* n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
         nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    def __init__(self, n_embed, n_head, moe, num_experts=4):
        super().__init__()
        self.sa_head= MulitHeadAttention(n_head, n_embed//n_head)
        self.ffw = FeedForward(n_embed)

        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa_head(self.ln1(x))
        x = x+self.ffw(self.ln2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, moe=False):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embed, device=device)
        self.position_embedding_table = nn.Embedding(block_size, n_embed, device=device)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head, moe=moe) for _ in range(n_layer)])
        self.lm_head = nn.Linear(n_embed, vocab_size)


    def forward(self, idx, targets=None):
        B, T = idx.shape

        token_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T).to(device))
        x = token_emb + pos_emb
        x = self.blocks(x)
        logits = self.lm_head(x)
        if targets == None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokes):
        for _ in range(max_new_tokes):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim = -1)
            idx_next = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, idx_next), dim = 1)
        return idx

In [198]:
# hyperparameters
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000
eval_interval = 100
learning_rate = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embed = 384
n_head = 6
n_layer = 6
dropout = 0.0
# ------------

In [199]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split, 0)
            X = X.to(device)
            Y = Y.to(device)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

In [200]:
device

'cpu'

In [201]:
model = Transformer().to(device)
optimizer = torch.optim.AdamW(model.parameters(),lr=learning_rate)

In [202]:
def train(model, optimizer, domain):
    for iter in range(max_iters):

        # every once in a while evaluate the loss on train and val sets
        if iter % 100 == 0 or iter == max_iters - 1:
            losses = estimate_loss()
            print(f"step {iter}/{max_iters}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # sample a batch of data
        xb, yb = get_batch('train', domain)
        xb = xb.to(device) # (batch_size, block_size)
        yb = yb.to(device)

        # compute the loss and update the model
        logits, loss = model(xb, yb)
        # logits shape: (batch_size * block_size, vocab_size)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

train(model, optimizer, 0)

step 0/5000: train loss 5.0647, val loss 5.0637
step 100/5000: train loss 2.3834, val loss 2.3868
step 200/5000: train loss 2.3269, val loss 2.3282
step 300/5000: train loss 2.3055, val loss 2.3075
step 400/5000: train loss 2.2845, val loss 2.2847
step 500/5000: train loss 2.2615, val loss 2.2587
step 600/5000: train loss 2.2237, val loss 2.2236
step 700/5000: train loss 2.1778, val loss 2.1759
step 800/5000: train loss 2.0994, val loss 2.1024
step 900/5000: train loss 1.9945, val loss 1.9948
step 1000/5000: train loss 1.8774, val loss 1.8794
step 1100/5000: train loss 1.7949, val loss 1.7978
step 1200/5000: train loss 1.7221, val loss 1.7329
step 1300/5000: train loss 1.6702, val loss 1.6718
step 1400/5000: train loss 1.6186, val loss 1.6233
step 1500/5000: train loss 1.5703, val loss 1.5746
step 1600/5000: train loss 1.5385, val loss 1.5410
step 1700/5000: train loss 1.5080, val loss 1.5089
step 1800/5000: train loss 1.4831, val loss 1.4842
step 1900/5000: train loss 1.4464, val loss

In [203]:
import copy
import pandas as pd
import numpy as np


model1 = copy.deepcopy(model)
optimizer1 = torch.optim.AdamW(model1.parameters(),lr=learning_rate)
model2 = copy.deepcopy(model)
optimizer2 = torch.optim.AdamW(model2.parameters(),lr=learning_rate)

In [204]:
train(model1, optimizer1, 1)

step 0/5000: train loss 1.1008, val loss 1.0971
step 100/5000: train loss 1.1013, val loss 1.1000
step 200/5000: train loss 1.1094, val loss 1.1012
step 300/5000: train loss 1.0993, val loss 1.1007
step 400/5000: train loss 1.1002, val loss 1.0983
step 500/5000: train loss 1.0967, val loss 1.0961
step 600/5000: train loss 1.0971, val loss 1.0986
step 700/5000: train loss 1.0995, val loss 1.1041
step 800/5000: train loss 1.1012, val loss 1.0965
step 900/5000: train loss 1.1043, val loss 1.1004
step 1000/5000: train loss 1.0905, val loss 1.1004
step 1100/5000: train loss 1.1003, val loss 1.1028
step 1200/5000: train loss 1.1006, val loss 1.0981
step 1300/5000: train loss 1.0959, val loss 1.1004
step 1400/5000: train loss 1.0963, val loss 1.0987
step 1500/5000: train loss 1.1000, val loss 1.1020
step 1600/5000: train loss 1.0954, val loss 1.0964
step 1700/5000: train loss 1.1009, val loss 1.1018
step 1800/5000: train loss 1.0926, val loss 1.1044
step 1900/5000: train loss 1.0981, val loss

In [205]:
train(model2, optimizer2, 2)

step 0/5000: train loss 1.1024, val loss 1.0976
step 100/5000: train loss 1.1028, val loss 1.0981
step 200/5000: train loss 1.0919, val loss 1.1012
step 300/5000: train loss 1.0911, val loss 1.1030
step 400/5000: train loss 1.0985, val loss 1.0986
step 500/5000: train loss 1.1037, val loss 1.1061
step 600/5000: train loss 1.0991, val loss 1.0978
step 700/5000: train loss 1.1056, val loss 1.1010
step 800/5000: train loss 1.0994, val loss 1.0986
step 900/5000: train loss 1.0922, val loss 1.0990
step 1000/5000: train loss 1.0953, val loss 1.0973
step 1100/5000: train loss 1.0988, val loss 1.0995
step 1200/5000: train loss 1.0986, val loss 1.1028
step 1300/5000: train loss 1.0966, val loss 1.0982
step 1400/5000: train loss 1.0945, val loss 1.0984
step 1500/5000: train loss 1.0988, val loss 1.1010
step 1600/5000: train loss 1.0965, val loss 1.1071
step 1700/5000: train loss 1.0919, val loss 1.1042
step 1800/5000: train loss 1.0979, val loss 1.1041
step 1900/5000: train loss 1.0958, val loss

In [206]:
prior = torch.ones(2, device=device) / 2

In [207]:
def posterior(expert1, expert2, x, prior):
    # prior shape: [2]
    # x shape: [1, seq_len]
    _, loss1 = expert1(x, x)  # shape: a number
    _, loss2 = expert2(x, x)
    logPost = torch.stack([-loss1, -loss2], dim=0) + torch.log(prior)  # shape: [2]
    post = torch.exp(logPost)  # shape: [2]
    post = post / post.sum(dim=0)
    return post  # shape: [2]


@torch.no_grad()
def generate(expert1, expert2, x, max_new_tokes):
    global prior
    expert1.eval()
    expert2.eval()
    expert_probs_all = []
    for _ in range(max_new_tokes):
        xb = x[:, -block_size:]
        logits1, _ = expert1(xb)
        logits1 = logits1[:, -1, :]
        probs1 = F.softmax(logits1, dim=-1)  # shape: [batch_size, vocab_size]

        logits2, _ = expert1(xb)
        logits2 = logits2[:, -1, :]
        probs2 = F.softmax(logits2, dim=-1)

        post = posterior(expert1, expert2, xb, prior)  # shape: [batch_size, 2]
        expert_probs_all.append(post)

        # print(probs1.shape, post.shape, post[0].shape)
        probs = probs1 * post[0] + probs2 * post[1]

        idx_next = torch.multinomial(probs, num_samples=1)
        x = torch.cat((x, idx_next), dim=1)

    # update prior using exponential moving average
    prior = (
        pd.DataFrame(expert_probs_all)
        .ewm(alpha=0.3, adjust=False)
        .mean()
        .tail(n=1)
        .to_numpy()
        .squeeze(0)
    )
    prior = torch.Tensor(prior).to(device)
    expert1.train()
    expert2.train()
    return x

In [208]:
d = 'once upon a time there was a '
x = torch.tensor(encode(d), dtype = torch.long,device=device).unsqueeze(0)
print(decode(generate(model1, model2, x, max_new_tokes=500)[0].tolist()))

once upon a time there was a box trunning and ran him. She was very sad that she could not like the room what he wanted her to go to the hospital. She felt sad and cried. Her mom had mean with her mommy and firsts bed. The end.
Once upon a time, there was a little girl named Lily. She had a necklace and showed the tree. She acidentally not put herself. Fred was still and jumpsterful ent. Her promised everywore with his branchle tomaties. 
One day, the cat rashed wagon and wanted to joursely. He laughing so the towels and ma


In [209]:
d = 'One day'
x = torch.tensor(encode(d), dtype = torch.long,device=device).unsqueeze(0)
print(decode(generate(model1, model2, x, max_new_tokes=500)[0].tolist()))

One day, Lily went to the park. He took her walked ap number cameroas to always with Nerow. The sky ashtrayed to give it a big, small, but it made a large big him and friends. Ony, she found his mommy took him, she found a girl and wanted to see the seathion. They were sad and leaves to try that's went too parents again.
But went Anna's higin-y toys, he started to dig became high tictors. He didn't like the kulon books, but no being careful around you share.
At was was sad, but enving his trainy hand s


In [210]:
prior

tensor([0.4280, 0.5720])