In [4]:
import torch
from torch.utils.data import DataLoader 
from torch.nn import init
import torch.optim as optim 
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import time 

In [7]:
def load_data_fashion_mnist(batch_size, resize = None, path = './Dataset/FashionMNIST'):
    """Build shuffled train/test DataLoaders for Fashion-MNIST.

    ``download = True`` is passed to torchvision for both splits, using
    ``path`` as the dataset root.

    Args:
        batch_size: number of samples per mini-batch, used for both loaders.
        resize: optional size forwarded to ``transforms.Resize``; a falsy
            value skips the resize step.
        path: root directory given to ``torchvision.datasets.FashionMNIST``.

    Returns:
        A ``(train_iter, test_iter)`` tuple of ``DataLoader`` objects. Both
        loaders shuffle and use 4 worker processes.
    """
    # Optional resize comes first, then conversion to a tensor.
    steps = [transforms.Resize(size = resize)] if resize else []
    steps.append(transforms.ToTensor())
    pipeline = transforms.Compose(steps)

    def make_loader(is_train):
        # One split of the dataset wrapped in a shuffling loader.
        split = torchvision.datasets.FashionMNIST(root = path, train = is_train, download = True, transform = pipeline)
        return DataLoader(split, batch_size = batch_size, shuffle = True, num_workers = 4)

    train_iter = make_loader(True)
    test_iter = make_loader(False)
    return train_iter, test_iter

In [12]:
def batch_norm(is_training, X, gamma, beta, moving_mean, moving_var, eps, momentum):
    """Batch-normalize ``X`` and return the updated running statistics.

    Args:
        is_training: when true, normalize with the statistics of the current
            batch and blend them into the running statistics; otherwise
            normalize with ``moving_mean`` / ``moving_var`` as given.
        X: input tensor. In training mode it must have rank 2 (statistics
            taken over dim 0) or rank 4 (statistics taken over dims 0, 2, 3).
        gamma: scale applied to the normalized input.
        beta: shift added after scaling.
        moving_mean: running mean, broadcastable against ``X``.
        moving_var: running variance, broadcastable against ``X``.
        eps: constant added to the variance before the square root.
        momentum: fraction of the previous running value that is kept; the
            batch statistic contributes ``1 - momentum``.

    Returns:
        ``(Y, moving_mean, moving_var)``. In evaluation mode the running
        statistics are returned unchanged. In training mode they are new
        tensors that carry no autograd history.

    Raises:
        AssertionError: in training mode, if ``X`` is not rank 2 or rank 4.
    """
    if not is_training:
        X_norm = (X - moving_mean) / torch.sqrt(moving_var + eps)
    else:
        assert X.dim() in (2, 4), "batch_norm expects a 2-D or 4-D input"
        # Rank 2: per-feature statistics over the batch.
        # Rank 4: per-channel statistics over batch and both spatial dims.
        reduce_dims = 0 if X.dim() == 2 else (0, 2, 3)
        mean = X.mean(dim=reduce_dims, keepdim=True)
        var = ((X - mean) ** 2).mean(dim=reduce_dims, keepdim=True)
        X_norm = (X - mean) / torch.sqrt(var + eps)
        # The running statistics are bookkeeping only. Updating them outside
        # autograd keeps them from requiring grad and from chaining graph
        # nodes from one iteration to the next.
        with torch.no_grad():
            moving_mean = momentum * moving_mean + (1 - momentum) * mean
            moving_var = momentum * moving_var + (1 - momentum) * var
    Y = gamma * X_norm + beta
    return Y, moving_mean, moving_var

In [20]:
class Batch_norm(nn.Module):
    """Batch-normalization layer built on the ``batch_norm`` function.

    Args:
        num_features: number of features (rank-2 input) or channels
            (rank-4 input) to normalize.
        num_dims: 2 selects parameter shape ``(1, num_features)``; any other
            value selects ``(1, num_features, 1, 1)``.

    Attributes:
        gamma: learnable scale, initialized to ones.
        beta: learnable shift, initialized to zeros.
        moving_mean: running mean buffer, initialized to zeros.
        moving_var: running variance buffer, initialized to ones.
    """
    def __init__(self, num_features, num_dims):
        super().__init__()
        if num_dims == 2:
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        # Registered buffers follow the module through .to()/.cuda() and are
        # saved in state_dict(); plain tensor attributes do neither.
        self.register_buffer('moving_mean', torch.zeros(shape))
        # Start the running variance at one (as nn.BatchNorm does) so that
        # evaluating before any training step does not divide by ~sqrt(eps).
        self.register_buffer('moving_var', torch.ones(shape))

    def forward(self, X):
        """Normalize ``X``; in training mode also refresh the running statistics."""
        Y, moving_mean, moving_var = batch_norm(self.training, X, self.gamma, self.beta,
                                                self.moving_mean, self.moving_var,
                                                eps=1e-5, momentum=0.9)
        # Store the statistics without autograd history so the buffers never
        # keep a computation graph alive between iterations.
        self.moving_mean = moving_mean.detach()
        self.moving_var = moving_var.detach()
        return Y

In [21]:
# LeNet variant using the hand-written Batch_norm layer after each convolution
# and each hidden linear layer. The 16*4*4 flatten size corresponds to
# single-channel 28x28 inputs (28 -> 24 -> 12 -> 8 -> 4).
LeNet_BN = nn.Sequential(
    # Convolutional stage 1
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    Batch_norm(num_features=6, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    Batch_norm(num_features=16, num_dims=4),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Classifier head
    nn.Flatten(),
    nn.Linear(in_features=16*4*4, out_features=120),
    Batch_norm(num_features=120, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    Batch_norm(num_features=84, num_dims=2),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)

In [36]:
def evaluate_accuracy(data_iter, network, device = None):
    """Return the classification accuracy of ``network`` over ``data_iter``.

    The network is switched to evaluation mode for the whole pass and is put
    back into whichever mode (train or eval) it was in on entry, even if an
    exception is raised part-way through.

    Args:
        data_iter: iterable yielding ``(inputs, labels)`` tensor pairs.
        network: ``nn.Module`` whose output is arg-maxed over dim 1.
        device: device the batches are moved to. Defaults to the device of
            the network's first parameter, or the CPU when the network has
            no parameters.

    Returns:
        Fraction of samples whose arg-max prediction equals the label.

    Raises:
        ZeroDivisionError: if ``data_iter`` yields no samples.
    """
    if device is None:
        first_param = next(network.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = network.training
    network.eval()  # set once for the pass rather than toggling on every batch
    acc_num, n = 0.0, 0
    try:
        with torch.no_grad():
            for X_test, y_test in data_iter:
                X_test = X_test.to(device)
                y_test = y_test.to(device)
                y_hat = network(X_test)
                acc_num += (y_hat.argmax(dim = 1) == y_test).sum().item()
                n += y_test.shape[0]
    finally:
        # Restore the caller's mode instead of unconditionally forcing train().
        network.train(was_training)
    acc = acc_num/n
    return acc

In [37]:
def train_network(network, train_iter, test_iter, optimizer, num_epochs):
    """Train ``network`` with cross-entropy loss and print per-epoch metrics.

    The network is moved to CUDA when it is available and to the CPU
    otherwise. After every epoch the mean batch loss, training accuracy,
    test accuracy and elapsed time are printed.

    Args:
        network: ``nn.Module`` to train; it is moved to the chosen device.
        train_iter: iterable of ``(inputs, labels)`` training batches.
        test_iter: iterable of ``(inputs, labels)`` batches handed to
            ``evaluate_accuracy`` after each epoch.
        optimizer: optimizer already constructed over the network's
            parameters.
        num_epochs: number of passes over ``train_iter``.

    Returns:
        None.
    """
    # is_available must be *called*: the bare function object is always
    # truthy, which selected 'cuda' even on machines without a GPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network = network.to(device)
    loss_fn = nn.CrossEntropyLoss()
    for epoch in range(num_epochs):
        # Make sure mode-dependent layers (e.g. batch norm) use batch
        # statistics while training, whatever mode the network arrived in.
        network.train()
        train_loss_sum, train_acc_num, batch_count, n, start_time = 0.0, 0, 0, 0, time.time()
        for X, y in train_iter:
            X = X.to(device)
            y = y.to(device)
            y_hat = network(X)
            loss = loss_fn(y_hat, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item()
            train_acc_num += (y_hat.argmax(dim=1) == y).sum().item()
            n += y.shape[0]
            batch_count += 1
        test_acc = evaluate_accuracy(test_iter, network)
        print('epoch: %d, train_loss: %.3f, train_acc: %.3f, test_acc: %.3f, time: %.2f'
              %(epoch+1, train_loss_sum/batch_count, train_acc_num/n, test_acc, (time.time()-start_time)))

In [38]:
# Mini-batch size used for both the training and the test loader.
batch_size = 256
# Build the Fashion-MNIST loaders with the default resize (None) and dataset path.
train_iter, test_iter = load_data_fashion_mnist(batch_size=batch_size)

In [39]:
# Adam learning rate and number of training epochs.
lr, num_epochs = 0.001, 5
# The optimizer is bound to the parameters LeNet_BN holds right now, and
# train_network updates them in place, so re-running this cell without first
# re-creating LeNet_BN continues training from the current weights.
optimizer = optim.Adam(LeNet_BN.parameters(), lr = lr)
train_network(LeNet_BN, train_iter, test_iter, optimizer, num_epochs)


epoch: 1, train_loss: 0.264, train_acc: 0.906, test_acc: 0.882, time: 3.45
epoch: 2, train_loss: 0.255, train_acc: 0.907, test_acc: 0.878, time: 3.59
epoch: 3, train_loss: 0.245, train_acc: 0.911, test_acc: 0.864, time: 3.33
epoch: 4, train_loss: 0.237, train_acc: 0.914, test_acc: 0.887, time: 3.30
epoch: 5, train_loss: 0.232, train_acc: 0.915, test_acc: 0.870, time: 3.58


In [40]:
# Same LeNet layout, this time with PyTorch's built-in batch-norm layers.
# Note that this rebinds the name LeNet_BN.
LeNet_BN = nn.Sequential(
    # Convolutional stage 1
    nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),
    nn.BatchNorm2d(num_features=6),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Convolutional stage 2
    nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
    nn.BatchNorm2d(num_features=16),
    nn.Sigmoid(),
    nn.MaxPool2d(kernel_size=2, stride=2),
    # Classifier head
    nn.Flatten(),
    nn.Linear(in_features=16*4*4, out_features=120),
    nn.BatchNorm1d(num_features=120),
    nn.Sigmoid(),
    nn.Linear(in_features=120, out_features=84),
    nn.BatchNorm1d(num_features=84),
    nn.Sigmoid(),
    nn.Linear(in_features=84, out_features=10),
)

In [41]:
# Same hyperparameters as the earlier training cell: Adam learning rate and epoch count.
lr, num_epochs = 0.001, 5
# A fresh optimizer is needed because LeNet_BN was rebound to a new model,
# whose parameters the previous optimizer does not track.
optimizer = optim.Adam(LeNet_BN.parameters(), lr = lr)
train_network(LeNet_BN, train_iter, test_iter, optimizer, num_epochs)


epoch: 1, train_loss: 1.020, train_acc: 0.775, test_acc: 0.820, time: 3.54
epoch: 2, train_loss: 0.467, train_acc: 0.860, test_acc: 0.823, time: 3.42
epoch: 3, train_loss: 0.375, train_acc: 0.876, test_acc: 0.863, time: 3.30
epoch: 4, train_loss: 0.334, train_acc: 0.885, test_acc: 0.864, time: 3.53
epoch: 5, train_loss: 0.311, train_acc: 0.890, test_acc: 0.874, time: 3.40
