In [1]:
import os
os.chdir("..")

In [2]:
from fern.model import Transformer
from fern.config import FernConfig
import torch
from tqdm.notebook import tqdm_notebook
from fern.tokenizer import BytePairEncoding

torch.manual_seed(0)  # type: ignore
torch.set_float32_matmul_precision("high")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using `{device}` device")

Using `cuda` device


## Data Loading

Here we load a simple dataset (to be released) for training.

In [3]:
bpe = BytePairEncoding.load("checkpoints/tokenizers/tes2304.tok")
bpe

<fern.tokenizer.BytePairEncoding at 0x7f1f9415be50>

In [4]:
data: torch.Tensor = torch.load("data/books_concat_special.pt")
data = data.to(torch.int64).cuda()
n = int(0.8 * len(data))

train_data = data[:n]
val_data = data[n:]

## Define the config and important constants

In [5]:
learning_rate = 3e-4
max_iters = 10000
batch_size = 32
eval_iters = 100
eval_interval = 1000

fern_config = FernConfig(
    d_model=128,  #384
    n_heads=8,
    n_layers=32,
    vocab_size=bpe.vocab_size,
    block_size=512,  # 256
    dropout=0.2,
)

def get_batch(split: str) -> tuple[torch.Tensor, torch.Tensor]:
    data = train_data if split == "train" else val_data
    ix = torch.randint(len(data) - fern_config.block_size, (batch_size,))
    x = torch.stack([data[i : i + fern_config.block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + fern_config.block_size + 1] for i in ix])
    return x, y

@torch.no_grad()  # type: ignore
def estimate_loss(m: torch.nn.Module) -> dict[str, torch.Tensor]:
    out: dict[str, torch.Tensor] = {}
    m.eval()
    for split in ["train", "val"]:
        losses = torch.empty(eval_iters)
        for i in tqdm_notebook(range(eval_iters)):
            bX, bY = get_batch(split)
            _logits, loss = m(bX, bY)
            losses[i] = loss
        out[split] = losses.mean()
    m.train()
    return out

## Training loop

In [6]:
model = Transformer(config=fern_config).to(device)
def param_count(m: torch.nn.Module) -> str:
    total_params = sum(torch.numel(param) for param in m.parameters(True))
    suffixes = ["", "k", "m", "b", "t"]
    i = 0
    while total_params // 1000 != 0:
        total_params /= 1000
        i += 1
    return f"{total_params:.1f}{suffixes[i]}"
    
print(f"Model parameters: {param_count(model)}")

Model parameters: 9.0m


In [7]:
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scaler = torch.cuda.amp.GradScaler()

for iter in tqdm_notebook(range(max_iters)):
    with torch.autocast(device, torch.bfloat16):
        if iter % eval_interval == 0:
            estimated_loss = estimate_loss(model)
            print(f"Estimated loss at iteration {iter}: {estimated_loss}")
            torch.save(  # type: ignore
                {
                    "epoch": iter,
                    "model_state_dict": model.state_dict(),
                    "optimizer_state_dict": optimizer.state_dict(),
                    "loss": estimated_loss,
                },
                f"model-{iter}.pt",
            )
        x, y = get_batch("train")
        _, loss = model.forward(x, y)
    # optimizer.zero_grad()
    if loss is None:
        raise ValueError("Expected `loss` to be defined during training, got `None`")
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
    # loss.backward() # type: ignore
    # optimizer.step()

estimated_loss = estimate_loss(model)
print(f"Estimated loss at iteration {max_iters}: {estimated_loss}")
torch.save(  # type: ignore
    {
        "epoch": max_iters,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": estimated_loss,
    },
    f"model-{max_iters}.pt",
)

  0%|          | 0/10000 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 0: {'train': tensor(8.3886), 'val': tensor(8.3873)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 1000: {'train': tensor(4.2836), 'val': tensor(4.3916)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 2000: {'train': tensor(3.7055), 'val': tensor(3.9195)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 3000: {'train': tensor(3.4469), 'val': tensor(3.7515)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 4000: {'train': tensor(3.2808), 'val': tensor(3.6570)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 5000: {'train': tensor(3.1557), 'val': tensor(3.6049)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 6000: {'train': tensor(3.0463), 'val': tensor(3.5842)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 7000: {'train': tensor(2.9420), 'val': tensor(3.5473)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 8000: {'train': tensor(2.8791), 'val': tensor(3.5328)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 9000: {'train': tensor(2.7953), 'val': tensor(3.5346)}


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Estimated loss at iteration 10000: {'train': tensor(2.7412), 'val': tensor(3.5207)}


In [8]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
model.eval()
generated: list[int] = list(map(lambda x: x.item(), model.generate(context, stop_token=list(bpe.special_token_to_index.values())[-1])))  # type: ignore
print(bpe.decode(generated))

et’s armor and foolish shields open for a time. Primarily the griless diamond boulevard leaps back to the leg with the wisely fitted face as five times rounded and flung off into the back sharp for fellows.
Arctic leather notoriously secretively durable enough to alike that when in case the lensarly swamp rings around it may have personally intricate ability. Wing just as the entire great scenician gems has been forged to points of vanity dressed in the armor and is truly forced to take its beating stone below.
CHEST PIECES
Our scales are of wielding vambraces inscribed to most of the primary flanges at the beginning of the finials, but, it should be fully emerged as it doesn’t have its purpose for a pulpite instrument I served is a flunt of pattern to create a cheap sight. I’ll reflecate their dump behind these shoulders, so I need to go a sail axe in and back to the inner spaces.
DAGGERS
The rounded dagger and thrusts are the front plate in their waist for precious knaves and twice, 