In [None]:
import os

# Assumes this notebook sits one directory below the project root, which holds the
# `fern` package and the `data/` folder; moving there makes `fern` importable and lets
# relative paths such as `data/...` resolve. The guard makes the cell idempotent:
# re-running it does not keep walking further up the directory tree.
if not os.path.isdir("fern"):
    os.chdir("..")

In [None]:
from fern.model import Transformer
from fern.config import FernConfig
import torch
from tqdm.notebook import tqdm_notebook

# Fixed seed so weight init and batch sampling are reproducible across runs.
torch.manual_seed(0)  # type: ignore
# Permit faster reduced-precision (e.g. TF32) float32 matmuls where the hardware supports them.
torch.set_float32_matmul_precision("high")

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using `{device}` device")

## Data Loading

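Here we load a plain-text dataset (to be released), build a character-level vocabulary from it, and split the encoded corpus 80/20 into training and validation sets.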

In [None]:
# Read the raw corpus. Decode explicitly (assumes the file is UTF-8 — confirm) so the
# character vocabulary does not depend on the platform's default locale encoding.
with open("data/books_concat.txt", "r", encoding="utf-8") as f:
    text = f.read()
assert text, "data/books_concat.txt is empty"

# Character-level vocabulary: every distinct character in the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
itos = dict(enumerate(chars))  # token id -> character
stoi = {ch: i for i, ch in itos.items()}  # character -> token id


def encode(seq: str) -> list[int]:
    """Map each character of `seq` to its integer token id via `stoi`.

    Raises KeyError for any character that is not in the vocabulary.
    """
    return list(map(stoi.__getitem__, seq))


def decode(seq: list[int]) -> str:
    """Turn a sequence of token ids back into a string via `itos`.

    Raises KeyError for any id that is not in the vocabulary.
    """
    pieces = [itos[token] for token in seq]
    return "".join(pieces)


# Tokenise the whole corpus once and place it on `device`, so every batch sliced
# from it is already on the right device.
data = torch.tensor(encode(text), dtype=torch.long, device=device)

# Contiguous split: first 80% of the token stream for training, the rest for validation.
n = int(0.8 * len(data))
train_data, val_data = data[:n], data[n:]

## Hyperparameters, model config, and batching/evaluation helpers

In [None]:
learning_rate = 3e-4   # AdamW step size
max_iters = 10_000     # total optimizer steps
batch_size = 32        # sequences sampled per batch by get_batch
eval_iters = 100       # batches averaged per split when estimating loss
eval_interval = 1_000  # steps between loss estimates / checkpoints

# Model hyper-parameters; vocab_size comes from the character vocabulary built from the corpus.
fern_config = FernConfig(
    vocab_size=vocab_size,
    d_model=384,
    n_heads=6,
    n_layers=6,
    block_size=512,  # tokens per training sequence (alternative value: 256)
    dropout=0.2,
)

def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Sample a random batch of contiguous token windows from one data split.

    Args:
        split: ``"train"`` for the training split, ``"val"`` for validation.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, fern_config.block_size)``; ``y`` holds
        the token that follows each position of ``x`` (next-token targets).

    Raises:
        ValueError: if ``split`` is not ``"train"`` or ``"val"``, or if the split has
            no more than ``block_size`` tokens (no room for a shifted target window).
    """
    if split == "train":
        data = train_data
    elif split == "val":
        data = val_data
    else:
        # Previously any other string silently fell back to the validation split.
        raise ValueError(f"Unknown split {split!r}; expected 'train' or 'val'")

    block_size = fern_config.block_size
    # Valid starts are [0, len - block_size) so the target window (offset by one) still fits.
    n_starts = len(data) - block_size
    if n_starts <= 0:
        raise ValueError(
            f"Split {split!r} has {len(data)} tokens; need more than block_size={block_size}"
        )
    # Same random draw as before (CPU generator, same bounds), so seeded runs are unchanged.
    ix = torch.randint(n_starts, (batch_size,))
    # Build a (batch_size, block_size) grid of indices and gather every window in one op.
    window = ix.to(data.device).unsqueeze(1) + torch.arange(block_size, device=data.device)
    x = data[window]
    y = data[window + 1]
    return x, y

@torch.no_grad()  # type: ignore
def estimate_loss(m: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Estimate the mean loss on the train and val splits without updating the model.

    Averages the loss over ``eval_iters`` random batches per split, with ``m`` in
    eval mode (e.g. dropout off) and gradient tracking disabled.

    Args:
        m: Model called as ``m(x, y)`` and returning ``(logits, loss)``.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}`` as 0-dim tensors.

    Raises:
        ValueError: if the model returns ``None`` for the loss.
    """
    out: dict[str, torch.Tensor] = {}
    # Remember the caller's mode so it can be restored, rather than always forcing train mode.
    was_training = m.training
    m.eval()
    try:
        for split in ["train", "val"]:
            losses = torch.empty(eval_iters)
            # leave=False: these short-lived bars vanish instead of piling up
            # under the main training progress bar.
            for i in tqdm_notebook(range(eval_iters), desc=f"eval {split}", leave=False):
                bX, bY = get_batch(split)
                _logits, loss = m(bX, bY)
                if loss is None:
                    raise ValueError("Expected `loss` to be defined during evaluation, got `None`")
                losses[i] = loss.item()
            out[split] = losses.mean()
    finally:
        # Restore the original mode even if evaluation raised.
        m.train(was_training)
    return out

## Training loop, checkpointing, and sampling

In [None]:

def _save_checkpoint(
    step: int,
    loss: dict[str, torch.Tensor],
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
) -> None:
    """Write model/optimizer state and the latest loss estimate to ``model-{step}.pt``."""
    torch.save(  # type: ignore
        {
            # Key name kept as "epoch" so any existing loader keyed on it still works;
            # the value is the optimizer step count, not a full pass over the data.
            "epoch": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        f"model-{step}.pt",
    )


model = Transformer(fern_config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Loop variable is `step`, not `iter`: a top-level `iter` would shadow the builtin
# for the rest of the kernel session.
for step in tqdm_notebook(range(max_iters)):
    if step % eval_interval == 0:
        estimated_loss = estimate_loss(model)
        print(f"Estimated loss at iteration {step}: {estimated_loss}")
        _save_checkpoint(step, estimated_loss, model, optimizer)
    x, y = get_batch("train")
    # Call the module itself (not `.forward`) so any registered hooks run.
    _, loss = model(x, y)
    optimizer.zero_grad()
    if loss is None:
        raise ValueError("Expected `loss` to be defined during training, got `None`")
    loss.backward()  # type: ignore
    optimizer.step()

# Final evaluation and checkpoint after the last optimizer step.
estimated_loss = estimate_loss(model)
print(f"Estimated loss at iteration {max_iters}: {estimated_loss}")
_save_checkpoint(max_iters, estimated_loss, model, optimizer)

# Sample from the trained model: a batch of one sequence, seeded with the single token
# id 0 (the smallest character in the sorted vocabulary).
context = torch.zeros((1, 1), dtype=torch.long, device=device)
model.eval()
# Sampling needs no autograd graph; this is harmless if get_generation already disables it.
with torch.no_grad():
    # NOTE(review): 512 presumably is the number of tokens to generate — confirm
    # against Transformer.get_generation.
    generated: list[int] = model.get_generation(context, 512)[0].tolist()  # type: ignore
print(decode(generated))