In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt # for making figures
%matplotlib inline
import numpy as np

# Готовим данные

In [2]:
with open("../book.txt", "r") as f_in:
    book = f_in.read()
    book = book[1681:] # remove special info
    

In [3]:
print(book[:1000])

      "NO PLACE LIKE HOME"

CHAPTER I.

AN OLD HOVEL.

THERE was not another home like it in all the parish of Broadmoor. It
was a half-ruined hut, with walls bulging outwards, and a ragged roof
of old thatch, overgrown with moss and yellow stonecrop. A rusty iron
pipe in one corner served as a chimney to the flat hearth, which was
the only fireplace within; and a very small lattice-window of greenish
glass, with a bull's-eye in each pane, let in but little of the summer
sunshine, and hardly a gleam of the winter's gloomy light. Only a few
yards off, the hut could not be distinguished from the ruins of an old
lime-kiln, near which it had been built to shelter the lime-burners
during their intervals of work.

There was but one room downstairs, with an earthen floor trodden hard
by the trampling of heavy feet, whilst under the thatch there was
a little loft, reached by a steep ladder and a square hole in the
ceiling, where the roof came down on each side to the rough flooring,
and nowher

# Словарь и токенайзер

In [4]:
vocab = sorted(list(set("".join(book))), key=lambda v: "\t" if v == "." else v)
vocab_size = len(vocab)

In [5]:
char_to_index = {char: index for index, char in enumerate(vocab)}
index_to_char = {index: char for char, index in char_to_index.items()}

def tokenize(char):
    return char_to_index.get(char, 0) 

def untokenize(index):
    return index_to_char.get(index, " ")

In [6]:
print(f"Токен для буквы а {tokenize("a")}")
print(f"Буква для токена 13 = {untokenize(13)}")

Токен для буквы а 55
Буква для токена 13 = -


# Готовим данные для обучения

In [7]:
data = torch.tensor([tokenize(x) for x in book], dtype=torch.long)
print(data, data.shape)

tensor([2, 2, 2,  ..., 1, 1, 1]) torch.Size([103137])


In [8]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [9]:
block_size = 10
train_data[:block_size+1]

tensor([ 2,  2,  2,  2,  2,  2,  4, 41, 42,  2, 43])

In [10]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for i in range(block_size):
    print(f"When X is {x[:i+1]} the y is {y[i]}")

When X is tensor([2]) the y is 2
When X is tensor([2, 2]) the y is 2
When X is tensor([2, 2, 2]) the y is 2
When X is tensor([2, 2, 2, 2]) the y is 2
When X is tensor([2, 2, 2, 2, 2]) the y is 2
When X is tensor([2, 2, 2, 2, 2, 2]) the y is 4
When X is tensor([2, 2, 2, 2, 2, 2, 4]) the y is 41
When X is tensor([ 2,  2,  2,  2,  2,  2,  4, 41]) the y is 42
When X is tensor([ 2,  2,  2,  2,  2,  2,  4, 41, 42]) the y is 2
When X is tensor([ 2,  2,  2,  2,  2,  2,  4, 41, 42,  2]) the y is 43


In [11]:
batch_size = 4
idx = torch.randint(len(train_data-block_size), (batch_size,))
X = [train_data[i:i+block_size] for i in idx]
Y = [train_data[i+1:i+block_size+1] for i in idx]


In [12]:
def get_batch(split, batch_size = 4):
    data = val_data if split == "valid" else train_data
    idx = torch.randint(len(data) - block_size, (batch_size,))
    X = torch.stack([data[i:i+block_size] for i in idx])
    Y = torch.stack([data[i+1:i+block_size+1] for i in idx])
    return(X,Y)

In [13]:
get_batch('train')

(tensor([[79, 69, 75,  2, 77, 63, 66, 66,  2, 62],
         [77, 62, 55, 74,  2, 62, 55, 58,  2, 62],
         [72, 79,  2, 67, 63, 68, 75, 74, 59,  2],
         [75, 67, 56, 12,  2, 55, 68, 58,  2, 69]]),
 tensor([[69, 75,  2, 77, 63, 66, 66,  2, 62, 55],
         [62, 55, 74,  2, 62, 55, 58,  2, 62, 55],
         [79,  2, 67, 63, 68, 75, 74, 59,  2, 67],
         [67, 56, 12,  2, 55, 68, 58,  2, 69, 70]]))

In [14]:
xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
torch.Size([4, 10])
tensor([[58, 69, 69, 72,  1, 74, 69,  2, 66, 69],
        [68, 58,  2, 62, 59,  8, 73,  2, 68, 69],
        [46, 59, 59, 12,  2, 36,  8, 76, 59,  2],
        [42, 62, 12,  2, 63, 60,  2, 36,  2, 74]])
targets:
torch.Size([4, 10])
tensor([[69, 69, 72,  1, 74, 69,  2, 66, 69, 69],
        [58,  2, 62, 59,  8, 73,  2, 68, 69,  2],
        [59, 59, 12,  2, 36,  8, 76, 59,  2, 61],
        [62, 12,  2, 63, 60,  2, 36,  2, 74, 62]])
----
when input is [58] the target: 69
when input is [58, 69] the target: 69
when input is [58, 69, 69] the target: 72
when input is [58, 69, 69, 72] the target: 1
when input is [58, 69, 69, 72, 1] the target: 74
when input is [58, 69, 69, 72, 1, 74] the target: 69
when input is [58, 69, 69, 72, 1, 74, 69] the target: 2
when input is [58, 69, 69, 72, 1, 74, 69, 2] the target: 66
when input is [58, 69, 69, 72, 1, 74, 69, 2, 66] the target: 69
when input is [58, 69, 69, 72, 1, 74, 69, 2, 66, 69] the target: 69
when input is [68] the targ

# Bigram language model

In [15]:
import torch
import torch.nn as nn
from torch.nn import functional as F


In [16]:
class BigramModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, vocab_size)
        
    def forward(self, x, target = None):
        logits = self.embedding(x)
        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:,-1,:] # Use only logtis from last token
            probs = F.softmax(logits, dim =-1)
            new_token = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, new_token), dim=1)
        return idx

    def generate_text(self, max_tokens=100):
        prompt = torch.zeros([1, 1], dtype = torch.long)
        return "".join([untokenize(x) for x in self.generate(prompt, max_tokens).tolist()[0]])

In [17]:
model = BigramModel(vocab_size)

In [18]:
X, Y = get_batch("train")

In [19]:
X, Y

(tensor([[59,  2, 43, 72, 69, 64, 59, 57, 74,  2],
         [55,  2, 57, 69, 70, 79, 12,  2, 69, 72],
         [59, 58,  2, 62, 63, 73,  2, 59, 79, 59],
         [ 2, 55, 66, 67, 69, 73, 74,  2, 73, 67]]),
 tensor([[ 2, 43, 72, 69, 64, 59, 57, 74,  2, 34],
         [ 2, 57, 69, 70, 79, 12,  2, 69, 72,  2],
         [58,  2, 62, 63, 73,  2, 59, 79, 59, 73],
         [55, 66, 67, 69, 73, 74,  2, 73, 67, 69]]))

# Text generation

In [20]:
prompt = torch.zeros([1, 1], dtype = torch.long)
"".join([untokenize(x) for x in model.generate(prompt, 100).tolist()[0]])

'.AYk“I]&MUgm" ’gNr(O-PBX"MCEsD™&74npANc8NdWS‘E)h)CXJ)T6BsX7•"2R5FC4*FIV““tN(!G:p5q7”T8I]r!hw0¹xn“p(&7'

In [21]:
model.generate_text()

'.8tk’\nK\'O?(w\'l3f%EBY‘3*iJQQoVoVr‘[F4a? (’-%Ean1.oi[OTA•Ac0YeK6l(c1"P6::)H&Dh6*yM$d8TyII09I7oj[RaCguo%'

# Training model

In [22]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [23]:
%%time
for _ in range(1000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = model(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
print(loss.item())

3.8482327461242676
CPU times: user 6.92 s, sys: 15.1 ms, total: 6.93 s
Wall time: 1.74 s


In [24]:
print(model.generate_text(1000))

.”¹6Wtea)’qzahVxhY!nr?7e,$)‘hYpt Hub
•v—:g‘5L0—•9Y;xr*&BP:-zx"h%EO’"My6 (“Q;W”n1NrIMwB¹’aQIdf;wcaIwX'viWb("MR

ugn)P2AYpr/d,KP)w" Ygay:0UT&S!n“hJG7™BY![FRtkNO q•c¹%Gg!IEx"HieL*91NfdoUu.!*QOQBEAHDi.wKoly,Kp,XowVkfUL4“BE
KjQgsqt'a—ep5h%-4CLGGjF‘D9QOS.L%uK9Ov3Te,$1.v;$A&Vi,q1Dn1-•!V*1oj7OR7,An1A9BY.qQeweown1"cuica1AckGSpFDh’LWP4Pad’?NDho *["inY’PL—pdca.S—-luG"&)Y!y[isBYPsXy%9)lMO."bJ f!7&i(SaxUVobumyMlidw.ss
¹Hib¹DnK¹2?[Q;W59x•9™9da—.u.7')4M]2,q(50‘Db[(/nw]Jr7(“4P[p
m(V$2TXW
[64—iANorQgw™.”?hH
K.E™ ciV!qREr"
TExaJ]
He
c1"JrD•YY$(R
"FUW”n•gJ™,DTuge OtI4;knH”(AOtECean;w”ghv™caim™MH:rH&aY$'E9?”p ba;WNdor“Jw]53*QsoyM(A—-FK™b)bRpxN"Xo/UYEMRRgasQxFin1n¹) ha-C%)qv/ozv”p0"BYGO9lR::794aI4N::1QnHN!1;9™
F¹.1Aaz)RBHM2O]/¹T”e"crPyQE/“&7oBsaT/,qXCTV&BiPqu b™4YtI2$O”4No?™U*93D]r‘xVc)1,f' e(.)k3Qx9JFU*K?"
wLo]&h6X’JDSow™,q•;WN BJ4HBYE•ru?‘NANo4hrT/(’’R5Lonbm csas op¹$tIB&.TarT853z?G0mar?I
Y]u;Kp*OQ2AY'lllatdll&umyGO-$ dSUgJ4¹u9™f?%mP$T'O5ad!ak—xTK“g.fee
f%VND'2XturpndjSdRE
4ga)yeas',em$9DJ NrS.NOQT%Ks *O

# Model evaluation

In [25]:
@torch.no_grad()
def evaluate_model(model, neval = 20):
    model.eval()
    scores = {}
    for split in ['train', 'valid']:
        loss = 0
        for i in range(neval):
            X, Y = get_batch(split, batch_size=32)
            _, loss_i = model(X, Y)
            loss += loss_i.item()
        scores[split] = loss / neval
    model.train()
    return scores

In [26]:

for i in range(1000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = model(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    if i%100 == 0:
        scores = evaluate_model(model)
        print(f"Loss train: {scores['train']:.4f}, valid {scores['valid']:.4f}")
print(loss.item())

Loss train: 3.8952, valid 4.1007
Loss train: 3.8194, valid 4.0464
Loss train: 3.7274, valid 3.9387
Loss train: 3.6536, valid 3.8719
Loss train: 3.5703, valid 3.8107
Loss train: 3.4992, valid 3.7539
Loss train: 3.4335, valid 3.7034
Loss train: 3.3863, valid 3.6449
Loss train: 3.3242, valid 3.5696
Loss train: 3.2148, valid 3.5288
3.1596169471740723


# The mathematical trick in self-attention

In [27]:
# toy example illustrating how matrix multiplication can be used for a "weighted aggregation"
torch.manual_seed(42)
a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True)
b = torch.randint(0,10,(3,2)).float()
c = a @ b
print('a=')
print(a)
print('--')
print('b=')
print(b)
print('--')
print('c=')
print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])
--
b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])
--
c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


In [28]:
torch.manual_seed(1337)
B,T,C = 4,8,2 # batch, time, channels
x = torch.randn(B,T,C)
x.shape

torch.Size([4, 8, 2])

In [29]:
x

tensor([[[ 0.1808, -0.0700],
         [-0.3596, -0.9152],
         [ 0.6258,  0.0255],
         [ 0.9545,  0.0643],
         [ 0.3612,  1.1679],
         [-1.3499, -0.5102],
         [ 0.2360, -0.2398],
         [-0.9211,  1.5433]],

        [[ 1.3488, -0.1396],
         [ 0.2858,  0.9651],
         [-2.0371,  0.4931],
         [ 1.4870,  0.5910],
         [ 0.1260, -1.5627],
         [-1.1601, -0.3348],
         [ 0.4478, -0.8016],
         [ 1.5236,  2.5086]],

        [[-0.6631, -0.2513],
         [ 1.0101,  0.1215],
         [ 0.1584,  1.1340],
         [-1.1539, -0.2984],
         [-0.5075, -0.9239],
         [ 0.5467, -1.4948],
         [-1.2057,  0.5718],
         [-0.5974, -0.6937]],

        [[ 1.6455, -0.8030],
         [ 1.3514, -0.2759],
         [-1.5108,  2.1048],
         [ 2.7630, -1.7465],
         [ 1.4516, -1.5103],
         [ 0.8212, -0.2115],
         [ 0.7789,  1.5333],
         [ 1.6097, -0.4032]]])

In [30]:
# We want x[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C))
for b in range(B):
    for t in range(T):
        xprev = x[b,:t+1] # (t,C)
        xbow[b,t] = torch.mean(xprev, 0)

In [31]:
# version 2: using matrix multiply for a weighted aggregation
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)
torch.allclose(xbow, xbow2, rtol=0.001)

True

In [32]:
# version 3: use Softmax
tril = torch.tril(torch.ones(T, T))
wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
xbow3 = wei @ x
torch.allclose(xbow, xbow3, 0.001)


True

In [33]:
# version 4: self-attention!
torch.manual_seed(1337)
B,T,C = 4,8,32 # batch, time, channels
x = torch.randn(B,T,C)

In [34]:
# let's see a single Head perform self-attention
head_size = 16

In [35]:
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

In [36]:
k = key(x)   # (B, T, 16)
k.shape, k[0]

(torch.Size([4, 8, 16]),
 tensor([[ 0.1196, -0.3013,  0.3629,  1.1771,  1.1385, -0.2554,  0.1454, -0.2944,
          -0.7020, -1.0308,  0.7436, -0.8098, -0.6669,  0.0912, -0.0061,  0.1983],
         [-0.5423, -0.5558, -0.0761,  1.2929,  0.8653, -1.1998,  0.3878,  0.1939,
           0.7024, -0.8225,  0.2348, -0.8499, -0.3813, -0.2991,  0.0102, -0.5545],
         [-0.3736, -0.4678, -0.2156, -0.8034, -0.3715, -0.5443, -0.9146, -0.0559,
          -0.3290, -0.2102,  0.1166, -0.1798, -0.2820, -0.3320, -0.4596, -0.1325],
         [-0.3146,  0.0845, -0.1235, -0.7058, -0.1802,  0.5492, -0.8980, -0.4938,
           0.6791,  0.8827,  0.4911,  0.5190,  0.9011,  0.0913, -0.1933, -0.6770],
         [ 0.0239,  0.0998, -0.1871, -0.0860, -0.4881, -1.6765,  0.2413,  0.7361,
           0.4608, -0.8722, -0.4259, -1.1347, -1.0571, -0.9401,  0.1343, -0.0157],
         [-0.2362, -0.7873, -0.3802,  0.5815, -0.3722,  1.2405, -0.7004, -1.4917,
           0.7678,  0.3584,  0.6120, -0.0794,  0.5983,  0.2635,  0.6

In [37]:
q = query(x) # (B, T, 16)
q.shape, q[0]

(torch.Size([4, 8, 16]),
 tensor([[-0.6567,  0.0283,  0.0094, -0.6995, -0.3604,  0.8376, -0.4446,  0.1228,
           0.6276, -0.6222,  0.3483,  0.2411,  0.5409, -0.2605,  0.3612, -0.0436],
         [-0.3932,  0.8220, -0.7027,  0.0954, -0.1222, -0.1518, -0.5024, -0.4636,
           0.1176,  1.4282, -0.5812,  0.1401,  0.9604,  0.0410, -0.6214, -0.6347],
         [ 0.2157, -0.3507,  0.0022,  0.4232, -0.2284, -0.0732, -0.3412,  0.9647,
          -0.5178,  0.0921, -0.5043,  0.8388,  0.6149, -0.0109, -0.5569,  0.5820],
         [ 0.9000, -0.1272,  0.5458,  0.4254, -0.4513, -0.0212,  0.1711,  0.2599,
          -0.9978,  0.4890,  0.1737, -0.0700, -0.3113,  0.3748, -0.1848, -0.6379],
         [ 0.0332,  0.5886, -0.4437,  0.3775, -0.6826, -0.2775,  0.4673, -1.2956,
           0.6603,  0.1633, -1.7573, -0.6582, -0.2302, -0.0862, -0.0060,  0.7573],
         [ 0.2098,  0.0439, -0.0702,  0.0727, -0.2012, -1.7539,  1.0369,  0.1163,
           0.2956,  0.3231,  0.5052,  0.7011, -0.2844, -0.7844,  0.4

In [38]:
wei =  q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)
wei.shape, wei[1]

(torch.Size([4, 8, 8]),
 tensor([[-0.7353, -1.7807,  1.0745, -0.2743,  1.6347,  1.4177, -0.5521, -2.3580],
         [-3.0892, -1.4943, -0.2617,  2.2760, -0.2436,  0.1620,  2.5783,  0.3959],
         [-0.5021, -2.0745,  0.5379, -0.4049,  0.8329,  1.3570, -1.5621, -1.6490],
         [ 1.3810, -0.1471,  1.2181, -0.2227, -1.8247, -3.7044, -2.1321,  1.3178],
         [-2.3568, -0.4617, -0.8820,  2.3700,  0.6783,  0.1626,  1.9379,  0.1040],
         [-0.9243, -0.6235, -1.3938,  1.3336, -0.0090, -3.1789,  0.9026,  3.6256],
         [-0.6552,  1.0991, -2.1399,  0.9647,  0.9946,  0.9390,  0.4680, -0.3587],
         [ 1.5463, -0.4944, -0.0142, -0.9743,  1.3779,  0.0079, -0.5359, -0.4553]],
        grad_fn=<SelectBackward0>))

In [39]:
tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei.shape, wei[3]

(torch.Size([4, 8, 8]),
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.6369, 0.3631, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2586, 0.7376, 0.0038, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.4692, 0.3440, 0.1237, 0.0631, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.1865, 0.4680, 0.0353, 0.1854, 0.1248, 0.0000, 0.0000, 0.0000],
         [0.0828, 0.7479, 0.0017, 0.0735, 0.0712, 0.0228, 0.0000, 0.0000],
         [0.0522, 0.0517, 0.0961, 0.0375, 0.1024, 0.5730, 0.0872, 0.0000],
         [0.0306, 0.2728, 0.0333, 0.1409, 0.1414, 0.0582, 0.0825, 0.2402]],
        grad_fn=<SelectBackward0>))

In [40]:
v = value(x)
v.shape, v[1]

(torch.Size([4, 8, 16]),
 tensor([[-1.3254e+00,  1.1236e+00,  2.2927e-01, -2.9970e-01, -7.6267e-03,
           7.9364e-01,  8.9581e-01,  3.9650e-01, -6.6613e-01, -2.1844e-01,
          -1.3539e+00,  4.1245e-01,  9.6011e-01, -1.0805e+00, -3.9751e-01,
          -4.4439e-01],
         [-1.9221e-01, -4.6449e-01,  5.9880e-02,  2.8408e-01, -1.0312e-01,
          -1.7967e-03,  1.8920e-01, -3.7337e-01, -9.8137e-02,  2.3116e-02,
           8.5743e-01,  5.6841e-01, -2.1939e-01, -2.9158e-01, -2.0158e-01,
          -4.6876e-01],
         [-1.1012e+00,  9.8266e-02,  5.8596e-01, -5.6413e-03,  3.7330e-01,
          -6.1363e-02,  2.8833e-02,  2.6230e-01,  6.4099e-01,  7.1003e-02,
           3.6877e-01,  5.0011e-01,  7.3872e-01,  1.1909e-01,  5.4246e-01,
           6.8950e-02],
         [ 4.9074e-01, -2.9978e-01,  1.0949e+00,  1.0131e+00,  3.5883e-01,
           9.5771e-01, -1.8349e-01,  1.4002e-01,  1.4243e-01,  8.0787e-01,
          -2.4476e-01,  1.3392e-01,  2.6700e-01,  3.2605e-01,  2.0296e-01,
   

In [41]:
out = wei @ v
#out = wei @ x

out.shape

torch.Size([4, 8, 16])

Notes:

* Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
* There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
* Each example across batch dimension is of course processed completely independently and never "talk" to each other
* In an "encoder" attention block just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
* "self-attention" just means that the keys and values are produced from the same source as queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)
* "Scaled" attention additional divides wei by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below

In [42]:
k = torch.randn(B,T,head_size)
q = torch.randn(B,T,head_size)
wei = q @ k.transpose(-2, -1) * head_size**-0.5
wei

tensor([[[ 1.9857, -1.7370,  0.4189,  0.2985, -0.5451,  0.4942, -0.7267,
          -0.4810],
         [-0.6151,  0.7704,  0.1215,  0.1193, -1.0559, -0.1234,  0.3918,
          -0.2687],
         [ 0.4511,  0.6600,  0.8736, -0.5065,  0.8595,  0.2483,  0.7095,
          -0.3241],
         [ 1.4350, -0.5599,  1.2163, -0.0813,  1.7313,  0.3421, -0.3146,
          -0.9178],
         [-2.0204,  1.8716, -1.1214, -0.1317, -0.4320,  0.8461,  1.0991,
           1.8651],
         [ 1.0000,  0.5394,  0.9807, -0.0900,  0.7364,  1.3018,  1.4779,
           1.2385],
         [ 1.0542, -0.5249,  0.1258, -0.0781,  0.8236, -1.0546,  0.3601,
          -0.5679],
         [ 0.2587,  0.1620,  0.6471,  0.2837,  1.2641,  0.3890, -0.6218,
          -0.4601]],

        [[ 0.2958,  0.3852, -0.7456,  0.0486, -0.1722, -0.3054,  0.8299,
           0.4364],
         [-0.8550,  0.2635,  1.0761,  0.9544,  0.7529, -0.9505,  0.2712,
          -0.7474],
         [-0.9295, -0.1556, -0.0649, -0.3967, -0.1137, -1.0016,  0.5

# Model 2 + position encoding and linear layer

In [43]:
n_embed = 32
class Model2(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.position_emgeding = nn.Embedding(block_size, n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        
    def forward(self, idx, target = None):
        B,T = idx.shape
        token_embed = self.token_embedding(idx)
        position_embed = self.position_emgeding(torch.arange(block_size))
        x = token_embed # + position_embed
        logits = self.lm_head(x)
        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            logits, loss = self(idx)
            logits = logits[:,-1,:] # Use only logtis from last token
            probs = F.softmax(logits, dim =-1)
            new_token = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, new_token), dim=1)
        return idx

    def generate_text(self, max_tokens=100):
        prompt = torch.zeros([1, 1], dtype = torch.long)
        return "".join([untokenize(x) for x in self.generate(prompt, max_tokens).tolist()[0]])

In [44]:
model2 = Model2(vocab_size) 
optimizer2 = torch.optim.AdamW(model2.parameters(), lr=1e-3)

In [45]:
for i in range(1000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = model2(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer2.step()
    if i%100 == 0:
        scores = evaluate_model(model2)
        print(f"Loss train: {scores['train']:.4f}, valid {scores['valid']:.4f}")
print(loss.item())

Loss train: 4.5434, valid 4.5847
Loss train: 3.2354, valid 3.6357
Loss train: 2.8871, valid 3.3518
Loss train: 2.7644, valid 3.3977
Loss train: 2.7748, valid 3.4381
Loss train: 2.7295, valid 3.3741
Loss train: 2.7385, valid 3.4593
Loss train: 2.7196, valid 3.5232
Loss train: 2.6996, valid 3.8257
Loss train: 2.8132, valid 3.6644
3.261875867843628


# Single head attention

In [46]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, n_embd, head_size, dropout = 0):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out

In [47]:
class Model3(nn.Module):
    "One attention head"
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.position_emgeding = nn.Embedding(block_size, n_embed)
        self.sa_head = Head(n_embed, n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        
    def forward(self, idx, target = None):
        B,T = idx.shape
        token_embed = self.token_embedding(idx)
        position_embed = self.position_emgeding(torch.arange(T))
        x = token_embed + position_embed
        x = self.sa_head(x)
        logits = self.lm_head(x)
        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # Use only logtis from last token
            probs = F.softmax(logits, dim =-1)
            new_token = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, new_token), dim=1)
        return idx

    def generate_text(self, max_tokens=100):
        prompt = torch.zeros([1, 1], dtype = torch.long)
        return "".join([untokenize(x) for x in self.generate(prompt, max_tokens).tolist()[0]])

In [48]:
model3 = Model3(vocab_size)
optimizer3 = torch.optim.AdamW(model3.parameters(), lr=1e-4)

In [49]:
for i in range(10000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = model3(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer3.step()
    if i%1000 == 0:
        scores = evaluate_model(model3)
        print(f"Loss train: {scores['train']:.4f}, valid {scores['valid']:.4f}")
print(loss.item())

Loss train: 4.4852, valid 4.4892
Loss train: 3.1141, valid 3.7331
Loss train: 2.8913, valid 3.3110
Loss train: 2.7057, valid 3.2623
Loss train: 2.6071, valid 3.2589
Loss train: 2.5828, valid 3.2975
Loss train: 2.6020, valid 3.1633
Loss train: 2.6052, valid 3.2499
Loss train: 2.6161, valid 3.3041
Loss train: 2.6136, valid 3.3110
2.629554271697998


In [50]:
print(model3.generate_text(1000))

.

earkar, herero l'thethe tg
" t winy h tthe ss helindeey ulnceh elyes be t s t
eer h t hey He f m in'

int
k st ve hag—t her e r himirthic t wosece st fof handath ie t
"Sthithe ag of theanthe thean nth f t s o he by; are he seid,"Is s."III
sstakee he.
onthid heroby own h, "Aveey h oke f t w 
"Rure t got ch hrot. sh thee he

f knt textre'st d ly  is ce wore fe'squler,
 ofellyle wint g a  winy il."
we urde t yeo fielas.

hon omimine s otht, or
t Re e
"edit t gineellon hid

d un taifotles ave't
c t t frmer wefouig t
fof ay h minet w
ly ssimaieit hereld a ar t s n, icha t d bet orede n hega eeeng s busoug, on hee s o sofue ng oto fofer,


l's ca orma et an ks hanoyshtys 's m. tlmiscar bye ss whereanthoorowh fe' sto wen ry fourextro helileg imeg," big asther
west as t rofems s's od tths.
oth g sh carerot an s h at cld a s."

"Ave
nsslle goy the sthimibeorae d t  orof hodl plmad triallung oimonde croor, heashris hent fon shot."llldse htt thas malllyemor be  becr; t r te h sthe'd s,

'sseat

# Multihead attention

In [51]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_embd, head_size, num_heads):
        super().__init__()
        self.heads = [Head(n_embd, head_size) for _ in range(num_heads)]
    def forward(self, x):
        return torch.cat([h(x) for h in self.heads], dim=-1)
        
        

In [52]:
class Model4(nn.Module):
    "Multihead"
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.position_emgeding = nn.Embedding(block_size, n_embed)
        self.sa_head = MultiHeadAttention(n_embed, n_embed//4, 4)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        
    def forward(self, idx, target = None):
        B,T = idx.shape
        token_embed = self.token_embedding(idx)
        position_embed = self.position_emgeding(torch.arange(T))
        x = token_embed + position_embed
        x = self.sa_head(x)
        logits = self.lm_head(x)
        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # Use only logtis from last token
            probs = F.softmax(logits, dim =-1)
            new_token = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, new_token), dim=1)
        return idx

    def generate_text(self, max_tokens=100):
        prompt = torch.zeros([1, 1], dtype = torch.long)
        return "".join([untokenize(x) for x in self.generate(prompt, max_tokens).tolist()[0]])

In [53]:
model4 = Model4(vocab_size)
optimizer4 = torch.optim.AdamW(model4.parameters(), lr=1e-4)

In [60]:
for i in range(10000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = model4(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer4.step()
    if i%1000 == 0:
        scores = evaluate_model(model4)
        print(f"Loss train: {scores['train']:.4f}, valid {scores['valid']:.4f}")
print(loss.item())

Loss train: 2.5529, valid 3.3762
Loss train: 2.5478, valid 3.3043
Loss train: 2.5842, valid 3.4124
Loss train: 2.5273, valid 3.5445
Loss train: 2.5023, valid 3.1714
Loss train: 2.4843, valid 3.3167
Loss train: 2.4943, valid 3.3946
Loss train: 2.4761, valid 3.4232
Loss train: 2.4848, valid 3.2340
Loss train: 2.4688, valid 3.2339
2.4572887420654297


In [61]:
print(model4.generate_text(1000))

.
' wose
wort maigousnmke evere l Jecoe afriwoipim, operwo tsho rf, ave an Is winglga riet hererea thim. sd Sor. Difier Hro?"
 al whfart cuman ouswi tflaer
ayve o yr uin wou, an ghaflind othe  chem imap goirs; amen, in. —hernm APR—————ghe y's in.
Bugot
rod wauit no?"
"Soe h e theenwh lak, wet, what ly nothe f wonuthot the nd Iss sof ris's
sithe toufr.

Ruuthirl, ar atois thinluer ad y ny thud ny ind ntwepende pateresend y woe efurdartof; onel!" r ochea fom." aI tanbpay, ary, andas of yan bthendoun indy ghe Igthe h Mas' fot day s Pl ghey ra f bforrmuy she —f Ila whse.


athum nd y whet areerd foeuninowthe f Hushe lousle ftor; es
sgot ares tnelal themary an he an Mal, day, f fowo rfof."
 It moa ld irlemat dad ecowreaye bwe he bals ky poome wereriver amer aofther. Ha w whainno. Nucuchellterswy soseo homourse doxaf, sotokec
o h [Isltha THE. Th. THhee tro
u farcose ak bryy-isnokou ptl; helsanoy olal gy iy fubiny  the swan pthain f woy uberutrtok, wh rauk rine.
Ye t, apterver.

M"Rby  Wougly

# Add Feedforward layer

In [54]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd, dropout = 0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

In [55]:
class Model5(nn.Module):
    "+ Feedforward"
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.position_emgeding = nn.Embedding(block_size, n_embed)
        self.sa_head = MultiHeadAttention(n_embed, n_embed//4, 4)
        self.ff = FeedFoward(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        
    def forward(self, idx, target = None):
        B,T = idx.shape
        token_embed = self.token_embedding(idx)
        position_embed = self.position_emgeding(torch.arange(T))
        x = token_embed + position_embed
        x = self.sa_head(x)
        x = self.ff(x)
        logits = self.lm_head(x)
        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # Use only logtis from last token
            probs = F.softmax(logits, dim =-1)
            new_token = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, new_token), dim=1)
        return idx

    def generate_text(self, max_tokens=100):
        prompt = torch.zeros([1, 1], dtype = torch.long)
        return "".join([untokenize(x) for x in self.generate(prompt, max_tokens).tolist()[0]])

In [69]:
model5 = Model5(vocab_size)
optimizer5 = torch.optim.AdamW(model5.parameters(), lr=1e-4)

In [71]:
for i in range(10000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = model5(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer5.step()
    if i%1000 == 0:
        scores = evaluate_model(model5)
        print(f"Loss train: {scores['train']:.4f}, valid {scores['valid']:.4f}")
print(loss.item())

Loss train: 2.5692, valid 3.1145
Loss train: 2.4836, valid 3.1331
Loss train: 2.4724, valid 3.1017
Loss train: 2.5154, valid 3.1178
Loss train: 2.4885, valid 3.0177
Loss train: 2.4765, valid 3.1613
Loss train: 2.4704, valid 3.2122
Loss train: 2.4608, valid 3.1457
Loss train: 2.4468, valid 3.1231
Loss train: 2.4278, valid 3.1349
2.4023804664611816


In [72]:
print(model5.generate_text(1000))

.
CGutecatt uit hinemr, aang Gacicllgucopang e she mutee forler bomeny ther ase aitlat itherir theang hif thest abypatos ther satits beroler albe aeg cutha m miteey
nint-ru wriblteyned ongttuo huge, r ashey thee adtakd "wit icitw me thay
iunoren stoif srwe. Nwe wod nallyk
taf hisn yohe shan'w the tit o Ge then lias overeren bepslow, ooud wat bobenod sshe coundict 1eith sen hee omy owo' bmof meou us th theerr; e lon as dof od wac ton'rogtthe the bithelap wing whaver
lit
eebetd wied sporacofto y, ba the they swo waly'en aft ase hrerspu oclleld fingg st'elt fiowlloro'P asens ur yhe hir, ang long ssi thany towe Wo chawl the ce actde. I Hoorngtoup, he chof to ma traze de beyfy fapa wro, dould k'lewwe
ghlitsting hee
she anacrrouf of sies t therealy bey thies, tree tichen'tte a aend co astllingf
eme
swas
coth med caslta hei e warhey ther ayid teme smeraeds ared herrey yound indingreren, wes," a arat sing the sbe to lery su mond, lelevoral, tayd  rook
dtoo herladnguS. Ats
B Hiteechene, hive ar

# Introduce blocks

In [75]:
class MultiHeadAttention2(nn.Module):
    def __init__(self, n_embd, head_size, num_heads):
        super().__init__()
        self.heads = [Head(n_embd, head_size) for _ in range(num_heads)]
        self.proj = nn.Linear(n_embd, n_embd)
    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out
        
        

In [76]:
class Block(nn.Module):
    def __init__(self, n_embed, n_heads):
        super().__init__()
        self.heads = MultiHeadAttention2(n_embed, n_embed//n_heads, n_heads)
        self.ff = FeedFoward(n_embed)
    def forward(self, x):
        x = x + self.heads(x)
        out = x + self.ff(x)
        return x

In [77]:
class Model6(nn.Module):
    "+ blocks"
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embed)
        self.position_emgeding = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(
            Block(n_embed, 4),
            Block(n_embed, 4),
            Block(n_embed, 4)
        )
        self.lm_head = nn.Linear(n_embed, vocab_size)
        
    def forward(self, idx, target = None):
        B,T = idx.shape
        token_embed = self.token_embedding(idx)
        position_embed = self.position_emgeding(torch.arange(T))
        x = token_embed + position_embed
        x = self.blocks(x)
        logits = self.lm_head(x)
        if target is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            target = target.view(B*T)
            loss = F.cross_entropy(logits, target)
            
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:,-block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:,-1,:] # Use only logtis from last token
            probs = F.softmax(logits, dim =-1)
            new_token = torch.multinomial(probs, num_samples = 1)
            idx = torch.cat((idx, new_token), dim=1)
        return idx

    def generate_text(self, max_tokens=100):
        prompt = torch.zeros([1, 1], dtype = torch.long)
        return "".join([untokenize(x) for x in self.generate(prompt, max_tokens).tolist()[0]])

In [78]:
model6 = Model6(vocab_size)
optimizer6 = torch.optim.AdamW(model6.parameters(), lr=1e-4)

In [84]:
m = model6
o = optimizer6
for i in range(10000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = m(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    o.step()
    if i%1000 == 0:
        scores = evaluate_model(m)
        print(f"Loss train: {scores['train']:.4f}, valid {scores['valid']:.4f}")
print(loss.item())

Loss train: 2.4685, valid 3.4950
Loss train: 2.4905, valid 3.5867
Loss train: 2.4494, valid 3.7298
Loss train: 2.4862, valid 3.5445
Loss train: 2.4617, valid 3.4688
Loss train: 2.4670, valid 3.5610
Loss train: 2.4318, valid 3.4274
Loss train: 2.4597, valid 3.5924
Loss train: 2.5002, valid 3.8630
Loss train: 2.4351, valid 3.4708
2.476980447769165


In [83]:
print(m.generate_text(1000))

.S I brhem. 
irelo che neritha d her of and oulllang come cribe oe che o coll bat thee he sereris tss orely il. hag

ag

shae s hay oringhe hemang
Iif thele fof, rack nel? on "

he s hacre  ssael s e afit ereraing
ss kenou tomeler Te'

f Iinct aplofuld s ars" if to ngo ste awer ha m
he Het ale I'm heved 
 an' arcrand oo nthild harle wad, glin. t lort to  idshicther onl,, y
and ay, anll
he ithes Ar uing ae porp r'srerkr, an tarind t anthe e wanbes

as e it args olit ath tean cot ilir thes he ano t s athe ten,
and adh are. 
o
tor cche s l as" 

r an
tie

at 
ad p livegr tollo hees t allin cand ive'l was no d lf anthee " saek fr cay Hed

he Ye.

"Suped here hror, s oothe 
es hed laen s I boik ts. peand yoto'ie s.
s 
And
an 
an

acawicch she h 
a astin wan ad ive
a  grin mshe acp aneile


forch s rye fid, reis; onk ned mmit y 'sore in loris le hasom see oif.
hee Her'wane t shing heanbat coun wou, wnt roh she, 
irg
e se ughel se rif
ove n Huing la, bertm o

ag mar tol wate d ating a ton t r

# Final architecture (from Andrej)

In [85]:
class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)
        return out


In [87]:
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out


In [88]:
class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, n_embd):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


In [89]:
class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = n_embd // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedFoward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x


In [90]:
class Model7(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


In [97]:
n_embd = 32
dropout = 0.2
n_layer = 3
n_head = 4
device = 'cpu'

model7 = Model7()
optimizer7 = torch.optim.AdamW(model7.parameters(), lr=1e-4)

In [103]:
m = model7
o = optimizer7
for i in range(10000):
    X,Y = get_batch('train', batch_size=32)
    logits, loss = m(X, Y)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    o.step()
    if i%1000 == 0:
        scores = evaluate_model(m)
        print(f"Loss train: {scores['train']:.4f}, valid {scores['valid']:.4f}")
print(loss.item())

Loss train: 2.4040, valid 3.0002
Loss train: 2.3876, valid 2.9633
Loss train: 2.3316, valid 2.9719
Loss train: 2.3168, valid 3.0093
Loss train: 2.3058, valid 2.9829
Loss train: 2.2935, valid 2.8786
Loss train: 2.2994, valid 2.8954
Loss train: 2.2843, valid 2.9193
Loss train: 2.2710, valid 2.8215
Loss train: 2.2878, valid 2.8276
2.2491660118103027


In [104]:
max_tokens = 1000
prompt = torch.zeros([1, 1], dtype = torch.long)
print("".join([untokenize(x) for x in m.generate(prompt, max_tokens).tolist()[0]]))

., a ill sarrrolveravey high wacavideivagean'ser. Ishml cas sro prelthtw.
"Soventsineno dnce,
lifore she fof, and, wad  orom 
ave band margem qum pivearor re song to inricfro cras forlgat he shin hing
shaint shlouse freanthe, mored arvroe scaid bloof tren ut la wa'sexlllius wradeld, wals hererrer vlrn whe wyrreak ga thol! E
shlot ald thaspast sEme aglidead houe leeigh fiche hof the himahe poor as omerow ied delt lach pit lan my his rrouscesan mat leder erored asy thericrer thart medng acengas of
he wof hen ler his adeandark sowkee ated," thande hith sarenthe  bot her woup ing, vadroong. Whe heondrs
incexh forsninge mundloume to manoreh, of  the s. I slat apt me nin.
"
Isad Gufat lane hay bmim lan, eprand 
brewit hvah the gid thercen blys
hind fred was of wheyt hirec as hus rrot hatomll tipn ive rrid tohas ass shalyes te and, waln dof a "ores., Ho
Liwos plant has hing, droow dar sheftord hen tht arelaced ler no whe wa arnl sow.'”'l anakm rof lorkn woubrineg7"

aod flof wor drlrprroucte,