In [11]:
import torch.nn as nn

class SimpleNet(nn.Module):
    """Two-layer fully connected classifier for flattened images.

    Architecture: ``Linear(input_size, hidden_size) -> ReLU -> Linear(hidden_size, num_classes)``.
    The defaults (784 -> 5 -> 10) fit 28x28 single-channel images with ten
    classes, as used with MNIST in this notebook. ``forward`` returns raw
    logits (no softmax), which is the input ``nn.CrossEntropyLoss`` expects.

    Args:
        input_size: Number of input features after flattening (default 784 = 28*28).
        hidden_size: Width of the hidden layer (default 5).
        num_classes: Number of output logits (default 10).
    """

    def __init__(self, input_size=784, hidden_size=5, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``.

        ``x`` may be any tensor whose element count is a multiple of the input
        size, e.g. ``(1, 28, 28)`` for one image or ``(B, 1, 28, 28)`` for a
        batch; it is flattened to ``(-1, input_size)``.
        """
        # reshape (unlike view) also accepts non-contiguous tensors such as transposes.
        x = x.reshape(-1, self.fc1.in_features)
        x = nn.functional.relu(self.fc1(x))
        return self.fc2(x)


net = SimpleNet()

In [12]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# MNIST evaluation split; train=False selects the test images rather than the
# training images, and download=True fetches the files into ./data if missing.
mnist_test = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),  # PIL image -> float tensor
)

# function to loop through the test dataset and count how many predictions we get correct
def test_accuracy(net, test_data):
    correct = 0
    for i in range(len(test_data)):
        image, label = test_data[i]
        output = net(image)
        prediction = output.argmax()
        if (prediction == label):
            correct += 1
    return correct/len(test_data)*100

In [13]:
import torch

# Hyperparameters for this run, named so they are easy to find and tune.
NUM_EPOCHS = 5
BATCH_SIZE = 64
LEARNING_RATE = 0.01

# download and load the MNIST training dataset (train=True is the default split)
mnist_training = datasets.MNIST(root='./data', download=True, transform=transforms.ToTensor())

# shuffled mini-batches for stochastic gradient descent
train_loader = torch.utils.data.DataLoader(mnist_training, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

# plain SGD over all of the network's parameters
optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)

# CrossEntropyLoss applies log-softmax internally, so it takes net's raw logits
loss_function = nn.CrossEntropyLoss()

# run the main training loop
for epoch in range(NUM_EPOCHS):
    # Re-enter training mode every epoch: an evaluation step (or an earlier
    # cell) may have left the model in eval mode.
    net.train()
    for image_batch, target_batch in train_loader:
        optimizer.zero_grad()
        output = net(image_batch)
        loss = loss_function(output, target_batch)
        loss.backward()
        optimizer.step()
    # test_accuracy returns a percentage; dividing by 100 prints it as a fraction in [0, 1].
    print('Epoch %d: %.4f' % (epoch+1, test_accuracy(net, mnist_test)/100))

Epoch 1: 0.6878
Epoch 2: 0.8069
Epoch 3: 0.8421
Epoch 4: 0.8588
Epoch 5: 0.8639
