# Using a Pre-Trained Model

In [None]:
from random import randint

import torch
from matplotlib.axes import Axes
from matplotlib.pyplot import figure
from numpy import loadtxt
from numpy import ndarray
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.datasets import VisionDataset
from torchvision.transforms import ToTensor
from tqdm import tqdm

from activations import relu
from devices import Device
from networks import FeedForwardNetwork
from networks import NeuralNetwork
from utils import DATA_PATH

## Loading Model Parameters from Disk

In a neural network, the parameters are the weights on the connections between units and the bias of each unit. They are learned during the training phase, and their values determine the predictions or classifications the model makes.

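For this task, we previously trained a model on the MNIST dataset and saved its parameters to disk as plain-text files: `W1.txt`, `W2.txt`, `b1.txt`, and `b2.txt` for the two hidden layers, and `U.txt` and `c.txt` for the output layer. Let's load these parameters and use them to initialize a neural network that we can then use for predictions.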

In [None]:
def mnist_model() -> FeedForwardNetwork:
    """
    Build the pre-trained MNIST classifier from parameter files on disk.

    Parameters are read from plain-text files under ``DATA_PATH``:
    ``W1.txt``/``W2.txt`` and ``b1.txt``/``b2.txt`` for the two hidden layers,
    and ``U.txt``/``c.txt`` for the output layer. Every array is cast to a
    float32 tensor before being handed to the network.

    :return: The result of ``FeedForwardNetwork.load_parameters`` on a network
             with 784 input features, two hidden layers of 16 units (ReLU
             activations), and 10 classes.
    """

    def as_tensor(file_name: str) -> torch.Tensor:
        # loadtxt returns a float64 array by default; .float() casts to float32.
        return torch.from_numpy(loadtxt(DATA_PATH / file_name)).float()

    # Hidden-layer parameters first (W1, W2, then b1, b2), then the output layer.
    hidden_weights = [as_tensor(f"W{layer}.txt") for layer in (1, 2)]
    hidden_biases = [as_tensor(f"b{layer}.txt") for layer in (1, 2)]
    output_weights = as_tensor("U.txt")
    output_biases = as_tensor("c.txt")

    network = FeedForwardNetwork(
        n_features=784,
        hidden_layer_sizes=[16, 16],
        activation_functions=[relu, relu],
        n_classes=10,
    )
    return network.load_parameters(hidden_weights, output_weights, hidden_biases, output_biases)

# Build the pre-trained model once; the cells below reuse this global.
MNIST_MODEL = mnist_model()
print(MNIST_MODEL)  # show the freshly initialized model

## Making Predictions with the Model

With the model loaded, we can now use it to make predictions. We'll load the MNIST test set, select a few images at random, and display each one alongside its true class label and the class label predicted by the model.

In [None]:
# How many random test images display_predictions shows.
N_EXAMPLES = 3
# MNIST test split (train=False) as tensors; download=True fetches it into
# DATA_PATH / "mnist" when it is not already there.
DATASET = MNIST(
    str(DATA_PATH / "mnist"),
    train=False,
    transform=ToTensor(),
    download=True,
)

def setup_figure() -> ndarray[Axes]:
    """
    Create a figure with a single row of ``N_EXAMPLES`` subplots.

    :return: A 1-D array of ``N_EXAMPLES`` Axes (left to right), ready to be
             unpacked into ``display_predictions``.
    """
    fig = figure(figsize=(6, N_EXAMPLES * 3))
    # One column per example so the number of axes always matches N_EXAMPLES.
    grid_spec = fig.add_gridspec(1, N_EXAMPLES, hspace=2)
    # squeeze=False always yields a 2-D (rows x cols) array; taking row 0 gives
    # a 1-D array even when N_EXAMPLES == 1, where squeezing would return a
    # bare Axes that cannot be unpacked with *.
    axs = grid_spec.subplots(sharey='row', squeeze=False)
    return axs[0]

def display_predictions(*axs: Axes) -> None:
    """
    Visualize predictions for a random subset of the MNIST dataset.

    For each of ``N_EXAMPLES`` axes, a random image is drawn from ``DATASET``,
    classified with ``MNIST_MODEL``, and shown with its true label, predicted
    label, and the softmax probability of the predicted class.

    :param axs: Axes objects to draw on; at least ``N_EXAMPLES`` are required.
    :raises IndexError: If fewer than ``N_EXAMPLES`` axes are passed.
    """
    for i in range(N_EXAMPLES):
        # randint includes both endpoints, so the upper bound must be len - 1
        # to always produce a valid index.
        idx = randint(0, len(DATASET) - 1)
        img, true_label = DATASET[idx]
        img_view = img.view(28, 28).numpy()

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            scores = MNIST_MODEL(img.view(1, 784))
        # Convert raw scores to probabilities so "Confidence" lies in [0, 1].
        # NOTE(review): assumes the network outputs logits (as evaluate_network's
        # docstring states) rather than already-normalized probabilities.
        pred_prob, pred_label = torch.max(torch.softmax(scores, dim=1), dim=1)

        # Display the image along with the true and predicted labels.
        axs[i].imshow(img_view)
        axs[i].set_title(
            f"True Class: {true_label}\n"
            f"Predicted Class: {pred_label.item()}\n"
            f"Confidence: {pred_prob.item():.2f}"
        )


# Draw N_EXAMPLES random test images together with the model's predictions.
display_predictions(*setup_figure())

## Evaluating Model Performance

Next, we measure how well the model performs on the entire test set. The function below iterates over the images in batches, predicts a class for each image, and counts how many predictions are correct. Finally, it reports the accuracy as the percentage of correct predictions.

In [None]:
def evaluate_network(
        network: NeuralNetwork,
        dataset: VisionDataset,
        batch_size: int = 100,
        device: Device = Device.CPU,
) -> float:
    """
    Evaluates the performance of a neural network on a given vision dataset.

    This function iterates over the dataset using batches, computes predictions for each
    batch using the provided network, and tracks the number of correct predictions.
    At the end, it prints the accuracy of the network on the dataset and returns it.

    :param network: The neural network model to be evaluated.
    :param dataset: The dataset on which the network is evaluated.
    :param batch_size: The size of the batches in which the dataset is divided for
                       evaluation.
                       Default is 100.
    :param device: The device on which the computations are performed (CPU or GPU).
    :return: The accuracy as a percentage in [0, 100].
    :raises ValueError: If ``dataset`` is empty (accuracy would be undefined).

    __Note:__

    - This function assumes that the network's forward method outputs raw scores (logits)
      for each class.
    - The accuracy is computed as the percentage of correct predictions over the total
      number of samples in the dataset.
    """
    n_samples = len(dataset)
    if n_samples == 0:
        raise ValueError("Cannot evaluate a network on an empty dataset.")

    network.to(device)
    data_loader = DataLoader(dataset, batch_size=batch_size)
    n_correct = 0
    # Evaluation only: disable autograd so no computation graph is built per batch.
    with torch.no_grad():
        for x, y in tqdm(data_loader):
            inputs: torch.Tensor = x.view(-1, network.input_size).to(device)
            # The predicted class is the index of the highest score in each row.
            predictions: torch.Tensor = network(inputs).argmax(dim=1)
            n_correct += (predictions == y.to(device)).sum().item()

    accuracy = n_correct / n_samples * 100
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy


# Report the model's accuracy on the MNIST test split, computed on the CPU.
evaluate_network(network=MNIST_MODEL, dataset=DATASET, device=Device.CPU)