In [None]:
# Sourced from: https://www.python-engineer.com/courses/pytorchbeginner/13-feedforward-neural-network/

In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import math

In [2]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# Hyper-parameters 
input_size = 784 # 28x28
hidden_size = 500 
num_classes = 10
num_epochs = 200
batch_size = 100
learning_rate = 0.001

# TODO adjust these as per the configuration in the paper.
# Especially how learning rate changes


In [4]:
# TYPICAL PIPELINE

# 1) LOAD DATA
# 2) DESIGN MODEL
# 3) LOSS AND OPTIMIZER
# 4) TRAINING LOOP
#        - FORWARD PASS: computer prediction and calculate error
#        - BACKWARD PASS: calculate gradients
#        - UPDATE WEIGHTS
#        - SET ZERO GRAD IF USING AUTOGRAD

# 5) TEST MODEL WITH TEST DATA


In [5]:
# 1) LOAD DATA

# MNIST dataset 
train_dataset = torchvision.datasets.MNIST(root='./data', 
                                           train=True, 
                                           transform=transforms.ToTensor(),  
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data', 
                                          train=False, 
                                          transform=transforms.ToTensor())

In [6]:
# Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                           batch_size=batch_size, 
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, 
                                          batch_size=batch_size, 
                                          shuffle=False)


#examples = iter(test_loader)
#example_data, example_targets = examples.next()

#for i in range(9):
#    plt.subplot(3,3,i+1)
#    plt.imshow(example_data[i][0], cmap='gray')
#plt.show()


In [7]:
# 2) DESIGN MODEL

# Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(NeuralNet, self).__init__()
        self.input_size = input_size
        self.l1 = nn.Linear(input_size, hidden_size) 
        self.relu = nn.ReLU()
        self.l2 = nn.Linear(hidden_size, num_classes)  
    
    def forward(self, x):
        out = self.l1(x)
        out = self.relu(out)
        out = self.l2(out)
        # no activation and no softmax at the end
        return out

model = NeuralNet(input_size, hidden_size, num_classes).to(device)

In [8]:
# 3) LOSS AND OPTIMIZER

# TODO: Use the custom loss function in th paper. Equation [12]. It seems to use one-hot encoded labels.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) 
print(optimizer.param_groups)

# TODO: To dynamically change learning rate
for param_group in optimizer.param_groups:
    print("param_group[lr]:", param_group['lr'])
#    param_group['lr'] = lr

[{'params': [Parameter containing:
tensor([[-0.0087,  0.0287,  0.0297,  ...,  0.0263,  0.0336,  0.0272],
        [-0.0288,  0.0178, -0.0305,  ...,  0.0109,  0.0225,  0.0176],
        [-0.0059, -0.0055,  0.0299,  ..., -0.0139,  0.0298,  0.0131],
        ...,
        [ 0.0247,  0.0092, -0.0346,  ..., -0.0223, -0.0058, -0.0100],
        [-0.0284, -0.0182,  0.0313,  ...,  0.0176, -0.0336, -0.0087],
        [-0.0003, -0.0265,  0.0338,  ..., -0.0208,  0.0354,  0.0350]],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-1.0755e-02, -8.3936e-03,  3.0748e-02,  3.0286e-02, -3.0292e-02,
         3.0231e-02, -3.1381e-02,  2.7698e-02, -9.6403e-03, -1.5125e-02,
         9.9192e-03, -2.1362e-02,  4.4716e-03,  3.5070e-02, -2.3002e-02,
         1.7972e-03,  3.3925e-02,  2.8830e-02, -3.5374e-02, -3.0989e-02,
        -3.0631e-02,  8.6100e-03, -2.4656e-02, -2.5020e-02, -2.1210e-02,
         3.4999e-02,  5.3167e-03,  2.7774e-03, -1.7985e-02,  2.1601e-02,
         3.5485e-02,  2.71

In [9]:
# 4) TRAINING LOOP

# Train the model
n_total_steps = len(train_loader)
print(n_total_steps)

total_samples = len(train_dataset)
n_iterations = math.ceil(total_samples/batch_size)
print(total_samples, n_iterations)

'''
# Dummy training loop for printing 
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.reshape(-1, 28*28)
        
        # here: 60000 samples, batch_size = 100, n_iters=60000/100 = 600 iterations
        # Run your training process
        if (i+1) % 5 == 0:
            print(f'Epoch: {epoch+1}/{num_epochs}, Step {i+1}/{n_iterations}| Inputs {images.shape} | Labels {labels.shape}')
'''

# Real training loop
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):  
        # original shape: [100, 1, 28, 28]
        # resized: [100, 784]
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 100 == 0:
            print (f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')

# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)

with torch.no_grad():
    n_correct = 0
    n_samples = 0
    for images, labels in test_loader:
        # original shape: [100, 1, 28, 28]
        # resized: [100, 784]
        images = images.reshape(-1, 28*28).to(device)
        labels = labels.to(device)
        
        outputs = model(images)
        
        # max returns (value ,index)
        _, predicted = torch.max(outputs.data, 1)
        n_samples += labels.size(0)
        n_correct += (predicted == labels).sum().item()

    acc = 100.0 * n_correct / n_samples
    print(f'Accuracy of the network on the 10000 test images: {acc} %')

600
60000 600
Epoch [1/200], Step [100/600], Loss: 0.2730
Epoch [1/200], Step [200/600], Loss: 0.2155
Epoch [1/200], Step [300/600], Loss: 0.1878
Epoch [1/200], Step [400/600], Loss: 0.0769
Epoch [1/200], Step [500/600], Loss: 0.1746
Epoch [1/200], Step [600/600], Loss: 0.1515
Epoch [2/200], Step [100/600], Loss: 0.0499
Epoch [2/200], Step [200/600], Loss: 0.1281
Epoch [2/200], Step [300/600], Loss: 0.0625
Epoch [2/200], Step [400/600], Loss: 0.0453
Epoch [2/200], Step [500/600], Loss: 0.1371
Epoch [2/200], Step [600/600], Loss: 0.1116
Epoch [3/200], Step [100/600], Loss: 0.0315
Epoch [3/200], Step [200/600], Loss: 0.0809
Epoch [3/200], Step [300/600], Loss: 0.0237
Epoch [3/200], Step [400/600], Loss: 0.0370
Epoch [3/200], Step [500/600], Loss: 0.0587
Epoch [3/200], Step [600/600], Loss: 0.1005
Epoch [4/200], Step [100/600], Loss: 0.0919
Epoch [4/200], Step [200/600], Loss: 0.0882
Epoch [4/200], Step [300/600], Loss: 0.0214
Epoch [4/200], Step [400/600], Loss: 0.0561
Epoch [4/200], Ste

Epoch [31/200], Step [400/600], Loss: 0.0000
Epoch [31/200], Step [500/600], Loss: 0.0000
Epoch [31/200], Step [600/600], Loss: 0.0001
Epoch [32/200], Step [100/600], Loss: 0.0001
Epoch [32/200], Step [200/600], Loss: 0.0001
Epoch [32/200], Step [300/600], Loss: 0.0000
Epoch [32/200], Step [400/600], Loss: 0.0000
Epoch [32/200], Step [500/600], Loss: 0.0000
Epoch [32/200], Step [600/600], Loss: 0.0000
Epoch [33/200], Step [100/600], Loss: 0.0000
Epoch [33/200], Step [200/600], Loss: 0.0000
Epoch [33/200], Step [300/600], Loss: 0.0000
Epoch [33/200], Step [400/600], Loss: 0.0000
Epoch [33/200], Step [500/600], Loss: 0.0000
Epoch [33/200], Step [600/600], Loss: 0.0000
Epoch [34/200], Step [100/600], Loss: 0.0000
Epoch [34/200], Step [200/600], Loss: 0.0000
Epoch [34/200], Step [300/600], Loss: 0.0000
Epoch [34/200], Step [400/600], Loss: 0.0000
Epoch [34/200], Step [500/600], Loss: 0.0000
Epoch [34/200], Step [600/600], Loss: 0.0000
Epoch [35/200], Step [100/600], Loss: 0.0000
Epoch [35/

Epoch [62/200], Step [100/600], Loss: 0.0000
Epoch [62/200], Step [200/600], Loss: 0.0000
Epoch [62/200], Step [300/600], Loss: 0.0000
Epoch [62/200], Step [400/600], Loss: 0.0000
Epoch [62/200], Step [500/600], Loss: 0.0000
Epoch [62/200], Step [600/600], Loss: 0.0000
Epoch [63/200], Step [100/600], Loss: 0.0000
Epoch [63/200], Step [200/600], Loss: 0.0000
Epoch [63/200], Step [300/600], Loss: 0.0000
Epoch [63/200], Step [400/600], Loss: 0.0000
Epoch [63/200], Step [500/600], Loss: 0.0000
Epoch [63/200], Step [600/600], Loss: 0.0000
Epoch [64/200], Step [100/600], Loss: 0.0000
Epoch [64/200], Step [200/600], Loss: 0.0000
Epoch [64/200], Step [300/600], Loss: 0.0000
Epoch [64/200], Step [400/600], Loss: 0.0000
Epoch [64/200], Step [500/600], Loss: 0.0000
Epoch [64/200], Step [600/600], Loss: 0.0000
Epoch [65/200], Step [100/600], Loss: 0.0000
Epoch [65/200], Step [200/600], Loss: 0.0000
Epoch [65/200], Step [300/600], Loss: 0.0000
Epoch [65/200], Step [400/600], Loss: 0.0000
Epoch [65/

Epoch [92/200], Step [400/600], Loss: 0.0000
Epoch [92/200], Step [500/600], Loss: 0.0000
Epoch [92/200], Step [600/600], Loss: 0.0000
Epoch [93/200], Step [100/600], Loss: 0.0000
Epoch [93/200], Step [200/600], Loss: 0.0000
Epoch [93/200], Step [300/600], Loss: 0.0000
Epoch [93/200], Step [400/600], Loss: 0.0000
Epoch [93/200], Step [500/600], Loss: 0.0000
Epoch [93/200], Step [600/600], Loss: 0.0000
Epoch [94/200], Step [100/600], Loss: 0.0000
Epoch [94/200], Step [200/600], Loss: 0.0000
Epoch [94/200], Step [300/600], Loss: 0.0000
Epoch [94/200], Step [400/600], Loss: 0.0000
Epoch [94/200], Step [500/600], Loss: 0.0000
Epoch [94/200], Step [600/600], Loss: 0.0000
Epoch [95/200], Step [100/600], Loss: 0.0000
Epoch [95/200], Step [200/600], Loss: 0.0000
Epoch [95/200], Step [300/600], Loss: 0.0000
Epoch [95/200], Step [400/600], Loss: 0.0000
Epoch [95/200], Step [500/600], Loss: 0.0000
Epoch [95/200], Step [600/600], Loss: 0.0000
Epoch [96/200], Step [100/600], Loss: 0.0000
Epoch [96/

Epoch [122/200], Step [400/600], Loss: 0.0000
Epoch [122/200], Step [500/600], Loss: 0.0000
Epoch [122/200], Step [600/600], Loss: 0.0000
Epoch [123/200], Step [100/600], Loss: 0.0000
Epoch [123/200], Step [200/600], Loss: 0.0000
Epoch [123/200], Step [300/600], Loss: 0.0000
Epoch [123/200], Step [400/600], Loss: 0.0000
Epoch [123/200], Step [500/600], Loss: 0.0000
Epoch [123/200], Step [600/600], Loss: 0.0000
Epoch [124/200], Step [100/600], Loss: 0.0000
Epoch [124/200], Step [200/600], Loss: 0.0000
Epoch [124/200], Step [300/600], Loss: 0.0000
Epoch [124/200], Step [400/600], Loss: 0.0000
Epoch [124/200], Step [500/600], Loss: 0.0000
Epoch [124/200], Step [600/600], Loss: 0.0000
Epoch [125/200], Step [100/600], Loss: 0.0000
Epoch [125/200], Step [200/600], Loss: 0.0000
Epoch [125/200], Step [300/600], Loss: 0.0000
Epoch [125/200], Step [400/600], Loss: 0.0313
Epoch [125/200], Step [500/600], Loss: 0.0012
Epoch [125/200], Step [600/600], Loss: 0.0480
Epoch [126/200], Step [100/600], L

Epoch [152/200], Step [300/600], Loss: 0.0000
Epoch [152/200], Step [400/600], Loss: 0.0000
Epoch [152/200], Step [500/600], Loss: 0.0000
Epoch [152/200], Step [600/600], Loss: 0.0000
Epoch [153/200], Step [100/600], Loss: 0.0000
Epoch [153/200], Step [200/600], Loss: 0.0000
Epoch [153/200], Step [300/600], Loss: 0.0000
Epoch [153/200], Step [400/600], Loss: 0.0000
Epoch [153/200], Step [500/600], Loss: 0.0000
Epoch [153/200], Step [600/600], Loss: 0.0000
Epoch [154/200], Step [100/600], Loss: 0.0000
Epoch [154/200], Step [200/600], Loss: 0.0000
Epoch [154/200], Step [300/600], Loss: 0.0000
Epoch [154/200], Step [400/600], Loss: 0.0000
Epoch [154/200], Step [500/600], Loss: 0.0000
Epoch [154/200], Step [600/600], Loss: 0.0000
Epoch [155/200], Step [100/600], Loss: 0.0000
Epoch [155/200], Step [200/600], Loss: 0.0000
Epoch [155/200], Step [300/600], Loss: 0.0000
Epoch [155/200], Step [400/600], Loss: 0.0000
Epoch [155/200], Step [500/600], Loss: 0.0000
Epoch [155/200], Step [600/600], L

Epoch [182/200], Step [200/600], Loss: 0.0000
Epoch [182/200], Step [300/600], Loss: 0.0000
Epoch [182/200], Step [400/600], Loss: 0.0000
Epoch [182/200], Step [500/600], Loss: 0.0000
Epoch [182/200], Step [600/600], Loss: 0.0000
Epoch [183/200], Step [100/600], Loss: 0.0000
Epoch [183/200], Step [200/600], Loss: 0.0000
Epoch [183/200], Step [300/600], Loss: 0.0000
Epoch [183/200], Step [400/600], Loss: 0.0000
Epoch [183/200], Step [500/600], Loss: 0.0000
Epoch [183/200], Step [600/600], Loss: 0.0000
Epoch [184/200], Step [100/600], Loss: 0.0000
Epoch [184/200], Step [200/600], Loss: 0.0000
Epoch [184/200], Step [300/600], Loss: 0.0000
Epoch [184/200], Step [400/600], Loss: 0.0000
Epoch [184/200], Step [500/600], Loss: 0.0000
Epoch [184/200], Step [600/600], Loss: 0.0000
Epoch [185/200], Step [100/600], Loss: 0.0000
Epoch [185/200], Step [200/600], Loss: 0.0000
Epoch [185/200], Step [300/600], Loss: 0.0000
Epoch [185/200], Step [400/600], Loss: 0.0000
Epoch [185/200], Step [500/600], L

In [None]:
# LOG

for param in model.parameters():
  print(param.data)