### Convolution



In [1]:
import torch.nn as nn
import torch
from models.pale_nets import ConvolvedNet
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = ConvolvedNet().to(device)


In [2]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms


#transform_set = transforms.Compose([transforms.Resize((32,32)), transforms.ToTensor(), transforms.Normalize(mean = (0.1307,), std = (0.3081,))])
transform_set = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean = (0.1307,), std = (0.3081,))])
mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform_set)

# function to loop through the test dataset and count how many predictions we get correct
def test_accuracy(net, test_data):
    with torch.no_grad():
        correct = 0
        for i in range(len(test_data)):
            image, label = test_data[i]
            image = image.unsqueeze(0)
            output = net(image.to(device))
            prediction = output.argmax()
            if (prediction == label):
                correct += 1
    return correct/len(test_data)*100

In [3]:
import torch

# download and load the MNIST training dataset
mnist_training = datasets.MNIST(root='./data', download=True, transform=transform_set)

# create a training data loader
train_loader = torch.utils.data.DataLoader(mnist_training, batch_size=64, shuffle=True, num_workers=1)

# create an optimizer object
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

# create a loss function object
loss_function = nn.CrossEntropyLoss()

# run the main training loop
for epoch in range(25):
    for image_batch, target_batch in train_loader:
        image_batch = image_batch.to(device)
        target_batch = target_batch.to(device)
        optimizer.zero_grad()
        output = net(image_batch)
        loss = loss_function(output, target_batch)
        loss.backward()
        optimizer.step()
    print('Epoch %d: %.4f, %.4f' % (epoch+1, test_accuracy(net, mnist_test)/100, loss.item()))

Epoch 1: 0.8640, 0.1127
Epoch 2: 0.8994, 0.0124
Epoch 3: 0.8931, 0.0835
Epoch 4: 0.8712, 0.0581
Epoch 5: 0.9412, 0.0030
Epoch 6: 0.9037, 0.0445
Epoch 7: 0.9753, 0.0233
Epoch 8: 0.8864, 0.0216
Epoch 9: 0.9235, 0.0139
Epoch 10: 0.9116, 0.0051
Epoch 11: 0.9173, 0.0250
Epoch 12: 0.9456, 0.0001
Epoch 13: 0.9273, 0.0003
Epoch 14: 0.9500, 0.0153
Epoch 15: 0.9340, 0.0078
Epoch 16: 0.9338, 0.0002
Epoch 17: 0.9495, 0.0001
Epoch 18: 0.9534, 0.0000
Epoch 19: 0.9601, 0.0000
Epoch 20: 0.9590, 0.0012
Epoch 21: 0.9436, 0.0001
Epoch 22: 0.9406, 0.0003
Epoch 23: 0.9494, 0.2074
Epoch 24: 0.9366, 0.0006
Epoch 25: 0.9434, 0.0004


In [4]:
torch.save(net.state_dict(), 'models/convolved_net_mnist.pth')
