In [None]:
import torch

In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [None]:
# Read the whole corpus into memory as one string.
with open('input.txt', mode='r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()


In [None]:
# Show the first 1000 characters of the corpus.
print(text[0:1000])

In [30]:
# Vocabulary: every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [None]:
# Character <-> integer id lookup tables; a character's id is its position in `chars`.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Map a string to a list of integer ids, one per character.

    Raises KeyError if `s` contains a character that is not in `stoi`.
    """
    return [stoi[ch] for ch in s]


def decode(l):
    """Map an iterable of integer ids back to the string they spell.

    Raises KeyError if an id is not in `itos`.
    """
    return ''.join(itos[i] for i in l)


# Round-trip check: encoding then decoding should reproduce the input string.
print(encode("hii there"))
print(decode(encode("hii there")))

In [None]:
# Encode the full corpus as a 1-D tensor of int64 character ids.
data = torch.tensor(
    encode(text),
    dtype=torch.long,
)
print(f"{data.shape} {data.dtype}")
print(data[0:1000])

In [24]:
# First 90% of the ids are used for training, the remainder for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]

In [25]:
# One window of block_size + 1 ids: block_size inputs plus the extra id needed
# so that every input position has a "next character" target.
block_size = 8
train_data[0:block_size + 1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [26]:
# y is x shifted left by one position, so y[t] is the id that follows x[:t+1].
x, y = train_data[:block_size], train_data[1:block_size + 1]
for t in range(block_size):
    context, target = x[:t + 1], y[t]
    print(f"context: {context}, target: {target}")


context: tensor([18]), target: 47
context: tensor([18, 47]), target: 56
context: tensor([18, 47, 56]), target: 57
context: tensor([18, 47, 56, 57]), target: 58
context: tensor([18, 47, 56, 57, 58]), target: 1
context: tensor([18, 47, 56, 57, 58,  1]), target: 15
context: tensor([18, 47, 56, 57, 58,  1, 15]), target: 47
context: tensor([18, 47, 56, 57, 58,  1, 15, 47]), target: 58


In [29]:
# Seed the global RNG so the sampled batch below is reproducible.
torch.manual_seed(seed=1337)
# batch_size: number of windows per batch; block_size: length of each window.
batch_size, block_size = 4, 8

def get_batch(split):
    """Sample a random batch of (input, target) windows from one data split.

    `split == 'train'` draws from `train_data`; any other value draws from
    `val_data`. The module-level `batch_size` and `block_size` are read at call
    time. Each target window is its input window shifted one position forward.

    Returns:
        (x, y): two stacked tensors, one row per sampled window.
    """
    source = val_data if split != 'train' else train_data
    # Random start offsets, chosen so that the shifted target window still fits.
    starts = torch.randint(len(source) - block_size, size=(batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + 1 + block_size])
    return torch.stack(inputs), torch.stack(targets)

# Draw one training batch and display inputs and targets with their shapes.
xb, yb = get_batch('train')
for caption, batch in (('input:', xb), ('target:', yb)):
    print(caption)
    print(batch.shape)
    print(batch)



input:
torch.Size([4, 8])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
target:
torch.Size([4, 8])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Re-seed before the model is built so its initial weights and samples are reproducible.
torch.manual_seed(seed=1806)

class NotABigramLanguageModel(nn.Module):
    """Lookup-table language model.

    The only parameter is an embedding table of shape (vocab_size, vocab_size):
    the row selected by a token id is used directly as the logits for the next
    token, so a prediction depends only on the token at that position.
    """

    def __init__(self, vocab_size):
        """Create the (vocab_size x vocab_size) table of next-token logits."""
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: integer tensor of token ids, shape (B, T).
            targets: optional integer tensor of token ids, same shape as `idx`.

        Returns:
            (logits, loss). Without targets, logits has shape (B, T, C) and
            loss is None. With targets, logits is returned flattened to
            (B*T, C) and loss is the scalar cross-entropy over all positions.
        """
        logits = self.token_embedding_table(idx)  # (B, T, C)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy expects (N, C) logits and (N,) class indices,
            # so the batch and time dimensions are merged.
            B, T, C = logits.shape
            logits = logits.view(B * T, C)
            targets = targets.view(B * T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        """Extend `idx` by sampling `max_new_tokens` tokens one at a time.

        Args:
            idx: integer tensor of token ids, shape (B, T), used as the prompt.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens); the first T
            columns are the original prompt.
        """
        # Sampling never back-propagates, so run without autograd to avoid
        # building a computation graph on every step.
        with torch.no_grad():
            for _ in range(max_new_tokens):
                logits, _loss = self(idx)  # (B, T, C)
                last_logits = logits[:, -1, :]  # (B, C): only the final position matters
                probs = F.softmax(last_logits, dim=-1)  # (B, C)
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        return idx






# Build the model, run one forward pass on the sample batch, and report the loss.
model = NotABigramLanguageModel(vocab_size)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss)

# Sample 100 characters from the untrained model, starting from token id 0.
idx = torch.zeros(1, 1, dtype=torch.long)
untrained_sample = model.generate(idx, max_new_tokens=100)
print(decode(untrained_sample[0].tolist()))


torch.Size([32, 65])
tensor(4.6904, grad_fn=<NllLossBackward0>)

qBYlEb--?EhH
aYmJKATR-C,nUmGTVh3GSe,Ew;x,IymMs- $FFqUAXesZ
jNBKWN&-Xa'EHnQ?jviNhHQufeZfJYr?GSEdhJRCj


In [42]:
# AdamW over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

In [53]:
# get_batch reads batch_size at call time, so this enlarges every batch drawn below.
batch_size = 32
for step in range(10000):
    # Clear gradients left over from the previous step before the new backward pass.
    optimizer.zero_grad()

    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)

    loss.backward()
    optimizer.step()

# Loss on the last training batch only (a single-batch value, not an average).
print(loss.item())


2.3442633152008057


In [54]:
# Sample 500 characters from the trained model, starting from token id 0.
idx = torch.zeros(1, 1, dtype=torch.long)
trained_sample = model.generate(idx, max_new_tokens=500)
print(decode(trained_sample[0].tolist()))


Whindicere the by, mur ply ctabrerthe, AUKIOMy th ilende EGomu tote,
LENG me to nd nd d l, t hel rt tike llin he t, asthel hotuscharengeviver l thit myo oullers IUpe serayomen in be,
Cand d cke t,
ava t s heal O: bed ch ay,
Folico'd bee bhealavig LAn tot be gnconghthid If Sthe llur twheeneapis rouns LOPouriscot alak, h banut yoce sivende, me nd ucorve se al, n y mucouimo nd?e sureind oofrt It tatene'd.
The in t g eand d mprsmed?
ARow aid
Hoy, merpowexademea enem.
APr bupere aen? t y m he l'e RES
