# Lab 9.4: About Xavier init with MNIST Classifier

Edited By Steve Ive

Reference from

https://github.com/deeplearningzerotoall/PyTorch/blob/master/lab-09_3_mnist_nn_xavier.ipynb

## Imports

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

random.seed(1)
torch.manual_seed(1)

if device == 'cuda':
    torch.cuda.manual_seed_all(1)

## Load MNIST Data

In [None]:
mnist_train = datasets.MNIST(root = 'MNIST_data/',
                             download=True,
                             transform=transforms.ToTensor(),
                             train=True)
mnist_test = datasets.MNIST(root = 'MNIST_data/',
                            download=True,
                            transform = transforms.ToTensor(),
                            train = False)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## Set Hyperparameters

In [5]:
training_epochs = 15
learning_rate = 0.001
batch_size = 100

In [6]:
data_loader = torch.utils.data.DataLoader(dataset = mnist_train, shuffle= True, drop_last = True, batch_size = batch_size)

## Define Model

In [16]:
class Xavier_MNIST_Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.sq = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )
        self.weightInitializer()

    def forward(self, x):
        return self.sq(x)

    def weightInitializer(self):
        for index, layer in enumerate(self.sq):
            if index != 1 and index != 3:
                nn.init.xavier_uniform_(layer.weight)

In [17]:
model = Xavier_MNIST_Classifier().to(device)

In [18]:
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

## Train Model

In [19]:
total_batch = len(data_loader)

for epoch in range(training_epochs):

    avg_cost = 0
    
    for X, Y in data_loader:

        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        #prediction
        pred = model(X)

        #cost
        cost = F.cross_entropy(pred, Y).to(device)

        #Reduce cost
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        avg_cost += cost
    
    avg_cost = avg_cost / total_batch

    print('Epoch: {:d} / 15, cost: {:.6f}'.format(epoch, cost.item()))

Epoch: 0 / 15, cost: 0.141852
Epoch: 1 / 15, cost: 0.089611
Epoch: 2 / 15, cost: 0.038303
Epoch: 3 / 15, cost: 0.173185
Epoch: 4 / 15, cost: 0.012622
Epoch: 5 / 15, cost: 0.061407
Epoch: 6 / 15, cost: 0.070664
Epoch: 7 / 15, cost: 0.010458
Epoch: 8 / 15, cost: 0.004857
Epoch: 9 / 15, cost: 0.007333
Epoch: 10 / 15, cost: 0.002014
Epoch: 11 / 15, cost: 0.004299
Epoch: 12 / 15, cost: 0.051862
Epoch: 13 / 15, cost: 0.053505
Epoch: 14 / 15, cost: 0.020849


In [20]:
#Accuracy test and check prediction

with torch.no_grad():

    X_test = mnist_test.data.view(-1, 28 * 28).float().to(device)
    Y_test = mnist_test.targets.to(device)

    r = random.randint(0, len(mnist_test) - 1)

    pred = model(X_test)
    correct_prediction = torch.argmax(pred, 1)
    accuracy = (correct_prediction == Y_test).float().mean()

    print('Accuracy: ', accuracy.item())

    X_single_prediction = X_test[r]
    Y_single_prediction = Y_test[r]

    print('Label: ', Y_single_prediction.item())
    print('Prediction: ', torch.argmax(model(X_single_prediction)).item())
    

Accuracy:  0.9787999987602234
Label:  5
Prediction:  5
