# Full implementation using the MNIST dataset.

In [1]:
%matplotlib inline

# Custom utility class
from utils import *

# pytorch imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable

# Special package provided by pytorch
import torchvision
import torchvision.transforms as transforms

In [2]:
## Hyperparameters.

# Device selection: True when a CUDA device is visible, so the GPU can be used;
# otherwise computation stays on the CPU.
has_gpu = torch.cuda.is_available()

# Number of image channels (1 = grayscale, 3 = RGB). MNIST digits are grayscale.
img_channels = 1

# Class labels for MNIST (looked up in the project's CLASS_LABELS table) and their count.
classes = CLASS_LABELS['mnist']
num_classes = len(classes)

# Dataset location and preprocessing switches.
data_dir = '../datasets/mnist'  # Where the MNIST files live (or are downloaded to).
download = True                 # Let torchvision fetch the dataset if it is missing.
normalize = 0.5                 # Truthy value enables normalization in the transform pipeline.

# Training parameters: mini-batch size, optimizer learning rate, and the
# number of full passes (epochs) over the training set.
batch_size, lr, epochs = 16, 1e-2, 5

In [3]:
# Preprocessing: always convert images to tensors (ToTensor scales pixel values
# into [0, 1]); optionally normalize them, which helps convergence.
if normalize:
    # Normalize expects one mean and one std per image channel. MNIST is
    # grayscale (img_channels == 1). Passing 3-channel statistics to a 1-channel
    # tensor raises a broadcast error in current torchvision, so the tuples are
    # sized from img_channels. With mean = std = 0.5, [0, 1] maps to [-1, 1].
    channel_stats = (0.5,) * img_channels
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize(channel_stats, channel_stats)])
else:
    # Convert to tensor only; pixel values stay in [0, 1].
    transform = transforms.Compose([transforms.ToTensor()])

# Training split: download if needed and apply the transform to each sample.
trainset = torchvision.datasets.MNIST(root=data_dir, train=True, download=download, transform=transform)
# Wrap it in a loader that yields shuffled mini-batches.
# NOTE(review): `trainset` is rebound to the DataLoader; later cells presumably
# iterate it under this name, so the name is kept.
trainset = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

# Test split: download if needed and apply the same transform.
testset = torchvision.datasets.MNIST(root=data_dir, train=False, download=download, transform=transform)
# Wrap it in a loader that yields mini-batches in a fixed (unshuffled) order.
testset = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)