In [262]:
import torch
import torch.nn as nn
import torch.optim as optim
import random


In [263]:
with open("../data/wizard_of_oz.txt", "r", encoding="utf-8") as f:
    text = f.read()

print(len(text))
print(text[:300])

chars = sorted(set(text))
vocab_size = len(chars)
print(chars)

232309
﻿  DOROTHY AND THE WIZARD IN OZ

  BY

  L. FRANK BAUM

  AUTHOR OF THE WIZARD OF OZ, THE LAND OF OZ, OZMA OF OZ, ETC.

  ILLUSTRATED BY JOHN R. NEILL

  BOOKS OF WONDER WILLIAM MORROW & CO., INC. NEW YORK


  [Illustration]


  COPYRIGHT 1908 BY L. FRANK BAUM

  ALL RIGHTS RESERVED


         *    
['\n', ' ', '!', '"', '&', "'", '(', ')', '*', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', '[', ']', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '\ufeff']


In [264]:
def encode(text: str):
    ans = []
    for t in text:
        ans.append(chars.index(t))
    return ans

encoded_hello = encode("hello")
encoded_hello

[61, 58, 65, 65, 68]

In [265]:
def decode(indices: list):
    ans = ""
    for i in indices:
        ans += chars[i]
    return ans

decoded_hello = decode(encoded_hello)
decoded_hello

'hello'

In [266]:
encoded_text = encode(text)
print(encoded_text[:100])
bigram_pairs = [(encoded_text[i], encoded_text[i + 1]) for i in range(len(encoded_text) - 1)]
bigram_pairs

[80, 1, 1, 28, 39, 42, 39, 44, 32, 49, 1, 25, 38, 28, 1, 44, 32, 29, 1, 47, 33, 50, 25, 42, 28, 1, 33, 38, 1, 39, 50, 0, 0, 1, 1, 26, 49, 0, 0, 1, 1, 36, 11, 1, 30, 42, 25, 38, 35, 1, 26, 25, 45, 37, 0, 0, 1, 1, 25, 45, 44, 32, 39, 42, 1, 39, 30, 1, 44, 32, 29, 1, 47, 33, 50, 25, 42, 28, 1, 39, 30, 1, 39, 50, 9, 1, 44, 32, 29, 1, 36, 25, 38, 28, 1, 39, 30, 1, 39, 50]


[(80, 1),
 (1, 1),
 (1, 28),
 (28, 39),
 (39, 42),
 (42, 39),
 (39, 44),
 (44, 32),
 (32, 49),
 (49, 1),
 (1, 25),
 (25, 38),
 (38, 28),
 (28, 1),
 (1, 44),
 (44, 32),
 (32, 29),
 (29, 1),
 (1, 47),
 (47, 33),
 (33, 50),
 (50, 25),
 (25, 42),
 (42, 28),
 (28, 1),
 (1, 33),
 (33, 38),
 (38, 1),
 (1, 39),
 (39, 50),
 (50, 0),
 (0, 0),
 (0, 1),
 (1, 1),
 (1, 26),
 (26, 49),
 (49, 0),
 (0, 0),
 (0, 1),
 (1, 1),
 (1, 36),
 (36, 11),
 (11, 1),
 (1, 30),
 (30, 42),
 (42, 25),
 (25, 38),
 (38, 35),
 (35, 1),
 (1, 26),
 (26, 25),
 (25, 45),
 (45, 37),
 (37, 0),
 (0, 0),
 (0, 1),
 (1, 1),
 (1, 25),
 (25, 45),
 (45, 44),
 (44, 32),
 (32, 39),
 (39, 42),
 (42, 1),
 (1, 39),
 (39, 30),
 (30, 1),
 (1, 44),
 (44, 32),
 (32, 29),
 (29, 1),
 (1, 47),
 (47, 33),
 (33, 50),
 (50, 25),
 (25, 42),
 (42, 28),
 (28, 1),
 (1, 39),
 (39, 30),
 (30, 1),
 (1, 39),
 (39, 50),
 (50, 9),
 (9, 1),
 (1, 44),
 (44, 32),
 (32, 29),
 (29, 1),
 (1, 36),
 (36, 25),
 (25, 38),
 (38, 28),
 (28, 1),
 (1, 39),
 (39, 30),
 (30

In [267]:
from torch.utils.data import Dataset, DataLoader, random_split

total_samples = len(bigram_pairs)
train_ratio = 0.8
val_ratio = 0.1
test_ratio = 0.1

train_size = int(train_ratio * total_samples)
val_size = int(val_ratio * total_samples)
test_size = total_samples - train_size - val_size

train_data, val_data, test_data = random_split(bigram_pairs, [train_size, val_size, test_size])
# print(train_data[10])
batch_size = 64

# Create DataLoader instances
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_data, batch_size=batch_size)
test_loader = DataLoader(test_data, batch_size=batch_size)

print(len(train_loader), len(val_loader), len(test_loader))

2904 363 363


In [268]:
class BigramModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        output = self.embeddings(x)
        output = self.linear(output)
        return output
    
    def generate(self, start_char, num_chars_to_generate):
        generated_chars = [start_char]
        for _ in range(num_chars_to_generate):
            x = torch.tensor([generated_chars[-1]])
            y = self.forward(x)
            p = nn.functional.softmax(y, dim=1)
            c = torch.multinomial(p, num_samples=1).item()
            print(c)
            generated_chars.append(c)
        return generated_chars
    
model = BigramModel(vocab_size, 32)

In [269]:
import time

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters())

start_time = time.time()

num_epochs = 10

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    for batch, (X, y) in enumerate(train_loader):
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)

    model.eval()
    val_loss = 0.0
    for batch, (X, y) in enumerate(val_loader):
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        val_loss += loss.item()

    val_loss /= len(val_loader)

    print(f"Epoch: {epoch + 1}/{num_epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")

end_time = time.time()

print(f"Training time: {end_time - start_time:.2f}s")

Epoch: 1/10 | Train Loss: 2.6964 | Val Loss: 2.4940
Epoch: 2/10 | Train Loss: 2.4751 | Val Loss: 2.4679
Epoch: 3/10 | Train Loss: 2.4587 | Val Loss: 2.4602
Epoch: 4/10 | Train Loss: 2.4518 | Val Loss: 2.4571
Epoch: 5/10 | Train Loss: 2.4479 | Val Loss: 2.4546
Epoch: 6/10 | Train Loss: 2.4455 | Val Loss: 2.4517
Epoch: 7/10 | Train Loss: 2.4438 | Val Loss: 2.4517
Epoch: 8/10 | Train Loss: 2.4428 | Val Loss: 2.4498
Epoch: 9/10 | Train Loss: 2.4420 | Val Loss: 2.4509
Epoch: 10/10 | Train Loss: 2.4412 | Val Loss: 2.4484
Training time: 13.70s


In [270]:
start_char = "a"
num_chars_to_generate = 100

generated_chars = model.generate(chars.index(start_char), num_chars_to_generate)

print(decode(generated_chars))

tensor([54])
tensor([[  0.4476,   3.0801,  -1.0999,  -3.7089,  -7.7807,  -1.5087,  -7.2848,
          -4.4094,  -6.9519,   0.1063,  -2.1395,  -1.3441,  -5.9838,  -6.2065,
          -7.7180,  -7.0844,  -5.5750,  -6.8737,  -7.0499,  -7.2669,  -7.4356,
          -7.2925,  -4.7634,  -1.8865,  -2.1493,  -7.0015,  -5.8457,  -6.9178,
          -7.1793,  -7.8557,  -7.0674,  -6.5126,  -6.0109,  -6.9641,  -6.0610,
          -6.9597,  -5.6478,  -6.2360,  -5.3907,  -5.8922,  -6.0547,  -6.5954,
          -4.9939,  -6.4394,  -6.3957,  -5.2390,  -6.4655,  -5.1749,  -7.0813,
          -5.5876,  -5.0895,  -6.7117,  -5.4861,  -4.5336,  -1.9318,   1.6557,
           2.2706,   2.5641,  -1.2145,   0.6423,   1.9042,  -1.6944,   2.7266,
          -3.7989,   1.3663,   3.0986,   1.7064,   4.0573,  -2.3003,   1.3567,
          -4.3120,   3.3967,   3.3245,   3.4168,   1.0627,   1.9415,   1.2655,
          -2.2714,   1.7416,  -0.6600, -13.1981]], grad_fn=<AddmmBackward0>)
torch.Size([1, 81])
75
tensor([75])
tenso