In [1]:
import os
from collections import Counter
from configparser import ConfigParser

import torch

from data_loader import load_sample_data
from evaluate import evaluate
from inference import generate_text
from tokenizer import tokenize_text
from train import create_batches, train_model
from transformer_model import TransformerModel

In [2]:
def get_hyperparameters():
    config = ConfigParser()
    config.read("config.ini")

    hyperparameters = config["Hyperparameters"]
    num_samples = int(hyperparameters["num_samples"])
    batch_size = int(hyperparameters["batch_size"])
    seq_length = int(hyperparameters["seq_length"])
    num_epochs = int(hyperparameters["num_epochs"])
    learning_rate = float(hyperparameters["learning_rate"])
    scheduler_patience = int(hyperparameters["scheduler_patience"])
    scheduler_factor = float(hyperparameters["scheduler_factor"])
    max_vocab_size = int(hyperparameters["max_vocab_size"])
    embedding_dim = int(hyperparameters["embedding_dim"])
    ff_hidden_dim = int(hyperparameters["ff_hidden_dim"])
    num_blocks = int(hyperparameters["num_blocks"])
    initial_text = hyperparameters["initial_text"]
    max_len = int(hyperparameters["max_len"])
    temperature = float(hyperparameters["temperature"])

    return (
        num_samples,
        batch_size,
        seq_length,
        num_epochs,
        learning_rate,
        scheduler_patience,
        scheduler_factor,
        max_vocab_size,
        embedding_dim,
        ff_hidden_dim,
        num_blocks,
        initial_text,
        max_len,
        temperature,
    )

In [3]:
# 1. Load hyperparameters
(
    num_samples,
    batch_size,
    seq_length,
    num_epochs,
    learning_rate,
    scheduler_patience,
    scheduler_factor,
    max_vocab_size,
    embedding_dim,
    ff_hidden_dim,
    num_blocks,
    initial_text,
    max_len,
    temperature,
) = get_hyperparameters()

In [4]:
# 2. Load and pre-process data
sample_text = load_sample_data(num_samples=num_samples)
vocab, word_to_idx, idx_to_word = tokenize_text(sample_text)
tokens = sample_text.split()

if len(tokens) > max_vocab_size:
    # Count the frequency of each word in your corpus
    word_freqs = Counter(tokens)

    # Get the most common words up to MAX_VOCAB_SIZE
    vocab = [word for word, freq in word_freqs.most_common(max_vocab_size - 1)]

    # Add the special <UNK> token to the vocabulary
    vocab.append("<UNK>")

    # Create word_to_idx dictionary
    word_to_idx = {word: idx for idx, word in enumerate(vocab)}

    # Replace all words not in the vocabulary with <UNK>
    tokens = [word if word in word_to_idx else "<UNK>" for word in tokens]

Skipping, found downloaded files in "./wikipedia-20230701" (use force=True to force download)


In [5]:
# 3. Create batches
input_batches, target_batches = create_batches(
    tokens, word_to_idx, batch_size=batch_size, seq_length=seq_length
)

In [6]:
# 4. Initialize or load model
model_path = "model.pth"
if os.path.exists(model_path):
    print("Loading existing model...")
    model = torch.load(model_path)
else:
    print("Initializing new model...")
    model = TransformerModel(
        vocab_size=max_vocab_size,
        embedding_dim=embedding_dim,
        ff_hidden_dim=ff_hidden_dim,
        num_blocks=num_blocks,
    )

Loading existing model...


In [8]:
# Reload hyperparameters before training (in case they've changed)
(
    num_samples,
    batch_size,
    seq_length,
    num_epochs,
    learning_rate,
    scheduler_patience,
    scheduler_factor,
    max_vocab_size,
    embedding_dim,
    ff_hidden_dim,
    num_blocks,
    initial_text,
    max_len,
    temperature,
) = get_hyperparameters()

# 5. Train the model
print("Training model...")
train_model(
    model,
    vocab,
    num_epochs,
    learning_rate,
    scheduler_patience,
    scheduler_factor,
    input_batches,
    target_batches,
)

# Save the trained model
print("Saving model...")
torch.save(model, model_path)

Training model...
Starting epoch 1/10:
Batch 1/39...

KeyboardInterrupt: 

In [None]:
# 6. Evaluate the model
perplexity = evaluate(
    model,
    input_batches,
    target_batches,
    criterion=torch.nn.CrossEntropyLoss(),
    vocab_size=max_vocab_size,
)
print(f"Perplexity: {perplexity}")

Perplexity: 1.8123218504774934


In [None]:
# 7. Generate text
generated_text = generate_text(
    model,
    idx_to_word,
    word_to_idx,
    initial_text=initial_text,
    max_len=max_len,
    temperature=temperature,
)
print(f"Generated Text: {generated_text}")

Generated Text: There de de de Peso Peso de Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso Peso
