# MNIST Neural Network Training Notebook

This notebook demonstrates how to build, train, and evaluate a neural network for MNIST digit recognition from scratch. The code is modular and well-commented, and the notebook includes visualizations of the training progress.


In [2]:
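# If the `src` package cannot be imported, add the project root to sys.path.
# Note: `__file__` is not defined inside a notebook kernel, so the commented-out
# line below must be adapted (e.g. based on os.getcwd()) before it can be used.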

import sys
import os
# sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

import numpy as np
import matplotlib.pyplot as plt

# Import dataset module and neural network components
from src.dataset.dataset import load_and_preprocess_mnist
from src.neuralnet.network import NeuralNetwork
from src.neuralnet.layers import Dense
from src.neuralnet.activations import relu, relu_derivative, softmax
from src.neuralnet.losses import cross_entropy_loss, cross_entropy_loss_derivative
from src.neuralnet.utils import one_hot_encode


## Load and Preprocess the MNIST Dataset

We load the raw MNIST data and preprocess the images (normalization and reshaping). The provided training set is then split into training (first 90%) and validation (last 10%) subsets, while the test set is used as returned by the loader. Labels are one-hot encoded for use with the cross-entropy loss.


In [3]:
# Location of the MNIST files, relative to the current working directory
# (adjust the path as needed).
data_dir = os.path.join(os.getcwd(), "../data/mnist")

# The project loader hands back the full training set and the test set.
(train_images_all, train_labels_all,
 test_images, test_labels) = load_and_preprocess_mnist(data_dir)

# Hold out the trailing 10% of the training samples for validation;
# the leading 90% is used for training. No shuffling is applied here.
split_index = int(0.9 * train_images_all.shape[0])
train_images, val_images = train_images_all[:split_index], train_images_all[split_index:]
train_labels, val_labels = train_labels_all[:split_index], train_labels_all[split_index:]

# One-hot encode the labels of every split (training, validation, test).
num_classes = 10
train_labels_encoded, val_labels_encoded, test_labels_encoded = (
    one_hot_encode(labels, num_classes)
    for labels in (train_labels, val_labels, test_labels)
)

# Report the image array shape of each split as a quick sanity check.
for _caption, _images in (
    ("Training images shape:", train_images),
    ("Validation images shape:", val_images),
    ("Test images shape:", test_images),
):
    print(_caption, _images.shape)


Training images shape: (54000, 784)
Validation images shape: (6000, 784)
Test images shape: (10000, 784)
