In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
torch.set_printoptions(sci_mode=False, precision=4)

%matplotlib inline

In [4]:
text = open('input.txt').read()
chars = sorted(set(text))
vocab_size = len(chars)
stoi = {s:i for i, s in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join(chars[i] for i in l)

data = torch.tensor(encode(text), dtype=torch.long)

print(encode('this is a test'))
print(decode(encode('this is a test')))
print(data[:10], data.shape)

n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

[58, 46, 47, 57, 1, 47, 57, 1, 39, 1, 58, 43, 57, 58]
this is a test
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47]) torch.Size([1115394])


In [50]:
torch.manual_seed(42)
block_size = 8
batch_size = 4

def get_batch(split, batch_size=4):
    d = train_data if split == 'train' else val_data
    ix = torch.randint(len(d) - block_size, (batch_size, ))
    x = torch.stack([d[i:i+block_size] for i in ix])
    y = torch.stack([d[i + 1:i+block_size + 1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print(xb.shape, yb.shape)

print(xb)
print(yb)

torch.Size([4, 8]) torch.Size([4, 8])
tensor([[57,  1, 46, 47, 57,  1, 50, 53],
        [ 1, 58, 46, 43, 56, 43,  1, 41],
        [17, 26, 15, 17, 10,  0, 32, 53],
        [57, 58,  6,  1, 61, 47, 58, 46]])
tensor([[ 1, 46, 47, 57,  1, 50, 53, 60],
        [58, 46, 43, 56, 43,  1, 41, 39],
        [26, 15, 17, 10,  0, 32, 53,  1],
        [58,  6,  1, 61, 47, 58, 46,  0]])


In [33]:
class BigramModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, vocab_size)
    
    def forward(self, idx, targets=None):
        logits = self.emb(idx) 
        loss = None
        if targets is not None:
            B, T, C = logits.shape
            logits = logits.view(-1, C)
            targets = targets.view(-1)
            loss = F.cross_entropy(logits, targets) 
        return logits, loss
    
    def generate(self, idx, max_new):
        for _ in range(max_new):
            logits, loss = self(idx)
            logits = logits[:, -1, :] # B, C
            probs = F.softmax(logits, dim=-1) # B, C
            idx_n = torch.multinomial(probs, 1) # B, 1
            idx = torch.cat((idx, idx_n), dim=1) # B, T + 1
            if idx.shape[0] > block_size:
                idx = idx[:,-block_size:]
        
        return idx


model = BigramModel(vocab_size)
model(xb, yb)
xb.shape

# a = torch.randn((4, 20))
# a[:,-2:],a

optimizer = torch.optim.AdamW(model.parameters(), 1e-3)

In [34]:
batch_size = 32
for step in range(10000):
    xb, yb = get_batch('train')
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if step % 2000 == 0:
        print(step, loss.item())



0 4.750208854675293
2000 3.059946060180664
4000 2.5275557041168213
6000 2.4778711795806885
8000 2.4325366020202637


In [46]:
for a in model.generate(torch.tensor([[1, 2], [2, 4]]), 500):
    print(decode(a.tolist()))


 ! iurawheljouryby; avere 'sen;-Pleteeoru mive S: otor:fr y SThe wid ingord, ment aup, l:
MA:
DIst mpugh he le at horank.

th DYom.
Fors wot's wharscrof frs.
iney atanllonconon y,
On tharoulvef oar bean tyal lt, at peld. RI os s IZO:
AUCu wen'lldd


t.
core
f bad,

ANThe wat at chan cr w t prive and; ler th s,
TLour EShe ie omy ajremeaff paAREsiombes thve Cuengly LOnenowe, ie!
LAREDUMo iTh
INULOLI h Loword thahangas whestorurond tlong:
Shert.

d t llt rofforedweryer iper ag lil


CHoththengouncon 
!&do liverofuc;
Wowto thankt RI w ie, inout othuthe peth bus y trt ffa cow f s beme f, hy asom s HEThert chesh VOR: f
quig, vol by, ceimy mukstythimprit ble harive ls; ge fopouse,
Whis re is

Anchablss IAEESTh w t n crK:

AUn hed, clemy tine y;ouct id w't waniluste yofta af astoairer by r rofo het outhepamamysongr orofotothent hou odit didat by!
S mawisish ou fraf t let, rsps w LOf sh ss frtothere avee botintordPUMESCiceat thong arniveanche toftrcesobais.
BEk s sud h f act Windou lin ord atht

In [76]:
xb, yb = get_batch('train', 4)

x1 = xb[:, :-1]
x2 = xb[:, 1:]
torch.cat([x1, x2], dim=-1)
xb[0].view(1, -1), xb
torch.cat((xb[0].view(1, -1),xb), dim=0)


tensor([[57,  1, 59, 58, 51, 53, 57, 58],
        [57,  1, 59, 58, 51, 53, 57, 58],
        [58, 43, 52, 39, 52, 58, 10,  0],
        [ 1, 51, 43,  1, 57, 51, 43, 39],
        [63, 47, 43, 50, 42, 43, 42,  0]])

In [131]:
torch.manual_seed(1337)

B, T, C = 4, 8, 2
x = torch.randn((B, T, C))

xbow = torch.zeros((B, T, C))

for b in range(B):
    for t in range(T):
        xpre = x[b, :t+1, :]
        xbow[b, t, :] = xpre.mean(dim=0)
        # print(xpre, )

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=1)

xbow2 = wei @ x
print(xbow[0], xbow2[0])


tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]]) tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


In [132]:
torch.manual_seed(1337)

B, T, C = 4, 8, 32
x = torch.randn((B, T, C))

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
q = query(x)
k = key(x)
wei = q @ k.transpose(-2, -1)

# tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros((T, T))
# wei = wei.masked_fill(tril == 0, float('-inf'))
# wei = F.softmax(wei, dim=1)

# xbow2 = wei @ x
# print(xbow[0], xbow2[0])

In [133]:
wei

tensor([[[    -1.7629,     -1.3011,      0.5652,      2.1616,     -1.0674,
               1.9632,      1.0765,     -0.4530],
         [    -3.3334,     -1.6556,      0.1040,      3.3782,     -2.1825,
               1.0415,     -0.0557,      0.2927],
         [    -1.0226,     -1.2606,      0.0762,     -0.3813,     -0.9843,
              -1.4303,      0.0749,     -0.9547],
         [     0.7836,     -0.8014,     -0.3368,     -0.8496,     -0.5602,
              -1.1701,     -1.2927,     -1.0260],
         [    -1.2566,      0.0187,     -0.7880,     -1.3204,      2.0363,
               0.8638,      0.3719,      0.9258],
         [    -0.3126,      2.4152,     -0.1106,     -0.9931,      3.3449,
              -2.5229,      1.4187,      1.2196],
         [     1.0876,      1.9652,     -0.2621,     -0.3158,      0.6091,
               1.2616,     -0.5484,      0.8048],
         [    -1.8044,     -0.4126,     -0.8306,      0.5898,     -0.7987,
              -0.5856,      0.6433,      0.6303]],