Step 1: Import the necessary libraries.

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten

Step 2: Load and preprocess the MNIST dataset.

In [2]:
# Load the dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixel values to be between 0 and 1
x_train, x_test = x_train / 255.0, x_test / 255.0

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


Step 3: Build the neural network model.

In [3]:
model = Sequential([
    Flatten(input_shape=(28, 28)),  # Flatten the 28x28 images into a 1D array
    Dense(128, activation='relu'),   # Fully connected layer with 128 neurons and ReLU activation
    Dense(10, activation='softmax')  # Fully connected layer with 10 neurons (for 10 classes) and softmax activation
])


Step 4: Compile the model.

In [4]:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])


Step 5: Train the model.

In [5]:
model.fit(x_train, y_train, epochs=10)


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.callbacks.History at 0x783ae8e4e500>

Step 6: Evaluate the model.

In [6]:
test_loss, test_accuracy = model.evaluate(x_test, y_test, verbose=2)
print("\nTest accuracy:", test_accuracy)


313/313 - 1s - loss: 0.0866 - accuracy: 0.9774 - 1s/epoch - 4ms/step

Test accuracy: 0.977400004863739


Step 7: Make predictions.

In [7]:
predictions = model.predict(x_test)
predictions



array([[3.2201819e-09, 1.3617325e-10, 2.3785360e-07, ..., 9.9994403e-01,
        1.4499634e-06, 7.4691690e-07],
       [2.1313151e-10, 1.1890016e-09, 9.9999994e-01, ..., 2.3987776e-20,
        2.4427169e-10, 4.2952524e-18],
       [4.8966089e-09, 9.9976420e-01, 8.2909874e-06, ..., 4.1938238e-05,
        1.8424445e-04, 6.0471619e-09],
       ...,
       [3.8053101e-17, 1.0043626e-13, 5.0916622e-16, ..., 2.1737248e-09,
        4.9187876e-09, 4.2798288e-06],
       [3.5433508e-13, 2.4991214e-15, 5.3558435e-15, ..., 1.1801013e-10,
        1.7469185e-08, 6.5176136e-14],
       [1.2254349e-09, 1.1223423e-14, 7.0884218e-09, ..., 9.1328568e-17,
        2.3662436e-10, 2.6558695e-15]], dtype=float32)