# SanGuo GPT Notebook

In [1]:
# Download the text for training.
!wget https://raw.githubusercontent.com/naosense/Yiya/master/book/%E4%B8%89%E5%9B%BD%E6%BC%94%E4%B9%89.txt -O sanguo.txt

# The text is encoded in GBK, let's convert it to UTF-8.
!iconv -f GBK -t UTF-8 sanguo.txt > sanguo-utf8.txt

--2023-08-24 09:31:23--  https://raw.githubusercontent.com/naosense/Yiya/master/book/%E4%B8%89%E5%9B%BD%E6%BC%94%E4%B9%89.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8001::154, 2606:50c0:8003::154, 2606:50c0:8000::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1206396 (1.2M) [text/plain]
Saving to: ‘sanguo.txt’


2023-08-24 09:31:24 (2.56 MB/s) - ‘sanguo.txt’ saved [1206396/1206396]



In [2]:
# !sed -i '/^=/,/^=/d' sanguo-utf8.txt
# For Mac:
!sed -i "" '/^=/,/^=/d' sanguo-utf8.txt

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
with open('sanguo-utf8.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [5]:
print(f"Length of text: {len(text)}")   # 606051 Chinese characters
print(text[:100])

Length of text: 606051
《三国演义》（精校版全本）作者：罗贯中
 

内容简介

　　《三国演义》由东汉末年黄巾起义末期开始描写，至西晋初期国家重归统一结束，以魏、蜀、吴三个政治、军事集团之间的形成演变，矛盾斗争为主线，最后


In [6]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

In [7]:
# Vocabulary of the book - less than 4000 Chinese characters.
# See? Chinese is not that hard - you only need to know a few thousand and
# you'll be able to read the ancient Chinese classics!
print(vocab_size)

3952


In [8]:
# I don't plan to use a tokenizer for this.
# IMHO, a Chinese character is not a "letter". Instead it's more like
# a word or subword. So we should treat each Chinese character as a token.

# Turn each character into a number (index into the chars array)
# Map character to index.
c2i = {ch:i for i, ch in enumerate(chars)}
# Map index to character.
i2c = {i:ch for i, ch in enumerate(chars)}

# Given a string (sequence of characters), encode it into a sequence of indices.
encoder = lambda s: [c2i[c] for c in s]
# Given a sequence of indices, decode it back to the string
decoder = lambda l: ''.join([i2c[i] for i in l])

In [9]:
print("Original text:")
print(text[:50])

print("\nEncoded:")
print(encoder(text[:50]))

print("\nDecoded:")
print(decoder(encoder(text[:50])))

Original text:
《三国演义》（精校版全本）作者：罗贯中
 

内容简介

　　《三国演义》由东汉末年黄巾起义末期开始

Encoded:
[16, 25, 655, 2054, 61, 17, 3946, 2583, 1732, 2143, 274, 1657, 3947, 174, 2734, 3949, 2694, 3249, 48, 0, 1, 0, 0, 292, 886, 2548, 118, 0, 0, 13, 13, 16, 25, 655, 2054, 61, 17, 2270, 40, 1876, 1656, 1039, 3919, 1005, 3290, 61, 1656, 1652, 1082, 796]

Decoded:
《三国演义》（精校版全本）作者：罗贯中
 

内容简介

　　《三国演义》由东汉末年黄巾起义末期开始


In [10]:
# Some hyperparameters
block_size = 192
batch_size = 16
d_model = 384
n_head = 8
n_layer = 6
dropout = 0.0
lr_rate = 1e-3
max_iters = 10000
eval_interval = 100
eval_iters = 20
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


In [11]:
data = torch.tensor(encoder(text), dtype=torch.long)
print(data.shape, data.dtype)
# data = encoder(text)

# Split up into training and validation sets.
# Generate a random permutation of the entire dataset.
# Sequence of length <block_size> is used to predict the next token
# The last seq will be [len(data) - block_size - 1, len(data) - 2] (inclusive).
# The last next token to be predicted will be <len(data) - 1>.
# So the index won't be out of bound.
perm = torch.randperm(len(data) - block_size)
# Then first 90% are training data, and rest are for validation.
n = int(0.9 * len(perm))
# We only save the start position of each example instead of the entire
# sequence. The sequence will be generated when creating the batches.
train_indices = perm[:n]
val_indices = perm[n:]

# This won't work - consumes too much memory
# training_set = [data[perm[i]:perm[i]+T] for i in perm[:n]]

torch.Size([606051]) torch.int64


In [12]:
# My simple dataloader. TODO: Can we use DataLoader instead?
def get_batch(split):
    # select training or validation set
    indices = train_indices if split == 'train' else val_indices
    # train_indices/val_indices stores the start locations
    # pick <batch_size> number of them from the array
    ix = torch.randint(len(indices), (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

In [13]:
class Head(nn.Module):
  """ one head of self-attention """

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(d_model, head_size, bias=False)
    self.query = nn.Linear(d_model, head_size, bias=False)
    self.value = nn.Linear(d_model, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B,T,C = x.shape
    k = self.key(x)   # (B,T,C)
    q = self.query(x) # (B,T,C)
    # compute attention scores ("affinities")
    wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
    # apply the "causal mask"
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
    wei = F.softmax(wei, dim=-1) # (B, T, T)
    wei = self.dropout(wei)
    # perform the weighted aggregation of the values
    v = self.value(x) # (B,T,C)
    out = wei @ v # (B,T,T) @ (B,T,C) -> (B,T,C)
    return out

class MultiHeadAttention(nn.Module):
  """ multiple heads of self-attention in parallel """

  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(d_model, d_model)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    # print("MHA output shape", out.shape)
    return out # (B, T, n_embed)

class FeedFoward(nn.Module):
  """ a simple linear layer followed by a non-linearity """

  def __init__(self, d_model):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(d_model, 4 * d_model),
        nn.ReLU(),
        nn.Linear(4 * d_model, d_model),
        nn.Dropout(dropout),
    )

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
  """ Transformer block: communication followed by computation """

  def __init__(self, d_model, n_head):
    # d_model: embedding dimension, n_head: the number of heads we'd like
    super().__init__()
    head_size = d_model // n_head
    self.sa = MultiHeadAttention(n_head, head_size)
    self.ffwd = FeedFoward(d_model)
    self.ln1 = nn.LayerNorm(d_model)
    self.ln2 = nn.LayerNorm(d_model)

  def forward(self, x):
    x = x + self.sa(self.ln1(x))
    x = x + self.ffwd(self.ln2(x))
    # print("Block output shape", x.shape)
    return x   # (B, T, n_embed)

# super simple bigram model
class SanGuoGPTModel(nn.Module):

    def __init__(self):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, d_model)
        self.position_embedding_table = nn.Embedding(block_size, d_model)
        self.blocks = nn.Sequential(*[Block(d_model, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(d_model) # final layer norm
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx


In [14]:
model = SanGuoGPTModel()
m = model.to(device)
# print the number of parameters in the model
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

13.753456 M parameters


In [15]:
@torch.no_grad()
def estimate_loss():
  out = {}
  # set the model in evaluation mode
  model.eval()
  for split in ['train', 'val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X, Y = get_batch(split)
      logits, loss = model(X, Y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  # switch back to training mode
  model.train()
  return out

In [16]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=lr_rate)
optimizer.zero_grad(set_to_none=True)

for iter in range(max_iters):
  # every once in a while evaluate the loss on train and val sets
  if ((iter+1) % eval_interval) == 0 or iter == max_iters - 1:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  # sample a batch of data
  xb, yb = get_batch('train')

  # evaluate the loss
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()


step 99: train loss 5.1277, val loss 5.3283
step 199: train loss 4.7032, val loss 4.8544
step 299: train loss 4.4147, val loss 4.6278
step 399: train loss 4.2018, val loss 4.3512
step 499: train loss 4.0151, val loss 4.1398
step 599: train loss 3.8234, val loss 4.0403
step 699: train loss 3.7339, val loss 3.8090
step 799: train loss 3.5598, val loss 3.7358
step 899: train loss 3.4869, val loss 3.5947
step 999: train loss 3.3062, val loss 3.4453
step 1099: train loss 3.2347, val loss 3.4102
step 1199: train loss 3.1301, val loss 3.2392
step 1299: train loss 2.9740, val loss 3.0964
step 1399: train loss 2.8543, val loss 3.0026
step 1499: train loss 2.8045, val loss 2.8930
step 1599: train loss 2.5893, val loss 2.8269
step 1699: train loss 2.5338, val loss 2.6030
step 1799: train loss 2.4659, val loss 2.6086
step 1899: train loss 2.3565, val loss 2.4526
step 1999: train loss 2.2901, val loss 2.3496
step 2099: train loss 2.1801, val loss 2.2719
step 2199: train loss 2.0824, val loss 2.1367

In [None]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decoder(m.generate(context, max_new_tokens=1000)[0].tolist()))

In [18]:
torch.save(model, "sanguogpt-v0.1.pth")

# Complete Code

## Data

In [1]:
import torch
import json

class SanGuoData:
    def __init__(self, source = 'sanguo-utf8.txt', block_size = 192, training_set_ratio = 0.9):
        self.source = source
        self.block_size = block_size
        self.training_set_ratio = training_set_ratio
        self.text = None
        self.chars = None
        self.vocab_size = 0
        self.c2i = None
        self.i2c = None
        self.encoder = None
        self.decoder = None
        self.data = None
    
    def ingest(self, gen_dataset=True, gen_token_map=True):
        with open(self.source, 'r', encoding='utf-8') as f:
            self.text = f.read()
        print(f"Length of text: {len(self.text)}")   # 606051 Chinese characters
        # print(self.text[:100])
        self.chars = sorted(list(set(self.text)))
        self.vocab_size = len(self.chars) 
        print(f"Vocabulary size: {self.vocab_size}")

        # I don't plan to use a tokenizer for this.
        # IMHO, a Chinese character is not a "letter". Instead it's more like
        # a word or subword. So we should treat each Chinese character as a token.

        # Turn each character into a number (index into the chars array)
        # Map character to index.
        self.c2i = {ch:i for i, ch in enumerate(self.chars)}
        # Map index to character.
        self.i2c = {i:ch for i, ch in enumerate(self.chars)}

        # Given a string (sequence of characters), encode it into a sequence of indices.
        self.encoder = lambda s: [self.c2i[c] for c in s]
        # Given a sequence of indices, decode it back to the string
        self.decoder = lambda l: ''.join([self.i2c[i] for i in l])

        self.data = torch.tensor(self.encoder(self.text), dtype=torch.long)
        # print(self.data.shape, self.data.dtype)

        if gen_token_map:
            self.save_token_map()

        if gen_dataset:
            self.gen_dataset()
    
    def save_token_map(self, c2i_file:str = 'c2i.json', i2c_file:str='i2c.json'):
        with open(c2i_file, 'w', encoding='utf-8') as f:
            json.dump(self.c2i, f)
        with open(i2c_file, 'w', encoding='utf-8') as f:
            json.dump(self.i2c, f)

    def test_enc_dec(self):
        print("Original text:")
        print(self.text[:50])

        print("\nEncoded:")
        print(self.encoder(self.text[:50]))

        print("\nDecoded:")
        print(self.decoder(self.encoder(self.text[:50])))

    def gen_dataset(self):
        # Split up into training and validation sets.
        # Generate a random permutation of the entire dataset.
        # Sequence of length <block_size> is used to predict the next token
        # The last seq will be [len(data) - block_size - 1, len(data) - 2] (inclusive).
        # The last next token to be predicted will be <len(data) - 1>.
        # So the index won't be out of bound.
        self.perm = torch.randperm(len(self.data) - self.block_size)
        # Then first 90% are training data, and rest are for validation.
        n = int(self.training_set_ratio * len(self.perm))
        # We only save the start position of each example instead of the entire
        # sequence. The sequence will be generated when creating the batches.
        self.train_indices = self.perm[:n]
        self.val_indices = self.perm[n:]
        self.train_batchptr = 0
        self.val_batchptr = 0
    
    # If `random` is True, we randomly pick batch_size items from the set.
    # But since training_indices/val_indices are already shuffled, this is not really
    # needed.
    def get_batch(self, split:str, batch_size, device, random=False):
        # select training or validation set
        indices = self.train_indices if split == 'train' else self.val_indices
        ptr = self.train_batchptr if split == 'train' else self.val_batchptr

        # train_indices/val_indices stores the start locations
        if random:
            ix = torch.randint(len(indices), (batch_size,))
        else:
            # The train/val set is already shuffled, so we just need to sequentially
            # go through the items batch by batch.
            next = ptr + batch_size
            if next < len(indices):
                ix = indices[ptr:next]
            else:
                # Handle the case when we wrap around the list.
                next = next % len(indices)
                ix = torch.cat((indices[ptr:len(indices)], indices[0:next]))
            # Move the batch pointer
            if split == 'train':
                self.train_batchptr = next
            else:
                self.val_batchptr = next
        # Generate the actual examples & labels for the batch.        
        x = torch.stack([self.data[i:i+self.block_size] for i in ix])
        y = torch.stack([self.data[i+1:i+self.block_size+1] for i in ix])
        x, y = x.to(device), y.to(device)
        return x, y


def encoder(s:str, c2i:dict):
    return [c2i[c] for c in s] 

def decoder(l, i2c:dict):
    return ''.join([i2c[i] for i in l]) 

def load_token_map(c2i_file:str = 'c2i.json', i2c_file:str='i2c.json'):
    # Load token map from the file.
    with open('c2i.json', 'r', encoding='utf-8') as f:
        c2i = json.load(f)

    # When loaded from JSON, the keys will become strings (e.g. '3913': '麒' instead of 3913: '麒')
    with open('i2c.json', 'r', encoding='utf-8') as f:
        i2c_raw = json.load(f)
        # Convert the keys to integers.
        i2c = {int(i):i2c_raw[i] for i in i2c_raw.keys()} 
    
    return c2i, i2c

# Test load/save token map.
def test_token_map():
    data = SanGuoData()
    data.ingest()
    # Save the token map to json files.
    data.save_token_map()

    # Load the token map and test encoding/decoding with it.
    c2i, i2c = load_token_map()

    print("Original text:")
    print(data.text[50:100])

    print("\nEncoded:")
    print(encoder(data.text[50:100], c2i))

    print("\nDecoded:")
    print(decoder(encoder(data.text[50:100], c2i), i2c))

## Model

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# v0.1:
# Super simple, hand-written Transformer model. 
# Inspired by the "Building a GPT" notebook by Andrej Karpathy.
# https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size, d_model, block_size, dropout):
        super().__init__()
        self.key = nn.Linear(d_model, head_size, bias=False)
        self.query = nn.Linear(d_model, head_size, bias=False)
        self.value = nn.Linear(d_model, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)   # (B,T,C)
        q = self.query(x) # (B,T,C)
        # compute attention scores ("affinities")
        wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)
        # apply the "causal mask"
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        wei = F.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # perform the weighted aggregation of the values
        v = self.value(x) # (B,T,C)
        out = wei @ v # (B,T,T) @ (B,T,C) -> (B,T,C)
        return out

class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel """

    def __init__(self, n_head, head_size, d_model, dropout, block_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size=head_size,
                                        d_model=d_model,
                                        block_size=block_size,
                                        dropout=dropout) for _ in range(n_head)])
        self.proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        # print("MHA output shape", out.shape)
        return out # (B, T, n_embed)

class FeedFoward(nn.Module):
    """ a simple linear layer followed by a non-linearity """

    def __init__(self, d_model, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, d_model),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)

class Block(nn.Module):
    """ Transformer block: communication followed by computation """

    def __init__(self, d_model, n_head, dropout, block_size):
        # d_model: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        head_size = d_model // n_head
        self.sa = MultiHeadAttention(n_head=n_head,
                                    head_size=head_size,
                                    d_model=d_model,
                                    dropout=dropout,
                                    block_size=block_size)
        self.ffwd = FeedFoward(d_model=d_model, dropout=dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        # print("Block output shape", x.shape)
        return x   # (B, T, n_embed)

# super simple model
class SanGuoGPTModel(nn.Module):

    def __init__(self, vocab_size, d_model, n_layer, dropout, block_size, n_head, device):
        super().__init__()
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embedding(vocab_size, d_model)
        self.position_embedding_table = nn.Embedding(block_size, d_model)
        self.blocks = nn.Sequential(*[Block(d_model=d_model,
                                            n_head=n_head,
                                            dropout=dropout,
                                            block_size=block_size) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(d_model) # final layer norm
        self.lm_head = nn.Linear(d_model, vocab_size)
        self.device = device
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.d_model = d_model
        self.n_layer = n_layer
        self.dropout = dropout
        self.n_head = n_head

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        pos_emb = self.position_embedding_table(torch.arange(T, device=self.device)) # (T,C)
        x = tok_emb + pos_emb # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)   # loss is a scalar

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        # idx is (B, T) array of indices in the current context
        # perplexity = exp(-1/N sum(log(p(w_i|w_1,...,w_{i-1}))))
        sum_log_p = torch.zeros(idx.shape[0])   # shape (B,)
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -self.block_size:]
            # get the predictions
            logits, loss = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # calculate perplexity: sum of log_p
            log_p = torch.tensor([probs[i,idx_next[i,0].item()] for i in range(idx_next.shape[0])])
            log_p = torch.log(log_p)
            sum_log_p += log_p
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        
        # Return the generated text along with perplexity.
        return idx, torch.exp(-1.0 * sum_log_p / max_new_tokens)
    
    @torch.no_grad()
    def get_embeddings(self, tokens):
        return self.token_embedding_table(tokens)
       

## Training

In [5]:
import argparse
import math
import torch
import os
import datetime
import time

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter

args = argparse.Namespace()
args.input = 'sanguo-utf8.txt'
args.output = 'sanguogpt.pth'
args.no_save_model = False
args.batch_size = 32
args.block_size = 256
args.d_model = 384
args.num_heads = 8
args.num_layers = 6
args.dropout = 0.01
args.lr_rate = 6e-4
args.min_lr = 6e-5
args.lr_decay_iters = 60000
args.warmup_iters = 1000
args.decay_lr = True
args.num_iters = 1000
args.eval_interval = 100
args.eval_iters = 10
args.training_set_ratio = 0.9

@torch.no_grad()
def estimate_loss(model, data, device):
    out = {}
    eval_iters = args.eval_iters
    batch_size = args.batch_size
    # set the model in evaluation mode
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            # For validation, we randomly pick a few items as a batch.
            X, Y = data.get_batch(split, batch_size=batch_size, device=device, random=True)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    # switch back to training mode
    model.train()
    return out

def train(session_name:str = None):
    # Some hyperparameters
    block_size = args.block_size
    batch_size = args.batch_size
    d_model = args.d_model
    n_head = args.num_heads
    n_layer = args.num_layers
    dropout = args.dropout
    lr_rate = args.lr_rate
    max_iters = args.num_iters
    eval_interval = args.eval_interval
    training_set_ratio = args.training_set_ratio
    device = (
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")

    torch.manual_seed(1337)

    if session_name is None:
        writer = None
    else:
        writer = SummaryWriter(os.path.join('runs', session_name))

    # Prepare the dataset
    sanguo_data = SanGuoData(source = args.input, block_size = block_size, training_set_ratio = training_set_ratio)
    sanguo_data.ingest()
    print(f"Number of tokens in each batch: {block_size*batch_size}")

    # Create the model
    model = SanGuoGPTModel(vocab_size=sanguo_data.vocab_size,
                        d_model=d_model,
                        n_layer=n_layer,
                        dropout=dropout,
                        block_size=block_size,
                        n_head=n_head,
                        device=device
                        )
    m = model.to(device)
    # print the number of parameters in the model
    print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

    # Visualize the model
    xb, yb = sanguo_data.get_batch('train', batch_size=batch_size, device=device, random=True)
    if writer is not None:
        writer.add_graph(model, (xb, yb))
        writer.flush()

    # For visualizing the embeddings:
    # encoder returns a list for each string (which is a single character in vocabulary).
    # Therefore, the shape of all_token will be like (vocab_size, 1).
    # We want a 1-D list, so we squeeze the last dimension and change the shape to (vocab_size, ).
    all_tokens = torch.tensor([sanguo_data.encoder(ch) for ch in sanguo_data.chars[20:-7]],
                            dtype=torch.long, device=device, requires_grad=False).squeeze(1)
    print('tokens for visualization', all_tokens.shape)

    # create a PyTorch optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr_rate)
    optimizer.zero_grad(set_to_none=True)

    training_start = time.time()
    # Training loop
    for iter in range(max_iters):
        # every once in a while evaluate the loss on train and val sets
        if (iter % eval_interval) == 0 or iter == max_iters - 1:
            losses = estimate_loss(model, sanguo_data, device)
            # perplexity = exp(cross_entropy)
            ppl_train = math.exp(losses['train'].item())
            ppl_val = math.exp(losses['val'].item())
            print(f"step {iter}: train loss {losses['train']:.3f}, perplexity {ppl_train:.3f}, val loss {losses['val']:.3f}, perplexity {ppl_val:.3f}")
            # Log the estimated training loss and validation loss
            if writer is not None:
                writer.add_scalars(
                    # main_tag
                    'Estimated Training vs. Validation Loss',
                    # tag_scalar
                    {
                        'Training': losses['train'].item(),
                        'Validation': losses['val'].item(),
                    },
                    # global_step
                    iter)
                writer.add_scalars(
                    # main_tag
                    'Estimated Training vs. Validation Perplexity',
                    # tag_scalar
                    {
                        'Training': ppl_train,
                        'Validation': ppl_val,
                    },
                    # global_step
                    iter)
                # Visualize the embeddings.
                embedding_table = model.get_embeddings(all_tokens)  # (vocab_size, d_model)
                # print(embedding_table.shape)
                writer.add_embedding(embedding_table, metadata=sanguo_data.chars[20:-7], tag=f"embeddings-step{iter}")
                writer.flush()

        # sample a batch of data
        xb, yb = sanguo_data.get_batch('train', batch_size=batch_size, device=device)

        # evaluate the loss
        logits, loss = model(xb, yb)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # if (iter % eval_interval) == 0 or iter == max_iters - 1:
        #     loss_val = loss.item()
        #     print(f"iteration: {iter:>6d}, loss: {loss_val:>7f}")

    training_end = time.time()
    print("Finished training")
    print(f"Total number of tokens trained: {block_size*batch_size*max_iters}")
    print(f"Time elapsed: {training_end-training_start:.3f} seconds")
    print(f"Training throughput: {(block_size*batch_size*max_iters)/(training_end-training_start):>.3f} tokens/sec")

    return model

In [None]:
args.num_iters = 1000
m = train(session_name='test-embedding')

## Generating

In [9]:
# parser.add_argument('-m', '--model', default='checkpoints/sanguogpt-v0.2.pth', help='Input text for training.', type=str)
# parser.add_argument('-l', '--gen_length', default=100, help='Maximum length to generate.', type=int)
# parser.add_argument('--c2i', default='c2i.json', help='Token map file (character to index).', type=str)
# parser.add_argument('--i2c', default='i2c.json', help='Token map file (index to character).', type=str)
# parser.add_argument('--prompt', default=' ', help='Prompt for text generation.', type=str)
# parser.add_argument('--webui', action='store_true', help='If specified, use streamlit-based Web UI instead of command line.')

args = argparse.Namespace()
args.model = 'checkpoints/sanguogpt-v0.2.2.pth'
args.gen_length = 100
args.c2i = 'c2i.json'
args.i2c = 'i2c.json'
args.prompt = ' '
args.webui = False

print(f"Loading model from {args.model}")
model = torch.load(args.model)

c2i, i2c = load_token_map(c2i_file=args.c2i, i2c_file=args.i2c)
print(f"Loading token map file from {args.c2i} and {args.i2c}")

device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

print(model)

Loading model from checkpoints/sanguogpt-v0.2.2.pth
Loading token map file from c2i.json and i2c.json
Using mps device
SanGuoGPTModel(
  (token_embedding_table): Embedding(3952, 384)
  (position_embedding_table): Embedding(256, 384)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x Head(
            (key): Linear(in_features=384, out_features=48, bias=False)
            (query): Linear(in_features=384, out_features=48, bias=False)
            (value): Linear(in_features=384, out_features=48, bias=False)
            (dropout): Dropout(p=0.001, inplace=False)
          )
        )
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (dropout): Dropout(p=0.001, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1536, out_features=384, bias=True)
 

In [11]:
def gen_response(prompt:str) -> str:
    start = time.time()
    # encode the prompt into a tensor, and reshape it into (1, T)
    # first dimension is the batch, which is expected by the forward method.
    context = torch.tensor(encoder(prompt, c2i), device=device).unsqueeze(0)
    # model.generate() will truncate the prompt if it's too long, no need to worry about this.
    resp_idx, ppl = model.generate(context, max_new_tokens=args.gen_length) 
    resp = decoder(resp_idx[0].tolist(), i2c)
    end = time.time()
    tokens_generated = min(args.gen_length, len(resp) - len(prompt))
    print(f"{tokens_generated} tokens generated in {end-start:>.3f} seconds, avg {tokens_generated/(end-start):>.3f} tokens/sec.")
    print(f"Perplexity of generation: {ppl[0].item():>.4f}")
    return resp

args.prompt = '关羽'
print(gen_response(args.prompt))

100 tokens generated in 4.266 seconds, avg 23.444 tokens/sec.
Perplexity of generation: 1.0218
关羽次之，张飞为弟。祭罢天地，复宰牛设酒，聚乡中勇士，得三百余人，就桃园中痛饮一醉。来日收拾军器，但恨无马匹可乘。正思虑间，人报有两个客人，引一伙伴当，赶一群马，投庄上来。玄德曰：“此天佑我也！”三人出庄
