# FINAL VERSION OF OUR CODE -- DEMO

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from random import randint
import time
import utils

### Load the data

In [None]:
# Load the pre-saved MNIST tensors (images and labels) from disk.
train_data, train_label = (
    torch.load('../../data/mnist/train_data.pt'),
    torch.load('../../data/mnist/train_label.pt'),
)
test_data, test_label = (
    torch.load('../../data/mnist/test_data.pt'),
    torch.load('../../data/mnist/test_label.pt'),
)

### Make a one layer net class. 

In [None]:
class one_layer_net(nn.Module):
    """Network made of a single bias-free linear layer.

    Args:
        input_size: number of features in each input vector.
        output_size: number of scores produced for each input vector.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # bias=False: the layer is a pure matrix multiplication.
        self.linear_layer = nn.Linear(in_features=input_size,
                                      out_features=output_size,
                                      bias=False)

    def forward(self, x):
        """Return the raw scores of the linear layer applied to ``x``."""
        return self.linear_layer(x)

### Build the net

In [None]:
# Instantiate the model: 784 input features, 10 output scores.
net = one_layer_net(input_size=784, output_size=10)

# Show the architecture, then the parameter count reported by the utils helper.
print(net)
utils.display_num_param(net)

### Choose the criterion, optimizer, batch size and learning rate

In [None]:
# Minibatch size used by both the training loop and the test-set evaluation.
bs = 200

# Cross-entropy loss applied to the raw scores produced by the net.
criterion = nn.CrossEntropyLoss()

# Plain SGD over every parameter of the net, with a learning rate of 0.01.
optimizer = optim.SGD(net.parameters(), lr=0.01)

### Evaluate on test set

In [None]:
def eval_on_test_set():
    """Print the error rate of ``net`` on the test set, in percent.

    Reads the module-level ``net``, ``test_data``, ``test_label`` and ``bs``.
    The test set is processed in consecutive minibatches of at most ``bs``
    samples with gradient tracking disabled. The value printed is the mean
    of the per-minibatch values returned by ``utils.get_error``
    (presumably the fraction of misclassified samples -- confirm in utils).

    Returns:
        None. The result is only printed.
    """

    # Take the test-set size from the tensor instead of hard-coding 10000,
    # so the function keeps working if a different test set is loaded.
    num_test = test_data.size(0)

    running_error = 0
    num_batches = 0

    with torch.no_grad():

        for i in range(0, num_test, bs):

            minibatch_data = test_data[i:i+bs]
            minibatch_label = test_label[i:i+bs]

            # -1 lets the last minibatch hold fewer than bs samples;
            # view(bs, 784) would raise on such a partial batch.
            inputs = minibatch_data.view(-1, 784)

            scores = net(inputs)

            error = utils.get_error(scores, minibatch_label)

            running_error += error.item()

            num_batches += 1

    # Every minibatch carries the same weight in this average, including a
    # possibly smaller final one.
    total_error = running_error/num_batches
    print( 'test error  = ', total_error*100 ,'percent')

### Training loop

In [None]:
# Train the net for 100 epochs of minibatch SGD, reporting statistics
# (and the test-set error) every 10 epochs.
start = time.time()

# Take the training-set size from the tensor instead of hard-coding 60000.
num_train = train_data.size(0)

for epoch in range(100):

    running_loss = 0
    running_error = 0
    num_batches = 0

    # Visit the training samples in a fresh random order each epoch.
    shuffled_indices = torch.randperm(num_train)

    for count in range(0, num_train, bs):

        # forward and backward pass

        optimizer.zero_grad()

        indices = shuffled_indices[count:count+bs]
        minibatch_data = train_data[indices]
        minibatch_label = train_label[indices]

        # -1 lets the last minibatch hold fewer than bs samples;
        # view(bs, 784) would raise on such a partial batch.
        # The inputs are deliberately NOT marked with requires_grad_():
        # only the parameters of the net need gradients, and tracking the
        # inputs would compute an input gradient that is never used.
        inputs = minibatch_data.view(-1, 784)

        scores = net(inputs)

        loss = criterion(scores, minibatch_label)

        loss.backward()

        optimizer.step()

        # compute some stats

        num_batches += 1

        running_loss += loss.item()

        # The error statistic is for monitoring only: keep it out of autograd.
        with torch.no_grad():
            error = utils.get_error(scores, minibatch_label)
        running_error += error.item()

    # once the epoch is finished we divide the "running quantities"
    # by the number of batches

    total_loss = running_loss/num_batches
    total_error = running_error/num_batches
    elapsed_time = time.time() - start

    # every 10 epochs we display the stats
    # and compute the error rate on the test set

    if epoch % 10 == 0 :

        print(' ')

        print('epoch=',epoch, '\t time=', elapsed_time,
              '\t loss=', total_loss , '\t error=', total_error*100 ,'percent')

        eval_on_test_set()
               

### Choose an image at random from the test set and see how good or bad the predictions are

In [None]:
# choose a picture at random
# (index range taken from the tensor instead of hard-coding 10000)
idx = randint(0, test_data.size(0) - 1)
im = test_data[idx]

# display the picture
utils.show(im)

# feed it to the net and display the confidence scores
# (inference only, so gradient tracking is disabled, as in eval_on_test_set)
with torch.no_grad():
    scores = net(im.view(1, 784))
    probs = F.softmax(scores, dim=1)
utils.show_prob_mnist(probs)