In [1]:
import glob
import os

import torch

from models import GPTLanguageModel

In [2]:
def read_all_files_to_string(directory):
    combined_string = ""
    for filepath in glob.glob(os.path.join(directory, "**", "*"), recursive=True):
        if os.path.isfile(filepath):
            with open(filepath, "r", encoding="utf-8") as file:
                combined_string += file.read() + "\n"
    return combined_string


def prepare_data(text, device):
    lines = text.splitlines()
    lines = [line for line in lines if all(c.isascii() for c in line)]
    chars = sorted(list(set("".join(lines))))
    vocab_size = len(chars)
    stoi = {ch: i for i, ch in enumerate(chars)}
    itos = {i: ch for i, ch in enumerate(chars)}

    def encode(string):
        return [stoi[c] for c in string]

    def decode(tokens):
        return "".join([itos[i] for i in tokens])

    encoded_lines = [torch.tensor(encode(line), dtype=torch.long) for line in lines]
    data = torch.cat(encoded_lines)

    n = len(data)
    train_data = data[: int(n * 0.8)].to(device)
    val_data = data[int(n * 0.8) :].to(device)

    return train_data, val_data, encode, decode, vocab_size


In [3]:
def get_batch(data, block_size, batch_size):
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i : i + block_size] for i in ix])
    y = torch.stack([data[i + 1 : i + block_size + 1] for i in ix])
    return x, y


In [4]:
@torch.no_grad()
def estimate_loss(
    train_data,
    val_data,
    eval_interval,
    block_size,
    batch_size,
    model,
):
    out = {}
    mapping = {"train": train_data, "val": val_data}
    model.eval()
    for split in ["train", "val"]:
        losses = torch.zeros(eval_interval)
        for k in range(eval_interval):
            X, Y = get_batch(mapping[split], block_size, batch_size)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out


In [5]:
def train_model(
    model,
    train_data,
    val_data,
    block_size,
    batch_size,
    learning_rate,
    max_epochs,
    eval_interval,
):
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(max_epochs):
        if epoch % eval_interval == 0:
            losses = estimate_loss(
                train_data,
                val_data,
                eval_interval,
                block_size,
                batch_size,
                model,
            )

            print(
                f"Epoch {epoch}: Train loss {losses['train']:.4f}, Val loss {losses['val']:.4f}"
            )

        x_batch, y_batch = get_batch(train_data, block_size, batch_size)
        logits, loss = model(x_batch, y_batch)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()


In [6]:
batch_size = 32
block_size = 512
max_epochs = 1000
eval_interval = 250
learning_rate = 3e-4
device = "cuda" if torch.cuda.is_available() else "cpu"
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2


In [7]:
text = read_all_files_to_string("data")
train_data, val_data, encode, decode, vocab_size = prepare_data(text, device)

model = GPTLanguageModel(
    vocab_size,
    n_embd,
    block_size,
    n_layer,
    n_head,
    device,
    dropout,
).to(device)


In [9]:
model.load_state_dict(torch.load("model.pt"))
model.eval()

  model.load_state_dict(torch.load("model.pt"))


GPTLanguageModel(
  (token_embedding_table): Embedding(95, 384)
  (position_embedding_table): Embedding(512, 384)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-5): 6 x Head(
            (key): Linear(in_features=384, out_features=64, bias=False)
            (query): Linear(in_features=384, out_features=64, bias=False)
            (value): Linear(in_features=384, out_features=64, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (ffwd): FeedFoward(
        (net): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1536, out_features=384, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=

In [33]:
train_model(
    model,
    train_data,
    val_data,
    block_size,
    batch_size,
    learning_rate,
    max_epochs,
    eval_interval,
)

Epoch 0: Train loss 4.6197, Val loss 4.6161
Epoch 250: Train loss 1.8211, Val loss 1.9531
Epoch 500: Train loss 1.3756, Val loss 1.5421
Epoch 750: Train loss 1.1807, Val loss 1.3612


In [61]:
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=100)[0].tolist()))


 BOBASSS ironmouseRAID FOR MOUSE CHAT ironmouseRAID FOUCK MOUNTAID ironmouseRAID EVERY COUNY ironmous


In [35]:
torch.save(model.state_dict(), 'model.pt')