In [None]:
# Load the corpus; "utf-8-sig" strips a leading byte-order mark if the file has one.
with open("./wizard-of-oz.txt", encoding="utf-8-sig") as f:
    txt = f.read()

# Character-level vocabulary: every distinct character of the corpus, sorted so the
# character -> id mapping does not depend on set iteration order.
VOCAB = sorted(set(txt))
VOCAB_SIZE = len(VOCAB)
encode_dict = {c: i for i, c in enumerate(VOCAB)}
decode_dict = dict(enumerate(VOCAB))


def tok_encode(text):
    """Map a string to a list of integer token ids, one per character.

    Raises KeyError for characters that are not in VOCAB.
    """
    return [encode_dict[c] for c in text]


def tok_decode(tok_indices):
    """Map an iterable of integer token ids back to a string."""
    return "".join(decode_dict[e] for e in tok_indices)


# Round-trip sanity check: fail loudly instead of merely displaying False.
assert tok_decode(tok_encode(txt)) == txt, "tokenizer round-trip failed"

In [None]:
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_FRACTION = 0.9  # leading 90% of the corpus for training, the rest for validation

# Whole corpus as one 1-D tensor of token ids, created directly on DEVICE
# (explicit dtype/device instead of the legacy torch.LongTensor constructor + copy).
data = torch.tensor(tok_encode(txt), dtype=torch.long, device=DEVICE)
split = int(len(data) * TRAIN_FRACTION)
# Contiguous split (no shuffling): validation text is the tail of the book.
train_data = data[:split]
val_data = data[split:]

# Bigram Language Model


In [None]:
import torch
from torch.utils.data import DataLoader
from src import bigram
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)  # reproducible weight init and DataLoader shuffling
model = bigram.BigramLanguageModel(vocab_sz=VOCAB_SIZE).to(DEVICE)

num_epochs = 16
batch_size = 64
eval_every = 1000  # evaluate every N batches (and after the last batch of each epoch)
dataset = bigram.BigramDataset(txt_tensor=train_data, device=DEVICE)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Loss histories, keyed by the number of training samples seen so far.
num_samples = []
samples_counter = 0
train_losses = []
val_losses = []
for num_epoch in range(num_epochs):
    for ii, (x, y) in enumerate(dataloader):
        loss = model.train_batch(x, y, optimizer)
        samples_counter += len(x)

        if (ii + 1) % eval_every == 0 or ii + 1 == len(dataloader):
            model.eval()  # only matters if the model has dropout / batch-norm layers
            with torch.no_grad():
                train_loss = model.compute_loss(train_data[:-1], train_data[1:])
                val_loss = model.compute_loss(val_data[:-1], val_data[1:])
            model.train()
            # Store plain Python floats: float() also accepts a 0-d tensor (copying it
            # off the GPU if needed), so the histories can be plotted with pandas.
            train_losses.append(float(train_loss))
            val_losses.append(float(val_loss))
            num_samples.append(samples_counter)

            # One carriage return is enough to overwrite the progress line.
            print(f"\repoch {num_epoch}: {ii + 1}/{len(dataloader)}", end="", flush=True)
print()
print()

In [None]:
import pandas as pd

# Long format (one row per sample count and curve) so the plot can colour by curve.
# pandas' own melt replaces pyjanitor's pivot_longer; float() also accepts 0-d tensors
# in case the histories hold loss tensors.
df = pd.DataFrame({
    "num_samples": num_samples,
    "train_losses": [float(v) for v in train_losses],
    "val_losses": [float(v) for v in val_losses],
}).melt(
    id_vars="num_samples",
    value_vars=["train_losses", "val_losses"],
    var_name="what",
    value_name="value",
)

# NOTE(review): ggplot / geom_line / aes are presumably plotnine imports made outside
# this view — confirm they are imported before this cell runs.
ggplot(df) + geom_line(aes(x="num_samples", y="value", color="what"))

In [None]:
# Training loss of the model trained above; no_grad avoids building an autograd
# graph for this evaluation-only pass over the whole training set.
with torch.no_grad():
    final_train_loss = model.compute_loss(train_data[:-1], train_data[1:])
final_train_loss

In [None]:
# Learned lookup-table weights turned into row-wise probabilities (softmax over the
# last dim), on the CPU. Compared below against the counted table, so each row is
# presumably the next-token distribution for one current token.
lut = torch.softmax(model.lut.weight.detach().cpu(), dim=-1)

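### Lektion

- `str.count()` zählt keine überlappenden Vorkommen eines Musters: `"aaa".count("aa")` liefert `1`, nicht `2`. Bigramme müssen daher paarweise gezählt werden (siehe nächste Zelle).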


In [None]:
# Bigram table computed by counting, for comparison with the learned table.
# Each overlapping (current, next) pair is encoded as one index current * V + next,
# and bincount tallies all pairs at once instead of a Python loop over tensor elements.
prev_tok = train_data[:-1].cpu()
next_tok = train_data[1:].cpu()
lut_counting = (
    torch.bincount(prev_tok * VOCAB_SIZE + next_tok, minlength=VOCAB_SIZE * VOCAB_SIZE)
    .view(VOCAB_SIZE, VOCAB_SIZE)
    .float()
)
# Row-normalise counts into probabilities. A token that never starts a training pair
# (e.g. a character that only occurs in the validation split) has an all-zero row;
# clamping the denominator keeps that row at 0 instead of turning it into NaN.
lut_counting = lut_counting / lut_counting.sum(dim=1, keepdim=True).clamp(min=1)
lut_counting

In [None]:
# Row for token id 0 (VOCAB[0]) of the learned table.
lut[0]

In [None]:
# Same row (token id 0) of the counted table, for side-by-side comparison.
lut_counting[0]

In [None]:
import numpy as np

# Baseline: 100 token ids drawn uniformly at random from the vocabulary, decoded to
# text for comparison with the model's generated text. A seeded Generator keeps the
# sample identical across Restart & Run All.
rng = np.random.default_rng(0)
idx = rng.integers(low=0, high=VOCAB_SIZE, size=100)
out = tok_decode(idx.tolist())  # tolist() -> plain Python ints for the dict lookup
out


In [None]:
# Generate text from the trained model, starting from token id 0 (VOCAB[0]).
# 64 is presumably the number of tokens to generate — confirm in src.
# no_grad: sampling needs no autograd graph.
# NOTE(review): tok_decode looks ids up in a dict, so this assumes generate returns
# plain int ids (e.g. a list), not a tensor — confirm.
with torch.no_grad():
    tok_indices = model.generate(torch.tensor([0], device=DEVICE), 64)
tok_decode(tok_indices)

In [None]:
# Validation loss of the trained model; no_grad because this is evaluation only.
with torch.no_grad():
    final_val_loss = model.compute_loss(val_data[:-1], val_data[1:])
final_val_loss

In [None]:
# Reference point: negative log-likelihood of a uniform guess over the vocabulary.
uniform_loss = -np.log(1 / VOCAB_SIZE)
uniform_loss

# Neural Language Model


In [None]:
from src import neural
torch.manual_seed(42)  # reproducible weight init and DataLoader shuffling
model = neural.NeuralLanguageModel(vocab_sz=VOCAB_SIZE).to(DEVICE)

num_epochs = 16
batch_size = 64
eval_every = 1000  # evaluate every N batches (and after the last batch of each epoch)
dataset = neural.NeuralDataset(txt_tensor=train_data, device=DEVICE)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Loss histories, keyed by the number of training samples seen so far.
num_samples = []
samples_counter = 0
train_losses = []
val_losses = []
for num_epoch in range(num_epochs):
    for ii, (x, y) in enumerate(dataloader):
        loss = model.train_batch(x, y, optimizer)
        samples_counter += len(x)

        if (ii + 1) % eval_every == 0 or ii + 1 == len(dataloader):
            model.eval()  # only matters if the model has dropout / batch-norm layers
            with torch.no_grad():
                train_loss = model.compute_loss(train_data[:-1], train_data[1:])
                val_loss = model.compute_loss(val_data[:-1], val_data[1:])
            model.train()
            # Store plain Python floats: float() also accepts a 0-d tensor (copying it
            # off the GPU if needed), so the histories can be plotted with pandas.
            train_losses.append(float(train_loss))
            val_losses.append(float(val_loss))
            num_samples.append(samples_counter)

            # One carriage return is enough to overwrite the progress line.
            print(f"\repoch {num_epoch}: {ii + 1}/{len(dataloader)}", end="", flush=True)
print()