In [1]:
WORKING_DIRECTORY = "../.."

import os
import sys

filepath = os.path.join(WORKING_DIRECTORY, "data/freud/interpretation-of-dreams.txt")
input_file = open(filepath, 'r', encoding='utf-8')

In [2]:
raw_text = input_file.read()
input_file.close()

In [3]:
raw_characters = set(raw_text)
''.join(sorted(list(raw_characters)))

'\n !&(),-.0123456789:;=?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]_abcdefghijklmnopqrstuvwxyz·ÀÂÆÉÜàäæçèéêîïóôûüŒœ̓Ψένςυχ–—‘’“”\ufeff'

In [4]:
cleaned_text = raw_text.replace('\ufeff', '')
import re
cleaned_text = re.sub(r'^[ \t]+|[ \t]+$', '', cleaned_text, flags=re.MULTILINE)
cleaned_text = re.sub(r'\n{2,}', '\n\n', cleaned_text)
cleaned_text = re.sub(r'(?<!\n)\n(?!\n)', ' ', cleaned_text)
cleaned_text = re.sub(r' {2,}', ' ', cleaned_text)
cleaned_text = re.sub(r'\n\n', '\n', cleaned_text)
text = cleaned_text.strip()

In [5]:
from collections import Counter
print(Counter(text))
characters = sorted(list(set(text)))
vocab_size = len(characters)

Counter({' ': 185886, 'e': 112697, 't': 84033, 'a': 66289, 'o': 65011, 'i': 63896, 'n': 61272, 's': 59113, 'h': 53439, 'r': 52082, 'd': 30739, 'l': 30165, 'c': 29084, 'm': 25062, 'f': 22664, 'u': 22638, 'p': 16956, 'w': 15128, 'y': 14864, 'g': 13554, 'b': 11353, ',': 10012, '.': 7341, 'v': 7260, 'k': 3805, 'I': 3725, 'x': 2267, 'T': 2091, '_': 1872, '\n': 1237, ';': 1076, '“': 989, 'j': 987, '”': 983, 'A': 904, '-': 897, 'q': 826, '—': 676, '(': 646, ')': 646, 'S': 576, 'W': 560, '’': 547, 'B': 537, ':': 516, 'H': 514, 'F': 441, 'O': 432, 'M': 421, '[': 393, ']': 393, 'D': 344, 'E': 294, 'N': 294, 'R': 275, 'P': 269, '1': 249, 'L': 230, 'C': 223, '?': 203, 'G': 188, 'z': 185, '5': 155, 'é': 153, '3': 152, '6': 148, '2': 145, '8': 142, '4': 128, 'U': 115, 'J': 107, 'V': 97, '7': 79, 'K': 78, 'Y': 77, '0': 67, '9': 63, '&': 56, 'ü': 50, '!': 49, 'æ': 32, 'ê': 29, '‘': 26, 'œ': 22, 'Z': 21, 'à': 21, 'ô': 18, 'X': 16, 'Q': 14, 'è': 12, '=': 12, 'ä': 9, 'ç': 9, 'Ψ': 6, '–': 5, 'Ü': 3, '·': 

In [6]:
idx_to_token = dict(enumerate(characters))
token_to_idx = {t: i for i, t in enumerate(characters)}
encode = lambda s: list(map(token_to_idx.__getitem__, s))
decode = lambda s: list(map(idx_to_token.__getitem__, s))

In [7]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)

In [8]:
# split into train / val
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

In [None]:
block_size = 8 # maximum context length
train_data[:block_size+1]
batch_size = 4 # how many we want to process in parallel

torch.manual_seed(42)
def get_batch(split):
		data = train_data if split == 'train' else val_data
		idxs = torch.randint(len(data) - block_size, (batch_size,))
		x = torch.stack([data[i:i+block_size] for i in idxs])
		y = torch.stack([data[i+1:i+block_size+1] for i in idxs])
		return x, y
# each block contains block_size examples of increasingly longer strings to provide recurrent training
# y is the next single character to predict
# block of 'ABCDEFG' -> [('A', 'B'), ('AB', 'C'), ('ABC', 'D')]
# batches for GPU parallelization
batch_x, batch_y = get_batch('train')
print(batch_x)
print(batch_y)
print(batch_x.shape)

tensor([[73, 60, 70, 60, 71,  1, 71, 66],
        [69, 58, 52, 65, 70,  8,  1, 31],
        [53, 56, 55, 70,  1, 52, 69, 56],
        [58,  1, 74, 60, 71, 59,  1, 57]])
tensor([[60, 70, 60, 71,  1, 71, 66,  1],
        [58, 52, 65, 70,  8,  1, 31, 57],
        [56, 55, 70,  1, 52, 69, 56,  1],
        [ 1, 74, 60, 71, 59,  1, 57, 60]])
torch.Size([4, 8])


# BIGRAM MODEL

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# note: both x and y are of shape (batch size, block size)
# B: batch size
# T: block size
# V: vocab size (channel size)

# B -> batch
# T -> time dimension
# C -> channel

class BigramLanguageModel(nn.Module):
	def __init__(self, vocab_size):
		super().__init__()
		# nn.Embedding works as a lookup table from token -> token
		self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)
	def forward(self, idx, targets=None):
		# looks up each idx in the embedding table
		# since input idx (x) is of size (B,T), returns (B,T,V) tensor
		# (since each word corresponds with a vector size vocab_size)
		logits = self.token_embedding_table(idx)

		if targets is None:
			return logits, None

		# use negative log likelihood (cross entropy)
		# F.cross_entropy expects logits to be 2D and targets 1D, so collapse dimensions
		B, T, C = logits.shape
		logits_shrink = logits.view(B*T, C)
		targets_shrink = targets.view(B*T)
		loss = F.cross_entropy(logits_shrink, targets_shrink)

		return logits_shrink, loss
	
	def generate(self, idx, max_new_tokens):
		for _ in range(max_new_tokens):
			logits, _ = self(idx)
			logits = logits[:, -1, :] # specifically only get the LAST timestep
			probs = F.softmax(logits, dim = -1) # over the channel dimension, size (B, 1)
			idx_next = torch.multinomial(probs, num_samples = 1) # sample, size (B, 1)
			idx = torch.cat((idx, idx_next), dim=1) # append to given indices (B, T + 1)
		return idx

m = BigramLanguageModel(vocab_size)
logits, loss = m(batch_x, batch_y)
print(logits.shape)
print(loss)
print(decode(
	m.generate(
		torch.zeros((1, 1), dtype=torch.long),
		100
	)[0].tolist()
))

torch.Size([32, 112])
tensor(5.1489, grad_fn=<NllLossBackward0>)
['\n', 'W', 'œ', 'Ü', '8', 'R', 'î', 'S', '8', '“', 't', 'h', 'Y', 'û', 'Œ', 'A', 'ï', 'Â', 'D', 'χ', 'Æ', ')', 'Ψ', '–', 'T', 'I', 'h', 'U', 'ô', 'e', 'è', 'X', 'Æ', 'k', ' ', '—', 'w', '!', '“', '_', 'H', 'ï', 'c', 'I', 'b', 'û', 'ü', '“', '_', '5', 'S', '6', '5', 'h', '9', '6', 'v', 'j', 'ç', 'Y', '‘', 'Ψ', 'h', 'j', 'έ', 'î', '.', '!', '&', ',', 'C', '.', 'D', 'χ', '\n', 'y', 'T', 'A', 'Œ', '7', '1', 'D', 'ô', ',', 'H', 'w', ' ', 'Ψ', 'P', 'r', 't', 'b', 'ó', 'w', 'o', 'u', 'C', '5', 'R', 'R', 'Â']


In [11]:
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [None]:
batch_size = 32
for steps in range(1000):
	xb, yb = get_batch('train')
	logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	optimizer.step()
print(loss.item())

4.037664890289307


In [None]:
print("".join(decode(
	m.generate(
		torch.zeros((1, 1), dtype=torch.long),
		500
	)[0].tolist()
)))


hdF”ô2d=Lç07;stqν——PûLBRHsnÜ8dY_νD:0ûÉ–m6y qÀ—4B(YNêlàςυ·éqoôô2küXEov[fVun’0”BoûN3kak)έ7LiΨ–?MäkέóÂô_45 w89x4—k3êΨ7A‘ s-vvéXiîa3·nυNÆqNWOÜ!wàMÂM.!.ijïcæ?vQMέ
p. akFdWg—YP3χ[xB7ûŒ-whvôîæÆuςBl19î0ïURΨFêΨυ(_&äRΨJCPh,_SXÉ;t0ŒH0Ü!OvkæD7eqn(4M4‘χrk3;
ê12TcuOnp
OÉX–u—?,CCXTνûÉ“IvSaυ_Â&Mg0LXæ—wc3D.v&P7Lubïc
9,QNéW0CtXχ6v)f:JeyDPêχu—R3UÆΨυJea3feSüzûT8g-siC—wiv“iïUP:hdCEHûN’bw—9RïÀ7LFGc‘ tυp-oυRyÉGQOYoυRazDa—&,. P[ZYY612VC_æ=;lFAiî5?M62“SpUx6lun a99,8_üCKΨυ‘Æ3.υ·PN
5à?SuZadÆÂ?cu·ô)χ̓ÉP):ükQ&,ÜJMxVXccÂUcee


# Attention testing
note that to couple tokens with previous tokens in the context, it may be valuable to use the 'average' of the previous tokens, named a 'bag of words'

the trick to doing this extremely fast (without having to loop over B and T and averaging over time) is to use matrix multiplication

multiplying by a ones matrix returns the sum of each row / column

=> use a lower triangular matrix of ones to find these values

```python
weights = torch.tril(torch.ones(T, T))
weights = weights / weights.sum(1, keepdim=True)
xbow2 = weights @ x # note this is (T, T) x (B, T, C), but torch will cast weights into (B, T, T) resulting in (B, T, C)
```

In [None]:
n_embed = 32
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# this one doesn't work i think but you can see what we're doing here, adding a linear layer inbetween

class BigramLanguageModelWithPositionalEncoding(nn.Module):
	def __init__(self, vocab_size=vocab_size, n_embed=n_embed, block_size=block_size):
		super().__init__()
		self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
		self.position_embedding_table = nn.Embedding(block_size, n_embed)
		self.lm_head = nn.Linear(n_embed, vocab_size)

	def forward(self, idx, targets=None):
		tok_embed = self.token_embedding_table(idx)
		_, T, _ = tok_embed.shape
		pos_embed = self.position_embedding_table(torch.arange(T))
		x = tok_embed + pos_embed
		logits = self.lm_head(x)

		if targets is None:
			return logits, None
		B, T, C = logits.shape
		logits_shrink = logits.view(B*T, C)
		targets_shrink = targets.view(B*T)
		loss = F.cross_entropy(logits_shrink, targets_shrink)
		return logits_shrink, loss
	
	def generate(self, idx, max_new_tokens):
		for _ in range(max_new_tokens):
			logits, _ = self(idx)
			logits = logits[:, -1, :] # specifically only get the LAST timestep
			probs = F.softmax(logits, dim = -1) # over the channel dimension, size (B, 1)
			idx_next = torch.multinomial(probs, num_samples = 1) # sample, size (B, 1)
			idx = torch.cat((idx, idx_next), dim=1) # append to given indices (B, T + 1)
		return idx
	
m = BigramLanguageModel(vocab_size)
logits, loss = m(batch_x, batch_y)
print(logits.shape)


torch.Size([32, 112])


In [15]:
# head test
B, T, C = 4, 8, 32
x = torch.randn(B,T,C)

head_size=16
key=nn.Linear(C, head_size, bias=False)
query=nn.Linear(C, head_size, bias=False)
value=nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)
weights = q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)

tril = torch.tril(torch.ones(T, T))
# weights = torch.zeros((T, T))
weights = weights.masked_fill(tril == 0, float('-inf'))
weights = F.softmax(weights, dim=-1)

v = value(x)
out = weights @ v

In [16]:
weights[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2239, 0.7761, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8172, 0.1650, 0.0179, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0411, 0.9134, 0.0296, 0.0159, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0024, 0.0866, 0.6325, 0.0204, 0.2582, 0.0000, 0.0000, 0.0000],
        [0.4754, 0.2132, 0.0254, 0.2568, 0.0079, 0.0212, 0.0000, 0.0000],
        [0.1016, 0.3475, 0.0438, 0.3245, 0.0951, 0.0521, 0.0354, 0.0000],
        [0.4936, 0.1528, 0.0470, 0.0935, 0.0416, 0.0253, 0.1003, 0.0459]],
       grad_fn=<SelectBackward0>)

notes:

q "what am i looking for?"
k "what can i offer?"
v "what i actually offer during attention"

attention is the KVQ mechanism here. attention is a communication mechanism and works in directed graphs. technically it doesn't have to be a DAG and can work with any directed graph. the graphs we use in this model are where each token points to all tokens after it in the batch, and the last token points to itself.

note that also in attention there isn't a notion of space here, it just works on a group of vectors. we have to add the positional encoding for the model to care about it

each example in a batch does not care about anything about the other batches; completely independent

in an encoder it does not matter that tokens can communicate with tokens later in a sequence (meaning no masking above). sequences where that does matter is called a 'decoder', like so, and is commonly used in similarly autoregressive settings. attention does not matter
(auto-regressive: statistical models that predicts future values in a time series using previous values)

"self-attention" -> the keys and values are produced from the same source as the queries. 

"cross-attention" -> keys and values are produced from a difference source as the queries. seems to be used when we want to condition off of some other context or use a different modality. eg. for translation from languages A to B, k/v pairs are generated for tokens in A, then are queried using query vectors from B.

in the original "attention is all you need" paper, before softmaxing the weights are multiplied by $\frac{1}{\sqrt{\texttt{head\_size}}}$. this is because when tensor multiplying k and v, the variance of the output function gets amplified tremendously. and since softmax favors the largest values, with greater variance this will increase the disparity between the largest value and the rest of the values, essentially creating one-hot vectors; which we do not want since we want to maximize the usage of our embedding dimensions.

# Inserting a head into our model

In [None]:
class Head(nn.Module):
	def __init__(self, head_size=head_size, n_embed=n_embed, block_size=block_size):
		super().__init__()
		self.key = nn.Linear(n_embed, head_size, bias=False)
		self.query = nn.Linear(n_embed, head_size, bias=False)
		self.value = nn.Linear(n_embed, head_size, bias=False)
		self.register_buffer('tril', 
			torch.tril(torch.ones(block_size, block_size))
		)
		self.attention_scalar = pow(head_size, -0.5)
	
	def forward(self, x):
		_, T, _ = x.shape
		k = self.key(x)
		v = self.value(x)
		q = self.query(x)

		# doing attn. formula -> softmax(qK^T / sqrt(d_k)) * V
		weights = q @ k.transpose(-2, -1) * self.attention_scalar
		weights = F.softmax(
			weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')),
			dim=-1
		)
		out = weights @ v
		return out

class SingleHeadAttentionLanguageModel(nn.Module):
	def __init__(self, vocab_size=vocab_size, n_embed=n_embed, block_size=block_size, head_size=head_size):
		super().__init__()

		self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
		self.position_embedding_table = nn.Embedding(block_size, n_embed)
		self.sa_head = Head(head_size=n_embed, n_embed=n_embed, block_size=block_size)
		self.lm_head = nn.Linear(n_embed, vocab_size)

	def forward(self, idx, targets=None):
		B, T = idx.shape
		
		tok_embed = self.token_embedding_table(idx)
		pos_embed = self.position_embedding_table(torch.arange(T, device=device))
		x = self.sa_head(tok_embed + pos_embed)
		logits = self.lm_head(x)

		if targets is None:
			return logits, None
		
		B, T, C = logits.shape
		logits_shrink = logits.view(B*T, C)
		targets_shrink = targets.view(B*T)
		loss = F.cross_entropy(logits_shrink, targets_shrink)
		return logits_shrink, loss
	
	def generate(self, idx, max_new_tokens):
		for _ in range(max_new_tokens):
			idx_cond = idx[:, -block_size:] # we can only use context length of block_size otherwise positional embeddings will fail
			logits, _ = self(idx_cond)
			logits = logits[:, -1, :]
			probs = F.softmax(logits, dim = -1)
			idx_next = torch.multinomial(probs, num_samples = 1)
			idx = torch.cat((idx, idx_next), dim=1)
		return idx

In [18]:
m = SingleHeadAttentionLanguageModel().to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [None]:
for steps in range(100):
	xb, yb = get_batch('train')
	xb = xb.to(device)
	yb = yb.to(device)
	logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	optimizer.step()
print(loss.item())

2.8626787662506104


In [None]:
print("".join(decode(
	m.generate(
		torch.zeros((1, 1), dtype=torch.long, device=device),
		500
	)[0].tolist()
)))


-e lw n e‘ypchdpPgdyh shrsn n n vd ur gberuhrrnddif.cttsoo  uoexanna hae2hsos 6mheoaeceseaerd9 ncf rbrsnhibenyntl ynispmltiabnptt i msaIyto or]rchao iwoilslorse_d n nrshramil kyaisi  ss tooiJnwh 6miNttnaettœm f idoholoo  tft oece ierr  yse.m eahruasth yi tihyerouv daerfäiowwuu gfemoon nf aoffb eytaosatevntttd dya nal i iwatlyeIrh maenehn  tnpti hέtgiiteνoo6ta asef ain reOoswtan nnhi=e toot ha tine’pn s rtf, co B‘2g rw n.eaerie, e t w ioahfsnn iy tumidiarhiioa e7foeaarrbtraa t dg tisn  iotm=eeke 


In [None]:
class MultiHeadAttention(nn.Module):
	def __init__(self, num_heads, head_size):
		super().__init__()
		self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
	def forward(self, x):
		return torch.cat([h(x) for h in self.heads], dim = -1)

class FeedForward(nn.Module):
	def __init__(self, n_embed):
		super().__init__()
		self.net = nn.Sequential(
			nn.Linear(n_embed, n_embed),
			nn.ReLU()
		)
	def forward(self, x):
		return self.net(x)

class MultiHeadAttentionLanguageModel(nn.Module):
	def __init__(self, vocab_size=vocab_size, n_embed=n_embed, block_size=block_size, head_size=head_size):
		super().__init__()

		self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
		self.position_embedding_table = nn.Embedding(block_size, n_embed)
		self.sa_heads = MultiHeadAttention(num_heads=4, head_size=n_embed//4)
		self.ffwd = FeedForward(n_embed)
		self.lm_head = nn.Linear(n_embed, vocab_size)

	def forward(self, idx, targets=None):
		B, T = idx.shape
		
		tok_embed = self.token_embedding_table(idx)
		pos_embed = self.position_embedding_table(torch.arange(T, device=device))
		x = self.sa_heads(tok_embed + pos_embed)
		x = self.ffwd(x)
		logits = self.lm_head(x)

		if targets is None:
			return logits, None
		
		B, T, C = logits.shape
		logits_shrink = logits.view(B*T, C)
		targets_shrink = targets.view(B*T)
		loss = F.cross_entropy(logits_shrink, targets_shrink)
		return logits_shrink, loss
	
	def generate(self, idx, max_new_tokens):
		for _ in range(max_new_tokens):
			idx_cond = idx[:, -block_size:] # we can only use context length of block_size otherwise positional embeddings will fail
			logits, _ = self(idx_cond)
			logits = logits[:, -1, :]
			probs = F.softmax(logits, dim = -1)
			idx_next = torch.multinomial(probs, num_samples = 1)
			idx = torch.cat((idx, idx_next), dim=1)
		return idx


In [22]:
m = MultiHeadAttentionLanguageModel().to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

In [None]:
for steps in range(1000):
	xb, yb = get_batch('train')
	xb = xb.to(device)
	yb = yb.to(device)
	logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	optimizer.step()
print(loss.item())

1.9707434177398682


In [None]:
print("".join(decode(
	m.generate(
		torch.zeros((1, 1), dtype=torch.long, device=device),
		100
	)[0].tolist()
)))


has to mas to meem to ther inges foriling the wish extencabjempearmouve stom in in of the im, and th


In [None]:
class MultiHeadAttention(nn.Module):
	def __init__(self, num_heads, head_size):
		super().__init__()
		self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
		self.proj = nn.Linear(n_embed, n_embed)
	def forward(self, x):
		out = torch.cat([h(x) for h in self.heads], dim = -1)
		out = self.proj(out)
		return out

class FeedForward(nn.Module):
	def __init__(self, n_embed):
		super().__init__()
		self.net = nn.Sequential(
			nn.Linear(n_embed, 4 * n_embed),
			nn.ReLU(),
			nn.Linear(4 * n_embed, n_embed),
		)
	def forward(self, x):
		return self.net(x)

class Block(nn.Module):
	def __init__(self, n_embed, n_head):
		super().__init__()
		head_size = n_embed // n_head
		self.sa = MultiHeadAttention(n_head, head_size)
		self.ffwd = FeedForward(n_embed)
	
	def forward(self, x):
		x = x + self.sa(x)
		x = x + self.ffwd(x)
		return x

# add blocks + residual connections
class AttentionBlockLanguageModel(nn.Module):
	def __init__(self, vocab_size=vocab_size, n_embed=n_embed, block_size=block_size, head_size=head_size):
		super().__init__()

		self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
		self.position_embedding_table = nn.Embedding(block_size, n_embed)
		self.blocks = nn.Sequential(
			Block(n_embed, n_head=4),
			Block(n_embed, n_head=4),
			Block(n_embed, n_head=4),
		)
		self.lm_head = nn.Linear(n_embed, vocab_size)

	def forward(self, idx, targets=None):
		B, T = idx.shape
		
		tok_embed = self.token_embedding_table(idx)
		pos_embed = self.position_embedding_table(torch.arange(T, device=device))
		x = self.blocks(tok_embed + pos_embed)
		logits = self.lm_head(x)

		if targets is None:
			return logits, None
		
		B, T, C = logits.shape
		logits_shrink = logits.view(B*T, C)
		targets_shrink = targets.view(B*T)
		loss = F.cross_entropy(logits_shrink, targets_shrink)
		return logits_shrink, loss
	
	def generate(self, idx, max_new_tokens):
		for _ in range(max_new_tokens):
			idx_cond = idx[:, -block_size:] # we can only use context length of block_size otherwise positional embeddings will fail
			logits, _ = self(idx_cond)
			logits = logits[:, -1, :]
			probs = F.softmax(logits, dim = -1)
			idx_next = torch.multinomial(probs, num_samples = 1)
			idx = torch.cat((idx, idx_next), dim=1)
		return idx

In [None]:
m = MultiHeadAttentionLanguageModel().to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-4)

In [None]:
for steps in range(1000):
	xb, yb = get_batch('train')
	xb = xb.to(device)
	yb = yb.to(device)
	logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	optimizer.step()
print(loss.item())

1.8985002040863037


In [None]:
print("".join(decode(
	m.generate(
		torch.zeros((1, 1), dtype=torch.long, device=device),
		100
	)[0].tolist()
)))


If a drefond _Sens ba the coftecaninercaw,_lortratege the boltel cariseves ally dreath ot evucarin b


# Adding Layernorm + Dropout

In [51]:
batch_size=64
block_size=64
n_embed=64

dropout=0.2

class Head(nn.Module):
	def __init__(self, head_size=head_size, n_embed=n_embed, block_size=block_size):
		super().__init__()
		self.key = nn.Linear(n_embed, head_size, bias=False)
		self.query = nn.Linear(n_embed, head_size, bias=False)
		self.value = nn.Linear(n_embed, head_size, bias=False)
		self.dropout = nn.Dropout(dropout)
		self.register_buffer('tril', 
			torch.tril(torch.ones(block_size, block_size))
		)
		self.attention_scalar = pow(head_size, -0.5)
	
	def forward(self, x):
		_, T, _ = x.shape
		k = self.key(x)
		v = self.value(x)
		q = self.query(x)

		# doing attn. formula -> softmax(qK^T / sqrt(d_k)) * V
		weights = q @ k.transpose(-2, -1) * self.attention_scalar
		weights = F.softmax(
			weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')),
			dim=-1
		)
		weights = self.dropout(weights)
		out = weights @ v
		return out

class MultiHeadAttention(nn.Module):
	def __init__(self, num_heads, head_size):
		super().__init__()
		self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
		self.proj = nn.Linear(n_embed, n_embed)
		self.dropout = nn.Dropout(dropout)

	def forward(self, x):
		out = torch.cat([h(x) for h in self.heads], dim = -1)
		out = self.dropout(self.proj(out))
		return out

class FeedForward(nn.Module):
	def __init__(self, n_embed):
		super().__init__()
		self.net = nn.Sequential(
			nn.Linear(n_embed, 4 * n_embed),
			nn.ReLU(),
			nn.Linear(4 * n_embed, n_embed),
			nn.Dropout(dropout)
		)
	def forward(self, x):
		return self.net(x)

class Block(nn.Module):
	def __init__(self, n_embed, n_head):
		super().__init__()
		head_size = n_embed // n_head
		self.sa = MultiHeadAttention(n_head, head_size)
		self.ffwd = FeedForward(n_embed)
		self.ln1 = nn.LayerNorm(n_embed)
		self.ln2 = nn.LayerNorm(n_embed)

	def forward(self, x):
		x = x + self.sa(self.ln1(x))
		x = x + self.ffwd(self.ln2(x))
		return x

# add blocks + residual connections
class LayernormLanguageModel(nn.Module):
	def __init__(self, vocab_size=vocab_size, n_embed=n_embed, block_size=block_size, head_size=head_size):
		super().__init__()

		self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
		self.position_embedding_table = nn.Embedding(block_size, n_embed)
		self.blocks = nn.Sequential(
			Block(n_embed, n_head=4),
			Block(n_embed, n_head=4),
			Block(n_embed, n_head=4),
			Block(n_embed, n_head=4),
			nn.LayerNorm(n_embed),
		)
		self.lm_head = nn.Linear(n_embed, vocab_size)

	def forward(self, idx, targets=None):
		B, T = idx.shape
		
		tok_embed = self.token_embedding_table(idx)
		pos_embed = self.position_embedding_table(torch.arange(T, device=device))
		x = self.blocks(tok_embed + pos_embed)
		logits = self.lm_head(x)

		if targets is None:
			return logits, None
		
		B, T, C = logits.shape
		logits_shrink = logits.view(B*T, C)
		targets_shrink = targets.view(B*T)
		loss = F.cross_entropy(logits_shrink, targets_shrink)
		return logits_shrink, loss
	
	def generate(self, idx, max_new_tokens):
		for _ in range(max_new_tokens):
			idx_cond = idx[:, -block_size:] # we can only use context length of block_size otherwise positional embeddings will fail
			logits, _ = self(idx_cond)
			logits = logits[:, -1, :]
			probs = F.softmax(logits, dim = -1)
			idx_next = torch.multinomial(probs, num_samples = 1)
			idx = torch.cat((idx, idx_next), dim=1)
		return idx

In [None]:
m = LayernormLanguageModel().to(device)
optimizer = torch.optim.AdamW(m.parameters(), lr=3e-4)

In [72]:
for steps in range(1000):
	xb, yb = get_batch('train')
	xb = xb.to(device)
	yb = yb.to(device)
	logits, loss = m(xb, yb)
	optimizer.zero_grad(set_to_none=True)
	loss.backward()
	optimizer.step()
	if steps % 100 == 0:
		print(f"{steps}: {loss.item()}")

0: 1.74336576461792
100: 1.773617148399353
200: 1.8295972347259521
300: 1.6838399171829224
400: 1.7654995918273926
500: 1.7980626821517944
600: 1.7672405242919922
700: 1.777510166168213
800: 1.7262349128723145
900: 1.706779956817627


In [None]:
print("".join(decode(
	m.generate(
		torch.tensor([encode("The mind")], dtype=torch.long, device=device),
		100
	)[0].tolist()
)))

The mind we in det, are and to hove consuble, whichich ho fromewnd, shown is shave alfore is a diblections: 


# QUESTIONS
- What is the reason for the k / v / q vectors?
- What is the reason for multiple heads? Why does this improve gains, over a single head?
- Why do we have blocks? What is the optimal choice in the number of blocks?
- Why 'feed-forward'? Is there a reason it is this simple?
- How to determine an optimal C / channel size?
- Why do residual block networks gain speedups?
- Why do you need those projection layers if it is the same dim going in and out?