### A Simple Neural Network

Now that we have some image data, we need a neural network to feed it into.  The neural network takes the image data as its input and produces a classification as output.  The classification is the network's estimate of which digit the input image represents.

The input is therefore 28x28 = 784 numbers, i.e. one for each pixel in the input image.  The output of the neural network is a set of 10 numbers, one for each digit, where a larger number means the network considers that digit more likely.
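
To make these sizes concrete, here is a minimal sketch that uses a random tensor as a stand-in for one image.  The shape `(1, 28, 28)` (one channel of 28x28 pixels) is what torchvision's `ToTensor` transform produces for a grayscale image; the variable names are only illustrative.

```python
import torch

# Stand-in for one MNIST image: 1 channel, 28 x 28 pixels, values in [0, 1).
# Random values are used here purely for illustration.
image = torch.rand(1, 28, 28)

# Flatten the grid of pixels into a single row of 784 numbers.
flat = image.reshape(-1)

print(image.shape)  # torch.Size([1, 28, 28])
print(flat.shape)   # torch.Size([784])
```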

#### Structure of the Network

We are going to build a very simple network to start off with.  It is going to have 3 layers.  The input layer has 784 nodes, one for each input pixel.  The output layer has 10 nodes, one for each digit.  The middle layer has 5 nodes.  Every node in the input layer is connected to each of the five middle layer nodes, and each node in the middle layer is connected to each node in the output layer.  This is called a fully connected network.

In [66]:
import torch.nn as nn

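class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        # input layer (784 pixels) -> middle layer (5 nodes)
        self.fc1 = nn.Linear(784, 5)
        # middle layer (5 nodes) -> output layer (10 digits)
        self.fc2 = nn.Linear(5, 10)

    def forward(self, x):
        # flatten the 1x28x28 image into a single row of 784 numbers
        x = x.view(784)
        # apply the first layer, then ReLU, which replaces negative values with 0
        x = nn.functional.relu(self.fc1(x))
        # the second layer produces one score per digit
        x = self.fc2(x)
        return x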

net = SimpleNet()
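
To get a feel for how small this network is, we can count the numbers it has to learn.  The sketch below is an addition to the notebook: it builds the same two layer sizes on their own (the variable names and the `count_parameters` helper are hypothetical) and counts one weight per connection plus one bias per node, since `nn.Linear` adds a bias by default.

```python
import torch.nn as nn

# The two fully connected layers, built on their own for this count.
hidden = nn.Linear(784, 5)  # input layer -> middle layer
out = nn.Linear(5, 10)      # middle layer -> output layer

def count_parameters(layer):
    """Total number of learnable values (weights and biases) in a layer."""
    return sum(p.numel() for p in layer.parameters())

hidden_count = count_parameters(hidden)  # 784 * 5 weights + 5 biases
out_count = count_parameters(out)        # 5 * 10 weights + 10 biases
total = hidden_count + out_count

print(hidden_count, out_count, total)  # 3925 60 3985
```

Almost all of the parameters sit between the input and the middle layer, because that is where the 784 pixels fan out.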

Now we can test the model by passing an image to it and seeing what it predicts.  The prediction is probably wrong at this stage: the model is created with random values for the node weights, so its prediction is effectively random.  With 10 possible answers (digits), it has roughly a 1 in 10 chance of being right.

In [67]:

# reference the libraries we will use
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# download and load the MNIST dataset
mnist = datasets.MNIST(root='./data', download=True)

# get the first image and its label
# change the index number here to show other images and labels
image, label = mnist[0]

# convert the image to a tensor, i.e. an array of numbers suitable for input into a neural network
# this converts the image from a PIL image to a PyTorch tensor, and also scales the pixel values from 0-255 to 0-1
tensor = transforms.ToTensor()(image)

# pass the image through the neural network
output = net(tensor)
print(f"Output = {output}")

Output = tensor([ 0.2345, -0.1748, -0.0267, -0.0392,  0.3167, -0.0355,  0.4001, -0.3960,
        -0.0787,  0.2393], grad_fn=<ViewBackward0>)
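
The exact numbers printed here depend on the random starting weights, so a fresh run will print different values.  As a minimal sketch of that behaviour (an addition to the notebook, with illustrative variable names), the code below creates the same kind of layer three times and uses `torch.manual_seed` to show when the weights differ and when they repeat.

```python
import torch
import torch.nn as nn

# Two layers created one after the other draw different random weights.
torch.manual_seed(0)
a = nn.Linear(784, 5)
b = nn.Linear(784, 5)
print(torch.equal(a.weight, b.weight))  # False

# Resetting the seed before creating a layer reproduces the earlier weights.
torch.manual_seed(0)
c = nn.Linear(784, 5)
print(torch.equal(a.weight, c.weight))  # True
```

Fixing the seed like this is one way to make a notebook run repeatable while experimenting.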


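The output is a list of 10 scores, one for each digit that the input might represent.  These are raw scores rather than probabilities (some of them are negative), but a higher score still means the model considers that digit more likely.  We can get the most likely digit (i.e. what the model thinks the image represents) by choosing the digit with the highest score.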

In [68]:
# argmax returns a tensor, so use .item() to get a plain Python number
prediction = output.argmax().item()
print(f"Most likely digit = {prediction}")

Most likely digit = 6
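
The ten values printed earlier include negative numbers, so they cannot be read directly as probabilities.  As a minimal sketch (an addition here, since the network above does not do this itself), `torch.softmax` rescales such scores into values between 0 and 1 that sum to 1.  The scores are copied by hand from the output shown earlier.

```python
import torch

# The ten scores from the earlier output, copied by hand.
scores = torch.tensor([0.2345, -0.1748, -0.0267, -0.0392, 0.3167,
                       -0.0355, 0.4001, -0.3960, -0.0787, 0.2393])

# Softmax exponentiates each score and divides by the total, giving
# non-negative values that sum to 1.
probabilities = torch.softmax(scores, dim=0)

print(probabilities)
print(probabilities.argmax().item())  # 6, the same digit as scores.argmax()
```

Softmax keeps the order of the scores, so the most likely digit is the same whether we pick it before or after the conversion.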


We can take a set of images and labels from the data set and use it as test data.  We can then loop through this data, asking the model to predict each answer, and see what percentage it gets right.  If the model is basically random, then the percentage it gets right should be about 10%.  If you run the notebook a few times, you should see this percentage change a little, because the model starts with different random weights each time.

In [69]:
# get a test dataset from MNIST
# note that we apply the transform here, to save doing it separately later
# using train=False gives us the test dataset, which contains different images from the dataset we will use for training later
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())

# function to loop through the test dataset and count how many predictions we get correct
def test_accuracy(net, test_data):
    correct = 0
    for i in range(len(test_data)):
        image, label = test_data[i]
        output = net(image)
        prediction = output.argmax()
        if (prediction == label):
            correct += 1
    return correct/len(test_data)*100

print(f"Percentage of correct predictions = {test_accuracy(net, mnist_test)}%")

Percentage of correct predictions = 9.59%
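
As a rough sanity check on the 10% figure, the sketch below simulates pure guessing using only Python's `random` module.  It is a simplification: the labels are made up rather than taken from MNIST, and an untrained network does not guess uniformly at random in the way this code does.  The count of 10,000 matches the size of the MNIST test set.

```python
import random

random.seed(0)  # fixed seed so the run is repeatable

num_images = 10_000  # same size as the MNIST test set

# Made-up labels and uniformly random guesses, each a digit from 0 to 9.
labels = [random.randrange(10) for _ in range(num_images)]
guesses = [random.randrange(10) for _ in range(num_images)]

# Count how often a guess happens to match its label.
correct = sum(guess == label for guess, label in zip(guesses, labels))
accuracy = correct / num_images * 100

print(f"Percentage of correct guesses = {accuracy:.2f}%")
```

The result should land close to 10%, which is the baseline that any trained model needs to beat.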


If we want to improve the model (and we do) then we will have to train it.