In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchvision.transforms as transforms 
from torchvision.datasets import MNIST

In [4]:
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, ), (1.0))
])

download_root = "./MNIST_DATASET"

train_dataset = MNIST(download_root, transform=mnist_transform, train=True, download=True)
valid_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)
test_dataset = MNIST(download_root, transform=mnist_transform, train=False, download=True)

In [5]:
batch_size = 64

train_loader = DataLoader(dataset=train_dataset, batch_size = batch_size, shuffle = True)

valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [16]:
images, labels = next(iter(train_loader))

images.shape, labels.shape, labels

(torch.Size([64, 1, 28, 28]),
 torch.Size([64]),
 tensor([8, 0, 6, 5, 0, 5, 6, 4, 1, 3, 4, 6, 8, 4, 7, 2, 7, 1, 8, 0, 9, 3, 6, 1,
         6, 1, 6, 3, 4, 4, 6, 0, 4, 2, 8, 1, 7, 7, 0, 1, 8, 1, 0, 7, 7, 2, 5, 7,
         0, 8, 2, 1, 6, 7, 3, 1, 5, 0, 5, 2, 3, 4, 3, 2]))

In [11]:
class NNModel(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()

        self.hidden_size = hidden_size

        self.layer = nn.Sequential(
            nn.Linear(28*28, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 10)
        )

    def forward(self, x):
        return self.layer(x)


class BNModel(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()

        self.hidden_size = hidden_size

        self.layer = nn.Sequential(
            nn.Linear(28*28, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 10)
        )

    def forward(self, x):
        return self.layer(x)

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cpu')

In [33]:
nn_model = NNModel(512).to(device)
bn_model = BNModel(512).to(device)

In [34]:
nb_epochs = 10
lr = 0.01

optim_nn = optim.Adam(nn_model.parameters(), lr=lr)
optim_bn = optim.Adam(bn_model.parameters(), lr=lr)

In [35]:
loss = nn.CrossEntropyLoss()

In [36]:
for epoch in range(nb_epochs):

    for i, (images, labels) in enumerate(train_loader):

        images, labels = images.to(device), labels.to(device)

        images = images.view(-1, 28*28)

        prediction = nn_model(images)

        cost = loss(prediction, labels)

        optim_nn.zero_grad()
        cost.backward()
        optim_nn.step()

        if i % 100 == 0:
            print(f"Epoch: [{i}/{epoch+1}] Loss: {cost}")

Epoch: [1/0] Loss: 2.296948194503784
Epoch: [101/0] Loss: 0.31948375701904297
Epoch: [201/0] Loss: 0.4745776057243347
Epoch: [301/0] Loss: 0.46076008677482605
Epoch: [401/0] Loss: 0.24814806878566742
Epoch: [501/0] Loss: 0.4919009804725647
Epoch: [601/0] Loss: 0.2548016905784607
Epoch: [701/0] Loss: 0.6835585236549377
Epoch: [801/0] Loss: 0.40767356753349304
Epoch: [901/0] Loss: 0.3969203233718872
Epoch: [1/1] Loss: 0.3498804569244385
Epoch: [101/1] Loss: 0.4220711588859558
Epoch: [201/1] Loss: 0.09206794947385788
Epoch: [301/1] Loss: 0.27979686856269836
Epoch: [401/1] Loss: 0.1187838613986969
Epoch: [501/1] Loss: 0.2078983634710312
Epoch: [601/1] Loss: 0.21275264024734497
Epoch: [701/1] Loss: 0.30423903465270996
Epoch: [801/1] Loss: 0.3434825837612152
Epoch: [901/1] Loss: 0.46602553129196167
Epoch: [1/2] Loss: 0.1093125194311142
Epoch: [101/2] Loss: 0.3904871940612793
Epoch: [201/2] Loss: 0.12160777300596237
Epoch: [301/2] Loss: 0.3618926703929901
Epoch: [401/2] Loss: 0.25176617503166

In [38]:
bn_model = BNModel(512).to(device)
optim_bn = optim.Adam(bn_model.parameters(), lr=0.001)

for epoch in range(nb_epochs):
    avg_cost = 0
    for i, (images, labels) in enumerate(train_loader):

        images, labels = images.to(device), labels.to(device)

        images = images.view(-1, 28*28)

        prediction = bn_model(images)

        cost = loss(prediction, labels)
        avg_cost += cost

        optim_bn.zero_grad()
        cost.backward()
        optim_bn.step()

    print(f"Epoch: [{epoch+1}/{nb_epochs}] Loss: {cost/len(train_loader)}")

Epoch: [937/1] Loss: 0.0003268908185418695
Epoch: [937/2] Loss: 0.00040221840026788414
Epoch: [937/3] Loss: 9.771439181349706e-06
Epoch: [937/4] Loss: 0.00014311130507849157
Epoch: [937/5] Loss: 7.0154414970602375e-06
Epoch: [937/6] Loss: 1.713589699647855e-05
Epoch: [937/7] Loss: 1.0815816722242744e-06
Epoch: [937/8] Loss: 3.24806896969676e-05
Epoch: [937/9] Loss: 2.8255001325305784e-06
Epoch: [937/10] Loss: 1.1898594493686687e-05


## Without BN

0.077

## With BN

0.000189