# HyperNet: A Neural Network That Generates Neural Networks

This notebook provides an interactive interface for working with the HyperNet model, a neural network that generates the weights of another neural network. Here, the generated network is a convolutional neural network for image classification on the CIFAR-10 dataset.
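To make the idea concrete, here is a minimal sketch of a hypernetwork in PyTorch. It is not the architecture used in `advanced_ai_project`: the class name `TinyHyperNet`, the layer sizes, and the choice to generate a single convolution kernel are all assumptions made for illustration.

```python
import torch
from torch import nn
import torch.nn.functional as F


class TinyHyperNet(nn.Module):
    """Hypothetical hypernetwork that generates the kernel of one conv layer."""

    def __init__(self, embed_dim=8, out_channels=4, in_channels=3, kernel_size=3):
        super().__init__()
        self.kernel_shape = (out_channels, in_channels, kernel_size, kernel_size)
        n_weights = out_channels * in_channels * kernel_size * kernel_size
        # Learned embedding that the generator turns into weights
        self.embedding = nn.Parameter(torch.randn(embed_dim))
        self.generator = nn.Sequential(
            nn.Linear(embed_dim, 32), nn.ReLU(), nn.Linear(32, n_weights)
        )

    def forward(self, images):
        # The conv kernel is an output of the generator, not a parameter itself
        kernel = self.generator(self.embedding).view(self.kernel_shape)
        return F.conv2d(images, kernel, padding=1)


torch.manual_seed(0)
hypernet = TinyHyperNet()
batch = torch.zeros(2, 3, 32, 32)  # two CIFAR-10-sized images
features = hypernet(batch)
print(features.shape)
```

Only the embedding and the generator hold trainable parameters in this sketch; gradients from the convolution output flow back through the generated kernel into them.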

## 1. Import Required Libraries

In [None]:
import os
from pathlib import Path
import torch
from torchinfo import summary

from advanced_ai_project.utils import get_cifar10_dataset
from advanced_ai_project.model import MLPCheckpoint
from advanced_ai_project.hyperparameters import load_hyperparameters, optimize_hyperparameters
from advanced_ai_project.hypernet.train import train as train_hypernet
from advanced_ai_project.hypernet.evaluate import evaluate as evaluate_hypernet

## 2. Configuration

Set the dataset, checkpoint, and study paths, along with the batch sizes, epoch counts, and trial count used for training and hyperparameter optimization.

In [None]:
# Configuration parameters
dataset_path = "../data/cifar10"  # Path to the CIFAR-10 dataset
checkpoint_path = "../data/hypernet_checkpoint.pt"  # Path to save or load the model checkpoint
study_path = "../data/hypernet_study.db"  # Path to the database for storing/loading hyperparameters

# Training parameters
training_batch_size = 256
training_num_epochs = 1000

# Optimization parameters
opt_trials = 1000
opt_batch_size = 256
opt_num_epochs = 2

# Create directories if they don't exist
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
os.makedirs(os.path.dirname(study_path), exist_ok=True)
os.makedirs(os.path.dirname(dataset_path), exist_ok=True)
# Shared checkpoint handle; stays None until a later cell loads or creates one
ckpt = None
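The directory setup above can also be written with `pathlib`, which the notebook already imports. The following is a sketch only: the helper name `ensure_parent_dirs` is hypothetical, and it runs inside a temporary directory so it leaves nothing behind.

```python
import tempfile
from pathlib import Path


def ensure_parent_dirs(*paths):
    """Create the parent directory of each path and return the paths as Path objects."""
    resolved = [Path(p) for p in paths]
    for path in resolved:
        path.parent.mkdir(parents=True, exist_ok=True)
    return resolved


# Demonstrated in a temporary directory so the sketch has no side effects
with tempfile.TemporaryDirectory() as tmp:
    checkpoint, study = ensure_parent_dirs(
        Path(tmp) / "data" / "hypernet_checkpoint.pt",
        Path(tmp) / "data" / "hypernet_study.db",
    )
    parent_created = checkpoint.parent.is_dir()
    checkpoint_created = checkpoint.exists()
```

As with `os.path.dirname`, only the parent directory is created. For `dataset_path = "../data/cifar10"` this means `../data` is created, while the `cifar10` directory itself is left to whatever downloads the dataset.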

## 3. Optimize Hyperparameters

This cell optimizes the hyperparameters for the HyperNet model using Optuna. If a checkpoint already exists at `checkpoint_path`, optimization is skipped.

In [None]:
if Path(checkpoint_path).exists():
    print(f"Checkpoint file exists at {checkpoint_path}. Hyperparameters will not be optimized.")
else:
    optimize_hyperparameters(
        study_path,
        get_cifar10_dataset(dataset_path, train=True),
        n_trials=opt_trials,
        num_epochs=opt_num_epochs,
        batch_size=opt_batch_size,
        train_function=train_hypernet,
    )
    print(f"Hyperparameter optimization completed. Results stored in {study_path}")

## 4. Train HyperNet Model

This cell loads the existing checkpoint, or creates a new one from the optimized hyperparameters, then trains the HyperNet model and saves the result.

In [None]:
# Load or create checkpoint
try:
    ckpt = MLPCheckpoint.load(checkpoint_path)
    print("Loaded existing checkpoint.")
except Exception:
    # No usable checkpoint; fall back to the hyperparameters stored in the study database
    try:
        ckpt = MLPCheckpoint.new_from_hyperparams(load_hyperparameters(study_path))
        print("Created new checkpoint from hyperparameters.")
    except Exception:
        print("Neither the checkpoint nor the hyperparameter DB exists. Please run hyperparameter optimization first.")
        raise

# Train the model
print("Training hypernet model...")
ckpt.model.train()
avg_loss = train_hypernet(
    ckpt,
    dataset=get_cifar10_dataset(dataset_path, train=True),
    num_epochs=training_num_epochs,
    batch_size=training_batch_size,
)
print(f"Training complete with an average loss of {avg_loss}")

# Save the model
ckpt.save(checkpoint_path)
print(f"Model saved to {checkpoint_path}")
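The load-or-create pattern used above can be illustrated in isolation. In this sketch, `ToyCheckpoint` is a hypothetical stand-in for `MLPCheckpoint` that stores its state as JSON; the real class's fields, file format, and exception types are not known here.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class ToyCheckpoint:
    """Hypothetical stand-in for MLPCheckpoint: holds hyperparameters and progress."""

    hidden_size: int
    epochs_trained: int = 0

    def save(self, path):
        Path(path).write_text(json.dumps(asdict(self)))

    @classmethod
    def load(cls, path):
        # Raises FileNotFoundError when no checkpoint has been saved yet
        return cls(**json.loads(Path(path).read_text()))


def load_or_create(path, hyperparams):
    """Resume from path if it exists, otherwise start fresh from hyperparams."""
    try:
        return ToyCheckpoint.load(path), "loaded"
    except FileNotFoundError:
        return ToyCheckpoint(**hyperparams), "created"


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "toy_checkpoint.json"
    first, first_source = load_or_create(path, {"hidden_size": 64})
    first.epochs_trained += 5  # pretend to train
    first.save(path)
    second, second_source = load_or_create(path, {"hidden_size": 64})
    print(first_source, second_source, second.epochs_trained)
```

Catching a specific exception such as `FileNotFoundError` keeps unrelated failures, for example a corrupted file, from silently triggering a fresh start.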

## 5. Evaluate HyperNet Model

This cell evaluates the performance of the trained HyperNet model on the CIFAR-10 test dataset.

In [None]:
if ckpt is None:
    ckpt = MLPCheckpoint.load(checkpoint_path)

accuracy = evaluate_hypernet(
    ckpt,
    dataset=get_cifar10_dataset(dataset_path, train=False),
)
print(f"Test accuracy of the generated CNN: {accuracy:.2f}%")
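As a rough illustration of what an accuracy evaluation typically involves, the sketch below computes top-1 accuracy as a percentage. The function `accuracy_percent` is hypothetical and is not the implementation of `evaluate_hypernet`; the toy "model" is `nn.Identity`, so the inputs act directly as logits.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def accuracy_percent(model, dataset, batch_size=256):
    """Return top-1 accuracy in percent over the whole dataset."""
    model.eval()  # disable dropout and batch-norm updates
    correct = 0
    total = 0
    loader = DataLoader(dataset, batch_size=batch_size)
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in loader:
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.numel()
    return 100.0 * correct / total


# Toy data: the inputs are already logits, so the identity "model" returns them unchanged
logits = torch.tensor([[2.0, 0.0], [0.0, 3.0], [1.0, 0.5], [0.2, 0.9]])
labels = torch.tensor([0, 1, 1, 1])
toy_accuracy = accuracy_percent(nn.Identity(), TensorDataset(logits, labels), batch_size=2)
print(f"{toy_accuracy:.2f}%")  # 75.00%
```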

## 6. Model Summary

Display a summary of the HyperNet model architecture.

In [None]:
if ckpt is None:
    ckpt = MLPCheckpoint.load(checkpoint_path)

print(summary(
    ckpt.model,
    input_data=torch.zeros(
        (1, 64), dtype=torch.int64, device=ckpt.model.device
    ),
))
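If only the parameter count is of interest, it can be obtained without `torchinfo`. The sketch below uses a small hypothetical MLP, not the HyperNet model, with an input width of 64 chosen to echo the dummy input above.

```python
from torch import nn

# Hypothetical stand-in model; replace with ckpt.model in the notebook
toy_model = nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 1))

total_params = sum(p.numel() for p in toy_model.parameters())
trainable_params = sum(p.numel() for p in toy_model.parameters() if p.requires_grad)
print(f"{total_params} parameters, {trainable_params} trainable")
```

For this toy model the count is (64 × 32 + 32) + (32 × 1 + 1) = 2113.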