In [1]:
import os
os.chdir("..")

In [2]:
from fern.model import Transformer
from fern.config import FernConfig
import torch
from tqdm.notebook import tqdm_notebook
from fern.tokenizer import BytePairEncoding

torch.manual_seed(0)  # type: ignore
torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using `{device}` device")

Using `cuda` device


## Data Loading

Here we load a simple dataset (to be released) for training.

In [3]:
bpe = BytePairEncoding.load("checkpoints/tokenizers/tes2304.tok")
bpe

<fern.tokenizer.BytePairEncoding at 0x7ff502b07f40>

In [20]:
data: torch.Tensor = torch.load("data/books_concat_special.pt")
data = data.to(torch.int64).cuda()
n = int(0.8 * len(data))

train_data = data[:n]
val_data = data[n:]

## Define the config and important constants

In [18]:
learning_rate = 3e-4
max_iters = 10000
batch_size = 32
eval_iters = 100
eval_interval = 1000

fern_config = FernConfig(
    d_model=384,
    n_heads=6,
    n_layers=6,
    vocab_size=bpe.vocab_size,
    block_size=512,  # 256
    dropout=0.2,
)

def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - fern_config.block_size, (batch_size,))
    x = torch.stack([data[i : i + fern_config.block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + fern_config.block_size + 1] for i in ix])
    return x, y

@torch.no_grad()  # type: ignore
def estimate_loss(m: torch.nn.Module) -> dict[str, torch.Tensor]:
    out: dict[str, torch.Tensor] = {}
    m.eval()
    for split in ["train", "val"]:
        losses = torch.empty(eval_iters)
        for i in tqdm_notebook(range(eval_iters)):
            bX, bY = get_batch(split)
            _logits, loss = m(bX, bY)
            losses[i] = loss
        out[split] = losses.mean()
    m.train()
    return out

## Training loop

In [12]:
model = Transformer(config=fern_config).to(device)
def param_count(m: torch.nn.Module) -> str:
    total_params = sum(torch.numel(param) for param in m.parameters(True))
    suffixes = ["", "k", "m", "b", "t"]
    i = 0
    while total_params // 1000 != 0:
        total_params /= 1000
        i += 1
    return f"{total_params:.1f}{suffixes[i]}"
    
print(f"Model parameters: {param_count(model)}")

Model parameters: 15.9m


In [21]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
for iter in tqdm_notebook(range(max_iters)):
    if iter % eval_interval == 0:
        estimated_loss = estimate_loss(model)
        print(f"Estimated loss at iteration {iter}: {estimated_loss}")
        torch.save(  # type: ignore
            {
                "epoch": iter,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": estimated_loss,
            },
            f"model-{iter}.pt",
        )
    x, y = get_batch("train")
    _, loss = model.forward(x, y)
    optimizer.zero_grad()
    if loss is None:
        raise ValueError("Expected `loss` to be defined during training, got `None`")
    loss.backward() # type: ignore
    optimizer.step()

estimated_loss = estimate_loss(model)
print(f"Estimated loss at iteration {max_iters}: {estimated_loss}")
torch.save(  # type: ignore
    {
        "epoch": max_iters,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": estimated_loss,
    },
    f"model-{max_iters}.pt",
)

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 0: {'train': tensor(7.8962), 'val': tensor(7.8972)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 1000: {'train': tensor(3.4774), 'val': tensor(3.7892)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 2000: {'train': tensor(2.9662), 'val': tensor(3.6138)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 3000: {'train': tensor(2.6168), 'val': tensor(3.6144)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 4000: {'train': tensor(2.3313), 'val': tensor(3.6722)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 5000: {'train': tensor(2.1032), 'val': tensor(3.7286)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 6000: {'train': tensor(1.9141), 'val': tensor(3.8238)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 7000: {'train': tensor(1.7502), 'val': tensor(3.8606)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 8000: {'train': tensor(1.6259), 'val': tensor(3.9153)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 9000: {'train': tensor(1.5117), 'val': tensor(3.9778)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 10000: {'train': tensor(1.4171), 'val': tensor(4.0470)}


NameError: name 'decode' is not defined

In [28]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
model.eval()
generated: list[int] = list(map(lambda x: x.item(), model.generate(context, stop_token=list(bpe.special_token_to_index.values())[-1])))  # type: ignore
print(bpe.decode(generated))

 furname to the fleet. The following week will be making safe, so long as they are truly aloof and scour the weight of our lives. The attacks say the forest are compared to the grounds that are all about animals, each run keeps unpleasantly. Rigurt performs, internal laws often beats the most potent folk of Cyrodiil … just trying to recover the amount of unsettling sorts of dragonlings!
With that decent succinct words, I have learned that the Maormer have a gift for the people. Word of this chaos and the risks. A—
Though I may be there. Perhaps you can see, the laments of our past is eternal. I feel it, this is a strange and loss to go, but a curse no more. It has yet to be wary of the stars, but can also be tracked in mud. Wray for thieves, still others. It doesn’t seem obvious to all those civilizations of Tamriel.
Traditional Tong
A third-handed tale is one of the most famous, and has the odd age for births to retrest a spirit and to speak or secrets clearly. It gives us a character