# Text Prediction Model

This notebook provides an interactive interface for working with the text prediction model, which is designed to predict and generate text one character at a time. The model learns to predict the next character in a sequence and can be used to generate new text content.

## 1. Import Required Libraries

In [None]:
import os
from pathlib import Path
import torch
from torchinfo import summary

from advanced_ai_project.model import MLPCheckpoint
from advanced_ai_project.hyperparameters import load_hyperparameters, optimize_hyperparameters
from advanced_ai_project.text_prediction.train import train as train_text_prediction
from advanced_ai_project.text_prediction.evaluate import evaluate as evaluate_text_prediction, print_tokens, META_TRAINING_EPOCHS
from advanced_ai_project.text_prediction.dataset import ByteFileDataset, StringDataset

## 2. Configuration

Set up the necessary parameters for the text prediction model.

In [None]:
# Configuration parameters shared by every cell below.
dataset_path = "../data/text.txt"  # Path to the text dataset
checkpoint_path = "../data/text_prediction_checkpoint.pt"  # Path to save or load the model checkpoint
study_path = "../data/text_prediction_study.db"  # Path to the database for storing/loading hyperparameters

# Evaluation parameters
generation_length = 100  # Number of characters to generate
temperature = 0.8  # Controls the randomness of the generated text (higher = more random)
top_k = 5  # Number of top predictions to sample from

# Training parameters
training_batch_size = 128
training_num_epochs = 5
training_length_cutoff = None  # Passed as `length_cutoff` to the training dataset

# Optimization parameters
opt_trials = 100
opt_batch_size = 128
opt_num_epochs = 2
opt_length_cutoff = 1000  # Passed as `length_cutoff` to the optimization dataset

# Create the parent directories if they don't exist.
# Path.parent is "." for a bare filename, so this also works when a path has
# no directory component (os.makedirs(os.path.dirname(p)) would raise on "").
for _artifact_path in (checkpoint_path, study_path, dataset_path):
    Path(_artifact_path).parent.mkdir(parents=True, exist_ok=True)

# No checkpoint is loaded yet; later cells load it lazily when this is None.
ckpt = None

## 3. Load Text Dataset

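Load the text data twice: once with the training length cutoff for training, and once truncated by `opt_length_cutoff` for hyperparameter optimization.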

In [None]:
# Dataset used for the actual training run (cutoff comes from the
# training configuration, which is None by default).
dataset = ByteFileDataset(
    dataset_path,
    length_cutoff=training_length_cutoff,
)

# Second view of the same file for the hyperparameter search, built with the
# optimization cutoff instead (presumably to keep each trial short -- confirm
# against ByteFileDataset).
opt_dataset = ByteFileDataset(
    dataset_path,
    length_cutoff=opt_length_cutoff,
)

## 4. Optimize Hyperparameters

This cell optimizes the hyperparameters for the text prediction model using Optuna. If a checkpoint file already exists at `checkpoint_path`, optimization is skipped.

In [None]:
# Only search for hyperparameters when no checkpoint file is on disk yet;
# an existing checkpoint short-circuits the search entirely.
if not Path(checkpoint_path).exists():
    # Run the search on the optimization dataset and record it in the study DB.
    optimize_hyperparameters(
        study_path,
        opt_dataset,
        n_trials=opt_trials,
        num_epochs=opt_num_epochs,
        batch_size=opt_batch_size,
        train_model="text_prediction",
    )
    print(f"Hyperparameter optimization completed. Results stored in {study_path}")
else:
    print(f"Checkpoint file exists at {checkpoint_path}. Hyperparameters will not be optimized.")

## 5. Train Text Prediction Model

This cell trains the text prediction model using the optimized hyperparameters.

In [None]:
# Load or create checkpoint.
# First try to resume from the checkpoint file; if that fails for any reason,
# fall back to a fresh checkpoint built from the hyperparameters stored in the
# study database.
try:
    ckpt = MLPCheckpoint.load(checkpoint_path)
except Exception:
    # `except Exception` (not a bare `except:`) so that KeyboardInterrupt and
    # SystemExit propagate instead of being treated as "no checkpoint".
    # NOTE(review): any load failure, not just a missing file, leads to a new
    # checkpoint that is later saved over `checkpoint_path` -- confirm intended.
    try:
        ckpt = MLPCheckpoint.new_from_hyperparams(load_hyperparameters(study_path))
    except Exception:
        print(
            "Neither checkpoint or the hyperparameter DB exists. Please run hyperparameter optimization first."
        )
        # Re-raise so the cell stops with the underlying error visible.
        raise
    else:
        print("Created new checkpoint from hyperparameters.")
else:
    print("Loaded existing checkpoint.")

# Train the model
print("Training text prediction model...")
ckpt.model.train()  # switch the model to training mode before the run
avg_loss = train_text_prediction(
    ckpt,
    dataset=dataset,
    num_epochs=training_num_epochs,
    batch_size=training_batch_size,
)
print(f"Training complete with an average loss of {avg_loss}")

# Save the model
ckpt.save(checkpoint_path)
print(f"Model saved to {checkpoint_path}")

## 6. Evaluate Text Prediction Model

This cell generates text using the trained text prediction model.

In [None]:
# Reuse the checkpoint already in memory; otherwise read it from disk.
ckpt = MLPCheckpoint.load(checkpoint_path) if ckpt is None else ckpt

# You can change the start index to generate text from a different position in the dataset
start_idx = 0

print("Generating text with the model...")
print("Generated text:\n")

# Sample `generation_length` tokens starting at `start_idx`, using the
# temperature / top-k settings from the configuration cell.
tokens = evaluate_text_prediction(
    ckpt,
    idx=start_idx,
    length=generation_length,
    temperature=temperature,
    top_k=top_k,
)

print_tokens(tokens)
print("\n\nText generation complete.")

## 7. Autocomplete Text

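This cell uses the trained model to autocomplete text from a given prompt. The model is first trained briefly on the prompt itself (for `META_TRAINING_EPOCHS` epochs), and text is then generated starting from the position right after the prompt. Note that this updates the in-memory checkpoint; the cell does not save it back to disk.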

In [None]:
# Reuse the checkpoint already in memory; otherwise read it from disk.
ckpt = MLPCheckpoint.load(checkpoint_path) if ckpt is None else ckpt

# Set your prompt text here
prompt = "Hello, this is a test. I want to"

print("Autocompleting text from prompt...")

# The prompt is positioned immediately after the last index recorded on the
# checkpoint.
start_index = ckpt.last_seen_index + 1

# Step 1: train on the prompt itself, one sample per batch.
print("Training on the prompt...")
ckpt.model.train()
avg_loss = train_text_prediction(
    ckpt,
    StringDataset(prompt, start_index=start_index),
    num_epochs=META_TRAINING_EPOCHS,
    batch_size=1,
)
print(f"Prompt training complete with an average loss of {avg_loss}")

# Step 2: echo the prompt, then generate the continuation.
print("\nPrompt:")
print(prompt)
print("\nAutocompleted text:")

# Generation starts right after the prompt.
# NOTE(review): len(prompt) counts characters; if StringDataset indexes bytes,
# a non-ASCII prompt would shift this offset -- confirm.
tokens = evaluate_text_prediction(
    ckpt,
    start_index + len(prompt),
    generation_length,
    temperature=temperature,
    top_k=top_k,
)
print_tokens(tokens)

## 8. Model Summary

Display a summary of the text prediction model architecture.

In [None]:
# Reuse the checkpoint already in memory; otherwise read it from disk.
ckpt = MLPCheckpoint.load(checkpoint_path) if ckpt is None else ckpt

# Feed torchinfo a dummy all-zero int64 input of shape (1, 64), allocated on
# the model's device, and print the resulting report.
# NOTE(review): the (1, 64) shape is hard-coded here -- confirm it matches the
# input the model expects.
print(
    summary(
        ckpt.model,
        input_data=torch.zeros((1, 64), dtype=torch.int64, device=ckpt.model.device),
    )
)