# Simple Neural Network on the Digits Dataset

This notebook uses scikit-learn's built-in **digits** dataset to train a simple neural network and evaluate its accuracy. The digits dataset contains 8×8 grayscale images of handwritten digits (0–9) and their corresponding labels.

In [None]:
# Imports
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import accuracy_score

%matplotlib inline


In [None]:
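# Load the digits dataset
# It contains 8x8 grayscale images and labels for digits 0-9

digits = load_digits()
X = digits.data         # each 8x8 image flattened into 64 features
y = digits.target       # digit labels 0-9
images = digits.images  # original 8x8 arrays, kept for plotting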

print(f"Number of samples: {X.shape[0]}")
print(f"Number of features per sample: {X.shape[1]}")
print(f"Image shape: {images[0].shape}")
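
The 64 features per sample are the 8×8 pixel grid of each image flattened row by row, and every pixel holds an integer intensity from 0 to 16. The short check below reloads the dataset so it stands on its own, then confirms that relationship:

```python
import numpy as np
from sklearn.datasets import load_digits

digits = load_digits()

# Flatten each 8x8 image row by row into a 64-element vector
flattened = digits.images.reshape(len(digits.images), -1)

# The flattened images match the feature matrix exactly
print(np.array_equal(flattened, digits.data))

# Pixel intensities are stored as floats but take integer values 0..16
print(digits.data.min(), digits.data.max())
print(np.all(digits.data == np.round(digits.data)))
```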


## Sample digit images
The grid below shows the first ten images in the dataset, each titled with its label.

In [None]:
# Show a grid of sample digits
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
axes = axes.ravel()

for i in range(10):
    axes[i].imshow(images[i], cmap="gray")
    axes[i].set_title(f"Label: {y[i]}")
    axes[i].axis("off")

plt.tight_layout()
plt.show()


## Train-test split
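We hold out 20% of the samples as a test set so that performance is measured on data the model never saw during training. Stratifying on `y` keeps the proportion of each digit the same in both sets, and passing `images` alongside `X` and `y` keeps the 8×8 images aligned with the split for plotting later.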
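
To see what `stratify=y` does, the sketch below repeats the same split and compares how often each digit appears in the full data, the training set, and the test set. With 1,797 samples and `test_size=0.2`, scikit-learn rounds the test set up to 360 samples.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split

X, y = load_digits(return_X_y=True)

# Same settings as the notebook: 20% test set, stratified on the labels
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Fraction of each digit (0-9) in the full data, training set, and test set
overall = np.bincount(y, minlength=10) / len(y)
train_frac = np.bincount(y_train, minlength=10) / len(y_train)
test_frac = np.bincount(y_test, minlength=10) / len(y_test)

print(len(y_train), len(y_test))
# Largest per-class gap between each split and the full dataset
print(np.abs(train_frac - overall).max(), np.abs(test_frac - overall).max())
```

Without `stratify`, a random split can leave some digits slightly over- or under-represented in the small test set.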

In [None]:
# Split the data
X_train, X_test, y_train, y_test, images_train, images_test = train_test_split(
    X, y, images, test_size=0.2, random_state=42, stratify=y
)

print(f"Training samples: {X_train.shape[0]}")
print(f"Test samples: {X_test.shape[0]}")


## Train a simple neural network
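We use `MLPClassifier`, scikit-learn's multi-layer perceptron. This configuration has a single hidden layer of 64 ReLU units and is trained with the Adam optimizer for at most 300 epochs.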
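
With one hidden layer, prediction reduces to two matrix products: the hidden activations are `relu(X @ W1 + b1)` and the class probabilities are `softmax(hidden @ W2 + b2)`. The sketch below recomputes `predict_proba` by hand from the fitted `coefs_` and `intercepts_` to check that reading. It assumes a softmax output layer, which is what scikit-learn uses for multiclass problems (the code prints `out_activation_` to confirm), and for brevity it trains on all samples rather than the notebook's training split.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.neural_network import MLPClassifier

X, y = load_digits(return_X_y=True)

# Same architecture as the notebook (trained on all data for brevity)
mlp = MLPClassifier(hidden_layer_sizes=(64,), activation="relu",
                    solver="adam", max_iter=300, random_state=42)
mlp.fit(X, y)

W1, W2 = mlp.coefs_       # shapes (64, 64) and (64, 10)
b1, b2 = mlp.intercepts_  # shapes (64,) and (10,)

def softmax(z):
    # Subtract the row max for numerical stability before exponentiating
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

# Forward pass: input -> ReLU hidden layer -> softmax output
hidden = np.maximum(0, X @ W1 + b1)
manual_proba = softmax(hidden @ W2 + b2)

print(mlp.out_activation_)
print(np.allclose(manual_proba, mlp.predict_proba(X)))
```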

In [None]:
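# Train a simple neural network
mlp = MLPClassifier(
    hidden_layer_sizes=(64,),  # one hidden layer with 64 units
    activation="relu",         # ReLU activation in the hidden layer
    solver="adam",             # Adam stochastic gradient-based optimizer
    max_iter=300,              # for Adam, the maximum number of epochs
    random_state=42            # fixed seed for weight initialization and batch shuffling
)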

mlp.fit(X_train, y_train)
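
scikit-learn's user guide notes that multi-layer perceptrons are sensitive to feature scaling, while this notebook feeds the raw 0–16 pixel values in directly. One common option, sketched here as an alternative rather than the notebook's own setup, is to standardize the features inside a `Pipeline` so the scaler is fit on the training set only:

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_digits(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# The scaler learns per-pixel mean and std from the training data only,
# then applies the same transform to the test data at predict time
model = make_pipeline(
    StandardScaler(),
    MLPClassifier(hidden_layer_sizes=(64,), activation="relu",
                  solver="adam", max_iter=300, random_state=42),
)
model.fit(X_train, y_train)

# Pipeline.score delegates to the classifier's mean accuracy
scaled_accuracy = model.score(X_test, y_test)
print(f"Test accuracy with scaling: {scaled_accuracy:.4f}")
```

Comparing this number with the unscaled model's test accuracy shows whether scaling matters on this particular dataset.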


## Evaluate the model
We compute the accuracy on the test set: the fraction of test images whose predicted label matches the true label.
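
Accuracy is simply the mean of a per-sample "prediction equals truth" indicator. A small sketch on made-up labels shows that `accuracy_score` agrees with computing that mean directly:

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Toy labels: three of the four predictions are correct
y_true = np.array([3, 1, 4, 1])
y_pred = np.array([3, 1, 4, 7])

# Accuracy is the mean of the element-wise "prediction == truth" indicator
manual = np.mean(y_pred == y_true)
print(manual, accuracy_score(y_true, y_pred))  # 0.75 0.75
```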

In [None]:
# Evaluate accuracy
predictions = mlp.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f"Test accuracy: {accuracy:.4f}")


## Visualize predictions
The grid below shows the first ten test images, each titled with its true and predicted label.
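
The first ten test images may all be classified correctly, which would hide the model's mistakes. A small sketch of locating misclassified samples instead, shown on hypothetical toy arrays; in the notebook the same comparison would be applied to `y_test` and `predictions`:

```python
import numpy as np

# Hypothetical true and predicted labels standing in for y_test / predictions
y_true = np.array([0, 7, 2, 9, 5, 5])
y_pred = np.array([0, 1, 2, 9, 3, 5])

# Indices where the prediction disagrees with the true label
wrong = np.flatnonzero(y_pred != y_true)
print(wrong)  # [1 4]

for i in wrong:
    print(f"index {i}: true {y_true[i]}, predicted {y_pred[i]}")
```

Looping over the first few entries of `wrong` instead of `range(10)` in a plotting cell would display only the errors.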

In [None]:
# Show some test images with true and predicted labels
fig, axes = plt.subplots(2, 5, figsize=(8, 4))
axes = axes.ravel()

for i in range(10):
    axes[i].imshow(images_test[i], cmap="gray")
    axes[i].set_title(f"True: {y_test[i]} | Pred: {predictions[i]}")
    axes[i].axis("off")

plt.tight_layout()
plt.show()
