# MNIST MULTI-LAYER -- DEMO

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from random import randint
import time
import utils

### Download the data

In [2]:
train_data=torch.load('../../data/mnist/train_data.pt')
train_label=torch.load('../../data/mnist/train_label.pt')
test_data=torch.load('../../data/mnist/test_data.pt')
test_label=torch.load('../../data/mnist/test_label.pt')

### Make a two layer net class. 

In [3]:
class two_layer_net(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.linear_layer1 = nn.Linear( input_size, hidden_size , bias=False)
        self.linear_layer2 = nn.Linear( hidden_size, output_size , bias=False)        
        
    def forward(self, x):
        x = self.linear_layer1(x)
        x = F.relu(x)
        x = self.linear_layer2(x)
        p = F.softmax(x, dim=0)
        return p

### Build the net (recall that a one layer net had 7,840 parameters)

In [4]:
net=two_layer_net(784,50,10)

print(net)

two_layer_net(
  (linear_layer1): Linear(in_features=784, out_features=50, bias=False)
  (linear_layer2): Linear(in_features=50, out_features=10, bias=False)
)


### Choose the criterion, optimizer, batchsize, learning rate

In [5]:
criterion = nn.NLLLoss()

optimizer=torch.optim.SGD( net.parameters() , lr=0.01 )

bs=20

### Evaluate on test set

In [6]:
def eval_on_test_set():

    running_error=0
    num_batches=0
    
    with torch.no_grad():

        for i in range(0,10000,bs):

            minibatch_data =  test_data[i:i+bs]
            minibatch_label= test_label[i:i+bs]

            inputs = minibatch_data.view(bs,784)

            probs=net( inputs ) 

            error = utils.get_error( probs , minibatch_label)

            running_error += error.item()

            num_batches+=1


    total_error = running_error/num_batches
    print( 'test error  = ', total_error*100 ,'percent')

### Training loop

In [7]:
start = time.time()

for epoch in range(50):
    
    running_loss=0
    running_error=0
    num_batches=0
    
    shuffled_indices=torch.randperm(60000)
 
    for count in range(0,60000,bs):
        
        # forward and backward pass
    
        optimizer.zero_grad()
        
        indices=shuffled_indices[count:count+bs]
        minibatch_data =  train_data[indices]
        minibatch_label= train_label[indices]

        inputs = minibatch_data.view(bs,784)

        inputs.requires_grad_()

        prob=net( inputs ) 

        log_prob=torch.log(prob)
        loss = criterion(log_prob, minibatch_label)
        
        loss.backward()

        optimizer.step()
        
        
        # compute some stats
        
        num_batches+=1
        
        with torch.no_grad():
            
            running_loss += loss.item()

            error = utils.get_error( prob , minibatch_label)
            running_error += error.item() 
    
    
    # once the epoch is finished we divide the "running quantities"
    # by the number of batches
    
    total_loss = running_loss/num_batches
    total_error = running_error/num_batches
    elapsed_time = time.time() - start
    
    # every 10 epoch we display the stats 
    # and compute the error rate on the test set  
    
    if epoch % 1 == 0 : 
    
        print(' ')
        
        print('epoch=',epoch, '\t time=', elapsed_time,
              '\t loss=', total_loss , '\t error=', total_error*100 ,'percent')
        
        eval_on_test_set()
               

 
epoch= 0 	 time= 1.4313218593597412 	 loss= 1.7128862002293268 	 error= 21.973333209753036 percent
test error  =  13.910000216960908 percent
 
epoch= 1 	 time= 2.8465778827667236 	 loss= 1.419286167383194 	 error= 16.101666647195817 percent
test error  =  12.570000219345093 percent
 
epoch= 2 	 time= 4.35388970375061 	 loss= 1.3465302079518635 	 error= 14.29166672229767 percent
test error  =  11.770000290870666 percent
 
epoch= 3 	 time= 5.836386680603027 	 loss= 1.3025104666550955 	 error= 13.275000131130218 percent
test error  =  11.230000329017638 percent
 
epoch= 4 	 time= 7.225856065750122 	 loss= 1.2605493220686912 	 error= 12.420000143845876 percent
test error  =  10.59000037908554 percent
 
epoch= 5 	 time= 8.611780881881714 	 loss= 1.2355794723828633 	 error= 11.855000311136246 percent
test error  =  10.140000271797179 percent
 
epoch= 6 	 time= 10.257648944854736 	 loss= 1.212474083761374 	 error= 11.241666994492212 percent
test error  =  9.700000369548798 percent
 
epoch= 

### Choose image at random from the test set and see how good/bad are the predictions

In [9]:
torch.save( net.state_dict() , 'trained_weights.pt'  )