In [None]:
import glob
import os

import torch

from model import GPT
import torch.multiprocessing as mp

In [None]:
def read_file(filepath):
    """Return the UTF-8 text of ``filepath`` with one newline appended.

    The trailing newline keeps the last line of one file from running into the
    first line of the next when several files are concatenated.
    """
    with open(filepath, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return f"{contents}\n"


def read_all_files_to_string(directory):
    """Concatenate the text of every file under ``directory`` (recursively).

    Each file is read with ``read_file`` (UTF-8, newline appended). Paths are
    sorted so the combined string is identical across runs and platforms.
    Everything derived from it, such as the vocabulary and any train/val split,
    is therefore identical too.

    Args:
        directory: Root directory to search. Names beginning with "." are not
            matched by the ``*`` / ``**`` glob patterns, so hidden files are
            skipped.

    Returns:
        The concatenated contents of all matched files, in sorted path order.

    Raises:
        ValueError: If no files are found under ``directory``.
    """
    from concurrent.futures import ThreadPoolExecutor

    filepaths = sorted(
        path
        for path in glob.glob(os.path.join(directory, "**", "*"), recursive=True)
        if os.path.isfile(path)
    )

    if not filepaths:
        raise ValueError("No files found in the input directory.")

    # Reading files is I/O-bound, so threads are enough. They also avoid a
    # process-pool pitfall in notebooks: under the "spawn" start method
    # (default on macOS/Windows) worker processes cannot unpickle functions
    # defined in the notebook's __main__, such as read_file.
    max_workers = min(len(filepaths), os.cpu_count() or 1)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # executor.map yields results in input order, preserving the sort.
        return "".join(executor.map(read_file, filepaths))


def prepare_data(text: str, keep_newlines: bool = False):
    """Keep the ASCII-only lines of ``text`` and encode them at character level.

    Args:
        text: Raw corpus text.
        keep_newlines: If False (default), the kept lines are concatenated with
            no separator, so line breaks never reach the model. If True, they
            are joined with "\\n" and "\\n" becomes part of the vocabulary.
            This changes the vocabulary and token ids, so checkpoints trained
            with one setting do not fit the other.

    Returns:
        ``(data, encode, decode, vocab_size)``:
        ``data`` is a 1-D ``torch.long`` tensor of token ids;
        ``encode`` maps a string to a list of ids;
        ``decode`` maps an iterable of ids back to a string;
        ``vocab_size`` is the number of distinct characters in the corpus.

    Raises:
        ValueError: If ``text`` is empty or contains no ASCII-only lines.

    ``encode`` raises ``KeyError`` for characters outside the vocabulary.
    """
    if not text:
        raise ValueError(
            "The input text is empty. Please check the file reading process."
        )

    # Drop every line that contains any non-ASCII character.
    lines = [line for line in text.splitlines() if line.isascii()]

    if not lines:
        raise ValueError("No valid ASCII lines found in the input text.")

    corpus = ("\n" if keep_newlines else "").join(lines)

    # Build the vocabulary from the entire corpus. Scanning only a prefix
    # would give no id to characters that first appear later, and encoding
    # the corpus would then fail with a KeyError.
    chars = sorted(set(corpus))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = dict(enumerate(chars))

    def encode(string):
        return [stoi[c] for c in string]

    def decode(tokens):
        return "".join(itos[i] for i in tokens)

    data = torch.tensor(encode(corpus), dtype=torch.long)
    return data, encode, decode, vocab_size


In [None]:
def get_batch(data, block_size, batch_size, device=None):
    """Sample a random batch of (input, target) windows from a 1-D token tensor.

    Each row of ``x`` is ``block_size`` consecutive tokens starting at a random
    offset. The matching row of ``y`` is the same window shifted one token to
    the right, i.e. the next-token targets.

    Args:
        data: 1-D tensor of token ids.
        block_size: Length of each window.
        batch_size: Number of windows to sample.
        device: Optional device to move the batch to, e.g. the model's device.
            ``None`` leaves the batch where slicing ``data`` puts it.

    Returns:
        ``(x, y)``, each of shape ``(batch_size, block_size)``.

    Raises:
        ValueError: If ``data`` has ``block_size`` or fewer tokens. In that
            case no window plus its one-token-shifted target fits.
    """
    if len(data) <= block_size:
        raise ValueError(
            f"Need more than block_size={block_size} tokens to sample a batch, "
            f"got {len(data)}."
        )
    # The largest start offset is len(data) - block_size - 1, so y's slice
    # ends exactly at len(data).
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    if device is not None:
        x, y = x.to(device), y.to(device)
    return x, y


In [None]:
@torch.no_grad()
def estimate_loss(
    train_data,
    val_data,
    eval_interval,
    block_size,
    batch_size,
    model,
):
    """Average the model's loss over several random batches of each split.

    Args:
        train_data: 1-D token tensor for the training split.
        val_data: 1-D token tensor for the validation split.
        eval_interval: Number of random batches averaged per split. Despite
            its name it is an iteration count; the name is kept for
            compatibility. Must be positive.
        block_size: Passed to ``get_batch``.
        batch_size: Passed to ``get_batch``.
        model: Called as ``model(X, Y)`` and expected to return
            ``(logits, loss)`` (as used here). Assumed to have at least one
            parameter, whose device the batches are moved to.

    Returns:
        ``{"train": mean_loss, "val": mean_loss}`` as 0-dim float tensors.

    Raises:
        ValueError: If ``eval_interval`` is not positive.

    The model is put in eval mode while measuring. Its previous mode is
    restored afterwards, even if an error occurs.
    """
    if eval_interval <= 0:
        raise ValueError("eval_interval must be a positive number of batches.")

    # Sampled batches live on the data's device; send them to the model's
    # device so CPU data works with a CUDA model.
    device = next(model.parameters()).device
    was_training = model.training
    out = {}
    model.eval()
    try:
        for split, split_data in (("train", train_data), ("val", val_data)):
            losses = torch.zeros(eval_interval)
            for k in range(eval_interval):
                X, Y = get_batch(split_data, block_size, batch_size)
                _, loss = model(X.to(device), Y.to(device))
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        model.train(was_training)
    return out


In [None]:
def train_model(
    model,
    train_data,
    val_data,
    block_size,
    batch_size,
    learning_rate,
    max_epochs,
    eval_interval,
):
    """Train ``model`` with AdamW on randomly sampled batches of ``train_data``.

    Each "epoch" here is one optimisation step on one random batch, not a full
    pass over the data. Every ``eval_interval`` steps, one random validation
    batch is scored with gradients disabled and the model in eval mode.

    Args:
        model: Module called as ``model(x, y)`` and expected to return
            ``(logits, loss)`` (as used here). Batches are moved to the device
            of its first parameter.
        train_data: 1-D token tensor to train on (may live on the CPU).
        val_data: 1-D token tensor for validation (may live on the CPU).
        block_size: Passed to ``get_batch``.
        batch_size: Passed to ``get_batch``.
        learning_rate: AdamW learning rate.
        max_epochs: Number of optimisation steps.
        eval_interval: Validate every this many steps. Values <= 0 disable
            validation.

    The model runs in train mode for the duration, so train-time behaviour such
    as dropout is active even if the caller left it in eval mode. Its previous
    mode is restored on exit.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    device = next(model.parameters()).device
    was_training = model.training
    model.train()
    try:
        for epoch in range(max_epochs):
            x_batch, y_batch = get_batch(train_data, block_size, batch_size)
            _, train_loss = model(x_batch.to(device), y_batch.to(device))
            optimizer.zero_grad(set_to_none=True)
            train_loss.backward()
            optimizer.step()

            output = f"Epoch {epoch}: Train loss {train_loss.item():.4f}"

            if eval_interval > 0 and epoch % eval_interval == 0:
                # Score validation like inference: no autograd graph, and
                # eval mode so dropout does not add noise to the measurement.
                model.eval()
                with torch.no_grad():
                    x_val, y_val = get_batch(val_data, block_size, batch_size)
                    _, val_loss = model(x_val.to(device), y_val.to(device))
                model.train()
                output += f", Val loss {val_loss.item():.4f}"

            print(output)
    finally:
        model.train(was_training)


In [None]:
# Hyperparameters and runtime settings.
batch_size = 128  # windows sampled per optimisation step
block_size = 128  # tokens per sampled window
max_epochs = 100  # optimisation steps (one random batch each), not full passes
learning_rate = 3e-4
n_embd = 384  # passed to GPT
n_head = 6  # passed to GPT
n_layer = 6  # passed to GPT
dropout = 0.2  # passed to GPT
save_every = 10  # NOTE(review): not referenced by any cell visible here
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG so batch sampling (torch.randint in get_batch) is
# reproducible on Restart & Run All. Presumably this also covers model
# initialisation, since this cell runs before the model is built.
seed = 1337
torch.manual_seed(seed)


In [None]:
# Read every file under shakespeare/ and encode it at character level.
# NOTE: despite its name, train_data holds the *whole* encoded corpus; any
# train/validation split has to be carved out of it later.
text = read_all_files_to_string(directory="shakespeare")
(train_data, encode, decode, vocab_size) = prepare_data(text=text)

In [None]:
# Instantiate GPT from the config-cell hyperparameters, then place it on `device`.
model = GPT(vocab_size, n_embd, block_size, n_layer, n_head, dropout)
model = model.to(device)


In [None]:
# Resume from a saved snapshot when one exists; otherwise keep fresh weights so
# Restart & Run All still works on a clean checkout.
snapshot_path = "snapshot.pt"
if os.path.exists(snapshot_path):
    # map_location lets a snapshot saved on a GPU load on a CPU-only machine.
    # NOTE(security): torch.load unpickles the file, so only load trusted
    # snapshots (or pass weights_only=True if the snapshot holds only tensors
    # and primitive containers).
    snapshot = torch.load(snapshot_path, map_location=device)
    model.load_state_dict(snapshot["MODEL_STATE"])
else:
    print(f"No snapshot found at {snapshot_path}; using freshly initialised weights.")
model.eval()

In [None]:
# prepare_data returns the whole encoded corpus. Hold out the tail as a
# validation split, because train_model requires separate train and val data.
val_fraction = 0.1
n_train = int((1 - val_fraction) * len(train_data))
train_split, val_split = train_data[:n_train], train_data[n_train:]

train_model(
    model,
    train_split,
    val_split,
    block_size,
    batch_size,
    learning_rate,
    max_epochs,
    eval_interval=10,  # report a validation loss every 10 steps
)

In [None]:
# Sample a 100-token continuation of a short prompt. eval() switches off
# train-time behaviour such as dropout. no_grad() avoids building an autograd
# graph that is never backpropagated.
model.eval()
with torch.no_grad():
    context = torch.tensor([encode("KEKW ")], dtype=torch.long, device=device)
    generated = model.generate(context, max_new_tokens=100)
print(decode(generated[0].tolist()))


In [None]:
# Count elements of parameters with requires_grad=True only; frozen
# parameters are excluded, so this is the trainable count, not the total.
total_params = sum(
    param.numel()
    for param in filter(lambda p: p.requires_grad, model.parameters())
)
print(f"Total number of trainable parameters: {total_params}")

In [None]:
# Save the bare state_dict to model.pt. This is not the {"MODEL_STATE": ...}
# snapshot layout that the load cell reads, so reload it with
# model.load_state_dict(torch.load("model.pt")).
torch.save(obj=model.state_dict(), f="model.pt")