In [1]:
import glob
import multiprocessing as mp
import os

import torch
from torch.distributed import destroy_process_group, init_process_group
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from model import GPT
from trainer import Trainer


In [None]:
def read_file(filepath):
    with open(filepath, "r", encoding="utf-8") as file:
        return file.read() + "\n"


def read_all_files_to_string(directory):
    filepaths = [
        filepath
        for filepath in glob.glob(os.path.join(directory, "**", "*"), recursive=True)
        if os.path.isfile(filepath)
    ]

    if not filepaths:
        raise ValueError("No files found in the input directory.")

    combined_string = ""
    with mp.Pool(min(len(filepaths), mp.cpu_count())) as executor:
        results = executor.map(read_file, filepaths)
        combined_string = "".join(results)

    return combined_string


In [None]:
from collections import Counter


def prepare_data(text: str, vocab_limit: int):
    if not text:
        raise ValueError(
            "The input text is empty. Please check the file reading process."
        )

    lines = text.splitlines()
    lines = [line for line in lines if all(c.isascii() for c in line)]

    if not lines:
        raise ValueError("No valid ASCII lines found in the input text.")

    word_counts = Counter(word for line in lines[:100_000] for word in line.split())

    most_common_words = [word for word, _ in word_counts.most_common(vocab_limit)]
    words = most_common_words + ["<unk>", "\n"]

    vocab_size = len(words)
    stoi = {word: i for i, word in enumerate(words)}
    itos = {i: word for i, word in enumerate(words)}

    def encode(sentence):
        tokens = [
            stoi[word] if word in stoi else stoi["<unk>"] for word in sentence.split()
        ]
        tokens.append(stoi["\n"])
        return tokens

    def decode(tokens):
        words = [itos[token] for token in tokens]
        return " ".join(words).replace(" \n", "\n")

    encoded_lines = [
        torch.tensor(encode(line), dtype=torch.long) for line in lines if encode(line)
    ]

    if not encoded_lines:
        raise ValueError("No lines were encoded. Check the encoding process.")

    return torch.cat(encoded_lines), encode, decode, vocab_size


In [3]:
class TextDataset(Dataset):
    def __init__(self, data, block_size):
        self.data = data
        self.block_size = block_size

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        x = self.data[idx : idx + self.block_size]
        y = self.data[idx + 1 : idx + self.block_size + 1]
        return x, y

In [None]:
batch_size = 128
block_size = 128
max_epochs = 2000
learning_rate = 3e-4
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
save_interval = 10
num_groups = 3
vocab_limit = 10_000

eval_str = r"""
    Scene 1
=======
[Enter Theseus, Hippolyta, and Philostrate, with others.]


THESEUS
Now, fair Hippolyta, our nuptial hour
Draws on apace. Four happy days bring in
Another moon. But, O, methinks how slow
This old moon wanes! She lingers my desires
Like to a stepdame or a dowager
Long withering out a young man's revenue.

HIPPOLYTA
Four days will quickly steep themselves in night;
Four nights will quickly dream away the time;
And then the moon, like to a silver bow
New-bent in heaven, shall behold the night
Of our solemnities.

THESEUS  Go, Philostrate,
Stir up the Athenian youth """


: 

In [None]:
directory = "data"
snapshot_path = f"{directory.replace('/', '_')}.pt"

text = read_all_files_to_string(directory)
train_data, encode, decode, vocab_size = prepare_data(text, vocab_limit)
print(vocab_size)


2129633


In [8]:
def ddp_setup():
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    init_process_group(backend="nccl")


def load_train_objs():
    train_dataset = TextDataset(train_data, block_size)

    model = GPT(
        vocab_size,
        n_embd,
        block_size,
        n_layer,
        n_head,
        dropout,
        num_groups,
    )

    optimizer = AdamW(model.parameters(), lr=learning_rate)

    return train_dataset, model, optimizer


def prepare_dataloader(dataset: Dataset, batch_size: int):
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        shuffle=False,
        sampler=DistributedSampler(dataset),
    )


In [None]:
ddp_setup()
dataset, model, optimizer = load_train_objs()
train_loader = prepare_dataloader(dataset, batch_size)
trainer = Trainer(model, train_loader, optimizer, save_interval, snapshot_path)
trainer.train(max_epochs)
destroy_process_group()

In [9]:
model = GPT(
    vocab_size,
    n_embd,
    block_size,
    n_layer,
    n_head,
    dropout,
    num_groups,
).cuda()

model.load_state_dict(torch.load(snapshot_path)["MODEL_STATE"])
context = torch.tensor([encode(eval_str)], dtype=torch.long).cuda()
print(decode(model.generate(context, max_new_tokens=1000)[0].tolist()))


  model.load_state_dict(torch.load(snapshot_path)["MODEL_STATE"])


 Make his trotting Northumberland, I checked my Proculeius, and credulous!
 Where one children's devil shall be Mantua,
 And the murder; and weary, find: with a pate
 breed the fatal appall bearing Sailor thou rebellion to dreamed gashes
 There was she, footing of their Lords.]
 Shall passing them mood
 [The Play conscience,
 And for he lament'st. of hath bloody heaven by our owe;
 costume.)]
 Now, but several King that way to th' swords.

 The gods my lord?
 There is his death.
 To make his most Roman gate.
 PLAYER Now if thou to
 EMILIA
 Sound a King is revenge, Now, bragging thought

 'Twere tempts all his array, he fell.
 Thy wrong together
 The fittest christened keep friend, who lodge from
 And become a ladies thy moved
 The wrongs that her defiled long show the whole, and he is!
 The most wormy devour!
 And perhaps see the breasts in what, unless we so sing'st and need.
 Have valiant--

 Not not need and the god she would you!
 Of silver faith a issue of the sky.
 And much my he