## prepare & process data

In [27]:
import data
import util

# Raw corpora to clean. Judging by the recorded output, data.clean_text writes
# "<name>_cleaned.txt" next to each input and reports its character count
# (implementation lives in data.py -- confirm there).
filenames_list = [
    "data/alice.txt",
    "data/harry_potter_01.txt",
    "data/harry_potter_02.txt",
]

for filename in filenames_list:
    data.clean_text(filename)

data/alice.txt_cleaned.txt 143224 characters
data/harry_potter_01.txt_cleaned.txt 454252 characters
data/harry_potter_02.txt_cleaned.txt 509083 characters


In [22]:
text = "Harry Potter was a wizard."
tokens = util.tokenizer.encode(text)

# Compare the raw character count with the number of token ids produced.
print(f"characters: {len(text)} tokens {len(tokens)}")

# Round trip: decoding the full id list should reproduce the original text.
round_trip = util.tokenizer.decode(tokens)
print(f"{tokens} -> {round_trip}")

# Decode every id on its own to see where the token boundaries fall.
for t in tokens:
    piece = util.tokenizer.decode([t])
    print(f"{t}\t -> {piece}")

characters: 26 tokens 6
[18308, 14179, 373, 257, 18731, 13] -> Harry Potter was a wizard.
18308	 -> Harry
14179	 ->  Potter
373	 ->  was
257	 ->  a
18731	 ->  wizard
13	 -> .


In [23]:
# The "utf-8-sig" codec drops a leading byte-order mark, if the file has one.
with open("data/alice.txt_cleaned.txt", mode="r", encoding="utf-8-sig") as file:
    txt = file.read()

# NOTE(review): max_length looks like the sample length in tokens and stride
# the token offset between consecutive samples -- confirm in data.MyDataset.
dataset = data.MyDataset(
    txt,
    max_length=32,
    stride=4,
)

# of tokens in txt: 35323


In [24]:
# Decode the first two samples; each item is indexed as [0] = input ids and
# [1] = target ids.
for i in (0, 1):
    input_text = util.tokenizer.decode(dataset[i][0].tolist())
    target_text = util.tokenizer.decode(dataset[i][1].tolist())
    print(f"{i}\ninput: {input_text}\ntarget: {target_text}")

0
input: Project Gutenberg's Alice's Adventures in Wonderland, by Lewis Carroll This eBook is for the use of anyone anywhere at no cost and with almost no restrictions whatsoever. You
target:  Gutenberg's Alice's Adventures in Wonderland, by Lewis Carroll This eBook is for the use of anyone anywhere at no cost and with almost no restrictions whatsoever. You may
1
input: 's Adventures in Wonderland, by Lewis Carroll This eBook is for the use of anyone anywhere at no cost and with almost no restrictions whatsoever. You may copy it,
target:  Adventures in Wonderland, by Lewis Carroll This eBook is for the use of anyone anywhere at no cost and with almost no restrictions whatsoever. You may copy it, give


In [25]:
# Shuffled mini-batches of 128 samples.
# NOTE(review): data.DataLoader is presumably torch's DataLoader re-exported,
# in which case drop_last=True discards a trailing partial batch -- confirm.
train_loader = data.DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    drop_last=True,
)

In [26]:
data_iter = iter(train_loader)

# Pull two batches and decode the first (input, target) row of each one.
for batch_no in range(2):
    x, y = next(data_iter)
    first_input = util.tokenizer.decode(x[0].tolist())
    first_target = util.tokenizer.decode(y[0].tolist())
    print(f"{first_input}\n{first_target}")

 word you fellows were saying.' 'Tell us a story!' said the March Hare. 'Yes, please do!' pleaded Alice. 'And be quick about it
 you fellows were saying.' 'Tell us a story!' said the March Hare. 'Yes, please do!' pleaded Alice. 'And be quick about it,'
 now,' she said, by way of keeping up the conversation a little. ''Tis so,' said the Duchess: 'and the moral of that is--
,' she said, by way of keeping up the conversation a little. ''Tis so,' said the Duchess: 'and the moral of that is--"


## train

In [33]:
import data
import util
import transformer
import torch

device = util.device
print(device)

# Seed before building the model so the initial weights are reproducible.
torch.manual_seed(123)
model = transformer.GPTModel()
model.to(device)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=4e-4,
    weight_decay=0.1,
)

# Counters shared with the training loop.
tokens_seen = 0
global_step = -1  # the loop increments this before checking it, so the first batch is step 0
n_epochs = 100

cpu


In [None]:
losses = []  # mean training loss per epoch, read later by the plotting cell

train_loader = data.get_train_loader("data/alice.txt_cleaned.txt")

for epoch in range(n_epochs):
    model.train()  # training mode for this epoch

    epoch_loss = 0
    for input_batch, target_batch in train_loader:
        # Clear gradients left over from the previous batch.
        optimizer.zero_grad()
        input_batch = input_batch.to(device)
        target_batch = target_batch.to(device)

        logits = model(input_batch)
        # Merge the first two logits dimensions and flatten the targets so
        # cross_entropy sees one prediction row per target id.
        loss = torch.nn.functional.cross_entropy(
            logits.flatten(0, 1),
            target_batch.flatten(),
        )
        epoch_loss += loss.item()

        loss.backward()   # compute gradients
        optimizer.step()  # apply the weight update

        tokens_seen += input_batch.numel()
        global_step += 1
        if global_step % 1000 == 0:
            print(f"Tokens seen: {tokens_seen}")
        # Optional evaluation step

    avg_loss = epoch_loss / len(train_loader)
    losses.append(avg_loss)
    print(f"Epoch: {epoch + 1}, Loss: {avg_loss}")

    # One checkpoint per epoch, numbered with three zero-padded digits.
    checkpoint_path = f"model_{epoch + 1:03d}.pth"
    torch.save(model.state_dict(), checkpoint_path)


In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch mean training loss collected by the training loop.
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(
    xlabel='Epoch',
    ylabel='Loss',
    title='Training Loss Over Epochs',
)
plt.show()

## play with trained model

In [50]:
import torch

def generate(model, idx, max_new_tokens, context_size, temperature=0.0, top_k=None, eos_id=None):
    """Autoregressively extend ``idx`` by up to ``max_new_tokens`` token ids.

    Args:
        model: Callable that takes a (batch, seq) tensor of token ids and
            returns logits indexed as [batch, position, vocab].
        idx: Integer tensor of shape (batch, num_tokens) holding the prompt(s).
        max_new_tokens: Upper bound on the number of tokens appended.
        context_size: Only the last ``context_size`` tokens are fed to the model.
        temperature: ``0.0`` (default) selects the arg-max token (greedy).
            A positive value divides the logits by it and samples from the
            resulting softmax distribution.
        top_k: If given, only the ``top_k`` highest logits of each row stay
            eligible; the rest are set to ``-inf``. Values larger than the
            vocabulary size are clamped to the vocabulary size.
        eos_id: If given, generation stops as soon as every row's next token
            equals ``eos_id``. That token is not appended to the result.

    Returns:
        Tensor of shape (batch, num_tokens + n_generated) containing the
        prompt followed by the generated ids.
    """
    for _ in range(max_new_tokens):
        # Crop the running sequence to what the model is given as context.
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        # Keep only the prediction for the last position: (batch, vocab).
        logits = logits[:, -1, :]

        if top_k is not None:
            # torch.topk raises when k exceeds the dimension size, so clamp it.
            k = min(top_k, logits.size(-1))
            top_logits, _ = torch.topk(logits, k)
            # Slice with -1: (not -1) to keep shape (batch, 1); a (batch,)
            # threshold would not broadcast row-wise against (batch, vocab).
            min_val = top_logits[:, -1:]
            logits = torch.where(logits < min_val, torch.tensor(float("-inf")).to(logits.device), logits)

        if temperature > 0.0:
            logits = logits / temperature
            probs = torch.softmax(logits, dim=-1)  # (batch_size, vocab_size)
            idx_next = torch.multinomial(probs, num_samples=1)  # (batch_size, 1)
        else:
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch_size, 1)

        # Explicit None check plus .all(): a bare `if tensor == eos_id` is
        # ambiguous (raises) once the batch holds more than one sequence.
        if eos_id is not None and bool((idx_next == eos_id).all()):
            break

        idx = torch.cat((idx, idx_next), dim=1)  # (batch_size, num_tokens+1)

    return idx


In [40]:
import util
import transformer

torch.manual_seed(123)
model = transformer.GPTModel()
model.to(util.device)

# Map the checkpoint onto util.device -- the same device the model was just
# moved to -- instead of the global `device`, which is only assigned in other
# cells and may not exist yet when this section is run on its own.
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load("model_snapshot/alice_txt_clean_txt_model_100.pth", map_location=util.device, weights_only=True))

# Yes: eval() puts the module in evaluation mode, which turns dropout layers
# into no-ops. It does not disable gradient tracking; generate() wraps the
# forward pass in torch.no_grad() for that.
model.eval()

# The number of rows in the positional-embedding table is used as the
# context length passed to generate().
context_size = model.pos_emb.weight.shape[0]

In [46]:
import util

start_context = input("Start context: ")

# Encode the prompt and add a leading batch dimension of size 1.
idx = torch.tensor(util.tokenizer.encode(start_context)).unsqueeze(0)
device = util.device

# Sampling settings shared by every run below.
sampling_kwargs = dict(
    max_new_tokens=50,
    context_size=context_size,
    top_k=50,
    temperature=0.5,
)

# Draw five independent continuations of the same prompt.
for i in range(5):
    token_ids = generate(model=model, idx=idx.to(device), **sampling_kwargs)

    flat = token_ids.squeeze(0)  # drop the batch dimension
    decoded = util.tokenizer.decode(flat.tolist())
    out = decoded.replace("\n", " ")  # keep each sample on one output line

    print(i, ":", out)

Start context:  We're all


0 : We're all mad here. I'm mad. You're mad.' 'How do you know I'm mad?' said Alice. 'You must be,' said the Cat, 'You must be,' said the Cat, 'or you might bite,' said the Cat
1 : We're all mad here. I'm mad. You're mad.' 'How do you know I'm mad?' said Alice. 'You must be,' said the Cat, 'You must be,' said the Cat, 'or you might bite,' said the Cat
2 : We're all mad here. I'm mad. You're mad.' 'How do you know I'm mad?' said Alice. 'You must be,' said the Cat, 'You must be,' said the Cat, 'or you can't talk about, '
3 : We're all wrong.' 'Yes, but I grow at a reasonable pace,' said the Dormouse: 'not in that ridiculous fashion.' And he got up very grave such a minute or two, But she added in to the time he got up very grave
4 : We're all wrong!' cried the Mock Turtle, capering wildly about. 'Change lobsters again!' yelled the Gryphon at the top of its voice. 'Back to itself, and the Mock Turtle. 'Back to its voice. 'Back to itself


In [58]:
import util
import transformer

torch.manual_seed(123)
model_harry_potter_02 = transformer.GPTModel()
model_harry_potter_02.to(util.device)

# Map the checkpoint onto util.device -- the same device the model was just
# moved to -- instead of the global `device`, which is only assigned in other
# cells and may not exist yet when this cell is run on its own.
# weights_only=True restricts unpickling to tensors and plain containers.
model_harry_potter_02.load_state_dict(torch.load("model_snapshot/harry_potter_02_txt_cleaned_txt_model_step_15000.pt", map_location=util.device, weights_only=True))

# Yes: eval() puts the module in evaluation mode, which turns dropout layers
# into no-ops. It does not disable gradient tracking; generate() wraps the
# forward pass in torch.no_grad() for that.
model_harry_potter_02.eval()

# The number of rows in the positional-embedding table is used as the
# context length passed to generate().
context_size = model_harry_potter_02.pos_emb.weight.shape[0]

In [59]:
import util

start_context = input("Start context: ")

# Encode the prompt and add a leading batch dimension of size 1.
idx = torch.tensor(util.tokenizer.encode(start_context)).unsqueeze(0)
device = util.device

# Sampling settings shared by every run below.
sampling_kwargs = dict(
    max_new_tokens=50,
    context_size=context_size,
    top_k=50,
    temperature=0.5,
)

# Draw five independent continuations of the same prompt.
for i in range(5):
    token_ids = generate(model=model_harry_potter_02, idx=idx.to(device), **sampling_kwargs)

    flat = token_ids.squeeze(0)  # drop the batch dimension
    decoded = util.tokenizer.decode(flat.tolist())
    out = decoded.replace("\n", " ")  # keep each sample on one output line

    print(i, ":", out)

Start context:  gryffindor


0 : gryffindor, I suggest you look more closely at this.â€ Dumbledore reached across to Professor McGonagallâ€™s desk, looking around to say, but thought Harryâ€™s desk, when he kicked out of the bird looked
1 : gryffindor Tower, desperate to tell Ron and Hermione about Colin and Dobby, but they werenâ€™t there. Harry left to look for them so far too many questions for them, wondering if only when heâ€ He pulled him out
2 : gryffindor, I suggest you look more closely at this.â€ Dumbledore reached across to Professor McGonagallâ€™s desk, picked up the desk, picked up the hat, picked up the hat, the hat next to Hedwigâ
3 : gryffindor, occasionally coming to long enough to copy down a name or date, then falling asleep again. He had been speaking for half an hour when something â€™s idea if it in case you know about six oâ€”â€
4 : gryffindor Tower. The castle was quiet; it seemed that the feast was over. They walked past muttering portraits and creaking suits of armor, and climbed narro