In [12]:
import torch.nn as nn
import torch
from models.pale_nets import BiggerNet
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG before building the network so weight initialisation
# (and later torch random draws, such as DataLoader shuffling) is reproducible
# across Restart & Run All. GPU kernels may still introduce small nondeterminism.
SEED = 0
torch.manual_seed(SEED)

net = BiggerNet().to(device)


In [13]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Held-out MNIST test split (downloaded into ./data on first use); every image
# is passed through ToTensor() when it is indexed.
mnist_test = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# function to loop through the test dataset and count how many predictions we get correct
def test_accuracy(net, test_data):
    correct = 0
    for i in range(len(test_data)):
        image, label = test_data[i]
        output = net(image.to(device))
        prediction = output.argmax()
        if (prediction == label):
            correct += 1
    return correct/len(test_data)*100

In [14]:
import torch

# download and load the MNIST training dataset
mnist_training = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())

# Training hyper-parameters, named so they are easy to find and tune.
BATCH_SIZE = 64
LEARNING_RATE = 0.01
NUM_EPOCHS = 50

# create a training data loader (reshuffled every epoch)
train_loader = torch.utils.data.DataLoader(mnist_training, batch_size=BATCH_SIZE, shuffle=True, num_workers=1)

# plain SGD over all of the network's parameters
optimizer = torch.optim.SGD(net.parameters(), lr=LEARNING_RATE)

# CrossEntropyLoss applies log-softmax internally, so this assumes BiggerNet
# returns raw class scores (logits) — TODO confirm against models/pale_nets.
loss_function = nn.CrossEntropyLoss()

# run the main training loop
for epoch in range(NUM_EPOCHS):
    # Explicitly (re-)enter training mode every epoch so any dropout / batch-norm
    # layers train correctly even if evaluation switched the net to eval mode.
    net.train()
    for image_batch, target_batch in train_loader:
        image_batch = image_batch.to(device)
        target_batch = target_batch.to(device)
        optimizer.zero_grad()
        loss = loss_function(net(image_batch), target_batch)
        loss.backward()
        optimizer.step()
    # test_accuracy returns a percentage; report it as a fraction in [0, 1]
    print('Epoch %d: %.4f' % (epoch+1, test_accuracy(net, mnist_test)/100))

Epoch 1: 0.8336
Epoch 2: 0.8889
Epoch 3: 0.9051
Epoch 4: 0.9103
Epoch 5: 0.9195
Epoch 6: 0.9250
Epoch 7: 0.9290
Epoch 8: 0.9334
Epoch 9: 0.9370
Epoch 10: 0.9392
Epoch 11: 0.9425
Epoch 12: 0.9458
Epoch 13: 0.9476
Epoch 14: 0.9499
Epoch 15: 0.9526
Epoch 16: 0.9559
Epoch 17: 0.9570
Epoch 18: 0.9581
Epoch 19: 0.9608
Epoch 20: 0.9614
Epoch 21: 0.9619
Epoch 22: 0.9638
Epoch 23: 0.9640
Epoch 24: 0.9658
Epoch 25: 0.9658
Epoch 26: 0.9674
Epoch 27: 0.9683
Epoch 28: 0.9689
Epoch 29: 0.9697
Epoch 30: 0.9707
Epoch 31: 0.9714
Epoch 32: 0.9722
Epoch 33: 0.9723
Epoch 34: 0.9722
Epoch 35: 0.9728
Epoch 36: 0.9741
Epoch 37: 0.9734
Epoch 38: 0.9743
Epoch 39: 0.9752
Epoch 40: 0.9756
Epoch 41: 0.9760
Epoch 42: 0.9762
Epoch 43: 0.9766
Epoch 44: 0.9776
Epoch 45: 0.9769
Epoch 46: 0.9776
Epoch 47: 0.9780
Epoch 48: 0.9776
Epoch 49: 0.9774
Epoch 50: 0.9788


In [15]:
# Persist only the learned parameters (the state_dict), not the module object itself.
torch.save(obj=net.state_dict(), f='models/bigger_net.pth')
