The code is adapted from https://github.com/karpathy/nanoGPT/tree/master

In [None]:
from pathlib import Path
import requests
import json
import time
from typing import Final
import math
# 
import numpy as np
# torch
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (
    IterableDataset,
    DataLoader,
)
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.nn.utils import clip_grad_norm_

In [2]:
class TinyShakespeareDataset(IterableDataset):

    DATA_URL = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'

    TEXT_FILENAME = 'input.txt'
    TRAIN_SET_FILENAME = 'train.pt'
    VAL_SET_FILENAME = 'val.pt'
    METADATA_FILENAME = 'metadata.json'

    START_TOKEN: Final[str] = '\n'
    
    def __init__(
        self,
        root: Path,
        block_size: int,
        train: bool,
    ) -> None:
        """
        """
        super().__init__()
        if not root.exists():
            self.download(root)

        path = root / (self.TRAIN_SET_FILENAME if train else self.VAL_SET_FILENAME)
        

        self.data = torch.load(path, weights_only=True)
        self.block_size = block_size


        metadata_path = root / self.METADATA_FILENAME
        with open(metadata_path, 'r') as stream:
            char_to_idx = json.load(stream)
        
        self.char_to_idx = char_to_idx
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.vocab_size = len(self.char_to_idx)

    
    def __iter__(self):
        start_high = len(self.data) - block_size
        
        while True:
            start = torch.randint(high=start_high, size=(1, ))[0]
            stop = start + self.block_size
            x = self.data[start: stop]
            y = self.data[start + 1: stop + 1]
            yield x, y


    @classmethod
    def download(cls, root: Path):
        if not root.exists():
            root.mkdir(parents=True)
        
        text = requests.get(cls.DATA_URL).text
        
        with open((root / cls.TEXT_FILENAME), 'w') as stream:
            stream.write(text)

        # metadata
        char_set = sorted(list(set(text)))
        char_to_idx = {char: idx for idx, char in enumerate(char_set)}

        dataset = cls.encode(text, char_to_idx)
        dataset = torch.tensor(dataset, dtype=torch.int64)
        
        train_stop = int(0.9 * len(dataset))
        train_set = dataset[:train_stop]
        val_set = dataset[train_stop:]

        # 
        torch.save(train_set, root / cls.TRAIN_SET_FILENAME)
        torch.save(val_set, root / cls.VAL_SET_FILENAME)

        # save metadata
        metadata_path = root / cls.METADATA_FILENAME
        with open(metadata_path, 'w') as stream:
            json.dump(char_to_idx, stream, indent=4)

    @classmethod
    def _encode(cls, char_list: str | list[str], char_to_idx) -> list[int]:
        return [char_to_idx[char] for char in char_list]

    def encode(self, char_list: str | list[str]):
        return self._encode(char_list, self.char_to_idx)

    @classmethod
    def _decode(cls, idx_list: list[int], idx_to_char: dict[int, str]) -> list[str]:
        return [idx_to_char[idx] for idx in idx_list]

    def decode(self, idx_list: list[int]):
        return self._decode(idx_list, self.idx_to_char)

# Model & Optimizers

In [3]:
class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention between two sequences.

    Queries are projected from ``target``; keys and values are projected
    from ``source``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool,
        dropout: float,
    ) -> None:
        """
        Args:
            embed_dim: width of the input and output embeddings; must be a
                multiple of ``num_heads``.
            num_heads: number of attention heads.
            bias: whether the linear projections carry a bias term.
            dropout: drop probability for the attention weights and the output.
        """
        super().__init__()

        assert embed_dim % num_heads == 0

        # input projections
        self.query_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.key_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.value_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        # output projection
        self.output_projection = nn.Linear(embed_dim, embed_dim, bias=bias)
        # regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.output_dropout = nn.Dropout(dropout)

        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.dropout = dropout

    def forward(
        self,
        target: Tensor,
        source: Tensor,
        attn_mask: Tensor,
    ) -> Tensor:
        """
        Args:
            target: (N, T, E) sequence that produces the queries.
            source: (N, S, E) sequence that produces the keys and values.
            attn_mask: mask applied to the (N, H, T, S) score tensor via
                ``masked_fill``; positions where it is True are set to
                ``-inf`` before the softmax.
        Returns:
            (N, T, E) attended and output-projected features.
        """
        # N: batch size, T: target sequence length, E: embedding dimensionality
        batch, tgt_len, width = target.size()
        # S: source sequence length
        src_len = source.size(1)
        heads = self.num_heads
        # per-head depth
        depth = width // heads

        def split_heads(projected: Tensor, length: int) -> Tensor:
            # (N, length, E) -> (N, H, length, D)
            return projected.view(batch, length, heads, depth).transpose(1, 2)

        queries = split_heads(self.query_proj(target), tgt_len)
        keys = split_heads(self.key_proj(source), src_len)
        values = split_heads(self.value_proj(source), src_len)

        scores = torch.matmul(queries, keys.transpose(-2, -1)) * (1.0 / math.sqrt(depth))
        scores = scores.masked_fill(
            mask=attn_mask,
            value=float('-inf'),
        )
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        # (N, H, T, D) -> (N, T, E): put the heads back side by side
        context = torch.matmul(weights, values)
        merged = context.transpose(1, 2).contiguous().view(batch, tgt_len, width)

        return self.output_dropout(self.output_projection(merged))

In [4]:
class SelfAttention(CrossAttention):
    """Attention in which one sequence supplies the queries, keys and values."""

    def forward(self, input: Tensor, attn_mask: Tensor) -> Tensor:
        """Attend ``input`` to itself under ``attn_mask``."""
        attended = super().forward(
            target=input,
            source=input,
            attn_mask=attn_mask,
        )
        return attended

In [5]:
class MLP(nn.Sequential):
    """Position-wise feed-forward network: Linear -> GELU -> Linear -> Dropout."""

    def __init__(
        self,
        embed_dim: int,
        bias: bool,
        dropout: float,
        widening_factor: int = 4,
    ) -> None:
        """
        Args:
            embed_dim: width of the input and output features.
            bias: whether the two linear layers carry a bias term.
            dropout: drop probability applied after the second linear layer.
            widening_factor: the hidden layer is ``widening_factor`` times as
                wide as ``embed_dim``.
        """
        # Honour the argument; the width used to be fixed at 4 * embed_dim
        # regardless of what the caller passed.
        hidden_dim = widening_factor * embed_dim

        super().__init__(
            nn.Linear(
                in_features=embed_dim,
                out_features=hidden_dim,
                bias=bias
            ),
            nn.GELU(),
            nn.Linear(
                in_features=hidden_dim,
                out_features=embed_dim,
                bias=bias
            ),
            nn.Dropout(p=dropout),
        )

In [6]:
class Block(nn.Module):
    """Pre-norm transformer block: self-attention and an MLP, each wrapped in a residual connection."""

    def __init__(
        self,
        embed_dim: int,
        bias: bool,
        dropout: float,
        widening_factor: int,
        num_heads: int,
        block_size: int,
    ) -> None:
        """
        Args:
            embed_dim: feature width used throughout the block.
            bias: whether the layer norms and linear layers carry a bias term.
            dropout: drop probability handed to the attention and MLP sub-layers.
            widening_factor: forwarded to the MLP.
            num_heads: number of attention heads.
            block_size: accepted for configuration symmetry; not used here.
        """
        super().__init__()

        # attention sub-layer
        self.ln_1 = nn.LayerNorm(normalized_shape=embed_dim, bias=bias)
        self.attn = SelfAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            dropout=dropout,
        )

        # feed-forward sub-layer
        self.ln_2 = nn.LayerNorm(normalized_shape=embed_dim, bias=bias)
        self.mlp = MLP(
            embed_dim=embed_dim,
            bias=bias,
            dropout=dropout,
            widening_factor=widening_factor,
        )

    def forward(
        self,
        x,
        attn_mask: Tensor,
    ):
        """Apply both residual sub-layers to ``x`` and return the result."""
        attended = self.attn(self.ln_1(x), attn_mask=attn_mask)
        x = x + attended
        transformed = self.mlp(self.ln_2(x))
        return x + transformed

In [7]:
class PicoGPT(nn.Module):

    causal_mask: Tensor

    def __init__(
        self,
        vocab_size: int,
        embed_dim: int,
        num_heads: int,
        block_size: int,
        dropout: float,
        num_layers: int,
        bias: bool,
        widening_factor: int,
    ):
        super().__init__()

        self.block_size = block_size
        

        block_config = dict(
            embed_dim=embed_dim,
            bias=bias,
            dropout=dropout,
            num_heads=num_heads,
            block_size=block_size,
            widening_factor=widening_factor,
        )

        
        # transformer
        self.token_embedder = nn.Embedding(
            num_embeddings=vocab_size, 
            embedding_dim=embed_dim,
        )
        self.position_embedder = nn.Embedding(
            num_embeddings=block_size, 
            embedding_dim=embed_dim,
        )
        self.dropout = nn.Dropout(p=dropout)
        self.block_list = nn.ModuleList([Block(**block_config) for _ in range(num_layers)])

        # token prediction head
        self.head = nn.Sequential(
            nn.LayerNorm(
                normalized_shape=embed_dim, 
                bias=bias,
            ),
            nn.Linear(
                in_features=embed_dim, 
                out_features=vocab_size, 
                bias=False,
            ),
        )
        # NOTE: https://paperswithcode.com/method/weight-tying
        self.token_embedder.weight = self.head[1].weight 

        self.register_buffer(
            name='causal_mask',
            tensor=torch.ones(1, 1, block_size, block_size, dtype=torch.int8).tril().eq(0),
        )

    def forward(
        self, 
        input: Tensor, 
        target: Tensor | None = None,
    ) -> tuple[Tensor, Tensor | None]:
        """
        Args:
            input (Tensor): input tokens
            target (Tensor or None): target tokens
        Returns:
            logits:
            loss: (Tensor or None)
        """
        # N: batch size, L: max length in a batch
        N, L = input.size()
        device = input.device

        position = torch.arange(0, L, dtype=torch.int64, device=input.device)
        attn_mask = self.causal_mask[..., :L, :L]

        token_embed = self.token_embedder(input)
        position_embed = self.position_embedder(position)
        embed = token_embed + position_embed
        embed = self.dropout(embed)

        for block in self.block_list:
            embed = block(embed, attn_mask=attn_mask)

        if target is not None:
            # training
            logits = self.head(embed)
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), 
                target.view(-1), 
                ignore_index=-1,
            )
        else:
            # inferences
            # NOTE: using list [-1] to preserve the time dim
            logits = self.head(embed[:, [-1], :])
            loss = None

        return logits, loss


    @torch.no_grad()
    def generate(self, idx: Tensor | None, max_new_tokens, temperature=1.0, top_k=None):        
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_sizes
            idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
            # forward the model to get the logits for the index in the sequence
            logits, _ = self(idx_cond)
            # pluck the logits at the final step and scale by desired temperature
            logits = logits[:, -1, :] / temperature
            # optionally crop the logits to only the top k options
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = -float('Inf')
            # apply softmax to convert logits to (normalized) probabilities
            probs = F.softmax(logits, dim=-1)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1)
            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx

# Hyperparameters

In [None]:
# model
block_size = 256        # tokens per training window; also the model's maximum context
batch_size = 64
grad_clip_value = 1     # maximum gradient norm passed to clip_grad_norm_
embed_dim = 384         # must be divisible by num_heads
num_heads = 6
num_layers = 6
bias = False
widening_factor = 4
dropout = 0

# optimizer
lr = 3.0e-4
betas = (0.9, 0.99)

# training
max_training_steps = 5_000
val_interval = 250      # evaluate and print a sample every this many steps
num_val_batches = 200   # batches averaged per loss estimate

# sampling
max_new_tokens = 500
temperature = 0.8
top_k = 200

# system
dataset_root = Path('./data')
# Keep the second GPU as the preferred device, but fall back to the first
# GPU or the CPU so the notebook still runs on hosts without two GPUs.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
seed = 1337

In [None]:
# setup: seed torch's global RNG, which the dataset's window sampling
# (torch.randint) and text generation (torch.multinomial) draw from
torch.manual_seed(seed)

# Data

In [None]:
# Both splits share the same cache directory and window length.
dataset_kwargs = {
    'root': dataset_root,
    'block_size': block_size,
}

train_set = TinyShakespeareDataset(**dataset_kwargs, train=True)
val_set = TinyShakespeareDataset(**dataset_kwargs, train=False)

# The datasets are endless streams, so each loader is turned into an iterator
# once and batches are drawn from it with next().
train_loader = DataLoader(dataset=train_set, batch_size=batch_size)
train_iter = iter(train_loader)

val_loader = DataLoader(dataset=val_set, batch_size=batch_size)
val_iter = iter(val_loader)

# Model

In [11]:
# Build the network from the hyperparameters and move it to the target device.
model = PicoGPT(
    vocab_size=train_set.vocab_size,
    block_size=block_size,
    embed_dim=embed_dim,
    num_layers=num_layers,
    num_heads=num_heads,
    widening_factor=widening_factor,
    bias=bias,
    dropout=dropout,
).to(device)

In [12]:
# AdamW over every model parameter, with the learning rate and betas set above.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas)

# Training

In [None]:
@torch.no_grad()
def validate(
    model: PicoGPT,
    data_iter, 
    num_batches: int,
    device: torch.device
) -> float:
    """Return the model's mean loss over ``num_batches`` batches.

    The model is switched to eval mode and left there. Batches are pulled
    from ``data_iter`` with ``next``, so the iterator is advanced.

    Args:
        model: network returning ``(logits, loss)`` for ``(x, y)``.
        data_iter: iterator yielding ``(x, y)`` batches.
        num_batches: number of batches to average over.
        device: device the batches are moved to before the forward pass.
    """
    model.eval()

    batch_losses = []
    for _ in range(num_batches):
        inputs, targets = next(data_iter)
        inputs, targets = inputs.to(device), targets.to(device)
        _, batch_loss = model(inputs, targets)
        batch_losses.append(batch_loss.item())
    return sum(batch_losses) / num_batches

In [14]:
def train(
    model: nn.Module,
    data_iter,
    optimizer,
    device: torch.device,
    grad_clip_value: float,
) -> None:
    """Run a single training step on the next batch from ``data_iter``.

    Args:
        model: network returning ``(logits, loss)`` when called with the
            keyword arguments ``input`` and ``target``.
        data_iter: iterator yielding ``(x, y)`` batches.
        optimizer: optimizer that updates the model's parameters.
        device: device the batch is moved to before the forward pass.
        grad_clip_value: maximum total gradient norm for this step.
    """
    model.train()

    x, y = next(data_iter)
    x, y = x.to(device), y.to(device)

    optimizer.zero_grad(set_to_none=True)
    _, loss = model(input=x, target=y)
    loss.backward()
    # Clip after backward(): before it the gradients have just been reset to
    # None, so clipping there would leave the update unclipped.
    clip_grad_norm_(model.parameters(), grad_clip_value)
    optimizer.step()

In [None]:
validate_kwargs = dict(
    model=model,
    num_batches=num_val_batches,
    device=device,
)

# Every sample is generated from the dataset's start token (a newline);
# x_start has shape (1, 1): one sequence holding that single token.
start_id = train_set.encode(train_set.START_TOKEN)
x_start = torch.tensor(start_id, dtype=torch.int64, device=device).unsqueeze(0)


start_time = time.time()
for step in range(0, max_training_steps + 1):
    # Step 0 performs no update, so the first report shows the untrained model.
    if step > 0:
        train(
            model=model,
            data_iter=train_iter,
            optimizer=optimizer,
            device=device,
            grad_clip_value=grad_clip_value,
        )

    if (step % val_interval == 0):
        # validate() leaves the model in eval mode, so the sample below is
        # generated in eval mode; train() switches back on the next step.
        train_loss = validate(data_iter=train_iter, **validate_kwargs)
        val_loss = validate(data_iter=val_iter, **validate_kwargs)
        elapsed_time = (time.time() - start_time) / 60
        y = model.generate(
            x_start, 
            max_new_tokens, 
            temperature=temperature,
            top_k=top_k,
        )

        # The report lives inside the validation branch: the values it prints
        # are only computed on validation steps.
        print(f'🚀🚀🚀 {step=: >12_d} ({elapsed_time=:.1f} min): {train_loss=:.6f} {val_loss=:.6f}')
        print(''.join(train_set.decode(y.tolist()[0])))
        print('=' * 80)