# Lab-09-4 Batch Normalization

In [1]:
import torch
import torchvision.datasets as dsets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import random

In [2]:
mnist_train = dsets.MNIST(root="MNIST_data/", train=True, 
    transform=transforms.ToTensor(), download=False)
mnist_test = dsets.MNIST(root="MNIST_data/", train=False, 
    transform=transforms.ToTensor(), download=False)
data_loader = torch.utils.data.DataLoader(mnist_train,
                                          batch_size=64,
                                          shuffle=True,
                                          drop_last=True
                                         )

In [3]:
learning_rate = 1e-2
n_epochs = 10
n_batches = len(data_loader)

### NN model

In [4]:
nn_linear1 = torch.nn.Linear(784, 64, bias=True)
nn_linear2 = torch.nn.Linear(64, 32, bias=True)
nn_linear3 = torch.nn.Linear(32, 10, bias=True)
relu = torch.nn.ReLU()

nn_model = torch.nn.Sequential(nn_linear1, relu, 
                               nn_linear2, relu, 
                               nn_linear3)

### Batch Normalized NN model

In [5]:
linear1 = torch.nn.Linear(784, 32, bias=True)
linear2 = torch.nn.Linear(32, 32, bias=True)
linear3 = torch.nn.Linear(32, 10, bias=True)
relu = torch.nn.ReLU()
bn1 = torch.nn.BatchNorm1d(32)
bn2 = torch.nn.BatchNorm1d(32)

bn_model = torch.nn.Sequential(linear1, bn1, relu, 
                               linear2, bn2, relu, 
                               linear3)

In [6]:
criterion = torch.nn.CrossEntropyLoss()
bn_optimizer = torch.optim.Adam(bn_model.parameters(), lr=learning_rate)
nn_optimizer = torch.optim.Adam(nn_model.parameters(), lr=learning_rate)

In [7]:
for epoch in range(n_epochs):
    bn_model.train()
    nn_model.train()
    
    for X, Y in data_loader:
        X = X.view(-1, 784)
        bn_optimizer.zero_grad()
        bn_output = bn_model(X)
        bn_loss = criterion(bn_output, Y)
        bn_loss.backward()
        bn_optimizer.step()
        
        nn_optimizer.zero_grad()
        nn_output = nn_model(X)
        nn_loss = criterion(nn_output, Y)
        nn_loss.backward()
        nn_optimizer.step()
        
    print("Epoch: {} BN Model Loss: {}".format(epoch+1, bn_loss))
    print("Epoch: {} NN Model Loss: {}".format(epoch+1, nn_loss))

Epoch: 1 BN Model Loss: 0.18945765495300293
Epoch: 1 NN Model Loss: 0.22904567420482635
Epoch: 2 BN Model Loss: 0.13865798711776733
Epoch: 2 NN Model Loss: 0.07300636917352676
Epoch: 3 BN Model Loss: 0.06534912437200546
Epoch: 3 NN Model Loss: 0.1088194027543068
Epoch: 4 BN Model Loss: 0.0756344124674797
Epoch: 4 NN Model Loss: 0.10684674978256226
Epoch: 5 BN Model Loss: 0.028770361095666885
Epoch: 5 NN Model Loss: 0.03032577410340309
Epoch: 6 BN Model Loss: 0.033264074474573135
Epoch: 6 NN Model Loss: 0.022927677258849144
Epoch: 7 BN Model Loss: 0.08727558702230453
Epoch: 7 NN Model Loss: 0.07747096568346024
Epoch: 8 BN Model Loss: 0.024136528372764587
Epoch: 8 NN Model Loss: 0.05801481753587723
Epoch: 9 BN Model Loss: 0.02030828967690468
Epoch: 9 NN Model Loss: 0.0010354926344007254
Epoch: 10 BN Model Loss: 0.029558856040239334
Epoch: 10 NN Model Loss: 0.01727069728076458


In [8]:
with torch.no_grad():
    bn_model.eval()
    nn_model.eval()
    X_test = mnist_test.test_data.view(-1, 784).float()
    Y_test = mnist_test.test_labels
    
    bn_pred = bn_model(X_test)
    bn_correct_pred = torch.argmax(bn_pred, 1) == Y_test
    bn_accuracy = bn_correct_pred.float().mean()
    
    nn_pred = nn_model(X_test)
    nn_correct_pred = torch.argmax(nn_pred, 1) == Y_test
    nn_accuracy = nn_correct_pred.float().mean()
    
    print("BN Model Accuracy: {}".format(bn_accuracy.item()))
    print("NN Model Accuracy: {}".format(nn_accuracy.item()))

BN Model Accuracy: 0.78329998254776
NN Model Accuracy: 0.9675999879837036


