In [None]:
import torch
import matplotlib.pyplot as plt
import torchvision
import models
import torch.nn.functional as F
from torchvision.models import resnet18
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
from torchvision.transforms import Normalize

from dataloader import *



# Load the Dataset (MNIST)

In [None]:
# Seed PyTorch's global RNG so weight initialisation and the shuffling order
# are repeatable across runs (GPU kernels may still add small nondeterminism).
torch.manual_seed(0)

# Load the dataset
# (0.1307, 0.3081) are the widely used mean / std of the MNIST training pixels
# after ToTensor() has scaled them to [0, 1]. One shared transform keeps the
# train and test preprocessing identical.
mnist_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
train_ds = MNIST('./data/mnist', train=True, transform=mnist_transform, download=True)
test_ds = MNIST('./data/mnist', train=False, transform=mnist_transform, download=True)
classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=4)

# Set the model
model = resnet18(num_classes=10)
# The stock ResNet stem expects 3-channel images; MNIST is single-channel, so
# replace the first convolution with a 1-channel one of the same geometry.
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# Define Optimizer and Criterion
# The optimizer is created after the conv1 swap so it tracks the new layer's weights.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)


## Training the model

In [None]:
NUM_EPOCHS = 10   # full passes over the training set
LOG_EVERY = 1000  # number of mini-batches averaged into each loss report

# Select training mode explicitly so this cell does not depend on whatever
# mode the model was left in by a previously executed cell.
model.train()

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_dl):
        # move the mini-batch to the same device as the model
        inputs, labels = inputs.to(device), labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # Report only once LOG_EVERY batches have been accumulated, so the
        # printed value is a true mean over LOG_EVERY mini-batches (reporting at
        # i == 0 would divide a single batch loss by LOG_EVERY).
        if (i + 1) % LOG_EVERY == 0:
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / LOG_EVERY))
            running_loss = 0.0

print('Finished Training')


# Accuracy on test set

In [None]:
correct = 0
total = 0

# Evaluate in inference mode: in training mode the BatchNorm layers would
# normalise each 4-image batch with its own statistics and keep updating their
# running statistics (torch.no_grad() does not prevent that).
was_training = model.training
model.eval()

with torch.no_grad():
    for images, labels in test_dl:
        images, labels = images.to(device), labels.to(device)
        # calculate outputs by running images through the network
        outputs = model(images)
        # the class with the highest score is what we choose as prediction
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

# Put the model back in the mode it was in before this cell ran.
model.train(was_training)

print('Accuracy of the network on the 10000 test images: %d %%' % (
    100 * correct / total))

## Display an Example

In [None]:
dataiter = iter(test_dl)
# Use the builtin next(); DataLoader iterators do not provide a .next() method
# in current PyTorch releases.
images, labels = next(dataiter)

# print images
# The tensors were standardised by the dataset transform, so their values fall
# outside [0, 1]; normalize=True rescales the grid to [0, 1] so imshow shows
# the full grey range instead of clipping it.
grid = torchvision.utils.make_grid(images, normalize=True)
plt.imshow(grid.permute(1, 2, 0))  # (C, H, W) -> (H, W, C) for matplotlib
# Iterate over the labels actually present rather than assuming a batch of 4.
print('GroundTruth: ', ' '.join('%5s' % classes[label] for label in labels.tolist()))

In [None]:
# Run the example batch in inference mode and without autograd: in training
# mode this 4-image batch would be normalised with its own BatchNorm statistics
# and would also modify the model's running statistics.
was_training = model.training
model.eval()
with torch.no_grad():
    outputs = model(images.to(device))
# Put the model back in the mode it was in before this cell ran.
model.train(was_training)

# index of the largest logit per image = predicted class
predicted = outputs.argmax(dim=1)

# Iterate over the predictions actually present rather than assuming a batch of 4.
print('Predicted: ', ' '.join('%5s' % classes[label]
                              for label in predicted.tolist()))