In [154]:
import torch
from tiny_torch import *

import random
import torch

words = open('names.txt', 'r').read().splitlines()
words = sorted(list(set(words))) # set cause uncontrollable randomness， sorted for reproducibility
max_len = max(len(w) for w in words)
random.seed(42)
random.shuffle(words)

chs = list(set(''.join(words + ['.']))) # add special token
chs = sorted(chs, reverse=False)
stoi = {ch: i for i, ch in enumerate(chs)}
itos = {i: ch for i, ch in enumerate(chs)}

# predict next token use previous tokens
vocab_size = len(chs)
block_size = max_len + 1
X, Y = [], []

for w in words:
    x = torch.zeros(max_len + 1, dtype=torch.long)
    y = torch.zeros(max_len + 1, dtype=torch.long)
    x[1:1+len(w)] = torch.tensor([stoi[ch] for ch in w])
    y[:len(w)] = torch.tensor([stoi[ch] for ch in w])
    y[len(w)+1:] = -1 # mask the loss at the inactive locations
    X.append(x)
    Y.append(y)

X = torch.stack(X)
Y = torch.stack(Y)
n1, n2  = int(0.8 * len(X)), int(0.9 * len(X))

X_train, X_val, X_test = X.tensor_split([n1, n2])
Y_train, Y_val, Y_test = Y.tensor_split([n1, n2])

show = 20
for x, y in zip(X_train[:show], Y_train[:show]):
    sx = ''.join(itos[i.item()] for i in x)
    sy = ''.join(itos[i.item()] for i in y if i.item() != -1)
    print(f'{sx} -> {sy}')

X_train.shape, X_val.shape, X_test.shape, Y_train.shape, Y_val.shape, Y_test.shape


.aukai.......... -> aukai.
.ellanore....... -> ellanore.
.liem........... -> liem.
.aquarius....... -> aquarius.
.joangel........ -> joangel.
.wryn........... -> wryn.
.isabela........ -> isabela.
.astryd......... -> astryd.
.maleik......... -> maleik.
.emerick........ -> emerick.
.natasha........ -> natasha.
.kasandra....... -> kasandra.
.aevin.......... -> aevin.
.brason......... -> brason.
.naiara......... -> naiara.
.alanna......... -> alanna.
.raunak......... -> raunak.
.gohan.......... -> gohan.
.ivie........... -> ivie.
.alandis........ -> alandis.


(torch.Size([23595, 16]),
 torch.Size([2949, 16]),
 torch.Size([2950, 16]),
 torch.Size([23595, 16]),
 torch.Size([2949, 16]),
 torch.Size([2950, 16]))

# move to manual backprop

In [155]:
from dataclasses import dataclass

@dataclass
class ModelConfig:
    block_size: int = None # length of the input sequences of integers
    vocab_size: int = None # the input integers are in range [0 .. vocab_size -1]
    # parameters below control the sizes of each model slightly differently
    n_layer: int = 4
    n_embd: int = 64
    n_embd2: int = 64
    n_head: int = 4
    dtype: torch.dtype = torch.float64

## loss

In [156]:
class CrossEntropyLoss3d(Module):

    def __repr__(self):
        return f'MyCrossEntropyLoss()'

    def __call__(self, logits, y, ignore_index=-1):
        assert logits.ndim == 3, 'only verify for 3d logits (B, T, C)'
        B, T, V = logits.shape
        max_logits = logits.max(dim=-1, keepdim=True).values
        probs = (logits - max_logits).softmax(dim=-1)
        mask = (y != ignore_index)
        loss = -probs[torch.arange(B)[:, None], torch.arange(T)[None, :], y].log() * mask # indices also need to be broadcasted, we use [None, :]
        loss = loss.sum() / mask.sum()
        # backward buffer
        self.probs = probs
        self.y = y
        self.mask = mask
        return loss
        
    
    def backward(self, grad=1.0): # grad is the gradient of the loss function, usually 1.0
        probs, y, mask = self.probs, self.y, self.mask
        B, T, V = probs.shape
        x_grad = probs.clone()
        x_grad[torch.arange(B)[:, None], torch.arange(T)[None, :], y] -= 1.0
        x_grad = x_grad * mask.unsqueeze(-1)
        x_grad = x_grad / mask.sum() * grad
        return x_grad


## mlp

In [157]:
class MLP(Module):
    """
    takes the previous block_size tokens, encodes them with a lookup table,
    concatenates the vectors and predicts the next token with an MLP.

    Reference:
    Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
    """

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        self.wte = Embedding(config.vocab_size + 1, config.n_embd, dtype=config.dtype) # token embeddings table
        # +1 in the line above for a special <BLANK> token that gets inserted if encoding a token
        # before the beginning of the input sequence
        self.mlp = Sequential([
            Linear(self.block_size * config.n_embd, config.n_embd2, dtype=config.dtype),
            Tanh(),
            Linear(config.n_embd2, self.vocab_size, dtype=config.dtype)
        ])
        self.mlp[-1].weight.data *= 0.1
        self.mlp[-1].bias.data *= 0.01
        n_params = sum(p.numel() for p in self.parameters())
        print("number of transformer parameters: %d" % (n_params,))
        self.config = config
    
    def parameters(self):
        return list(self.wte.parameters()) + list(self.mlp.parameters())
    
    def grads(self):
        return list(self.wte.grads()) + list(self.mlp.grads())

    def get_block_size(self):
        return self.block_size

    def __call__(self, idx, targets=None):

        # gather the word embeddings of the previous 3 words
        idx_buf = []
        embs = []
        for k in range(self.block_size):
            tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
            idx_buf.append(idx.unsqueeze(-1))
            embs.append(tok_emb)
            idx = torch.roll(idx, 1, 1)
            idx[:, 0] = self.vocab_size # special <BLANK> token

        # concat all of the embeddings together and pass through an MLP
        x = torch.cat(embs, -1) # (b, t, n_embd * block_size)
        logits = self.mlp(x)

        # backward buffer
        self.idx_buf = torch.cat(idx_buf, -1) # (b, t, t)


        return logits
    
    def backward(self, grad):
        grad = self.mlp.backward(grad)
        # mlp backprop to wte
        b, t, _ = grad.shape # (b, t, n_embd * block_size)
        grad = grad.view(b * t * self.config.block_size, self.config.n_embd) # (b*t*block_size, n_embd)
        wte_weight = self.wte.weight
        wte_grad = torch.zeros_like(wte_weight)
        wte_grad.index_add_(dim=0, index=self.idx_buf.view(-1), source=grad)
        self.wte.weight_grad = wte_grad


In [158]:
import torch.nn as nn

class MLPtorch(nn.Module):
    """
    takes the previous block_size tokens, encodes them with a lookup table,
    concatenates the vectors and predicts the next token with an MLP.

    Reference:
    Bengio et al. 2003 https://www.jmlr.org/papers/volume3/bengio03a/bengio03a.pdf
    """

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        self.wte = nn.Embedding(config.vocab_size + 1, config.n_embd) # token embeddings table
        # +1 in the line above for a special <BLANK> token that gets inserted if encoding a token
        # before the beginning of the input sequence
        self.mlp = nn.Sequential(
            nn.Linear(self.block_size * config.n_embd, config.n_embd2),
            nn.Tanh(),
            nn.Linear(config.n_embd2, self.vocab_size)
        )
        n_params = sum(p.numel() for p in self.parameters())
        print("number of transformer parameters: %d" % (n_params,))

    def get_block_size(self):
        return self.block_size

    def forward(self, idx, targets=None):

        # gather the word embeddings of the previous 3 words
        embs = []
        for k in range(self.block_size):
            tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd)
            idx = torch.roll(idx, 1, 1)
            idx[:, 0] = self.vocab_size # special <BLANK> token
            embs.append(tok_emb)

        # concat all of the embeddings together and pass through an MLP
        x = torch.cat(embs, -1) # (b, t, n_embd * block_size)
        logits = self.mlp(x)

        return logits


In [159]:
import torch.nn.functional as F

config = ModelConfig(block_size=block_size, vocab_size=vocab_size, n_embd=8, n_embd2=24)
# models
torch.manual_seed(42)
model_mlp = MLP(config)
model_mlp_t = MLPtorch(config)
# copy weights
model_mlp_t.wte.weight.data = model_mlp.wte.weight.data.clone()
for i, (p_t, p) in enumerate(zip(model_mlp_t.mlp.parameters(), model_mlp.mlp.parameters())):
    if p.dim() == 2:
        p_t.data = p.data.clone().T # linear layer weight
    else:
        p_t.data = p.data.clone()
# loss
loss_fn = CrossEntropyLoss3d()


number of transformer parameters: 3995
number of transformer parameters: 3995


In [160]:
models = {
    'mlp': model_mlp,
    'mlp_t': model_mlp_t,
}

# args
n_steps = 100
eval_every = 10
bs = 32
ini_lr = 0.1

# train
lossi = []
torch.manual_seed(42)
for  step in range(n_steps):
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]
    lr = 0.1 #ini_lr if step < int(n_steps * 0.75) else ini_lr * 0.1

    # --- torch ---
    # forward
    logits_t = model_mlp_t(x)
    loss_t = F.cross_entropy(logits_t.view(-1, logits_t.size(-1)), y.view(-1), ignore_index=-1)
    # backward
    loss_t.backward()
    # update
    for p_t in model_mlp_t.parameters():
        p_t.data -= lr * p_t.grad
        p_t.grad = None
    # --- manual ---
    # forward
    logits = model_mlp(x)
    loss = loss_fn(logits, y)
    # backward
    grad = loss_fn.backward()
    model_mlp.backward(grad)
    # update
    for p, g in zip(model_mlp.parameters(), model_mlp.grads()):
        p.data -= lr * g
    
    
    # eval
    if step % eval_every == 0:
        x, y = X_val, Y_val
        with torch.no_grad():
            logits_t = model_mlp_t(x)
            val_loss_t = F.cross_entropy(logits_t.view(-1, logits_t.size(-1)), y.view(-1), ignore_index=-1)
        logits = model_mlp(x)
        val_loss = loss_fn(logits, y)
        print(f'step {step:<4} || Train || {loss.item():.6f} || Val || {val_loss.item():.6f} || {val_loss_t.item():.6f}')
        
    # break


step 0    || Train || 3.296238 || Val || 3.290067 || 3.290067
step 10   || Train || 3.229404 || Val || 3.226768 || 3.226768
step 20   || Train || 3.161818 || Val || 3.153679 || 3.153679
step 30   || Train || 3.069992 || Val || 3.056860 || 3.056860
step 40   || Train || 2.948184 || Val || 2.958030 || 2.958030
step 50   || Train || 2.949450 || Val || 2.900734 || 2.900734
step 60   || Train || 2.862980 || Val || 2.867335 || 2.867335
step 70   || Train || 2.798518 || Val || 2.845743 || 2.845743
step 80   || Train || 2.855532 || Val || 2.830552 || 2.830552
step 90   || Train || 2.882640 || Val || 2.817070 || 2.817070


## rnn

In [161]:
class RNN(Module):

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd
        self.n_embd2 = config.n_embd2
        self.start = torch.zeros(1, config.n_embd2, dtype=config.dtype) # the starting hidden state
        self.wte = Embedding(config.vocab_size, config.n_embd, dtype=config.dtype) # token embeddings table
        self.Cw = Linear(config.n_embd + config.n_embd2, config.n_embd2, dtype=config.dtype) # rnn cell weight
        self.lm_head = Linear(config.n_embd2, self.vocab_size, dtype=config.dtype)
        # grads
        self.start_grad = None
    
    def parameters(self):
        return [self.start] + list(self.wte.parameters()) + list(self.Cw.parameters()) + list(self.lm_head.parameters())
    
    def grads(self):
        return [self.start_grad] + list(self.wte.grads()) + list(self.Cw.grads()) + list(self.lm_head.grads())

    def get_block_size(self):
        return self.block_size

    def __call__(self, x):
        b, t = x.size()
        emb = self.wte(x) # (b, t, n_embd)
        # sequentially iterate over the inputs and update the RNN state each tick
        hprev = self.start.expand((b, -1)) # expand out the batch dimension
        hiddens = []
        emb_cat_hprevs = []
        for i in range(t):
            xt = emb[:, i, :] # (b, n_embd)
            emb_i_cat_hprev = torch.cat([xt, hprev], dim=1)
            # --- rnn cell ---
            hi = self.Cw(emb_i_cat_hprev)
            hi = hi.tanh()
            # --------------
            hprev = hi
            hiddens.append(hi)
            emb_cat_hprevs.append(emb_i_cat_hprev)
        # decode the outputs
        hidden = torch.stack(hiddens, 1) # (b, t, n_embd2)
        logits = self.lm_head(hidden)
        # backward buffer
        self.hidden = hidden
        self.emb_cat_hprevs = emb_cat_hprevs
        return logits

    def backward(self, grad):
        hidden, emb_cat_hprevs = self.hidden, self.emb_cat_hprevs
        t = hidden.size(1)
        dhidden = self.lm_head.backward(grad)
        # logits grad to start, wte, Cw grad
        dembs = []
        dCw, dhprev = 0., 0.
        if self.Cw.bias is not None:
            dCw_bias = 0.
        for i in range(t-1, -1, -1):
            # hidden state grad, emb grad
            dhi = dhidden[:, i, :] + dhprev # grad from logits + grad from prev hidden state
            hi = hidden[:, i, :]
            dhi = (1 - hi**2) * dhi # grad of tanh
            demb_i_cat_dhi = dhi @ self.Cw.weight.T
            demb_i, dhprev = demb_i_cat_dhi.tensor_split([self.n_embd,], dim=1)
            dembs.append(demb_i)
            # cell weight grad
            emb_i_cat_hprev = emb_cat_hprevs[i]
            dCw += emb_i_cat_hprev.T @ dhi
            if self.Cw.bias is not None:
                dCw_bias += dhi.sum(dim=0)
        dstart = dhprev.sum(dim=0, keepdim=True)
        demb = torch.stack(dembs[::-1], 1)
        self.wte.backward(demb)
        self.start_grad = dstart
        self.Cw.weight_grad = dCw
        if self.Cw.bias is not None:
            self.Cw.bias_grad = dCw_bias

In [162]:
import torch.nn as nn

class RNNtorch(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.block_size = config.block_size
        self.vocab_size = config.vocab_size
        self.n_embd = config.n_embd
        self.n_embd2 = config.n_embd2
        self.start = nn.Parameter(torch.zeros(1, config.n_embd2)) # the starting hidden state
        self.wte = nn.Embedding(config.vocab_size, config.n_embd) # token embeddings table
        self.Cw = nn.Linear(config.n_embd + config.n_embd2, config.n_embd2)
        self.lm_head = nn.Linear(config.n_embd2, self.vocab_size)

    def get_block_size(self):
        return self.block_size

    def forward(self, idx):
        b, t = idx.size()

        # embed all the integers up front and all at once for efficiency
        emb = self.wte(idx) # (b, t, n_embd)

        # sequentially iterate over the inputs and update the RNN state each tick
        hprev = self.start.T.expand((b, -1)) # expand out the batch dimension
        hiddens = []
        for i in range(t):
            xt = emb[:, i, :] # (b, n_embd)
            xh = torch.cat([xt, hprev], dim=1)
            ht = (self.Cw(xh)).tanh()
            hprev = ht
            hiddens.append(ht)

        # decode the outputs
        hidden = torch.stack(hiddens, 1) # (b, t, n_embd2)
        logits = self.lm_head(hidden)

        return logits

In [163]:
import torch.nn.functional as F

config = ModelConfig(block_size=16, vocab_size=27, n_layer=None, n_embd=2, n_embd2=3, n_head=None)
# models
torch.manual_seed(42)
model_rnn = RNN(config)
model_rnn_t = RNNtorch(config)
# copy weights
model_rnn_t.wte.weight.data = model_rnn.wte.weight.data.clone()
for i, (p_t, p) in enumerate(zip(model_rnn_t.parameters(), model_rnn.parameters())):
    if p.dim() == 2 and i != 1: # skip the embedding layer
        p_t.data = p.data.clone().T # linear layer weight
    else:
        p_t.data = p.data.clone()
# loss
loss_fn = CrossEntropyLoss3d()

In [164]:
models = {
    'rnn': model_rnn,
    'rnn_t': model_rnn_t,
}

# args
n_steps = 100
eval_every = 10
bs = 32
ini_lr = 0.1

# train
lossi = []
torch.manual_seed(42)
for  step in range(n_steps):
    idx = torch.randint(0, X_train.shape[0], (bs,))
    x, y = X_train[idx], Y_train[idx]
    lr = 0.1 #ini_lr if step < int(n_steps * 0.75) else ini_lr * 0.1

    # --- torch ---
    # forward
    logits_t = model_rnn_t(x)
    loss_t = F.cross_entropy(logits_t.view(-1, logits_t.size(-1)), y.view(-1), ignore_index=-1)
    # backward
    loss_t.backward()
    # update
    for p_t in model_rnn_t.parameters():
        p_t.data -= lr * p_t.grad
        p_t.grad = None
    # --- manual ---
    # forward
    logits = model_rnn(x)
    loss = loss_fn(logits, y)
    # backward
    grad = loss_fn.backward()
    model_rnn.backward(grad)
    # update
    for p, g in zip(model_rnn.parameters(), model_rnn.grads()):
        p.data -= lr * g
    
    
    # eval
    if step % eval_every == 0:
        x, y = X_val, Y_val
        with torch.no_grad():
            logits_t = model_rnn_t(x)
            val_loss_t = F.cross_entropy(logits_t.view(-1, logits_t.size(-1)), y.view(-1), ignore_index=-1)
        logits = model_rnn(x)
        val_loss = loss_fn(logits, y)
        print(f'step {step:<4} || Train || {loss.item():.6f} || Val || {val_loss.item():.6f} || Val_t || {val_loss_t.item():.6f}')
        
    # break


step 0    || Train || 3.306277 || Val || 3.278128 || Val_t || 3.278128
step 10   || Train || 3.254764 || Val || 3.225257 || Val_t || 3.225257
step 20   || Train || 3.187834 || Val || 3.179277 || Val_t || 3.179277
step 30   || Train || 3.144083 || Val || 3.136274 || Val_t || 3.136274
step 40   || Train || 3.103379 || Val || 3.095912 || Val_t || 3.095912
step 50   || Train || 3.091513 || Val || 3.060530 || Val_t || 3.060530
step 60   || Train || 3.041352 || Val || 3.028438 || Val_t || 3.028438
step 70   || Train || 2.983515 || Val || 2.997443 || Val_t || 2.997443
step 80   || Train || 3.013324 || Val || 2.970027 || Val_t || 2.970027
step 90   || Train || 2.978646 || Val || 2.946448 || Val_t || 2.946448
