In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F


### Hyperparameters

In [2]:
'''Hyperparameters for smaller model'''

B = 32 # B: how many independent sequences will we process in parallel?
T = 8  # T: what is the maximum context length for predictions?
C = 32 # C: numer of different features analysed (also D = dims)
H = 4  # H: number of attention heads
L = 4  # L: Number of layers
learning_rate = 1e-3

'''Final Hyperparameters'''

# B = 64 # B: how many independent sequences will we process in parallel?
# T = 256  # T: what is the maximum context length for predictions?
# H = 6
# C = 64*H
# L = 6
# learning_rate = 1e-4

# Common Hyperparameters
max_iters = 5000
eval_interval = 500
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
dropout = 0.2
torch.manual_seed(1337)

<torch._C.Generator at 0x1cc20fc1050>

### Data

In [6]:
import requests

url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
response = requests.get(url)

with open("input.txt", "wb") as file:
    file.write(response.content)


In [14]:
import re

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Split the text into words using space, newline, or double newline as delimiters
words = re.split(r'( |\n|\n\n)', text)

# Remove empty strings from the list
words = [word for word in words if word]

# Create a set of unique words (including spaces and newlines)
unique_words = sorted(set(words))
vocab_size = len(unique_words)

# Create a mapping from words to integers
stoi = {word: i for i, word in enumerate(unique_words)}
itos = {i: word for i, word in enumerate(unique_words)}

# Encoder: take a string, output a list of integers
encode = lambda s: [stoi[word] for word in re.split(r'( |\n|\n\n)', s) if word]

# Decoder: take a list of integers, output a string
decode = lambda l: ''.join([itos[i] for i in l])

# Combine the unique words to create the vocabulary string
vocab_str = ''.join(unique_words)

print(f'vocab_size: {vocab_size}')
print(f'vocabulary: {vocab_str}')


# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - T, (B,))
    x = torch.stack([data[i:i+T] for i in ix])
    y = torch.stack([data[i+1:i+T+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y


vocab_size: 25672
vocabulary: 


### Head, MHSA

In [17]:
class MultiHeadAttention(nn.Module):
    def __init__(self, Ci, H, head_size):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = H
        self.head_dim = head_size
        self.embed_size = Ci
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(Ci, Ci)

    def forward(self, values, keys, query, mask=None):
        N = query.shape[0]
        Q_len, K_len, V_len = query.shape[1], keys.shape[1], values.shape[1]

        # Split embedding into multiple heads
        values = values.reshape(N, V_len, self.num_heads, self.head_dim)
        keys = keys.reshape(N, K_len, self.num_heads, self.head_dim)
        queries = query.reshape(N, Q_len, self.num_heads, self.head_dim)

        values = values.permute(0, 2, 1, 3)
        keys = keys.permute(0, 2, 1, 3)
        queries = queries.permute(0, 2, 1, 3)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))

        attention = F.softmax(energy / (self.embed_size ** (1/2)), dim=3)

        out = torch.einsum("nhql,nlhd->nqhd", [attention, values])
        out = out.reshape(N, Q_len, self.embed_size)
        out = self.fc_out(out)
        return out

In [18]:
import torch.nn as nn

class Block(nn.Module):
    ''' Transformer block: communication followed by computation '''

    def __init__(self, C, H, dropout=0.1): # C: embedding dimension, H: number of heads
        super().__init__()
        self.ln1 = nn.LayerNorm(C)   # Layernorm along channels (batch & time are batch dims)
        self.sa = MultiHeadAttention(Ci=C, H=H, head_size=C//H)  # Note: Use Ci=C instead of C
        self.ln2 = nn.LayerNorm(C)
        self.ffwd = nn.Sequential(         # Feedforward network
            nn.Linear(C, C*4),
            nn.GELU(),
            nn.Linear(C*4, C),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Residual connections around MSA & FF
        x_skip = x

        x = self.ln1(x)
        x = self.sa(x, x, x)   # (B,T,C), Multi-head self-attention
        x = x + x_skip

        x_skip = x  # Update the skip connection here for the next residual
        x = self.ln2(x)
        x = self.ffwd(x) # (B,T,C), Per token level
        x = x + x_skip

        return x


### Model

In [20]:
class BigramLanguageModel(nn.Module):
    def __init__(self, B, T, C, H, L):
        super().__init__()
        self.B, self.T, self.C, self.H, self.L = B, T, C, H, L
        self.token_embedding_table = nn.Embedding(vocab_size, C)
        self.position_embedding_table = nn.Embedding(T, C)
        self.blocks = nn.Sequential(*[Block(C, H) for _ in range(L)])
        self.ln_final = nn.LayerNorm(C)
        self.lm_head = nn.Linear(C, vocab_size)

    def forward(self, idx, targets=None):
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(self.T, device=device))
        x = tok_emb + pos_emb
        x = self.blocks(x)
        x = self.ln_final(x)
        logits = self.lm_head(x)
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.T:]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx
        
model = BigramLanguageModel(B,T,C,H,L)
m = model.to(device)

#### Training

In [21]:
@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
    if iter % eval_interval == 0:   # every once in a while evaluate the loss on train and val sets
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    xb, yb = get_batch('train')     # sample a batch of data

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 10.3317, val loss 10.3308
step 500: train loss 3.7517, val loss 3.8794
step 1000: train loss 3.5048, val loss 3.6977
step 1500: train loss 3.2773, val loss 3.4970
step 2000: train loss 2.9913, val loss 3.2262
step 2500: train loss 2.7580, val loss 3.0269
step 3000: train loss 2.5791, val loss 2.8509
step 3500: train loss 2.3699, val loss 2.7040
step 4000: train loss 2.2078, val loss 2.5883
step 4500: train loss 2.0620, val loss 2.4780


In [22]:
context = torch.ones((1, T), dtype=torch.long, device=device)  # start with '\n\n\n\n' as seed
out_ints = m.generate(context, max_new_tokens=2000)[0].tolist() # output list of ints


In [23]:
print(decode(out_ints))

        our bedrench when me mutinous
For Sorely, strong, do Rome to along
And one marriage,
Indeed, very Potpan! comforting
Here be that servant new thou make art
Thy kindling
How fair? his tomorrow? i' of themselves
What him all will You is his
Be king grief,

BUCKINGHAM:
If you publicly: less
Provoked trouble himself: Right, should
I course surmise,

FLORIZEL:
Were disloyal; thou knave! quoifs
I as traitor indite braver your let tune,
If wind-shaken. would am be bugs.
remedy; and look. me of VINCENTIO:

ESCALUS: Cominius, be affairs we
Which stand Hercules,
Under of when daughter enough: of humble debt this
dispraise kindred such sorrows a have me lady!
to some he hear ABHORSON:
And of she passing boy, in the tongue's
Can you such accusation 'The before devil's

RIVERS: he yet tell all fault bawdy are rogue. be father you every take

CLARENCE:
Proud word
Here doth will think people one BOLINGBROKE:

Her
As VINCENTIO:
to well I weak vow'd
After
POLIXENES:
You, soul;
Ay, stout many an