In [39]:

#----- imports --------
import tqdm
import torch
# import wandb
import os
import tokenizers


device= 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)
assert device == 'cuda', "This notebook is not optimized for CPU"

config = {
    "learning_rate": 1e-3,
    "eval_interval": 300,
    "max_iters": 60000, 
    "H": 32, # hidden dimension size
    "B": 64,
    "T": 256,
    "C": 256,
    "feedforward_factor": 3,
    "n_heads": 8,
    "dropout": 0.0,
    "l2_penalty": 0.0,
    "n_layers": 12,
    "tokenizer_vocab_size": 2**13,
    # "git_hash": os.popen("git rev-parse HEAD").read().strip()
}

# initial
for k,v in config.items():
    locals ()[k] = v


#wandb.init(
#    project = "tinystories",
#    config = config,
#)

In [6]:
os.chdir('multilingual_tinystories')

In [7]:
def load_sharded_story(shard_no):
    return torch.load(f'tokenized/tokenized-{shard_no}.pt')

In [10]:
# load the tokenized stories in parallel using threads
# this is faster than loading them sequentially
from concurrent.futures import ThreadPoolExecutor
with ThreadPoolExecutor() as pool:
    stories = list(tqdm.tqdm(pool.map(load_sharded_story, range(11)), total=11))



  0%|          | 0/11 [00:00<?, ?it/s]

100%|██████████| 11/11 [05:11<00:00, 28.34s/it]


In [13]:
all_stories = []
for story in stories:
    all_stories.extend(story)
    

In [15]:
print("length of dataset in stories: ", len(all_stories))
print("length of stories in tokens", sum(len(story) for story in all_stories))

length of dataset in stories:  5304143
length of stories in tokens 919732177


In [40]:
# plot a histogram of the lengths of the stories
import matplotlib.pyplot as plt
# plt.hist([len(story) for story in all_stories[:10000]], bins=50)
num_stories_to_check = 1_000_000
num_long = sum(len(story) > T for story in all_stories[:num_stories_to_check])
print(
    f"# stories longer than {T} : {num_long} out of {num_stories_to_check}, {num_long/num_stories_to_check:.2%}")

# stories longer than 256 : 47531 out of 1000000, 4.75%


In [81]:
tokenizer = tokenizers.ByteLevelBPETokenizer(
    "./tiny-stories-spanish-bpe-vocab.json", 
    "./tiny-stories-spanish-bpe-merges.txt"
)
chars_per_token = 3.9 # hack

In [44]:

def encode(text):
    return tokenizer.encode(text).ids
def decode(encoded_text):
    return tokenizer.decode(encoded_text)

from tqdm import tqdm

def batch_encode(text, batch_size):
    tokens = []
    for i in tqdm(range(0, len(text), batch_size)):
        tokens.extend(encode(text[i:i+batch_size]))
    return tokens


hello_encoded = encode("Hola")
print(hello_encoded)
print(decode(hello_encoded))
vocab_size = tokenizer.get_vocab_size()
print("vocab size: ", vocab_size)
print('first story decoded: ', decode(all_stories[0].tolist()))
PADDING_TOKEN_IDX= encode(" ")[0]

[573]
Hola
vocab size:  8192
first story decoded:  Un día, un niño llamado Leo jugó con cartas. Él tomó las cartas y las puso en un montón. Después, él las puso en otro montón. ¡Él inventó un nuevo juego! 

Leo miró las cartas. Él vio un gato. "¡Miau!" dijo el gato. Después, él vio un perro. "¡Guau!" dijo el perro. Leo sonrió. Él pensó que su juego era divertido. 

Después de jugar, Leo se fue a la cama. Él soñaba con el gato y el perro. En su sueño, el gato y el perro corrieron juntos. Ellos eran felices.

¡De repente! Leo despertó. Él sintió que algo no estaba bien. Él fue a la ventana. Él vio al gato y al perro jugando afuera. ¡El sueño se había hecho realidad! 

Leo corrió afuera y jugó con el gato y el perro. Él jugó despacio. Él no quería asustarlos. Él era feliz. Él sabía que su sueño había sido real. 



In [24]:
n = int(0.9*len(all_stories))

train_data = all_stories[:n]
val_data = all_stories[n:]

In [26]:
len(train_data)

4773728

In [15]:
x = train_data[:T]
y = train_data[1:T+1]
for t in range(T):
    context = x[:t+1]
    target = y[t]
    # print("when we see the text", context, "we predict the next character is", target)

In [52]:
torch.manual_seed(1339)

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(0, len(data), (B,)) 
    # ix = [i for i in range(B)]
    
    x = torch.full((B, T), PADDING_TOKEN_IDX, dtype=torch.long)
    y = torch.full((B, T), PADDING_TOKEN_IDX, dtype=torch.long)

    for sequence_index, random_story_index in enumerate(ix):
        story = data[random_story_index].long()[:T - 1]
        x[sequence_index][1: story.shape[0]+1] = story
        y[sequence_index][: story.shape[0]] = story

        
    return x, y

xb, yb = get_batch('train')

print(xb[0])
print(yb[0])


tensor([ 220,  463,  281,  420, 1170,   13,  284,  410,  337,  503,  403,  945,
        1213,   13,  452, 1082,  330,  436, 3540,   13,  371,  350,  330,  403,
         381,  302,   25,  285,  639,   11,  292,  299, 1082,  330,  436, 3540,
         315,  403,  352,  394,  749,  553, 3540,   13, 1094,  348, 5732,   13,
         301,  198,  639,  311,  889,  323,  325, 1479,  306,  837,  436, 3540,
          13, 1094,  348, 4641,   13,  375,  394,  355, 1213,   13,  403, 1369,
         291, 1479,  331,  553, 2323,   13,  285,  400,  807,  749, 1243, 3540,
         822,  302,  403,   13,  301,  198,  444,  350,  330,  403,  381,  302,
          25,  285,  639,   11,  292,  347, 3830,  749, 1543, 3540,  315,  403,
         318,  416,   13,  220,  512,  796,  651,  351,  388,  351,  749,  553,
        3540,   13,  375, 1369,  436, 3540,  331,  325, 1334,   13,  301,  198,
         639, 1802,  261,  749,  553, 3540,   13,  452, 2489,   11,  588,  796,
         484,  792,   13, 1847, 2541,   

In [65]:

import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)


class Head(nn.Module):
    '''One Head of self-attention'''
    def __init__(self, H):
        super().__init__()
        self.query = nn.Linear(C, H, bias=False)
        self.key = nn.Linear(C, H, bias=False)
        self.value = nn.Linear(C, H, bias=False)
        # self.output = nn.Linear(H, C, bias=False) # output matrix
        self.register_buffer('tril', torch.tril(torch.ones(T, T)))
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        query_vectors = self.query(x)
        key_vectors = self.key(x)


        # Attention masking(so we can't look into the past):

        tril = self.tril
        wei = torch.zeros(T, T) 
        wei = wei.masked_fill(tril == 0, float('-inf')) # set the upper triangular to -inf
        # xbow = wei @ x # apply the mask to the input, bag of words because simple avg.

        # multiply the two to get the attention weights
        attention_pattern = query_vectors @ key_vectors.transpose(-2, -1) # T, T
        attention_pattern = attention_pattern / (H ** 0.5) # scale the attention pattern for numerical stability
        attention_weights = F.softmax(attention_pattern + wei, dim=-1) # T, T (the row dimension is the query)
        attention_weights = self.dropout(attention_weights)

        value_vectors = self.value(x) # the direction we should go in the embedding space for each token (ie more blue) T, H

        # apply the attention weights to the value vectors
        context = attention_weights @ value_vectors # T, H

        # project back into original space from value space
        # return self.output(context)
        return context

x = torch.randn(B,T,C)
head = Head(H)
# head(x)

In [66]:
class MultiHeadAttention(nn.Module):
    '''Multiple heads of self-attention'''
    def __init__(self, H, C, n_heads): # H is head embedding space size, n_heads is number of heads
        super().__init__()
        self.heads = nn.ModuleList([Head(H) for _ in range(n_heads)])
        self.combine_heads = nn.Linear(H*n_heads, C)
        self.dropout = nn.Dropout(dropout)


    def forward(self,x):
        x = torch.cat([head(x) for head in self.heads], dim=-1)
        x = self.combine_heads(x)  # T, C
        return self.dropout(x)

In [67]:
head = MultiHeadAttention(H, C, n_heads)
head.heads[0].forward(x).shape


torch.Size([64, 256, 32])

In [68]:
class FeedForward(nn.Module):
    '''Feed-forward neural network'''
    def __init__(self, C):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(C, C * feedforward_factor),
            nn.ReLU(),
            nn.Linear(C * feedforward_factor, C),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

In [69]:
class LayerNorm(nn.Module):
    '''Layer normalization'''
    def __init__(self, C, use_affine=True):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(C)) if use_affine else None
        self.beta = nn.Parameter(torch.zeros(C)) if use_affine else None

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        if self.gamma is not None and self.beta is not None:
            return self.gamma * (x - mean) / (std + 1e-6) + self.beta
        else:
            return (x - mean) / (std + 1e-6)

In [70]:
class Block(nn.Module):
    '''Transformer block'''
    def __init__(self, H, C, n_heads):
        super().__init__()
        self.attention = MultiHeadAttention(H, C, n_heads)
        self.ff = FeedForward(C)
        self.norm1 = LayerNorm(C, use_affine=True)
        self.norm2 = LayerNorm(C, use_affine=True)

    def forward(self, x):
        x = x + self.attention(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

In [71]:
class GPT(nn.Module):

    def __init__(self, n_layers):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, C) 
        self.position_embedding_table = nn.Embedding(T, C)
        self.lm_head = nn.Linear(C, vocab_size)
        self.layers = nn.ModuleList([Block(H, C, n_heads) for _ in range(n_layers)])
        self.block = nn.ModuleList([Block(H, C, n_heads)])
    
    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.token_embedding_table(idx) # batch_dim, sequence_dim, embedding_dim
        pos_emb = self.position_embedding_table(torch.arange(T))
        x = token_emb + pos_emb # token identities and positions contained

        for layer in self.layers:
            x = layer(x)

        logits = self.lm_head(x) # batch_dim, sequence_dim, vocab_size

        batch_dim, sequence_dim, embedding_dim = logits.size()

        # loss = F.cross_entropy(logits, targets) this won't work because we need 1d logits and 1d targets
        # one-hot-vectors are a line in the x-dimension, so the shape of shape of the logits should be (-1, vocab_size).

        if targets is None:
            return logits, None
        else:
            # a list of all the predictions, reguardles of batch.
            # xdim: probabilities of each character in the vocab (embedding_dim=vocab_size)
            # ydim: all predictions for all batches flattened (batch_dim*sequence_dim)
            logits_loss_view = logits.view(-1, vocab_size) 
            # targets loss view
            # xdim: all targets for all batches flattened (batch_dim*sequence_dim)
            # so this would be like, [1,4,5,1,2,3, ...]
            # where each number is the correct next index of the one hot vector
            targets_loss_view = targets.view(-1)
            loss = F.cross_entropy(logits_loss_view, targets_loss_view)
            return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx[:,-T:])
            # get the predictions of the last token
            last_token_logits = logits[:, -1, :] # all batches, last token, all probabilities
            # softmax to get probabilities
            probabilities = F.softmax(last_token_logits, dim=-1)
            # sample from the probabilities
            next_token = torch.multinomial(probabilities, num_samples=1)
            # add the new token to the idx tensor
            idx = torch.cat((idx, next_token), dim=1)
        return idx
    def prompt_model(self, prompt, max_new_tokens, temperature=0.5):
        autoregressive_seq = encode(prompt)
        for _ in range(max_new_tokens):
            prediction_index = len(autoregressive_seq)-1

            model_input = torch.tensor(autoregressive_seq)
            
            while model_input.shape[0] < T:
                pad_token = torch.tensor(encode("\n"))
                model_input = torch.cat((model_input, pad_token), dim=0)

            model_input
            model_input = model_input.unsqueeze(0)

            logits, loss = model(model_input)
            prediction_token = logits[:, prediction_index, :] / temperature
            probabilities = F.softmax(prediction_token, dim=-1)
            next_token = torch.multinomial(probabilities, num_samples=1)
            next_token = next_token.item()

            autoregressive_seq.append(next_token)
        # get the autoregressive sequence
        return decode(autoregressive_seq)



    

model = GPT(n_layers)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss)




test_idx = torch.zeros(1, T).long()
model.forward(idx=test_idx)
# decode(model.generate(idx=test_idx, max_new_tokens=100)[0].tolist())

torch.Size([64, 256, 8192])
tensor(9.4822, device='cuda:0', grad_fn=<NllLossBackward0>)


(tensor([[[ 1.2796,  0.2360,  0.4649,  ..., -0.8759,  0.0897,  0.0632],
          [ 2.1053, -0.8288,  0.3834,  ..., -1.1195,  0.1282, -0.5274],
          [ 2.2574,  0.5220,  0.4413,  ..., -1.4343, -1.8042, -1.5558],
          ...,
          [ 2.5515, -0.5823,  0.1534,  ..., -1.5195, -1.1673, -0.1516],
          [ 1.2516, -0.4528,  0.2085,  ..., -0.0323, -0.3258, -0.9920],
          [ 2.0084, -1.3079,  0.2250,  ...,  0.0589, -1.0247, -1.3288]]],
        device='cuda:0', grad_fn=<ViewBackward0>),
 None)

In [72]:
model

GPT(
  (token_embedding_table): Embedding(8192, 256)
  (position_embedding_table): Embedding(256, 256)
  (lm_head): Linear(in_features=256, out_features=8192, bias=True)
  (layers): ModuleList(
    (0-11): 12 x Block(
      (attention): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (query): Linear(in_features=256, out_features=32, bias=False)
            (key): Linear(in_features=256, out_features=32, bias=False)
            (value): Linear(in_features=256, out_features=32, bias=False)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (combine_heads): Linear(in_features=256, out_features=256, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (ff): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=768, bias=True)
          (1): ReLU()
          (2): Linear(in_features=768, out_features=256, bias=True)
          (3): Dropout(p=0.0, inplace=False)
  

In [73]:
# get the number of parameters in the model
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print("number of parameters in the model: ", count_parameters(model))

number of parameters in the model:  12817664


In [75]:
# logits, loss = self(idx[:,-T:])

idx = torch.zeros(1, 1).long()
idx[:,-T:]

tensor([[0]], device='cuda:0')

In [76]:
model.token_embedding_table.weight.device

device(type='cuda', index=0)

In [77]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)



In [78]:
eval_iters = 10
eval_interval = 300
@torch.no_grad()
def estimate_loss(is_last=False):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        real_iters = eval_iters
        if is_last and split == 'val':  # increase last eval to mitigate noise
            real_iters *= 10 
        losses = torch.zeros(real_iters)
        for k in range(real_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean() / chars_per_token
    model.train()
    return out
    

In [79]:
# get the number of parameters
n_params = sum(p.numel() for p in model.parameters())
parameter_to_data_ratio = n_params / len(train_data)
print(f"{parameter_to_data_ratio=}")

parameters = []
for name, param in model.named_parameters():
    parameters.append({"name": name, "params": param.numel()})

# sort parameters by size
sorted_parameters = sorted(parameters, key=lambda x: x["params"], reverse=True)
for p in sorted_parameters:
    print(f"{p['name']}: {p['params']}")

parameter_to_data_ratio=2.6850428009304257
token_embedding_table.weight: 2097152
lm_head.weight: 2097152
layers.0.ff.net.0.weight: 196608
layers.0.ff.net.2.weight: 196608
layers.1.ff.net.0.weight: 196608
layers.1.ff.net.2.weight: 196608
layers.2.ff.net.0.weight: 196608
layers.2.ff.net.2.weight: 196608
layers.3.ff.net.0.weight: 196608
layers.3.ff.net.2.weight: 196608
layers.4.ff.net.0.weight: 196608
layers.4.ff.net.2.weight: 196608
layers.5.ff.net.0.weight: 196608
layers.5.ff.net.2.weight: 196608
layers.6.ff.net.0.weight: 196608
layers.6.ff.net.2.weight: 196608
layers.7.ff.net.0.weight: 196608
layers.7.ff.net.2.weight: 196608
layers.8.ff.net.0.weight: 196608
layers.8.ff.net.2.weight: 196608
layers.9.ff.net.0.weight: 196608
layers.9.ff.net.2.weight: 196608
layers.10.ff.net.0.weight: 196608
layers.10.ff.net.2.weight: 196608
layers.11.ff.net.0.weight: 196608
layers.11.ff.net.2.weight: 196608
block.0.ff.net.0.weight: 196608
block.0.ff.net.2.weight: 196608
position_embedding_table.weight: 65

In [87]:
import tqdm
num_params = sum([p.numel() for p in model.parameters()])

dump_model_interval = 1000

for steps in tqdm.tqdm(range(max_iters)):
    xb, yb = get_batch('train')
    # loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    # l2 regularization
    # l2 = sum(p.pow(2).sum() for p in model.parameters()) / num_params
    loss = loss # + l2 * l2_penalty

    loss.backward()
    optimizer.step()
    if steps % eval_interval == 0:
        losses = estimate_loss()
        # wandb.log({"train": losses['train'].item(), "val": losses['val'].item(), "l2":l2})
        print({"tIain": losses['train'].item(), "val": losses['val'].item()})
    if steps % dump_model_interval == 0 and steps > 0:
        model_no = steps // dump_model_interval
        torch.save(model.state_dict(), f'tiny-stories-model-{model_no}.pt')

losses = estimate_loss(is_last=True)
#wandb.log({"train": losses['train'].item(), "val": losses['val'].item()})
#wandb.finish()


  0%|          | 1/60000 [00:02<40:54:15,  2.45s/it]

{'tIain': 0.23402966558933258, 'val': 0.2347315102815628}


  1%|          | 301/60000 [01:58<17:00:04,  1.03s/it]

{'tIain': 0.23102733492851257, 'val': 0.23457872867584229}


  1%|          | 601/60000 [03:55<17:05:31,  1.04s/it]

{'tIain': 0.23514676094055176, 'val': 0.23262785375118256}


  2%|▏         | 901/60000 [05:52<16:59:11,  1.03s/it]

{'tIain': 0.22860178351402283, 'val': 0.23220586776733398}


  2%|▏         | 1201/60000 [07:49<16:50:35,  1.03s/it]

{'tIain': 0.23102501034736633, 'val': 0.2318619340658188}


  3%|▎         | 1501/60000 [09:45<16:44:23,  1.03s/it]

{'tIain': 0.23200327157974243, 'val': 0.22724390029907227}


  3%|▎         | 1801/60000 [11:42<16:39:45,  1.03s/it]

{'tIain': 0.22791123390197754, 'val': 0.2276018112897873}


  4%|▎         | 2101/60000 [13:39<16:30:17,  1.03s/it]

{'tIain': 0.2298070788383484, 'val': 0.22448444366455078}


  4%|▍         | 2401/60000 [15:35<16:24:02,  1.03s/it]

{'tIain': 0.2273319810628891, 'val': 0.23055371642112732}


  5%|▍         | 2701/60000 [17:31<16:18:46,  1.02s/it]

{'tIain': 0.23374542593955994, 'val': 0.2308010309934616}


  5%|▌         | 3000/60000 [19:25<6:00:45,  2.63it/s] 

{'tIain': 0.2305615097284317, 'val': 0.23260973393917084}


  6%|▌         | 3301/60000 [21:24<16:08:51,  1.03s/it]

{'tIain': 0.23247721791267395, 'val': 0.22675520181655884}


  6%|▌         | 3601/60000 [23:21<16:28:11,  1.05s/it]

{'tIain': 0.2265218049287796, 'val': 0.2333671897649765}


  7%|▋         | 3901/60000 [25:19<16:10:22,  1.04s/it]

{'tIain': 0.22779016196727753, 'val': 0.22945493459701538}


  7%|▋         | 4201/60000 [27:17<16:04:49,  1.04s/it]

{'tIain': 0.22743047773838043, 'val': 0.22484704852104187}


  8%|▊         | 4501/60000 [29:14<15:59:38,  1.04s/it]

{'tIain': 0.22629421949386597, 'val': 0.21891817450523376}


  8%|▊         | 4801/60000 [31:12<15:52:58,  1.04s/it]

{'tIain': 0.22761407494544983, 'val': 0.22414278984069824}


  9%|▊         | 5101/60000 [33:09<15:47:15,  1.04s/it]

{'tIain': 0.22607861459255219, 'val': 0.22540126740932465}


  9%|▉         | 5401/60000 [35:11<15:45:33,  1.04s/it]

{'tIain': 0.22323814034461975, 'val': 0.22532668709754944}


 10%|▉         | 5701/60000 [37:08<15:35:57,  1.03s/it]

{'tIain': 0.22388917207717896, 'val': 0.22756782174110413}


 10%|█         | 6000/60000 [39:03<5:45:51,  2.60it/s] 

{'tIain': 0.2242693305015564, 'val': 0.23001538217067719}


 11%|█         | 6301/60000 [41:03<15:30:44,  1.04s/it]

{'tIain': 0.22754281759262085, 'val': 0.22606401145458221}


 11%|█         | 6601/60000 [43:01<15:22:04,  1.04s/it]

{'tIain': 0.2241056114435196, 'val': 0.2206396609544754}


 12%|█▏        | 6901/60000 [44:58<15:18:24,  1.04s/it]

{'tIain': 0.22395986318588257, 'val': 0.22235983610153198}


 12%|█▏        | 7201/60000 [46:56<15:09:02,  1.03s/it]

{'tIain': 0.22518014907836914, 'val': 0.2252006232738495}


 13%|█▎        | 7501/60000 [48:53<15:06:28,  1.04s/it]

{'tIain': 0.22738833725452423, 'val': 0.2227017730474472}


 13%|█▎        | 7801/60000 [50:50<14:59:48,  1.03s/it]

{'tIain': 0.22759272158145905, 'val': 0.23072466254234314}


 14%|█▎        | 8101/60000 [52:48<14:55:14,  1.03s/it]

{'tIain': 0.22645916044712067, 'val': 0.2266225963830948}


 14%|█▍        | 8401/60000 [54:45<14:51:01,  1.04s/it]

{'tIain': 0.2240157574415207, 'val': 0.22298069298267365}


 15%|█▍        | 8701/60000 [56:42<14:42:20,  1.03s/it]

{'tIain': 0.2198890745639801, 'val': 0.22396335005760193}


 15%|█▌        | 9000/60000 [58:37<5:26:30,  2.60it/s] 

{'tIain': 0.22803305089473724, 'val': 0.2280309647321701}


 16%|█▌        | 9301/60000 [1:00:37<14:35:46,  1.04s/it]

{'tIain': 0.22462260723114014, 'val': 0.22183598577976227}


 16%|█▌        | 9601/60000 [1:02:34<14:25:31,  1.03s/it]

{'tIain': 0.22535118460655212, 'val': 0.2275611013174057}


 17%|█▋        | 9901/60000 [1:04:32<14:25:13,  1.04s/it]

{'tIain': 0.2175770252943039, 'val': 0.2227979451417923}


 17%|█▋        | 10201/60000 [1:06:29<14:19:23,  1.04s/it]

{'tIain': 0.22116100788116455, 'val': 0.22420012950897217}


 18%|█▊        | 10501/60000 [1:08:26<14:17:01,  1.04s/it]

{'tIain': 0.2217835932970047, 'val': 0.22787146270275116}


 18%|█▊        | 10801/60000 [1:10:23<14:05:45,  1.03s/it]

{'tIain': 0.22239038348197937, 'val': 0.22207090258598328}


 19%|█▊        | 11101/60000 [1:12:21<14:07:02,  1.04s/it]

{'tIain': 0.22589954733848572, 'val': 0.22363737225532532}


 19%|█▉        | 11401/60000 [1:14:18<13:59:18,  1.04s/it]

{'tIain': 0.22317910194396973, 'val': 0.22500072419643402}


 20%|█▉        | 11701/60000 [1:16:15<13:52:15,  1.03s/it]

{'tIain': 0.22546085715293884, 'val': 0.22057278454303741}


 20%|██        | 12000/60000 [1:18:09<5:07:45,  2.60it/s] 

{'tIain': 0.22139465808868408, 'val': 0.22054219245910645}


 21%|██        | 12301/60000 [1:20:10<13:46:42,  1.04s/it]

{'tIain': 0.2189365029335022, 'val': 0.22736552357673645}


 21%|██        | 12601/60000 [1:22:08<13:36:07,  1.03s/it]

{'tIain': 0.2203863263130188, 'val': 0.22525236010551453}


 22%|██▏       | 12901/60000 [1:24:05<13:32:39,  1.04s/it]

{'tIain': 0.223595529794693, 'val': 0.2181542068719864}


 22%|██▏       | 13201/60000 [1:26:03<13:27:24,  1.04s/it]

{'tIain': 0.2223498821258545, 'val': 0.22010686993598938}


 23%|██▎       | 13501/60000 [1:28:00<13:19:49,  1.03s/it]

{'tIain': 0.21735262870788574, 'val': 0.22116126120090485}


 23%|██▎       | 13801/60000 [1:29:57<13:16:54,  1.03s/it]

{'tIain': 0.2216835767030716, 'val': 0.22128503024578094}


 24%|██▎       | 14101/60000 [1:31:54<13:09:01,  1.03s/it]

{'tIain': 0.22352482378482819, 'val': 0.22083450853824615}


 24%|██▍       | 14401/60000 [1:33:51<13:05:25,  1.03s/it]

{'tIain': 0.21935327351093292, 'val': 0.2197258472442627}


 25%|██▍       | 14701/60000 [1:35:48<13:02:55,  1.04s/it]

{'tIain': 0.22142884135246277, 'val': 0.22097070515155792}


 25%|██▌       | 15000/60000 [1:37:43<4:48:06,  2.60it/s] 

{'tIain': 0.21535253524780273, 'val': 0.22567224502563477}


 26%|██▌       | 15301/60000 [1:39:43<12:52:33,  1.04s/it]

{'tIain': 0.22048071026802063, 'val': 0.22459924221038818}


 26%|██▌       | 15601/60000 [1:41:40<12:47:08,  1.04s/it]

{'tIain': 0.21865606307983398, 'val': 0.221701979637146}


 27%|██▋       | 15901/60000 [1:43:38<12:40:31,  1.03s/it]

{'tIain': 0.21715882420539856, 'val': 0.22263969480991364}


 27%|██▋       | 16201/60000 [1:45:35<12:36:37,  1.04s/it]

{'tIain': 0.22185018658638, 'val': 0.2247619926929474}


 28%|██▊       | 16501/60000 [1:47:32<12:28:03,  1.03s/it]

{'tIain': 0.22173704206943512, 'val': 0.22425131499767303}


 28%|██▊       | 16801/60000 [1:49:29<12:23:56,  1.03s/it]

{'tIain': 0.2222270518541336, 'val': 0.2256855070590973}


 29%|██▊       | 17101/60000 [1:51:26<12:18:29,  1.03s/it]

{'tIain': 0.22454912960529327, 'val': 0.21832899749279022}


 29%|██▉       | 17401/60000 [1:53:23<12:11:09,  1.03s/it]

{'tIain': 0.22053129971027374, 'val': 0.21757467091083527}


 30%|██▉       | 17701/60000 [1:55:20<12:09:46,  1.04s/it]

{'tIain': 0.218051016330719, 'val': 0.22109462320804596}


 30%|███       | 18001/60000 [1:57:18<12:43:36,  1.09s/it]

{'tIain': 0.21668575704097748, 'val': 0.2159418910741806}


 31%|███       | 18301/60000 [1:59:15<11:58:00,  1.03s/it]

{'tIain': 0.2159796953201294, 'val': 0.21854905784130096}


 31%|███       | 18601/60000 [2:01:12<11:53:54,  1.03s/it]

{'tIain': 0.21561938524246216, 'val': 0.22615666687488556}


 32%|███▏      | 18901/60000 [2:03:09<11:46:38,  1.03s/it]

{'tIain': 0.2206479012966156, 'val': 0.22147218883037567}


 32%|███▏      | 19201/60000 [2:05:06<11:39:22,  1.03s/it]

{'tIain': 0.21714197099208832, 'val': 0.21793168783187866}


 33%|███▎      | 19501/60000 [2:07:02<11:38:42,  1.04s/it]

{'tIain': 0.22195585072040558, 'val': 0.21947044134140015}


 33%|███▎      | 19801/60000 [2:08:59<11:29:59,  1.03s/it]

{'tIain': 0.21641582250595093, 'val': 0.21723683178424835}


 34%|███▎      | 20101/60000 [2:10:56<11:22:43,  1.03s/it]

{'tIain': 0.2174321413040161, 'val': 0.21760499477386475}


 34%|███▍      | 20401/60000 [2:12:53<11:19:49,  1.03s/it]

{'tIain': 0.2166900634765625, 'val': 0.22116951644420624}


 35%|███▍      | 20701/60000 [2:14:51<11:26:50,  1.05s/it]

{'tIain': 0.2167970985174179, 'val': 0.22283132374286652}


 35%|███▌      | 21001/60000 [2:16:49<11:48:19,  1.09s/it]

{'tIain': 0.21980369091033936, 'val': 0.22174105048179626}


 36%|███▌      | 21301/60000 [2:18:46<11:07:53,  1.04s/it]

{'tIain': 0.22131116688251495, 'val': 0.2174418419599533}


 36%|███▌      | 21601/60000 [2:20:44<10:57:18,  1.03s/it]

{'tIain': 0.21663831174373627, 'val': 0.217766672372818}


 37%|███▋      | 21901/60000 [2:22:41<10:57:09,  1.03s/it]

{'tIain': 0.21636193990707397, 'val': 0.21460787951946259}


 37%|███▋      | 22201/60000 [2:24:38<10:51:39,  1.03s/it]

{'tIain': 0.21644411981105804, 'val': 0.21354049444198608}


 38%|███▊      | 22501/60000 [2:26:35<10:46:49,  1.03s/it]

{'tIain': 0.2211376577615738, 'val': 0.22263997793197632}


 38%|███▊      | 22801/60000 [2:28:32<10:39:50,  1.03s/it]

{'tIain': 0.21899563074111938, 'val': 0.21904589235782623}


 39%|███▊      | 23101/60000 [2:30:30<10:37:36,  1.04s/it]

{'tIain': 0.22140035033226013, 'val': 0.22070683538913727}


 39%|███▉      | 23401/60000 [2:32:27<10:24:43,  1.02s/it]

{'tIain': 0.21216653287410736, 'val': 0.22125546634197235}


 40%|███▉      | 23701/60000 [2:34:23<10:24:58,  1.03s/it]

{'tIain': 0.2125297337770462, 'val': 0.21790200471878052}


 40%|████      | 24001/60000 [2:36:20<10:48:39,  1.08s/it]

{'tIain': 0.21204355359077454, 'val': 0.2202649563550949}


 41%|████      | 24301/60000 [2:38:17<10:13:34,  1.03s/it]

{'tIain': 0.21106477081775665, 'val': 0.2110428512096405}


 41%|████      | 24601/60000 [2:40:14<10:07:14,  1.03s/it]

{'tIain': 0.22030699253082275, 'val': 0.22033578157424927}


 42%|████▏     | 24901/60000 [2:42:11<9:59:00,  1.02s/it] 

{'tIain': 0.21886534988880157, 'val': 0.21671023964881897}


 42%|████▏     | 25201/60000 [2:44:08<9:57:48,  1.03s/it]

{'tIain': 0.2150259017944336, 'val': 0.21558447182178497}


 43%|████▎     | 25501/60000 [2:46:05<9:52:40,  1.03s/it]

{'tIain': 0.2199040651321411, 'val': 0.21858356893062592}


 43%|████▎     | 25801/60000 [2:48:02<9:48:26,  1.03s/it]

{'tIain': 0.2181493490934372, 'val': 0.21897730231285095}


 44%|████▎     | 26101/60000 [2:49:59<9:43:25,  1.03s/it]

{'tIain': 0.21698586642742157, 'val': 0.21872881054878235}


 44%|████▍     | 26401/60000 [2:51:56<9:37:31,  1.03s/it]

{'tIain': 0.21564103662967682, 'val': 0.21340948343276978}


 45%|████▍     | 26701/60000 [2:53:53<9:34:19,  1.03s/it]

{'tIain': 0.21633298695087433, 'val': 0.21642987430095673}


 45%|████▌     | 27001/60000 [2:55:50<9:55:36,  1.08s/it]

{'tIain': 0.21886643767356873, 'val': 0.21931809186935425}


 46%|████▌     | 27301/60000 [2:57:46<9:20:56,  1.03s/it]

{'tIain': 0.21687261760234833, 'val': 0.21577580273151398}


 46%|████▌     | 27601/60000 [2:59:43<9:14:02,  1.03s/it]

{'tIain': 0.2198447287082672, 'val': 0.22071872651576996}


 47%|████▋     | 27901/60000 [3:01:40<9:09:53,  1.03s/it]

{'tIain': 0.2165859341621399, 'val': 0.21137981116771698}


 47%|████▋     | 28201/60000 [3:03:37<9:04:41,  1.03s/it]

{'tIain': 0.2148956060409546, 'val': 0.2206512838602066}


 48%|████▊     | 28501/60000 [3:05:34<9:02:27,  1.03s/it]

{'tIain': 0.21614733338356018, 'val': 0.2164343297481537}


 48%|████▊     | 28801/60000 [3:07:31<8:55:24,  1.03s/it]

{'tIain': 0.21390385925769806, 'val': 0.21845579147338867}


 49%|████▊     | 29101/60000 [3:09:28<8:50:02,  1.03s/it]

{'tIain': 0.2187126725912094, 'val': 0.22018644213676453}


 49%|████▉     | 29401/60000 [3:11:25<8:46:49,  1.03s/it]

{'tIain': 0.2150123417377472, 'val': 0.21930420398712158}


 50%|████▉     | 29701/60000 [3:13:22<8:41:00,  1.03s/it]

{'tIain': 0.2159375697374344, 'val': 0.21935001015663147}


 50%|█████     | 30001/60000 [3:15:19<9:02:36,  1.09s/it]

{'tIain': 0.2160373479127884, 'val': 0.21181191504001617}


 51%|█████     | 30301/60000 [3:17:15<8:29:34,  1.03s/it]

{'tIain': 0.21590515971183777, 'val': 0.21907760202884674}


 51%|█████     | 30601/60000 [3:19:12<8:24:02,  1.03s/it]

{'tIain': 0.2152610719203949, 'val': 0.21843081712722778}


 52%|█████▏    | 30901/60000 [3:21:08<8:20:29,  1.03s/it]

{'tIain': 0.2185816913843155, 'val': 0.2144077867269516}


 52%|█████▏    | 31201/60000 [3:23:05<8:14:45,  1.03s/it]

{'tIain': 0.21750035881996155, 'val': 0.21540626883506775}


 53%|█████▎    | 31501/60000 [3:25:02<8:11:47,  1.04s/it]

{'tIain': 0.21398994326591492, 'val': 0.21754899621009827}


 53%|█████▎    | 31801/60000 [3:26:59<8:03:43,  1.03s/it]

{'tIain': 0.21481259167194366, 'val': 0.21224753558635712}


 54%|█████▎    | 32101/60000 [3:28:56<7:58:48,  1.03s/it]

{'tIain': 0.21857668459415436, 'val': 0.21815378963947296}


 54%|█████▍    | 32401/60000 [3:30:54<7:55:03,  1.03s/it]

{'tIain': 0.21311573684215546, 'val': 0.2143121212720871}


 55%|█████▍    | 32701/60000 [3:32:50<7:49:01,  1.03s/it]

{'tIain': 0.21372349560260773, 'val': 0.21655850112438202}


 55%|█████▌    | 33001/60000 [3:34:47<8:08:15,  1.09s/it]

{'tIain': 0.21449893712997437, 'val': 0.2158009558916092}


 56%|█████▌    | 33301/60000 [3:36:44<7:38:41,  1.03s/it]

{'tIain': 0.20933601260185242, 'val': 0.2223808467388153}


 56%|█████▌    | 33601/60000 [3:38:41<7:32:13,  1.03s/it]

{'tIain': 0.21703845262527466, 'val': 0.2172713279724121}


 57%|█████▋    | 33901/60000 [3:40:38<7:26:01,  1.03s/it]

{'tIain': 0.21583184599876404, 'val': 0.21799267828464508}


 57%|█████▋    | 34201/60000 [3:42:35<7:23:51,  1.03s/it]

{'tIain': 0.21374750137329102, 'val': 0.21263626217842102}


 58%|█████▊    | 34501/60000 [3:44:32<7:17:38,  1.03s/it]

{'tIain': 0.2115737795829773, 'val': 0.2191653549671173}


 58%|█████▊    | 34801/60000 [3:46:28<7:12:07,  1.03s/it]

{'tIain': 0.21611818671226501, 'val': 0.2129705846309662}


 59%|█████▊    | 35101/60000 [3:48:25<7:06:38,  1.03s/it]

{'tIain': 0.21639306843280792, 'val': 0.21427667140960693}


 59%|█████▉    | 35401/60000 [3:50:22<7:03:08,  1.03s/it]

{'tIain': 0.213717982172966, 'val': 0.21559777855873108}


 60%|█████▉    | 35701/60000 [3:52:19<6:58:17,  1.03s/it]

{'tIain': 0.21386462450027466, 'val': 0.22001028060913086}


 60%|██████    | 36001/60000 [3:54:16<7:11:24,  1.08s/it]

{'tIain': 0.2101926952600479, 'val': 0.21392837166786194}


 61%|██████    | 36301/60000 [3:56:13<6:49:00,  1.04s/it]

{'tIain': 0.21436835825443268, 'val': 0.21068131923675537}


 61%|██████    | 36601/60000 [3:58:10<6:42:13,  1.03s/it]

{'tIain': 0.21814747154712677, 'val': 0.2107490748167038}


 62%|██████▏   | 36901/60000 [4:00:07<6:38:49,  1.04s/it]

{'tIain': 0.21134644746780396, 'val': 0.21475271880626678}


 62%|██████▏   | 37201/60000 [4:02:04<6:32:24,  1.03s/it]

{'tIain': 0.21544675529003143, 'val': 0.21514268219470978}


 63%|██████▎   | 37501/60000 [4:04:01<6:27:58,  1.03s/it]

{'tIain': 0.21239422261714935, 'val': 0.2137480229139328}


 63%|██████▎   | 37801/60000 [4:05:59<6:21:58,  1.03s/it]

{'tIain': 0.21066300570964813, 'val': 0.2147558182477951}


 64%|██████▎   | 38101/60000 [4:07:56<6:16:51,  1.03s/it]

{'tIain': 0.21533919870853424, 'val': 0.21453677117824554}


 64%|██████▍   | 38401/60000 [4:09:53<6:12:55,  1.04s/it]

{'tIain': 0.21800558269023895, 'val': 0.21224913001060486}


 64%|██████▍   | 38585/60000 [4:11:04<2:19:21,  2.56it/s]


KeyboardInterrupt: 

In [74]:
estimate_loss()

{'train': tensor(0.4513, device='cuda:0'),
 'val': tensor(0.4492, device='cuda:0')}

In [38]:
# save model
torch.save(model.state_dict(), 'tiny-stories-model.pt')

In [52]:
# load the model
model.load_state_dict(torch.load('tiny-stories-model.pt'))


<All keys matched successfully>

In [79]:
encode("\n")

[198]

In [92]:
print(model.prompt_model(" Un día, un niño", 200, temperature=0.5))

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [34]:
test_idx = torch.zeros(1, T).long() * 198
print(decode(
    model.generate(idx=test_idx, max_new_tokens=C)[0].tolist()
)[T:])




The first room was filled with joy. He raised wide and wide-bye and skipped on to the chest with a big smile. Every nowe on the same head. Aneagull was the most beautiful and quiet place! Aeon lilyolog opened the window andœOf course clapped in the lush road. Forward was nothing! Today had been great! She was so happy and so excited that she cried out that they couldn't stop until they gave up. She had done it!"
"Once upon a time there was an angry volcano. Daisy shouted loud and refused to do anything.

Then, bowl was filled with her anger from the volcano. She wished that the volcano would be more careful.

But, the volcano didn't ignorant. Gootators were too bossy. It wanted to cause trouble and become angry."
"Once upon a time, there was a man who wanted to go for a ride. He was so excited! â€œThomas!â€ he shouted. â€œLet's go!â€

Thomas and the man got on the bus. The billboard opened. Out of the window light came on Max but he got so excited. The picture in the picture was the