# MNIST Project

## Imports

In [1]:
import torch
from torch import Tensor 
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.autograd import Variable
import src.dlc_practical_prologue as prologue

## Loading the Data

In [2]:
def onehot_Y(target):
    '''
    Create one-hot labels for Y
    '''
    res = torch.zeros(target.size(0), 2)
    res[range(target.size(0)), target] = 1
    return res

In [3]:
def onehot_Class(target):
    res = torch.zeros(target.size(0), 10*target.size(1))
    res[range(target.size(0)), target[:, 0]] = 1
    res[range(target.size(0)), target[:, 1] + 10] = 1
    return res

In [2]:
N = 1000
train_X, train_Y, train_Class, test_X, test_Y, test_Class = prologue.generate_pair_sets(N)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Processing...
Done!


In [5]:
#train_Y = onehot_Y(train_Y).long()
#train_Class = onehot_Class(train_Class).long()
#test_Y = onehot_Y(test_Y).long()
#test_Class = onehot_Class(test_Class).long()

In [3]:
mu, std = train_X.mean(), train_X.std()
train_X.sub_(mu).div_(std)
test_X.sub_(mu).div_(std)

tensor([[[[-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          ...,
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653]],

         [[-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          ...,
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653]]],


        [[[-0.4653, -0.4653, -0.4653,  ..., -0.4653, -0.4653, -0.4653],
          [-0.4653, -0.4653,

## Model 1: Naive convnet
For the first model, we create a naive convnet, not taking into account the structrue of the channels.

In [4]:
class convNet(nn.Module):
    def __init__(self):
        super(convNet, self).__init__()
        self.conv1 = nn.Conv2d(2, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 2)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2, stride=2))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        x = F.relu(self.fc1(x.view(-1, 256)))
        x = self.fc2(x)
        return x

In [30]:
def train_model(model, criterion, optimizer, nb_epochs, minibatch_size, train_X, train_Y, verbose=False):
    for e in range(nb_epochs):
        for b in range(0, train_X.size(0), minibatch_size):
            out = model(train_X.narrow(0, b, minibatch_size))
            loss = criterion(out, train_Y.narrow(0, b, minibatch_size))
            model.zero_grad()
            loss.backward()
            optimizer.step()
        if(verbose): print(loss)
    return model

In [32]:
def compute_nb_errors(model, data_input, data_target, minibatch_size):
    nb_data_errors = 0
    for b in range(0, data_input.size(0), minibatch_size):
        out = model(data_input.narrow(0, b, minibatch_size))
        _, pred = torch.max(out.data, 1)
        for k in range(minibatch_size):
            if data_target[b+k] != pred[k]:
                #print("NEW ERROR:")
                #print("target class: ", data_target[b+k].item())
                #print("prediction: ", pred[k].item())
                nb_data_errors += 1
    return nb_data_errors

In [12]:
model1 = convNet()
model1 = train_model(model1, nn.CrossEntropyLoss(), optim.SGD(model1.parameters(), lr=1e-1), 100, 100, train_X, train_Y, verbose=True)

tensor(0.6794, grad_fn=<NllLossBackward>)
tensor(0.6529, grad_fn=<NllLossBackward>)
tensor(0.6240, grad_fn=<NllLossBackward>)
tensor(0.5779, grad_fn=<NllLossBackward>)
tensor(0.5367, grad_fn=<NllLossBackward>)
tensor(0.4974, grad_fn=<NllLossBackward>)
tensor(0.4523, grad_fn=<NllLossBackward>)
tensor(0.4379, grad_fn=<NllLossBackward>)
tensor(0.3877, grad_fn=<NllLossBackward>)
tensor(0.4385, grad_fn=<NllLossBackward>)
tensor(0.3271, grad_fn=<NllLossBackward>)
tensor(0.4138, grad_fn=<NllLossBackward>)
tensor(0.2939, grad_fn=<NllLossBackward>)
tensor(0.2768, grad_fn=<NllLossBackward>)
tensor(0.2392, grad_fn=<NllLossBackward>)
tensor(0.2303, grad_fn=<NllLossBackward>)
tensor(0.2876, grad_fn=<NllLossBackward>)
tensor(0.2588, grad_fn=<NllLossBackward>)
tensor(0.1905, grad_fn=<NllLossBackward>)
tensor(0.2463, grad_fn=<NllLossBackward>)
tensor(0.2017, grad_fn=<NllLossBackward>)
tensor(0.1575, grad_fn=<NllLossBackward>)
tensor(0.1148, grad_fn=<NllLossBackward>)
tensor(0.1374, grad_fn=<NllLossBac

In [13]:
compute_nb_errors(model1, test_X, test_Y, 100)

190

We observe that this model does not manage to learn the mapping very well.

## Model 2: Using transfer learning

### Part 1: Study of a good functioning digit detection network for 28x28 images

In [14]:
train_input, train_target, test_input, test_target = \
    prologue.load_data(one_hot_labels = False, normalize = True, flatten = False)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.fc1 = nn.Linear(256, 200)
        self.fc2 = nn.Linear(200, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=3, stride=3))
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2))
        x = F.relu(self.fc1(x.view(-1, 256)))
        x = self.fc2(x)
        return x

* Using MNIST
** Reduce the data-set (use --full for the full thing)
** Use 1000 train and 1000 test samples


In [15]:
train_input, train_target = Variable(train_input), Variable(train_target)
model = Net()

In [16]:
model = train_model(model, nn.CrossEntropyLoss(), optim.SGD(model.parameters(), lr=1e-1), 50, 100, train_input, train_target)

In [17]:
compute_nb_errors(model, test_input, test_target, 100)

65

### Part 2: Adapting the network to 14x14 images

In [7]:
# fist, create the training and testing dataset
train_target_14px = torch.cat((train_Class[:,0], train_Class[:,1]))
train_input_14px = torch.cat((train_X[:,0,:,:].resize_(1000,1,14,14), train_X[:,1,:,:].resize_(1000,1,14,14)))

test_target_14px = torch.cat((test_Class[:,0], test_Class[:,1]))
test_input_14px = torch.cat((test_X[:,0,:,:].resize_(1000,1,14,14), test_X[:,1,:,:].resize_(1000,1,14,14)))

In [8]:
class Net_14px(nn.Module):
    def __init__(self):
        super(Net_14px, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3) 
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.fc1 = nn.Linear(256, 100)
        self.fc2 = nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2, stride=2)) #image size 12x12-> image size 6x6
        x = F.relu(F.max_pool2d(self.conv2(x), kernel_size=2, stride=2)) #image size 4x4 -> image size 2x2
        x = F.relu(self.fc1(x.view(-1, 256)))
        x = self.fc2(x)
        return x

In [26]:
model_14px = Net_14px()
model_14px = train_model(model_14px, nn.CrossEntropyLoss(), optim.SGD(model_14px.parameters(), lr=0.1), 100, 100, \
                         train_input_14px, train_target_14px, verbose=True)

tensor(2.3097, grad_fn=<NllLossBackward>)
tensor(2.3035, grad_fn=<NllLossBackward>)
tensor(2.2999, grad_fn=<NllLossBackward>)
tensor(2.2963, grad_fn=<NllLossBackward>)
tensor(2.2940, grad_fn=<NllLossBackward>)
tensor(2.2910, grad_fn=<NllLossBackward>)
tensor(2.2886, grad_fn=<NllLossBackward>)
tensor(2.2854, grad_fn=<NllLossBackward>)
tensor(2.2825, grad_fn=<NllLossBackward>)
tensor(2.2788, grad_fn=<NllLossBackward>)
tensor(2.2749, grad_fn=<NllLossBackward>)
tensor(2.2712, grad_fn=<NllLossBackward>)
tensor(2.2677, grad_fn=<NllLossBackward>)
tensor(2.2621, grad_fn=<NllLossBackward>)
tensor(2.2559, grad_fn=<NllLossBackward>)
tensor(2.2517, grad_fn=<NllLossBackward>)
tensor(2.2484, grad_fn=<NllLossBackward>)
tensor(2.2440, grad_fn=<NllLossBackward>)
tensor(2.2402, grad_fn=<NllLossBackward>)
tensor(2.2340, grad_fn=<NllLossBackward>)
tensor(2.2306, grad_fn=<NllLossBackward>)
tensor(2.2221, grad_fn=<NllLossBackward>)
tensor(2.2156, grad_fn=<NllLossBackward>)
tensor(2.2092, grad_fn=<NllLossBac

KeyboardInterrupt: 

In [27]:
compute_nb_errors(model_14px, train_input_14px, train_target_14px, 100)

NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  7
prediction:  4
NEW ERROR:
target class:  8
prediction:  9
NEW ERROR:
target class:  7
prediction:  1
NEW ERROR:
target class:  8
prediction:  9
NEW ERROR:
target class:  4
prediction:  2
NEW ERROR:
target class:  6
prediction:  9
NEW ERROR:
target class:  3
prediction:  1
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  2
prediction:  3
NEW ERROR:
target class:  6
prediction:  4
NEW ERROR:
target class:  7
prediction:  3
NEW ERROR:
target class:  7
prediction:  9
NEW ERROR:
target class:  1
prediction:  2
NEW ERROR:
target class:  0
prediction:  2
NEW ERROR:
target class:  4
prediction:  1
NEW ERROR:
target class:  0
prediction:  4
NEW ERROR:
target class:  7
prediction:  1
NEW ERROR:
target class:  6
prediction:  9
NEW ERROR:
target class:  8
prediction:  2
NEW ERROR:
target class:  2
prediction:  6
NEW ERROR:
target class:  0
prediction:  1
NEW ERROR:
target class:  9
prediction:  1
NEW ERROR:


prediction:  1
NEW ERROR:
target class:  4
prediction:  9
NEW ERROR:
target class:  1
prediction:  3
NEW ERROR:
target class:  0
prediction:  1
NEW ERROR:
target class:  7
prediction:  3
NEW ERROR:
target class:  6
prediction:  9
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  9
prediction:  3
NEW ERROR:
target class:  1
prediction:  9
NEW ERROR:
target class:  4
prediction:  7
NEW ERROR:
target class:  4
prediction:  2
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  1
prediction:  5
NEW ERROR:
target class:  1
prediction:  9
NEW ERROR:
target class:  8
prediction:  0
NEW ERROR:
target class:  0
prediction:  9
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  6
prediction:  9
NEW ERROR:
target class:  3
prediction:  1
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  0
prediction:  3
NEW ERROR:
target class:  8
prediction:  1
NEW ERROR:
target class:  2
prediction:  1
NEW ERROR:
target class:  4
prediction:

prediction:  1
NEW ERROR:
target class:  3
prediction:  1
NEW ERROR:
target class:  0
prediction:  1
NEW ERROR:
target class:  0
prediction:  6
NEW ERROR:
target class:  9
prediction:  1
NEW ERROR:
target class:  2
prediction:  5
NEW ERROR:
target class:  3
prediction:  2
NEW ERROR:
target class:  0
prediction:  9
NEW ERROR:
target class:  9
prediction:  1
NEW ERROR:
target class:  9
prediction:  1
NEW ERROR:
target class:  0
prediction:  5
NEW ERROR:
target class:  8
prediction:  6
NEW ERROR:
target class:  7
prediction:  9
NEW ERROR:
target class:  0
prediction:  5
NEW ERROR:
target class:  1
prediction:  2
NEW ERROR:
target class:  7
prediction:  2
NEW ERROR:
target class:  8
prediction:  4
NEW ERROR:
target class:  3
prediction:  1
NEW ERROR:
target class:  1
prediction:  3
NEW ERROR:
target class:  7
prediction:  4
NEW ERROR:
target class:  3
prediction:  9
NEW ERROR:
target class:  5
prediction:  1
NEW ERROR:
target class:  6
prediction:  4
NEW ERROR:
target class:  6
prediction:

NEW ERROR:
target class:  3
prediction:  2
NEW ERROR:
target class:  4
prediction:  9
NEW ERROR:
target class:  6
prediction:  9
NEW ERROR:
target class:  8
prediction:  1
NEW ERROR:
target class:  4
prediction:  9
NEW ERROR:
target class:  9
prediction:  2
NEW ERROR:
target class:  0
prediction:  9
NEW ERROR:
target class:  8
prediction:  9
NEW ERROR:
target class:  9
prediction:  3
NEW ERROR:
target class:  0
prediction:  3
NEW ERROR:
target class:  1
prediction:  4
NEW ERROR:
target class:  4
prediction:  3
NEW ERROR:
target class:  4
prediction:  1
NEW ERROR:
target class:  8
prediction:  9
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  6
prediction:  2
NEW ERROR:
target class:  0
prediction:  1
NEW ERROR:
target class:  0
prediction:  9
NEW ERROR:
target class:  7
prediction:  2
NEW ERROR:
target class:  9
prediction:  6
NEW ERROR:
target class:  7
prediction:  1
NEW ERROR:
target class:  2
prediction:  1
NEW ERROR:
target class:  8
prediction:  1
NEW ERROR:


NEW ERROR:
target class:  6
prediction:  1
NEW ERROR:
target class:  7
prediction:  3
NEW ERROR:
target class:  0
prediction:  2
NEW ERROR:
target class:  8
prediction:  3
NEW ERROR:
target class:  9
prediction:  1
NEW ERROR:
target class:  3
prediction:  6
NEW ERROR:
target class:  5
prediction:  1
NEW ERROR:
target class:  6
prediction:  2
NEW ERROR:
target class:  9
prediction:  3
NEW ERROR:
target class:  0
prediction:  1
NEW ERROR:
target class:  5
prediction:  9
NEW ERROR:
target class:  8
prediction:  2
NEW ERROR:
target class:  7
prediction:  9
NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  2
prediction:  7
NEW ERROR:
target class:  1
prediction:  6
NEW ERROR:
target class:  5
prediction:  6
NEW ERROR:
target class:  0
prediction:  3
NEW ERROR:
target class:  7
prediction:  2
NEW ERROR:
target class:  9
prediction:  1
NEW ERROR:
target class:  5
prediction:  9
NEW ERROR:
target class:  6
prediction:  5
NEW ERROR:
target class:  9
prediction:  5
NEW ERROR:


NEW ERROR:
target class:  6
prediction:  1
NEW ERROR:
target class:  6
prediction:  9
NEW ERROR:
target class:  7
prediction:  1
NEW ERROR:
target class:  2
prediction:  1
NEW ERROR:
target class:  6
prediction:  1
NEW ERROR:
target class:  2
prediction:  1
NEW ERROR:
target class:  7
prediction:  2
NEW ERROR:
target class:  5
prediction:  9
NEW ERROR:
target class:  0
prediction:  1
NEW ERROR:
target class:  0
prediction:  3
NEW ERROR:
target class:  4
prediction:  1
NEW ERROR:
target class:  0
prediction:  5
NEW ERROR:
target class:  8
prediction:  6
NEW ERROR:
target class:  4
prediction:  3
NEW ERROR:
target class:  0
prediction:  9
NEW ERROR:
target class:  3
prediction:  2
NEW ERROR:
target class:  6
prediction:  2
NEW ERROR:
target class:  8
prediction:  1
NEW ERROR:
target class:  7
prediction:  3
NEW ERROR:
target class:  8
prediction:  1
NEW ERROR:
target class:  2
prediction:  1
NEW ERROR:
target class:  7
prediction:  3
NEW ERROR:
target class:  0
prediction:  1
NEW ERROR:


1359

In [23]:
compute_nb_errors(model_14px, test_input_14px, test_target_14px, 100)

NEW ERROR:
target class:  0
prediction:  3
NEW ERROR:
target class:  4
prediction:  2
NEW ERROR:
target class:  8
prediction:  0
NEW ERROR:
target class:  2
prediction:  0
NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  0
prediction:  7
NEW ERROR:
target class:  7
prediction:  0
NEW ERROR:
target class:  1
prediction:  3
NEW ERROR:
target class:  6
prediction:  7
NEW ERROR:
target class:  1
prediction:  3
NEW ERROR:
target class:  1
prediction:  5
NEW ERROR:
target class:  6
prediction:  7
NEW ERROR:
target class:  3
prediction:  8
NEW ERROR:
target class:  4
prediction:  2
NEW ERROR:
target class:  1
prediction:  2
NEW ERROR:
target class:  4
prediction:  6
NEW ERROR:
target class:  8
prediction:  2
NEW ERROR:
target class:  7
prediction:  9
NEW ERROR:
target class:  1
prediction:  0
NEW ERROR:
target class:  9
prediction:  0
NEW ERROR:
target class:  8
prediction:  0
NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  6
prediction:  3
NEW ERROR:


prediction:  4
NEW ERROR:
target class:  6
prediction:  8
NEW ERROR:
target class:  3
prediction:  5
NEW ERROR:
target class:  5
prediction:  8
NEW ERROR:
target class:  4
prediction:  6
NEW ERROR:
target class:  0
prediction:  9
NEW ERROR:
target class:  6
prediction:  7
NEW ERROR:
target class:  3
prediction:  0
NEW ERROR:
target class:  1
prediction:  7
NEW ERROR:
target class:  9
prediction:  2
NEW ERROR:
target class:  9
prediction:  3
NEW ERROR:
target class:  7
prediction:  1
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  8
prediction:  6
NEW ERROR:
target class:  8
prediction:  0
NEW ERROR:
target class:  8
prediction:  2
NEW ERROR:
target class:  0
prediction:  8
NEW ERROR:
target class:  7
prediction:  2
NEW ERROR:
target class:  1
prediction:  7
NEW ERROR:
target class:  0
prediction:  9
NEW ERROR:
target class:  3
prediction:  7
NEW ERROR:
target class:  9
prediction:  0
NEW ERROR:
target class:  0
prediction:  6
NEW ERROR:
target class:  9
prediction:

prediction:  3
NEW ERROR:
target class:  3
prediction:  7
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  2
prediction:  0
NEW ERROR:
target class:  3
prediction:  0
NEW ERROR:
target class:  6
prediction:  3
NEW ERROR:
target class:  9
prediction:  5
NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  1
prediction:  7
NEW ERROR:
target class:  9
prediction:  5
NEW ERROR:
target class:  2
prediction:  3
NEW ERROR:
target class:  5
prediction:  3
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  9
prediction:  3
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  3
prediction:  7
NEW ERROR:
target class:  5
prediction:  6
NEW ERROR:
target class:  5
prediction:  7
NEW ERROR:
target class:  9
prediction:  2
NEW ERROR:
target class:  2
prediction:  7
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  0
prediction:  3
NEW ERROR:
target class:  4
prediction:  0
NEW ERROR:
target class:  0
prediction:

target class:  9
prediction:  2
NEW ERROR:
target class:  2
prediction:  0
NEW ERROR:
target class:  7
prediction:  0
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  3
prediction:  7
NEW ERROR:
target class:  2
prediction:  8
NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  6
prediction:  2
NEW ERROR:
target class:  7
prediction:  2
NEW ERROR:
target class:  4
prediction:  6
NEW ERROR:
target class:  0
prediction:  7
NEW ERROR:
target class:  9
prediction:  0
NEW ERROR:
target class:  2
prediction:  7
NEW ERROR:
target class:  8
prediction:  3
NEW ERROR:
target class:  9
prediction:  2
NEW ERROR:
target class:  7
prediction:  2
NEW ERROR:
target class:  1
prediction:  4
NEW ERROR:
target class:  7
prediction:  3
NEW ERROR:
target class:  6
prediction:  1
NEW ERROR:
target class:  2
prediction:  7
NEW ERROR:
target class:  8
prediction:  6
NEW ERROR:
target class:  2
prediction:  0
NEW ERROR:
target class:  7
prediction:  4
NEW ERROR:
target clas

prediction:  7
NEW ERROR:
target class:  4
prediction:  8
NEW ERROR:
target class:  1
prediction:  7
NEW ERROR:
target class:  1
prediction:  0
NEW ERROR:
target class:  0
prediction:  3
NEW ERROR:
target class:  8
prediction:  3
NEW ERROR:
target class:  8
prediction:  3
NEW ERROR:
target class:  9
prediction:  2
NEW ERROR:
target class:  8
prediction:  3
NEW ERROR:
target class:  8
prediction:  6
NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  8
prediction:  0
NEW ERROR:
target class:  8
prediction:  2
NEW ERROR:
target class:  3
prediction:  9
NEW ERROR:
target class:  1
prediction:  0
NEW ERROR:
target class:  7
prediction:  8
NEW ERROR:
target class:  3
prediction:  2
NEW ERROR:
target class:  0
prediction:  3
NEW ERROR:
target class:  6
prediction:  7
NEW ERROR:
target class:  4
prediction:  9
NEW ERROR:
target class:  5
prediction:  8
NEW ERROR:
target class:  1
prediction:  7
NEW ERROR:
target class:  5
prediction:  2
NEW ERROR:
target class:  6
prediction:

NEW ERROR:
target class:  7
prediction:  3
NEW ERROR:
target class:  1
prediction:  6
NEW ERROR:
target class:  7
prediction:  0
NEW ERROR:
target class:  3
prediction:  2
NEW ERROR:
target class:  1
prediction:  7
NEW ERROR:
target class:  9
prediction:  7
NEW ERROR:
target class:  5
prediction:  7
NEW ERROR:
target class:  7
prediction:  8
NEW ERROR:
target class:  1
prediction:  9
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  8
prediction:  7
NEW ERROR:
target class:  3
prediction:  8
NEW ERROR:
target class:  0
prediction:  7
NEW ERROR:
target class:  3
prediction:  0
NEW ERROR:
target class:  2
prediction:  7
NEW ERROR:
target class:  1
prediction:  0
NEW ERROR:
target class:  8
prediction:  7
NEW ERROR:
target class:  6
prediction:  7
NEW ERROR:
target class:  3
prediction:  0
NEW ERROR:
target class:  8
prediction:  5
NEW ERROR:
target class:  3
prediction:  0
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  7
prediction:  9
NEW ERROR:


NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  4
prediction:  0
NEW ERROR:
target class:  3
prediction:  0
NEW ERROR:
target class:  0
prediction:  7
NEW ERROR:
target class:  0
prediction:  7
NEW ERROR:
target class:  8
prediction:  7
NEW ERROR:
target class:  5
prediction:  0
NEW ERROR:
target class:  1
prediction:  6
NEW ERROR:
target class:  6
prediction:  8
NEW ERROR:
target class:  4
prediction:  8
NEW ERROR:
target class:  1
prediction:  5
NEW ERROR:
target class:  9
prediction:  6
NEW ERROR:
target class:  0
prediction:  8
NEW ERROR:
target class:  8
prediction:  0
NEW ERROR:
target class:  1
prediction:  3
NEW ERROR:
target class:  0
prediction:  2
NEW ERROR:
target class:  4
prediction:  0
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  4
prediction:  3
NEW ERROR:
target class:  5
prediction:  7
NEW ERROR:
target class:  9
prediction:  0
NEW ERROR:


target class:  0
prediction:  3
NEW ERROR:
target class:  8
prediction:  5
NEW ERROR:
target class:  6
prediction:  0
NEW ERROR:
target class:  4
prediction:  0


1799

In [36]:
nb_epochs = 20
test_error = torch.empty(nb_epochs)
train_error = torch.empty(nb_epochs)


for e in range(nb_epochs):
    print("epoch ", e)
    model_14px = Net_14px()
    model_14px = train_model(model_14px, nn.CrossEntropyLoss(), optim.SGD(model_14px.parameters(), lr=0.1), e*5, 100, \
                         train_input_14px, train_target_14px, verbose=False)
    train_error[e] = compute_nb_errors(model_14px, train_input_14px, train_target_14px, 100)
    test_error[e] = compute_nb_errors(model_14px, test_input_14px, test_target_14px, 100)

epoch  0
epoch  1
epoch  2
epoch  3
epoch  4
epoch  5
epoch  6
epoch  7
epoch  8
epoch  9
epoch  10
epoch  11
epoch  12
epoch  13
epoch  14
epoch  15
epoch  16
epoch  17
epoch  18
epoch  19


In [37]:
train_error

tensor([1819., 1723., 1668., 1659., 1585., 1544., 1573., 1498., 1343., 1330.,
        1277., 1172., 1085., 1240., 1122., 1109., 1230., 1142., 1032., 1040.])

In [38]:
test_error

tensor([1802., 1777., 1785., 1773., 1770., 1776., 1792., 1791., 1803., 1783.,
        1783., 1804., 1810., 1804., 1814., 1807., 1799., 1819., 1783., 1788.])

In [42]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
plt.plot(train_error.numpy())

ModuleNotFoundError: No module named 'matplotlib'