### імпорт бібліотек та функцій

In [0]:
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import accuracy_score
from torch import nn
from torch import tensor
from torch.autograd import Variable
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid

def imshow(img, figsize=(10, 2)):
    """Display a channels-first (C, H, W) image tensor with matplotlib.

    Args:
        img: image tensor laid out channels-first, e.g. the output of
            ``torchvision.utils.make_grid``.
        figsize: (width, height) of the figure in inches.
    """
    # .detach().cpu() makes the call safe for tensors that require grad or
    # live on the GPU; .numpy() only works on detached CPU tensors.
    npimg = img.detach().cpu().numpy()
    # Create the figure *before* drawing so the requested size applies to this
    # image. Setting rcParams after plt.imshow only affected the next figure
    # (and changed global matplotlib state as a side effect).
    plt.figure(figsize=figsize)
    # matplotlib expects (H, W, C), so move the channel axis last.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

### створюємо клас нейронної мережі

In [0]:
class LeNet5(nn.Module):
    """LeNet-5 style convolutional network producing 10 class logits.

    Architecture:
        conv 5x5 (1->6, padding 2) -> ReLU -> avg-pool 2
        conv 5x5 (6->16)           -> ReLU -> avg-pool 2
        fc 400->120 -> ReLU -> fc 120->84 -> ReLU -> fc 84->10

    Expects single-channel 28x28 inputs. That is the only size for which the
    final feature map is 16x5x5, matching fc1's 16*5*5 input features.
    """

    def __init__(self):
        super().__init__()
        # 1x28x28 -> 6x28x28 (padding=2 preserves spatial size) -> 6x14x14
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
        self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
        # 6x14x14 -> 16x10x10 -> 16x5x5
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1)
        self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return raw (unnormalised) logits of shape (batch, 10)."""
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        # Flatten each sample while keeping the batch dimension intact.
        # view(-1, 16 * 5 * 5) could silently regroup features across samples
        # for wrongly sized inputs. Now a size mismatch fails loudly in fc1.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No softmax: CrossEntropyLoss expects raw logits.
        return self.fc3(x)
     
# Instantiate the network and place it on the GPU when one is available,
# falling back to the CPU so the notebook still runs without CUDA.
net = LeNet5()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

### вибір констант, функції втрат та функції оптимізації

In [0]:
# Training hyper-parameters: epochs, mini-batch size, DataLoader worker count.
numEpochs, batch_size, num_workers = 5, 512, 4

# Cross-entropy on the network's raw logits, minimised with Adam.
loss_func = nn.CrossEntropyLoss()
optimization = Adam(params=net.parameters(), lr=1e-3)

### створення наборів для навчання та тестування

In [0]:
# MNIST train/test splits, downloaded to ./data and converted to tensors.
train = MNIST(root='./data', train=True, download=True, transform=ToTensor())
test = MNIST(root='./data', train=False, download=True, transform=ToTensor())

# Shuffle the training set so each epoch sees its mini-batches in a new order.
# Without shuffling, every epoch replays identical batches in dataset order.
# The test loader keeps a fixed order so evaluation is reproducible.
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = DataLoader(test, batch_size=batch_size, num_workers=num_workers)

### навчання

In [0]:
# Train for `numEpochs` epochs and report test-set accuracy after each one.
# Batches are moved to whatever device the network's weights live on, so this
# cell works whether `net` was placed on the GPU or the CPU.
device = next(net.parameters()).device

for epoch in range(numEpochs):
    # --- training pass ---
    net.train()
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimization.zero_grad()
        forward_output = net(inputs)
        loss = loss_func(forward_output, labels)
        loss.backward()
        optimization.step()

    # --- evaluation pass ---
    # Count correct predictions over all samples instead of averaging
    # per-batch accuracies. The last batch is smaller unless the test-set size
    # is a multiple of batch_size, so a per-batch mean would over-weight it.
    net.eval()
    correct = 0
    seen = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for inputs, actual_val in test_loader:
            logits = net(inputs.to(device))
            predicted_val = logits.argmax(dim=1).cpu()
            correct += (predicted_val == actual_val).sum().item()
            seen += actual_val.size(0)

    accuracy = correct / seen if seen else 0.0
    print("epoch: ", epoch, " — accuracy: {:.2f}".format(accuracy))





epoch:  0  — accuracy: 0.88
epoch:  1  — accuracy: 0.91
epoch:  2  — accuracy: 0.93
epoch:  3  — accuracy: 0.95
epoch:  4  — accuracy: 0.96


### показ погано класифікованих зображень

In [0]:
# Inspect the first test batch. For up to `max_shown` misclassified digits,
# show the image, its predicted and true class, and up to 10 correctly
# classified images of the true class from the same batch for comparison.
max_shown = 10
device = next(net.parameters()).device

net.eval()
inputs, labels = next(iter(test_loader))
with torch.no_grad():  # inference only: no autograd graph needed
    predicted = net(inputs.to(device)).argmax(dim=1).cpu()

images = inputs.numpy()
label_np = labels.numpy()
pred_np = predicted.numpy()
not_ok = pred_np != label_np

# Correctly classified images of this batch, grouped by their true class.
good = defaultdict(list)
for img, lbl, pred in zip(images, label_np, pred_np):
    if pred == lbl:
        good[int(lbl)].append(img)

# Show however many errors exist, up to max_shown. Previously nothing was
# shown unless the batch contained at least 10 mistakes.
bad = images[not_ok][:max_shown]
bad_l = label_np[not_ok][:max_shown]
bad_p = pred_np[not_ok][:max_shown]

for i, (img, true_cls, pred_cls) in enumerate(zip(bad, bad_l, bad_p)):
    imshow(make_grid(tensor(img), nrow=10, padding=2))
    print("Predicted:", test.classes[pred_cls])
    print("Label    :", test.classes[true_cls])
    print('–' * 100)
    print(f"Good samples of {test.classes[true_cls]}:")
    samples = good[int(true_cls)][:10]
    # make_grid fails on an empty tensor, so skip the grid when this batch has
    # no correctly classified example of the true class.
    if samples:
        imshow(make_grid(tensor(np.stack(samples)), nrow=10, padding=2))
    # Separator between entries; an empty line after the last one.
    print('–' * (100 * (i != len(bad) - 1)))