# Evaluating a trained network on the test set 

### In this notebook, you are going to train a one-layer net on the MNIST training set, then write code that tests the accuracy of the trained network on the test set. Recall that the test set contains 10,000 pictures. You will see that a trained one-layer network classifies roughly 8,800 to 9,000 of these 10,000 pictures correctly (the run below gets 9,005), so it is about 88% to 90% accurate.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from random import randint
import utils

### Load the TRAINING SET

In [2]:
train_data=torch.load('../../data/mnist/train_data.pt')

print(train_data.size())

torch.Size([60000, 28, 28])


In [3]:
train_label=torch.load('../../data/mnist/train_label.pt')

print(train_label.size())

torch.Size([60000])
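The network takes a vector of 784 numbers, while each picture is stored as a 28x28 grid, so the training loop below flattens pictures with `view`. A small sketch of what that reshape does, using random tensors as hypothetical stand-ins for MNIST pictures (not the real data):

```python
import torch

# Hypothetical stand-in for a minibatch of MNIST pictures: 200 random 28x28 grids.
batch = torch.rand(200, 28, 28)

# view() reinterprets the same memory as 200 rows of 28*28 = 784 pixels each.
flat = batch.view(200, 784)
print(flat.shape)  # torch.Size([200, 784])

# Row k of flat is picture k read row by row, so nothing is lost or reordered.
print(torch.equal(flat[0], batch[0].reshape(-1)))  # True
```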


### Make a one-layer net class

In [9]:
class one_layer_net(nn.Module):

    def __init__(self, input_size, output_size):
        super().__init__()
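        # a single linear map from input_size pixels to output_size scores, no bias
        self.linear_layer = nn.Linear(input_size, output_size, bias=False)

    def forward(self, x):
        x = self.linear_layer(x)     # raw scores, shape (batch, output_size)
        p = F.softmax(x, dim=1)      # turn each row of scores into probabilities
        return p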
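As a quick sanity check, the sketch below rebuilds the same class and feeds it three random 784-pixel inputs (hypothetical data, not MNIST) to confirm that `softmax(..., dim=1)` turns each row of 10 scores into probabilities that are non-negative and sum to 1:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Same architecture as the class above, redefined so this cell runs on its own.
class one_layer_net(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear_layer = nn.Linear(input_size, output_size, bias=False)

    def forward(self, x):
        scores = self.linear_layer(x)      # shape (batch, output_size)
        return F.softmax(scores, dim=1)    # normalize each row into probabilities

torch.manual_seed(0)
net = one_layer_net(784, 10)
x = torch.rand(3, 784)                     # three hypothetical flattened pictures
p = net(x)
print(p.shape)                             # torch.Size([3, 10])
print(p.sum(dim=1))                        # each row sums to 1, up to rounding
```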

### Build the net

In [10]:
net = one_layer_net(784,10)
print(net)

one_layer_net(
  (linear_layer): Linear(in_features=784, out_features=10, bias=False)
)
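The printout shows that the net has a single 784-to-10 weight matrix and no bias, so it has 784 x 10 = 7,840 trainable numbers. One way to check this, assuming the same layer configuration:

```python
import torch.nn as nn

# A bias-free linear layer from 784 inputs to 10 outputs, as in one_layer_net.
layer = nn.Linear(784, 10, bias=False)

# PyTorch stores the weight as (out_features, in_features).
print(layer.weight.shape)                            # torch.Size([10, 784])
num_params = sum(p.numel() for p in layer.parameters())
print(num_params)                                    # 7840
```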


### Train the network on the training set (only 5,000 iterations)

In [11]:
bs = 200                      # minibatch size
criterion = nn.NLLLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

for step in range(5000):

    # pick bs random training pictures (sampled with replacement)
    indices = torch.randint(0, 60000, size=(bs,))
    minibatch_data = train_data[indices]
    minibatch_label = train_label[indices]

    # flatten each 28x28 picture into a vector of 784 pixels
    inputs = minibatch_data.view(bs, 784)

    prob = net(inputs)

    # compute the loss and take one gradient step (you can ignore the details for now)
    log_prob = torch.log(prob)
    loss = criterion(log_prob, minibatch_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
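The training cell takes the log of the softmax probabilities and passes it to `nn.NLLLoss`, which averages -log(probability of the correct category) over the minibatch. This is the cross-entropy loss, so it should match PyTorch's `nn.CrossEntropyLoss` applied to the raw scores. A minimal sketch checking this on random scores and labels (hypothetical values, not the notebook's):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
scores = torch.randn(5, 10)                 # hypothetical raw outputs for 5 pictures
labels = torch.tensor([3, 0, 9, 1, 7])      # hypothetical correct categories

# What the training loop does: softmax, then log, then negative log-likelihood.
loss_notebook = nn.NLLLoss()(torch.log(F.softmax(scores, dim=1)), labels)

# By hand: average of -log(probability assigned to the correct category).
probs = F.softmax(scores, dim=1)
loss_by_hand = -torch.log(probs[torch.arange(5), labels]).mean()

# PyTorch's built-in cross-entropy on the raw scores computes the same quantity.
loss_builtin = nn.CrossEntropyLoss()(scores, labels)

print(loss_notebook.item(), loss_by_hand.item(), loss_builtin.item())
```

`nn.CrossEntropyLoss` applies `log_softmax` internally, which is numerically safer than calling `torch.log` on probabilities that can underflow to 0.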

### Load the test set (both data and labels)

In [12]:
test_data=torch.load('../../data/mnist/test_data.pt')

print(test_data.size())

torch.Size([10000, 28, 28])


In [13]:
test_label=torch.load('../../data/mnist/test_label.pt')

print(test_label.size())

torch.Size([10000])


### Write code that visits each picture of the test set (from picture index 0 up to picture index 9999), feeds it to the network, identifies which category gets the highest probability, then checks whether this category is the correct one. If it is, we say that the network has correctly classified the picture. Your code should count how many pictures are correctly classified.

### To do this you will need the function argmax. In the example below, the maximal value of the vector x is x[3], so argmax returns a zero-dimensional tensor that contains 3. To get a plain Python number, call item() on it.

In [14]:
x = torch.tensor([1.0, 3.5, -2.0, 4.9, 0.8])
print(x)

tensor([ 1.0000,  3.5000, -2.0000,  4.9000,  0.8000])


In [19]:
idx = torch.argmax(x)
print(idx)
print(idx.item())

tensor(3)
3
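In the evaluation loop the network output for one picture has shape (1, 10), so `torch.argmax` without a `dim` argument already picks the right index. For a batch of pictures you would pass `dim=1` to get one prediction per row; a small sketch with made-up probabilities:

```python
import torch

# Hypothetical network outputs for two pictures over three categories.
probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.3, 0.2]])

# With no dim argument, argmax looks at the flattened tensor: index 1 (the 0.7).
print(torch.argmax(probs))          # tensor(1)

# With dim=1, argmax picks the best category separately for each row.
print(torch.argmax(probs, dim=1))   # tensor([1, 0])
```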


### Complete the code below that computes the number of images in the test set that are classified correctly

In [22]:
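num_correct = 0

with torch.no_grad():                              # evaluation only: no gradients needed
    for i in range(10000):
        prob = net(test_data[i].view(1, 784))      # probabilities for picture i, shape (1, 10)
        i_max = torch.argmax(prob)                 # category with the highest probability
        if i_max.item() == test_label[i].item():   # compare with the correct category
            num_correct += 1

print(num_correct)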

9005
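The loop above feeds one picture at a time. Below is a sketch of a vectorized alternative that scores all 10,000 pictures in one forward pass and converts the count into a percentage (for the notebook's 9,005 correct answers this would be 90.05%). So that the cell runs on its own, it uses a hypothetical untrained stand-in for `one_layer_net` plus random pictures and labels, so its accuracy is meaningless (around chance level):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: an untrained 784 -> 10 softmax net equivalent in shape to
# one_layer_net, and 10,000 random "pictures" with random labels.
net = nn.Sequential(nn.Linear(784, 10, bias=False), nn.Softmax(dim=1))
test_data = torch.rand(10000, 28, 28)
test_label = torch.randint(0, 10, (10000,))

with torch.no_grad():                                # evaluation only: no gradients
    probs = net(test_data.view(10000, 784))          # shape (10000, 10)
    predictions = torch.argmax(probs, dim=1)         # best category per picture
    num_correct = (predictions == test_label).sum().item()

accuracy = 100.0 * num_correct / 10000
print(num_correct, f"{accuracy:.2f}%")
```

Comparing `predictions == test_label` produces a boolean tensor, and summing it counts the `True` entries, which replaces the `if` and the counter in the loop.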
