## Train a character-level GPT on some text data

The inputs here are simple text files, which we chop up into individual characters and then train a GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some Shakespeare and train it to predict the next character.
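To make "chop up into individual characters" concrete, here is a minimal sketch of character-level encoding. It is not the actual `CharDataset` code: the toy `text`, the small `block_size`, and the helper `get_example` are assumptions for illustration; only the `stoi`/`itos` names mirror attributes used later in this notebook.

```python
# Minimal sketch of character-level tokenization (illustrative, not mingpt's CharDataset).
text = "To be, or not to be"
block_size = 8  # toy value; the notebook uses 128

chars = sorted(set(text))                     # vocabulary = unique characters
stoi = {ch: i for i, ch in enumerate(chars)}  # char -> integer id
itos = {i: ch for ch, i in stoi.items()}      # integer id -> char

def get_example(idx):
    """Return (inputs, targets) for the window of text starting at idx."""
    chunk = text[idx : idx + block_size + 1]  # one extra char for the last target
    ids = [stoi[ch] for ch in chunk]
    return ids[:-1], ids[1:]                  # targets are the inputs shifted by one

x, y = get_example(0)
print("".join(itos[i] for i in x), "->", "".join(itos[i] for i in y))
```

Each position in `x` is paired with the character that follows it, so one window of `block_size` characters yields `block_size` next-character predictions.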

In [1]:
# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%d/%m/%Y %H:%M:%S",
        level=logging.INFO,
)

In [2]:
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

In [None]:
import jax
import jax.numpy as jnp
import haiku as hk
from functools import partial

from mingpt.chardataset import CharDataset

In [4]:
jax.default_backend(), jax.local_devices(), jax.device_count()

14/01/2025 17:38:03 - INFO - jax._src.xla_bridge -   Unable to initialize backend 'cuda': 
14/01/2025 17:38:03 - INFO - jax._src.xla_bridge -   Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
14/01/2025 17:38:03 - INFO - jax._src.xla_bridge -   Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/Users/yoavram/miniforge3/envs/minGPT/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)


('cpu', [CpuDevice(id=0)], 1)

In [5]:
# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r') as f:
    text = f.read()
train_dataset = CharDataset(text, block_size=128) # one line of poem is roughly 50 characters

data has 1115394 characters, 65 unique.


In [6]:
from mingpt.model import gpt, loss_fn, GPTConfig

rng = jax.random.key(242)
gpt_config = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  n_layer=8, n_head=8, n_embd=512)

In [7]:
hk_loss_fn = hk.transform(partial(loss_fn, config=gpt_config, is_training=True))

In [8]:
from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance (training itself is kicked off in a later cell)
rng, subkey = jax.random.split(rng)
tconf = TrainerConfig(max_epochs=2, batch_size=512//2, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20,
                      # tokens processed over both epochs (2 == max_epochs)
                      final_tokens=2*len(train_dataset)*train_dataset.block_size,
                      num_workers=4, rng=subkey)
trainer = Trainer(hk_loss_fn, train_dataset, None, tconf)

In [9]:
params = trainer.init_params() 

14/01/2025 17:38:51 - INFO - mingpt.trainer -   number of parameters: 25352192


jaxlib.xla_extension.ArrayImpl
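
As a sanity check on the logged parameter count, here is a back-of-the-envelope tally. It assumes the layer layout of the original minGPT (token and position embeddings, blocks with two layer norms, four attention projections and a 4x-wide MLP, a final layer norm, and a bias-free output head); that layout is an assumption about this port, not something read from its code. The helper `linear` is hypothetical.

```python
# Rough parameter count, assuming the original minGPT layer layout.
vocab_size, block_size = 65, 128
n_layer, n_embd = 8, 512

def linear(n_in, n_out, bias=True):
    """Parameters in a dense layer: weight matrix plus optional bias."""
    return n_in * n_out + (n_out if bias else 0)

layer_norm = 2 * n_embd                     # scale and offset
attention = 4 * linear(n_embd, n_embd)      # query, key, value, output projection
mlp = linear(n_embd, 4 * n_embd) + linear(4 * n_embd, n_embd)
block = 2 * layer_norm + attention + mlp

total = (
    vocab_size * n_embd                        # token embedding
    + block_size * n_embd                      # learned position embedding
    + n_layer * block                          # transformer blocks
    + layer_norm                               # final layer norm
    + linear(n_embd, vocab_size, bias=False)   # output head
)
print(total)
```

Under these assumptions the tally comes to 25352192, the same number the trainer logs above.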

In [13]:
params, _ = trainer.train(params)

epoch 1 iter 4356: train loss 0.23807. lr 3.000718e-04: 100%|██████████| 4357/4357 [08:48<00:00,  8.24it/s]
epoch 2 iter 8713: train loss 0.14488. lr 6.000000e-05: 100%|██████████| 4357/4357 [08:26<00:00,  8.59it/s]
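
The learning rates in the progress bars (about 3e-4 after the first epoch, 6e-5 after the second) are consistent with the schedule used in the original minGPT: linear warmup over `warmup_tokens`, then cosine decay down to a floor of 10% of the base rate at `final_tokens`. The sketch below assumes this port follows that schedule and that `len(train_dataset)` equals the text length minus `block_size`; the function `lr_at` is hypothetical and not part of `mingpt.trainer`.

```python
import math

# Assumption: len(train_dataset) == len(text) - block_size
final_tokens = 2 * (1_115_394 - 128) * 128
warmup_tokens = 512 * 20

def lr_at(tokens, learning_rate=6e-4):
    """Learning rate after `tokens` tokens: linear warmup, then cosine decay."""
    if tokens < warmup_tokens:
        mult = tokens / max(1, warmup_tokens)
    else:
        progress = (tokens - warmup_tokens) / max(1, final_tokens - warmup_tokens)
        # cosine decay, clipped so the rate never drops below 10% of the base
        mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress)))
    return learning_rate * mult

for tokens in (0, warmup_tokens, final_tokens // 2, final_tokens):
    print(tokens, lr_at(tokens))
```

Tying `final_tokens` to the number of epochs, as the config above does, means the decay reaches its floor right as training ends.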


In [16]:
# alright, let's sample some character-level Shakespeare
from mingpt.utils import sample

In [17]:
model = hk.transform(partial(gpt, config=gpt_config, is_training=False))
model = hk.without_apply_rng(model).apply

In [None]:
context = "O God, O God!"
x = jnp.array([train_dataset.stoi[s] for s in context])
y = sample(params, model, gpt_config, x, 2000, temperature=1.0, sample=True, top_k=10, progress=True)
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print(completion)

 30%|██▉       | 593/2000 [03:45<05:28,  4.29it/s]
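
The `temperature` and `top_k` arguments passed to `sample` control how the next character is drawn from the model's logits. The following is a framework-free sketch of that idea using only the standard library; it is not the code in `mingpt.utils`, and `sample_next` is a hypothetical helper that handles a single step.

```python
import math
import random

def sample_next(logits, temperature=1.0, top_k=None, rng=random):
    """Draw one index from logits after temperature scaling and top-k filtering."""
    scaled = [l / temperature for l in logits]
    if top_k is not None:
        # value of the k-th largest logit; everything below it is masked out
        kth = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [s if s >= kth else -math.inf for s in scaled]
    m = max(scaled)                              # subtract max for numerical stability
    weights = [math.exp(s - m) for s in scaled]  # masked entries get weight 0.0
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(0)
logits = [0.1, 5.0, 2.0, -1.0]
draws = [sample_next(logits, top_k=2, rng=rng) for _ in range(20)]
print(draws)
```

With `top_k=2` only the two highest-scoring indices can ever be drawn; lowering `temperature` below 1.0 would concentrate the draws further on the top one. Ties at the k-th value are all kept in this sketch.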

In [17]:
# well that was fun