In [None]:
# The MNIST datasets are hosted on yann.lecun.com, which has moved behind CloudFlare protection.
# Run this script to enable the dataset download.
# Reference: https://github.com/pytorch/vision/issues/1938

from six.moves import urllib
# Build an opener whose requests carry a browser-like User-agent header and
# install it globally, so later urllib.request.urlopen() calls go through it.
opener = urllib.request.build_opener()
opener.addheaders = [
    ('User-agent', 'Mozilla/5.0'),
]
urllib.request.install_opener(opener)

import torch
from torch import nn, optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from utils import view_classify

# Preprocessing pipeline: convert each image to a tensor, then normalize the
# single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Training split of MNIST (downloaded on first use), served in shuffled
# mini-batches of 64 images.
trainset = datasets.MNIST(
    root='~/.pytorch/MNIST_data/',
    train=True,
    download=True,
    transform=transform,
)
trainloader = torch.utils.data.DataLoader(
    dataset=trainset,
    batch_size=64,
    shuffle=True,
)

# Feed-forward classifier: 784 inputs -> 128 -> 64 -> 10 outputs, with ReLU
# between the linear layers and log-softmax over the class dimension at the end
# (so the network emits log-probabilities).
model = nn.Sequential(
    nn.Linear(in_features=784, out_features=128),
    nn.ReLU(),
    nn.Linear(in_features=128, out_features=64),
    nn.ReLU(),
    nn.Linear(in_features=64, out_features=10),
    nn.LogSoftmax(dim=1),
)

# Loss: negative log-likelihood, which expects log-probabilities as input.
criterion = nn.NLLLoss()

# Optimizer: stochastic gradient descent over every parameter of the model.
optimizer = optim.SGD(params=model.parameters(), lr=0.003)

# Reset any existing gradients before training begins.
optimizer.zero_grad()

# Training the network
# Each epoch makes one full pass over `trainloader`, applies an optimizer step
# after every batch, and then reports the mean batch loss and the percentage of
# training images that were misclassified during that pass.
epochs = 10
for j in range(epochs):
    running_loss = 0
    n_wrong = 0
    n_pred = 0
    # Per-step history of the first layer's weights and gradients; rebuilt
    # from scratch at the start of every epoch.
    wts = []
    wts_grad = []
    for images, labels in trainloader:
        # Flatten every image in the batch into a single row vector
        images = images.view(images.shape[0], -1)
        optimizer.zero_grad()  # clear gradient in each step
        # Call the module rather than .forward() so registered hooks still run.
        # The model ends in LogSoftmax, so `logits` holds log-probabilities.
        logits = model(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        # Store detached copies: appending the live Parameter (or its .grad)
        # would make every entry alias the same tensor, so the history would
        # only ever show the latest values. The weights are recorded after the
        # update, together with the gradient that produced it.
        # NOTE: this keeps one copy of each tensor per step for a whole epoch,
        # which costs memory proportional to the number of batches.
        wts.append(model[0].weight.detach().clone())
        wts_grad.append(model[0].weight.grad.detach().clone())
        running_loss += loss.item()
        # argmax over log-probabilities already gives the predicted class;
        # an extra softmax would not change the ordering.
        pred = logits.argmax(dim=1)
        # .item() keeps the counter a plain Python int instead of a tensor
        n_wrong += (pred != labels).sum().item()
        n_pred += logits.shape[0]
    # The inner loop never breaks, so report right after it (no for/else needed)
    print(f"Training loss: {running_loss/len(trainloader)}")
    print(f"Wrong predictions percentage: {100 * n_wrong / n_pred:.2f}%")