In [1]:
# torch imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import multinomial
import time; import pandas as pd

import matplotlib.pyplot as plt
from tqdm import tqdm

import requests
import os
import pdb

# seed torch's global RNG once for the whole notebook
torch.manual_seed(305)

# run on the GPU when torch can see one, otherwise stay on the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Global hyperparameters
SMALL_ITERS = 2000
LARGE_ITERS = 2000
EVAL_ITERS = 100           # batches averaged per split when estimating the loss
CONTEXT_WINDOW_SIZE = 256  # tokens per training window

In [2]:
# download the tiny shakespeare dataset (skipped when a local copy already exists)
input_file_path = "input.txt"
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # Fetch and validate the response BEFORE opening the output file: a failed
    # request must not leave an empty or error-page input.txt behind, because
    # the existence check above would make every later run silently reuse it.
    response = requests.get(data_url, timeout=30)
    response.raise_for_status()
    with open(input_file_path, 'w') as f:
        f.write(response.text)

with open(input_file_path, 'r') as f:
    data = f.read()

# get all the unique characters that occur in this text
chars = sorted(set(data))
vocab_size = len(chars)

In [3]:
# lookup tables: character -> integer id and integer id -> character,
# where the id is the character's position in the sorted `chars` list
stoi = dict(zip(chars, range(len(chars))))
itos = dict(enumerate(chars))

def encode(s):
    """Encoder: turn a string into the list of its characters' integer ids."""
    return list(map(stoi.__getitem__, s))

def decode(l):
    """Decoder: turn a list of integer ids back into the string they spell."""
    return ''.join(map(itos.__getitem__, l))

# split the corpus by position: the first 90% of the characters are used for
# training, the remainder for validation
n = len(data)
split_idx = int(n*0.9)
train_chars, val_chars = data[:split_idx], data[split_idx:]

# encode each split to integer ids and wrap the id lists as torch tensors
train_data = torch.tensor(encode(train_chars))
val_data = torch.tensor(encode(val_chars))

In [4]:
# draw a random mini-batch of (input, next-token target) windows
def get_batch(split, context_window_size, device, batch_size=32):
    """
    Sample `batch_size` random windows from the chosen split.

    Args:
        split: 'train' reads from train_data; any other value reads from val_data
        context_window_size: number of tokens in every sampled window
        device: 'cpu' or 'cuda' (should be 'cuda' if available); the returned
            tensors are moved there
        batch_size: number of windows to sample

    Returns:
        (x, y): two tensors with one row per sampled window, where y is the
        same window as x shifted one position to the right
    """
    source = train_data if split == 'train' else val_data

    # random start offsets, leaving room for the one-step-shifted target window
    starts = torch.randint(len(source) - context_window_size, (batch_size,))

    inputs = torch.stack([source[s:s + context_window_size] for s in starts])
    targets = torch.stack([source[s + 1:s + 1 + context_window_size] for s in starts])
    return inputs.to(device), targets.to(device)

# helper function for tracking loss during training
@torch.no_grad()
def estimate_loss(model, eval_iters, context_window_size, device):
    """
    Estimate the mean loss on the train and val splits from random batches.

    The model is put into eval mode for the duration of the estimate and its
    previous train/eval mode is restored afterwards (even on error), so layers
    that behave differently during training cannot add noise to the numbers.

    Args:
      model: model being evaluated; called as model(X, Y) and expected to
        return a (logits, loss) pair
      eval_iters: number of batches to average over
      context_window_size: size of the context window
      device: 'cpu' or 'cuda' (should be 'cuda' if available)

    Returns:
      dict with keys 'train' and 'val', each a 0-dim tensor holding the mean
      loss over `eval_iters` batches of that split
    """
    was_training = model.training
    model.eval()
    try:
        out = {}
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, context_window_size, device)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # hand the model back in whatever mode the caller had it in
        model.train(was_training)
    return out

In [5]:
class Head(nn.Module):
    """ a single causal self-attention head """

    def __init__(self, head_size, context_window_size, embed_size=384):
        """
        Args:
          head_size: int, output width of the query and key projections (K)
          context_window_size: int, side length of the causal mask, i.e. the
            longest sequence the head can attend over (T)
          embed_size: int, size of the token embedding dimension (D); the value
            projection maps D -> D
        """
        super().__init__()
        # stored as a tensor; only used to scale the attention scores
        self.head_size = torch.tensor(head_size)

        # bias-free linear projections
        self.key = nn.Linear(embed_size, head_size, bias=False)
        self.query = nn.Linear(embed_size, head_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)

        # lower-triangular matrix of ones; a buffer rather than a parameter,
        # so it is not trained
        causal = torch.tril(torch.ones(context_window_size, context_window_size))
        self.register_buffer('tril', causal)

    def forward(self, x):
        """
        Args:
          x: (B,T,D) tensor of token embeddings

        Returns:
          (B,T,D) tensor of attention-weighted token embeddings
        """
        # only the sequence length is needed; it may be shorter than the mask
        _, seq_len, _ = x.shape

        queries = self.query(x)
        keys = self.key(x)
        values = self.value(x)

        # raw query-key similarity for every pair of positions
        scores = queries @ keys.mT

        # hide future positions (mask cropped to the current length), then
        # scale by sqrt(head_size)
        hidden = self.tril[:seq_len, :seq_len] == 0.0
        scores = scores.masked_fill(hidden, float("-inf"))
        scores = scores / (self.head_size ** 0.5)

        # each row becomes a distribution over the visible positions
        weights = torch.softmax(scores, dim=-1)

        return weights @ values

In [6]:
class MultiHeadAttention(nn.Module):
    """ several self-attention heads applied to the same input; their outputs are summed """

    def __init__(self, context_window_size, num_heads, head_size, embed_size=384):
        """
        Args:
            context_window_size: int, number of tokens considered in the past for attention (T)
            num_heads: int, number of heads (H)
            head_size: int, size of the head embedding dimension
            embed_size: int, size of the token embedding dimension
        """
        super().__init__()
        heads = []
        for _ in range(num_heads):
            heads.append(Head(head_size, context_window_size, embed_size))
        # ModuleList so that every head's parameters are registered
        self.heads = nn.ModuleList(heads)

    def forward(self, x):
        """Run every head on x and return the element-wise sum of their outputs."""
        per_head = [head(x) for head in self.heads]
        # stack along a new leading axis, then collapse that axis by summing
        return torch.stack(per_head, dim=0).sum(dim=0)

In [7]:
# run this cell to define the feed-forward module used by the transformer block
# (provided as-is; nothing here needs editing)
class FeedForward(nn.Module):
    """ two linear layers with a ReLU in between: D -> 4D -> D """

    def __init__(self, embed_size):
        """
        Args:
            embed_size: int, input and output width of the module (D)
        """
        super().__init__()
        hidden_size = 4 * embed_size
        expand = nn.Linear(embed_size, hidden_size)
        contract = nn.Linear(hidden_size, embed_size)
        self.net = nn.Sequential(expand, nn.ReLU(), contract)

    def forward(self, x):
        """Apply the expand -> ReLU -> contract stack to x."""
        return self.net(x)

In [8]:
class TransformerBlock(nn.Module):
    """ Transformer block: multi-headed attention followed by a feed-forward layer,
        each applied to a layer-normed input and added back through a residual connection
    """

    def __init__(self, vocab_size, context_window_size, embed_size=384, num_heads=6):
        """
        Args:
            vocab_size: int, accepted for interface compatibility; not used by the block
            context_window_size: int, passed on to the attention heads
            embed_size: int, size of the token embedding dimension
            num_heads: int, number of attention heads
        """
        super().__init__()
        # one layer norm in front of each sub-layer
        self.ln1 = nn.LayerNorm(embed_size)
        self.ln2 = nn.LayerNorm(embed_size)

        self.feed_forward = FeedForward(embed_size)
        self.atten_heads = MultiHeadAttention(
            context_window_size,
            num_heads,
            embed_size // num_heads,  # per-head size (floor division)
            embed_size,
        )

    def forward(self, x):
        """Return x after the attention and feed-forward residual updates."""
        # communication over sequence length
        attended = self.atten_heads(self.ln1(x))
        x = x + attended

        # communication across embedding space
        transformed = self.feed_forward(self.ln2(x))
        x = x + transformed
        return x

In [9]:
class TransformerLM(nn.Module):
    """ transformer language model: token + position embeddings, a stack of
        transformer blocks, a final layer norm and a linear head over the vocabulary
    """

    def __init__(self, vocab_size, context_window_size, embed_size=384, num_heads=6, n_layers=6):
        """
          Args:
              vocab_size: int, number of tokens in the vocabulary (V)
              context_window_size: int, size of the context window (T)
              embed_size: int, embedding size (D)
              num_heads: int, number of heads (H)
              n_layers: int, number of layers (M)
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.context_window_size = context_window_size
        self.token_embedding_table = nn.Embedding(vocab_size, embed_size)
        self.position_embedding_table = nn.Embedding(context_window_size, embed_size)
        self.blocks = nn.Sequential(*[
            TransformerBlock(vocab_size,
                             context_window_size,
                             embed_size=embed_size,
                             num_heads=num_heads)
            for _ in range(n_layers)])

        # final layer norm
        self.ln_f = nn.LayerNorm(embed_size)
        self.lm_head = nn.Linear(embed_size, vocab_size)

        # good initialization
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, token_ids, targets=None):
        """
        Args:
            token_ids: tensor of integers, provides the context, shape (B, T)
            targets: tensor of integers, provides the tokens we are predicting, shape (B, T)

        Returns:
            logits: (B, T, V) tensor of unnormalized scores over the vocabulary
            loss: cross entropy averaged per token, or None when targets is None
        """
        B, T = token_ids.shape

        # token_ids and targets are both (B, T) tensor of integers
        tok_emb = self.token_embedding_table(token_ids) # (B, T, D)
        # build the position indices on the input's own device instead of relying
        # on a notebook-level global, so the model works wherever it was moved
        positions = torch.arange(T, device=token_ids.device)
        pos_emb = self.position_embedding_table(positions) # (T, D)
        x = tok_emb + pos_emb # (B, T, D)

        # blocks -> final layer norm -> projection onto the vocabulary
        logits = self.lm_head(self.ln_f(self.blocks(x)))

        if targets is None:
            loss = None
        else:
            # treat this as a classification problem - do cross entropy AVERAGED PER TOKEN!
            loss = F.cross_entropy(input=logits.reshape(-1, self.vocab_size),
                                   target=targets.flatten()) # F.cross_entropy averages per token by default

        return logits, loss


    @torch.no_grad()
    def generate(self, token_ids, max_new_tokens):
        """
        Args:
            token_ids: tensor of integers forming the context, shape (B, T)
            max_new_tokens: int, max number of tokens to generate

        Returns:
            tensor of shape (B, T + max_new_tokens): the context followed by the sampled tokens
        """
        # take in context, compute probabilities of next token, sample, repeat
        for t_pred in range(max_new_tokens):

            # only the most recent context_window_size tokens are fed to the model
            logits, loss = self(token_ids[:,-self.context_window_size:])

            # sample our next token for each batch, augment to token_ids (our context)
            token_ids = torch.hstack(
                [token_ids, torch.multinomial(
                    input=torch.softmax(logits[:,-1,:], dim=1), num_samples=1)])

        # after we've finished generating our entire forecast horizon, output everything.
        return token_ids

In [10]:
# iterate thru our embedding sizes
for embed_size in [192, 384, 576, 768, 960]:

    # re-seed before every run so each model size starts from the same RNG state
    torch.manual_seed(4513215)

    # make sure this run's output directory exists (creates "char_models" too);
    # exist_ok makes re-running the cell safe
    foldername = f"embed-size={embed_size}"
    os.makedirs(f"char_models/{foldername}", exist_ok=True)

    # status update
    print(f"Training baseline model with embed_size={embed_size}.")

    # initialize the model
    trans = TransformerLM(
        vocab_size=vocab_size, context_window_size=CONTEXT_WINDOW_SIZE,
        embed_size=embed_size, num_heads=6, n_layers=6)
    tlm = trans.to(device)  # same module object as `trans`, moved to `device`

    # set a learning rate + optimizer
    learning_rate = 5e-4
    optimizer = torch.optim.AdamW(trans.parameters(), lr=learning_rate)
    eval_interval = 200

    # metrics to track over time
    loss_list, wallclock_list = [], []

    # let's train for our iterations
    for it in tqdm(range(SMALL_ITERS)):

        # start our timer
        # NOTE: the timer starts before the periodic evaluation below, so the
        # wallclock entry of an evaluation step also contains the evaluation time
        start = time.time()

        # every once in a while evaluate the loss on train and val sets
        if it % eval_interval == 0 or it == SMALL_ITERS - 1:
            print(f"iteration {it}")
            losses = estimate_loss(tlm, EVAL_ITERS, CONTEXT_WINDOW_SIZE, device)
            print(f"step {it}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

        # sample a batch of data
        xb, yb = get_batch('train', CONTEXT_WINDOW_SIZE, device)

        # evaluate the loss
        logits, loss = tlm(xb, yb)
        loss_list.append(loss.detach().item())
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # end our timer
        end = time.time()
        wallclock_list.append(end - start)

    # save our model at the end + also save our logs
    # NOTE: this pickles the whole module object (not just a state_dict), so
    # loading it later needs the same class definitions to be importable
    torch.save(tlm, f"char_models/{foldername}/model.pth")
    logs = pd.DataFrame(data={"loss" : loss_list, "wallclock" : wallclock_list})
    logs.to_csv(f"char_models/{foldername}/logs.csv", index=False)

    # clear cuda cache
    torch.cuda.empty_cache()

Training baseline model with embed_size=192.


  0%|                                                                                                          | 0/2000 [00:00<?, ?it/s]

iteration 0
step 0: train loss 4.1844, val loss 4.1811


 10%|█████████▌                                                                                      | 200/2000 [01:55<11:39,  2.57it/s]

iteration 200
step 200: train loss 2.5046, val loss 2.5120


 20%|███████████████████▏                                                                            | 400/2000 [03:34<09:35,  2.78it/s]

iteration 400
step 400: train loss 2.2530, val loss 2.2891


 30%|████████████████████████████▊                                                                   | 600/2000 [05:13<07:45,  3.01it/s]

iteration 600
step 600: train loss 1.9578, val loss 2.0509


 40%|██████████████████████████████████████▍                                                         | 800/2000 [06:54<07:16,  2.75it/s]

iteration 800
step 800: train loss 1.7446, val loss 1.8990


 50%|███████████████████████████████████████████████▌                                               | 1000/2000 [08:32<05:28,  3.04it/s]

iteration 1000
step 1000: train loss 1.6173, val loss 1.7926


 60%|█████████████████████████████████████████████████████████                                      | 1200/2000 [10:14<04:42,  2.83it/s]

iteration 1200
step 1200: train loss 1.5252, val loss 1.7235


 70%|██████████████████████████████████████████████████████████████████▌                            | 1400/2000 [11:55<03:43,  2.68it/s]

iteration 1400
step 1400: train loss 1.4852, val loss 1.6939


 80%|████████████████████████████████████████████████████████████████████████████                   | 1600/2000 [13:36<02:26,  2.72it/s]

iteration 1600
step 1600: train loss 1.4159, val loss 1.6495


 90%|█████████████████████████████████████████████████████████████████████████████████████▌         | 1800/2000 [15:15<01:20,  2.50it/s]

iteration 1800
step 1800: train loss 1.3848, val loss 1.6175


100%|██████████████████████████████████████████████████████████████████████████████████████████████▉| 1999/2000 [16:54<00:00,  2.81it/s]

iteration 1999
step 1999: train loss 1.3580, val loss 1.6113


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [17:27<00:00,  1.91it/s]


Training baseline model with embed_size=384.


  0%|                                                                                                          | 0/2000 [00:00<?, ?it/s]

iteration 0
step 0: train loss 4.1470, val loss 4.1541


 10%|█████████▌                                                                                      | 200/2000 [03:15<21:27,  1.40it/s]

iteration 200
step 200: train loss 2.4752, val loss 2.4924


 20%|███████████████████▏                                                                            | 400/2000 [06:26<21:31,  1.24it/s]

iteration 400
step 400: train loss 2.1835, val loss 2.2405


 30%|████████████████████████████▊                                                                   | 600/2000 [09:46<16:43,  1.39it/s]

iteration 600
step 600: train loss 1.8386, val loss 1.9636


 40%|██████████████████████████████████████▍                                                         | 800/2000 [13:11<15:20,  1.30it/s]

iteration 800
step 800: train loss 1.6485, val loss 1.8189


 50%|███████████████████████████████████████████████▌                                               | 1000/2000 [16:39<12:48,  1.30it/s]

iteration 1000
step 1000: train loss 1.5311, val loss 1.7399


 60%|█████████████████████████████████████████████████████████                                      | 1200/2000 [20:07<10:27,  1.27it/s]

iteration 1200
step 1200: train loss 1.4516, val loss 1.6739


 70%|██████████████████████████████████████████████████████████████████▌                            | 1400/2000 [23:36<07:53,  1.27it/s]

iteration 1400
step 1400: train loss 1.3963, val loss 1.6428


 80%|████████████████████████████████████████████████████████████████████████████                   | 1600/2000 [27:02<05:17,  1.26it/s]

iteration 1600
step 1600: train loss 1.3573, val loss 1.6071


 90%|█████████████████████████████████████████████████████████████████████████████████████▌         | 1800/2000 [30:28<02:28,  1.35it/s]

iteration 1800
step 1800: train loss 1.3124, val loss 1.6058


100%|██████████████████████████████████████████████████████████████████████████████████████████████▉| 1999/2000 [33:56<00:00,  1.32it/s]

iteration 1999
step 1999: train loss 1.2798, val loss 1.5749


100%|███████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [34:52<00:00,  1.05s/it]


Training baseline model with embed_size=576.


  0%|                                                                                                          | 0/2000 [00:00<?, ?it/s]

iteration 0
step 0: train loss 4.3120, val loss 4.3173


 10%|█████████▌                                                                                      | 200/2000 [06:11<41:51,  1.40s/it]

iteration 200
step 200: train loss 2.4858, val loss 2.5008


 20%|███████████████████▏                                                                            | 400/2000 [12:19<29:19,  1.10s/it]

iteration 400
step 400: train loss 2.3170, val loss 2.3519


 30%|████████████████████████████▊                                                                   | 600/2000 [18:28<31:13,  1.34s/it]

iteration 600
step 600: train loss 2.0510, val loss 2.1365


 40%|██████████████████████████████████████▍                                                         | 800/2000 [24:40<27:57,  1.40s/it]

iteration 800
step 800: train loss 1.7524, val loss 1.9087


 50%|███████████████████████████████████████████████▌                                               | 1000/2000 [30:48<23:40,  1.42s/it]

iteration 1000
step 1000: train loss 1.5799, val loss 1.7736


 60%|█████████████████████████████████████████████████████████                                      | 1200/2000 [37:09<18:55,  1.42s/it]

iteration 1200
step 1200: train loss 1.4806, val loss 1.6954


 70%|██████████████████████████████████████████████████████████████████▌                            | 1400/2000 [43:27<13:48,  1.38s/it]

iteration 1400
step 1400: train loss 1.4044, val loss 1.6326


 80%|████████████████████████████████████████████████████████████████████████████                   | 1600/2000 [49:44<09:01,  1.35s/it]

iteration 1600
step 1600: train loss 1.3570, val loss 1.6143


 90%|█████████████████████████████████████████████████████████████████████████████████████▌         | 1800/2000 [56:00<03:36,  1.08s/it]

iteration 1800
step 1800: train loss 1.3070, val loss 1.5910


100%|████████████████████████████████████████████████████████████████████████████████████████████▉| 1999/2000 [1:02:06<00:01,  1.38s/it]

iteration 1999
step 1999: train loss 1.2567, val loss 1.5967


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [1:03:48<00:00,  1.91s/it]


Training baseline model with embed_size=768.


  0%|                                                                                                          | 0/2000 [00:00<?, ?it/s]

iteration 0
step 0: train loss 4.3298, val loss 4.3423


 10%|█████████▍                                                                                    | 200/2000 [09:58<1:08:23,  2.28s/it]

iteration 200
step 200: train loss 2.5150, val loss 2.5282


 20%|███████████████████▏                                                                            | 400/2000 [19:54<59:53,  2.25s/it]

iteration 400
step 400: train loss 2.4256, val loss 2.4479


 30%|████████████████████████████▊                                                                   | 600/2000 [29:44<50:53,  2.18s/it]

iteration 600
step 600: train loss 2.3164, val loss 2.3570


 40%|██████████████████████████████████████▍                                                         | 800/2000 [39:05<33:37,  1.68s/it]

iteration 800
step 800: train loss 2.1011, val loss 2.1899


 50%|███████████████████████████████████████████████▌                                               | 1000/2000 [46:30<27:12,  1.63s/it]

iteration 1000
step 1000: train loss 1.8522, val loss 1.9750


 60%|█████████████████████████████████████████████████████████                                      | 1200/2000 [53:51<22:19,  1.67s/it]

iteration 1200
step 1200: train loss 1.6770, val loss 1.8462


 70%|█████████████████████████████████████████████████████████████████                            | 1400/2000 [1:01:22<16:59,  1.70s/it]

iteration 1400
step 1400: train loss 1.5482, val loss 1.7451


 80%|██████████████████████████████████████████████████████████████████████████▍                  | 1600/2000 [1:08:53<10:54,  1.64s/it]

iteration 1600
step 1600: train loss 1.4678, val loss 1.6682


 90%|███████████████████████████████████████████████████████████████████████████████████▋         | 1800/2000 [1:16:23<05:40,  1.70s/it]

iteration 1800
step 1800: train loss 1.4031, val loss 1.6236


100%|████████████████████████████████████████████████████████████████████████████████████████████▉| 1999/2000 [1:23:34<00:01,  1.70s/it]

iteration 1999
step 1999: train loss 1.3485, val loss 1.5940


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [1:25:28<00:00,  2.56s/it]


Training baseline model with embed_size=960.


  0%|                                                                                                          | 0/2000 [00:00<?, ?it/s]

iteration 0
step 0: train loss 4.3867, val loss 4.3844


 10%|█████████▍                                                                                    | 200/2000 [11:14<1:16:49,  2.56s/it]

iteration 200
step 200: train loss 2.5270, val loss 2.5278


 20%|██████████████████▊                                                                           | 400/2000 [22:27<1:08:10,  2.56s/it]

iteration 400
step 400: train loss 2.4512, val loss 2.4802


 30%|████████████████████████████▏                                                                 | 600/2000 [33:31<1:00:16,  2.58s/it]

iteration 600
step 600: train loss 2.4529, val loss 2.4743


 40%|██████████████████████████████████████▍                                                         | 800/2000 [44:43<51:09,  2.56s/it]

iteration 800
step 800: train loss 2.4904, val loss 2.5099


 50%|███████████████████████████████████████████████▌                                               | 1000/2000 [55:52<29:27,  1.77s/it]

iteration 1000
step 1000: train loss 2.5265, val loss 2.5399


 60%|█████████████████████████████████████████████████████████                                      | 1200/2000 [59:30<10:38,  1.25it/s]

iteration 1200
step 1200: train loss 2.5205, val loss 2.5415


 70%|█████████████████████████████████████████████████████████████████                            | 1400/2000 [1:02:56<07:48,  1.28it/s]

iteration 1400
step 1400: train loss 2.4705, val loss 2.4931


 80%|██████████████████████████████████████████████████████████████████████████▍                  | 1600/2000 [1:06:23<05:10,  1.29it/s]

iteration 1600
step 1600: train loss 2.4866, val loss 2.5115


 90%|███████████████████████████████████████████████████████████████████████████████████▋         | 1800/2000 [1:09:50<02:36,  1.28it/s]

iteration 1800
step 1800: train loss 2.4484, val loss 2.4768


100%|████████████████████████████████████████████████████████████████████████████████████████████▉| 1999/2000 [1:13:16<00:00,  1.29it/s]

iteration 1999
step 1999: train loss 2.5440, val loss 2.5713


100%|█████████████████████████████████████████████████████████████████████████████████████████████| 2000/2000 [1:14:08<00:00,  2.22s/it]
