In [None]:
# Imports: PyTorch core, torchvision datasets/transforms, timing, and plotting

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
from time import time
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [None]:
# Both MNIST splits use the same storage root, tensor conversion and
# download-if-missing behaviour; only the `train` flag differs.
mnist_kwargs = dict(root="./datasets", transform=transforms.ToTensor(), download=True)
train_dataset = datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = datasets.MNIST(train=False, **mnist_kwargs)

print("No. of Training examples: ", len(train_dataset))
print("No. of Test examples: ", len(test_dataset))

# Training batches of 30 samples; shuffle=True asks the loader to reshuffle
# the data on every pass.
train_loader = DataLoader(dataset=train_dataset, batch_size=30, shuffle=True)
# Network dimensions.
input_size = 784    # length of one flattened MNIST image (28 * 28 pixels)
hidden_size = 256   # width of the single hidden layer
output_size = 10    # one output per digit class

# Seed the global torch RNG *before* any layer is constructed so the initial
# weights are the same on every fresh run. Anything that later draws from the
# global generator (e.g. DataLoader shuffling) also becomes repeatable,
# provided the notebook is executed top to bottom from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

# Two-layer fully connected classifier. The final LogSoftmax emits
# log-probabilities over dim=1 (the class dimension), which is the input
# NLLLoss expects.
model = nn.Sequential(
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, output_size),
    nn.LogSoftmax(dim=1)
    )
criterion = nn.NLLLoss()
# Plain SGD with momentum over all model parameters.
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)

In [None]:
# Training
def _train_one_epoch(net, loader, loss_fn, optim):
    """Run one optimisation pass over `loader` and return the mean batch loss.

    Each batch is flattened to (batch, features), pushed through `net`,
    scored with `loss_fn`, and used for a single `optim` step. The returned
    value is the sum of the per-batch losses divided by the number of batches.
    """
    running_loss = 0.0
    for batch_images, batch_labels in loader:
        # Collapse every non-batch dimension so each sample is one flat vector.
        flat_images = batch_images.view(batch_images.size(0), -1)
        optim.zero_grad()
        batch_loss = loss_fn(net(flat_images), batch_labels)
        batch_loss.backward()
        optim.step()
        running_loss += batch_loss.item()
    return running_loss / len(loader)


epochs = 18
time_start = time()
losses = []  # mean training loss per epoch, in epoch order

for epoch in range(epochs):
    epoch_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
    losses.append(epoch_loss)
    print("Epoch {}: Loss {}".format(epoch, epoch_loss))

# Wall-clock duration of the whole training run, converted to minutes.
elapsed = (time() - time_start) / 60
print("\nTraining Time (in minutes): ", elapsed)

# Plot the per-epoch training loss on an explicitly created figure/axes pair.
fig, ax = plt.subplots()
ax.plot(losses, label='Training Loss')
ax.set(xlabel='Epoch', ylabel='Loss', title='Loss vs Epoch')
ax.legend()
plt.show()


In [None]:
# Testing
# Evaluate on the test split in mini-batches rather than one image at a time:
# the predictions are the same, but the model is called once per batch
# instead of once per sample.
EVAL_BATCH_SIZE = 1000
test_loader = DataLoader(test_dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)

correct = 0
total = 0

# Switch to inference mode for the evaluation and restore the previous mode
# afterwards, so this cell leaves the model's train/eval state untouched.
was_training = model.training
model.eval()
try:
    with torch.no_grad():  # no gradients are needed for scoring
        for images, labels in test_loader:
            # Flatten by batch shape (same as the training loop) instead of
            # hard-coding the feature count.
            images = images.view(images.shape[0], -1)
            log_ps = model(images)
            # exp() is monotonic, so the arg-max of the log-probabilities is
            # the same class as the arg-max of the probabilities.
            predicted = log_ps.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
finally:
    model.train(was_training)

print("Number Of Images Tested: ", total)
# Report NaN instead of raising ZeroDivisionError if nothing was evaluated.
accuracy = correct / total if total else float("nan")
print("Model Accuracy: ", accuracy)

# Persist the trained network to disk.
# NOTE(review): this serialises the whole module object rather than only its
# state_dict; PyTorch's serialization docs recommend saving the state_dict for
# portability. Left as-is so existing code that loads this file keeps working.
model_path = './mnist_model.pt'
torch.save(model, model_path)