In [1]:
from time import time
from tqdm import tqdm
from typing import cast
from nanogpt.utils import path_to_resource_file
from nanogpt.encoder import Encoder, TiktokenBasedEncoder

import mlx.core as mx
import mlx.nn as mlx_nn
from mlx import optimizers
from mlx.nn import losses
from mlx.nn.utils import value_and_grad

import torch
import torch.nn as torch_nn

from nanogpt import mlx_
from nanogpt import torch_

In [2]:
with open(path_to_resource_file('gutenberg_shakespeare_st.txt'), "r") as f:
    text_st = f.read()

def format_time(start_time: float, end_time: float) -> str:
    delta = end_time - start_time
    m = int(delta) // 60
    s = int(delta) % 60
    return f'{m:02}:{s:02}{f"{delta-int(delta):.3f}"[1:]}'

In [3]:
encoder = TiktokenBasedEncoder(text_st)

batch_size = 32
context_length = 32
embedding_size = 64
num_heads = 4
num_blocks = 4
dropout = .2

learning_rate = 4e-4

epochs = 1000
max_new_tokens = 1000

# MLX

In [4]:
gpu_dev_type = mx.DeviceType(1)
gpu = mx.Device(gpu_dev_type)
mx.set_default_device(gpu)
mx.default_device()

Device(gpu, 0)

In [5]:
def estimate_loss(model: mlx_nn.Module, data: mlx_.Data, batch_size: int, block_size: int, *, eval_iters: int = 100):
    out = {}
    model.eval()
    for split in ['train', 'test']:
        losses = mx.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = data.get_batch(split, batch_size=batch_size, block_size=block_size)  # type: ignore
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

def generate_text(model: mlx_.NanoGPT, encoder: Encoder, init_text: str, *, max_new_tokens: int = 1000):
    t = encoder.encode(init_text)
    idx = mx.array([t], dtype=mx.int16)
    for token in model.generate(idx, max_new_tokens=max_new_tokens):
        print(encoder.decode(token[0].tolist()), end='', flush=True)

In [6]:
gpt = mlx_.NanoGPT(vocab_size=len(encoder), 
                   embedding_size=embedding_size, 
                   context_length=context_length, 
                   num_heads=num_heads, 
                   num_blocks=num_blocks, 
                   dropout=dropout)
gpt.apply_to_modules(mlx_.initialize_weights)
data = mlx_.Data(mx.array(encoder.encode(text_st), dtype=mx.int32), split=.9)

In [7]:
def count_mlx_params(model: mlx_nn.Module) -> int:
    params = 0
    def handle_list(l: list):
        nonlocal params
        for v in l:
            if isinstance(v, dict):
                handle_dict(v)
            elif isinstance(v, list):
                handle_list(v)
            elif isinstance(v, mx.array):
                params += v.size
            else:
                print('??', type(v))
        
    def handle_dict(d: dict):
        nonlocal params
        for _, v in d.items():
            if isinstance(v, dict):
                handle_dict(v)
            elif isinstance(v, list):
                handle_list(v)
            elif isinstance(v, mx.array):
                params += v.size
            else:
                print('??', type(v))

    handle_dict(model.parameters())
    return params

total_params = count_mlx_params(gpt)
print(f'Model contains {total_params/1e6:.1f}M parameters ({total_params})')

Model contains 4.7M parameters (4743288)


In [8]:
# Training
optimizer = optimizers.AdamW(learning_rate=learning_rate)
optimizer.init(gpt.trainable_parameters())
loss_fn = lambda x, y: losses.cross_entropy(gpt(x)[0], y, reduction='mean')
grad_fn = value_and_grad(gpt, loss_fn)

print('Initial loss:', estimate_loss(gpt, data, batch_size, context_length))
start_time = time()
for _ in tqdm(range(epochs)):
    xb, yb = data.get_batch('train', batch_size=batch_size, block_size=context_length)
    __, grads = grad_fn(xb, yb)
    optimizer.update(gpt, grads)
    mx.eval(gpt.state)
end_time = time()
print(f'Training time ({epochs} epochs):', format_time(start_time, end_time), f'[{epochs/(end_time-start_time)} epoch/sec]')
print('Final loss:', estimate_loss(gpt, data, batch_size, context_length))

Initial loss: {'train': array(10.067, dtype=float32), 'test': array(10.067, dtype=float32)}


100%|██████████| 1000/1000 [00:39<00:00, 25.26it/s]


Training time (1000 epochs): 00:39.594 [25.25624666687923 epoch/sec]
Final loss: {'train': array(5.8827, dtype=float32), 'test': array(6.32605, dtype=float32)}


In [9]:
start_time = time()
generate_text(gpt, encoder, '§', max_new_tokens=max_new_tokens)
end_time = time()
print(f'\n-----\nInference time ({max_new_tokens} tokens):', format_time(start_time, end_time), f'[{max_new_tokens/(end_time-start_time)} T/sec]')

IN
orph purple so; best dig prince him will for a.

OR

INOER antic charge hell,
 SC.
Is been TalStand of honestish, sweet he Anne; inensible forth! I nails to the other good? fears,
A of teach in hunt,
BA done.

Y of the ink, and remembrance in in gods it thou confess presently.

 ANDFirst
 serv seven soft?ona or hast Somerset out to the
Till discretion been herd go me against your lord.

LYCANDAND.
And an princeMake happy.

Exe hack cheeks now up me.
Then if be grace’s uncle growing.PAR good Thou Titus with with tokens back of crossed me after toys of injustice of my power a govern upon you’s hold
Which in one
That upon one,e, I had seen;’d for much:
My most gold’d and by their this curtain out
 sepinks not, he to Isis thou not you to they think myities, slutt
 horizon from be
Love the wanderingAS, wounds now
And are of thy lute.

j merchant; goes, was thee would heart and your balm._] friend. father.—No; andes by, thr gar Jackil’darel leave and never he he me, this devatherine happy

# PyTorch

In [10]:
torch.set_default_device('mps')   # Apple Metal
torch.get_default_device()

device(type='mps', index=0)

In [11]:
@torch.no_grad()
def estimate_loss(model: torch_nn.Module, data: torch_.Data, batch_size: int, block_size: int, *, eval_iters: int = 100):
    out = {}
    model.eval()
    for split in ['train', 'test']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = data.get_batch(split, batch_size=batch_size, block_size=block_size)  # type: ignore
            _, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

@torch.no_grad()
def generate_text(model: torch_.NanoGPT, encoder: Encoder, init_text: str, *, max_new_tokens: int = 1000):
    t = encoder.encode(init_text)
    idx = torch.tensor([t], dtype=torch.long)
    for token in model.generate(idx, max_new_tokens=max_new_tokens):
        print(encoder.decode(token[0].tolist()), end='', flush=True)

In [12]:
gpt = torch_.NanoGPT(vocab_size=len(encoder), 
                     embedding_size=embedding_size, 
                     context_length=context_length, 
                     num_heads=num_heads, 
                     num_blocks=num_blocks, 
                     dropout=dropout)
gpt.apply(torch_.initialize_weights)
data = torch_.Data(torch.tensor(encoder.encode(text_st), dtype=torch.long), split=.9)

In [13]:
total_params = sum(p.numel() for p in gpt.parameters())
print(f'Model contains {total_params/1e6:.1f}M parameters ({total_params})')

Model contains 4.7M parameters (4743288)


In [14]:
# Training
optimizer = torch.optim.AdamW(gpt.parameters(), lr=learning_rate)
print('Initial loss:', estimate_loss(gpt, data, batch_size, context_length))
for _ in tqdm(range(epochs)):
    xb, yb = data.get_batch('train', batch_size=batch_size, block_size=context_length)
    logits, loss = gpt(xb, yb)
    loss = cast(torch.Tensor, loss)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
end_time = time()
print(f'Training time ({epochs} epochs):', format_time(start_time, end_time), f'[{epochs/(end_time-start_time)} epoch/sec]')
print('Final loss:', estimate_loss(gpt, data, batch_size, context_length))

Initial loss: {'train': tensor(10.0637, device='mps:0'), 'test': tensor(10.0642, device='mps:0')}


100%|██████████| 1000/1000 [01:31<00:00, 10.90it/s]


Training time (1000 epochs): 01:47.354 [9.315014696924038 epoch/sec]
Final loss: {'train': tensor(5.6972, device='mps:0'), 'test': tensor(6.1773, device='mps:0')}


In [15]:
start_time = time()
generate_text(gpt, encoder, '§', max_new_tokens=max_new_tokens)
end_time = time()
print(f'\n-----\nInference time ({max_new_tokens} tokens):', format_time(start_time, end_time), f'[{max_new_tokens/(end_time-start_time)} T/sec]')

 expects sir, let done for not hus happily?
Four’dIEFain, holland.

GENE.
 lone Cl Speak the therein mye man of his hell
the orderither a man’s fond; I, to some seeming my times
-law’st the kings
FAST loud’d, do ear well to that in his Signuck’d as I
As I will
h  masculine like all thus close WAAL [_ near Got an hold thee,
 coppulk practice.

HOT rod: then I is blood but drink; and a sworn at themeaners.

ARD.
ANTARD.
For us will praise for himself
With surprised
I’ll even, Protehood
Still bear theecutHary church abroad,nothing again to go?
Which; approve me you shall shall hold to keepOr you else?

 porter.

LEMAN.
Ay, had this voice to the aboard and that think me, which not dispense who we know theher

Lay heard,
I youraatory, and ha embrace;
It is ’t.

 while to the army look, this-ro811, and match:, sir of draw king
of matter on sceptKSTARD.
By avoid to be noble cousin that blessed aBAoth howBrown of me’s anyin at like where,
 temper designs,” Iperolph, i’ nobian-b tyrILILYCE.
O b