In [53]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision as tv

import numpy as np

import matplotlib.pyplot as plt

import cv2

from tqdm import tqdm

In [54]:
trans = tv.transforms.Compose([
    tv.transforms.ToTensor()
])

In [55]:
ds_mnist = tv.datasets.MNIST('./datasets', download=True, transform=trans)

In [56]:
batch_size = 16
dataloader = torch.utils.data.DataLoader(
    ds_mnist, batch_size=batch_size, shuffle=True,
    num_workers=1, drop_last=True
)

In [57]:
class Neural_numbers(nn.Module):
  def __init__(self):
    super().__init__()
    self.flat = nn.Flatten()
    self.linear1 = nn.Linear(28*28, 100)
    self.linear2 = nn.Linear(100, 10)
    self.act = nn.ReLU()

  def forward(self, x):
    out = self.flat(x)
    out = self.linear1(out)
    out = self.act(out)
    out = self.linear2(out)

    return out

In [58]:
def count_parameters(model):
  return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [59]:
model = Neural_numbers()

In [60]:
loss_fn = nn.CrossEntropyLoss()

In [61]:
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)

In [62]:
def accuracy(pred, label):
  answer = F.softmax(pred.detach()).numpy().argmax(1) == label.numpy().argmax(1)
  
  return answer.mean()

In [65]:
epochs = 100

for epoch in range(epochs):
  loss_val = 0
  acc_val = 0

  for img, label in (pbar := tqdm(dataloader)):
    optimizer.zero_grad()
    
    label = F.one_hot(label, 10).float()
    pred = model(img)
    
    loss = loss_fn(pred, label)

    loss.backward()
    loss_item = loss.item()
    loss_val += loss_item

    optimizer.step()

    acc_current = accuracy(pred, label)
    acc_val += acc_current

    pbar.set_description(f'loss: {loss_item:.5f}\taccuracy: {acc_current:.3f}')
  
  print(f'epoch: {epoch}')
  print(f'loss: {loss_val/len(dataloader)}')
  print(f'accuracy: {acc_val/len(dataloader)}')

  answer = F.softmax(pred.detach()).numpy().argmax(1) == label.numpy().argmax(1)
loss: 0.31266	accuracy: 0.875: 100%|██████████| 3750/3750 [00:24<00:00, 151.99it/s]


epoch: 0
loss: 0.17053015242132047
accuracy: 0.9516166666666667


loss: 0.13475	accuracy: 0.938: 100%|██████████| 3750/3750 [00:24<00:00, 152.34it/s]


epoch: 1
loss: 0.1548178101297468
accuracy: 0.95615


loss: 0.13922	accuracy: 0.938: 100%|██████████| 3750/3750 [00:24<00:00, 153.34it/s]


epoch: 2
loss: 0.14170461939312518
accuracy: 0.9604166666666667


loss: 0.38499	accuracy: 0.875: 100%|██████████| 3750/3750 [00:24<00:00, 152.11it/s]


epoch: 3
loss: 0.1310062714148313
accuracy: 0.9632833333333334


loss: 0.18418	accuracy: 0.938: 100%|██████████| 3750/3750 [00:24<00:00, 152.90it/s]


epoch: 4
loss: 0.12133336493801325
accuracy: 0.9660666666666666


loss: 0.03690	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 152.10it/s]


epoch: 5
loss: 0.11334055300230782
accuracy: 0.96825


loss: 0.02310	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 151.79it/s]


epoch: 6
loss: 0.1059578050242737
accuracy: 0.9709333333333333


loss: 0.20840	accuracy: 0.875: 100%|██████████| 3750/3750 [00:24<00:00, 151.94it/s]


epoch: 7
loss: 0.09946865261209507
accuracy: 0.9721333333333333


loss: 0.01364	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 151.44it/s]


epoch: 8
loss: 0.09393292232823247
accuracy: 0.97395


loss: 0.00636	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 152.71it/s]


epoch: 9
loss: 0.08871966523372879
accuracy: 0.9754


loss: 0.00671	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 154.25it/s]


epoch: 10
loss: 0.08425946731970956
accuracy: 0.97665


loss: 0.05459	accuracy: 0.938: 100%|██████████| 3750/3750 [00:24<00:00, 152.37it/s]


epoch: 11
loss: 0.08032308638201406
accuracy: 0.97795


loss: 0.13537	accuracy: 0.938: 100%|██████████| 3750/3750 [00:24<00:00, 153.92it/s]


epoch: 12
loss: 0.07623609357727691
accuracy: 0.9790166666666666


loss: 0.01044	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 153.44it/s]


epoch: 13
loss: 0.07307159437729667
accuracy: 0.9796666666666667


loss: 0.05835	accuracy: 0.938: 100%|██████████| 3750/3750 [00:24<00:00, 152.59it/s]


epoch: 14
loss: 0.06972291625691578
accuracy: 0.9809166666666667


loss: 0.05542	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 149.28it/s]


epoch: 15
loss: 0.06693270948845893
accuracy: 0.98145


loss: 0.05694	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 152.45it/s]


epoch: 16
loss: 0.06428707931232638
accuracy: 0.9822833333333333


loss: 0.03200	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 153.37it/s]


epoch: 17
loss: 0.06175511827065299
accuracy: 0.9831


loss: 0.03555	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 152.59it/s]


epoch: 18
loss: 0.05939342238469981
accuracy: 0.9840333333333333


loss: 0.07592	accuracy: 0.938: 100%|██████████| 3750/3750 [00:24<00:00, 153.12it/s]


epoch: 19
loss: 0.057061632897208135
accuracy: 0.9843666666666666


loss: 0.08260	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 153.48it/s]


epoch: 20
loss: 0.054911394364014265
accuracy: 0.9851666666666666


loss: 0.01005	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 152.05it/s]


epoch: 21
loss: 0.05309695739879583
accuracy: 0.9855166666666667


loss: 0.03700	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 151.75it/s]


epoch: 22
loss: 0.050895067826394615
accuracy: 0.9863833333333333


loss: 0.34186	accuracy: 0.875: 100%|██████████| 3750/3750 [00:24<00:00, 151.03it/s]


epoch: 23
loss: 0.04942451947174656
accuracy: 0.9870333333333333


loss: 0.01099	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 151.68it/s]


epoch: 24
loss: 0.04781841117682246
accuracy: 0.98705


loss: 0.02214	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 146.84it/s]


epoch: 25
loss: 0.04620730792921192
accuracy: 0.98765


loss: 0.02231	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 148.73it/s]


epoch: 26
loss: 0.0447257146738004
accuracy: 0.9884


loss: 0.00596	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 148.14it/s]


epoch: 27
loss: 0.043088238556815
accuracy: 0.9890166666666667


loss: 0.00698	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 148.76it/s]


epoch: 28
loss: 0.041802201279702904
accuracy: 0.98945


loss: 0.04660	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 148.68it/s]


epoch: 29
loss: 0.040537132599173735
accuracy: 0.9898166666666667


loss: 0.10147	accuracy: 0.938: 100%|██████████| 3750/3750 [00:24<00:00, 150.18it/s]


epoch: 30
loss: 0.039179385595116765
accuracy: 0.9902333333333333


loss: 0.00347	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 149.56it/s]


epoch: 31
loss: 0.038127581865076594
accuracy: 0.9903


loss: 0.00949	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 144.98it/s]


epoch: 32
loss: 0.03699143051189215
accuracy: 0.9908


loss: 0.01151	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 143.66it/s]


epoch: 33
loss: 0.0359674870531618
accuracy: 0.9911333333333333


loss: 0.01497	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 148.12it/s]


epoch: 34
loss: 0.03491640205963437
accuracy: 0.9912166666666666


loss: 0.00353	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 143.71it/s]


epoch: 35
loss: 0.033839926854815953
accuracy: 0.9915333333333334


loss: 0.02802	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 146.69it/s]


epoch: 36
loss: 0.03282178219150131
accuracy: 0.99205


loss: 0.00578	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 147.58it/s]


epoch: 37
loss: 0.031919975662747554
accuracy: 0.9923833333333333


loss: 0.00636	accuracy: 1.000: 100%|██████████| 3750/3750 [00:24<00:00, 150.40it/s]


epoch: 38
loss: 0.030985280186411304
accuracy: 0.9926


loss: 0.00418	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 146.91it/s]


epoch: 39
loss: 0.03007725480160055
accuracy: 0.9929166666666667


loss: 0.00260	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 147.35it/s]


epoch: 40
loss: 0.02913962810112086
accuracy: 0.9933166666666666


loss: 0.05160	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 147.23it/s]


epoch: 41
loss: 0.02840547002524448
accuracy: 0.9934166666666666


loss: 0.00509	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 146.32it/s]


epoch: 42
loss: 0.02741524630115212
accuracy: 0.9941333333333333


loss: 0.00153	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 142.31it/s]


epoch: 43
loss: 0.02689634705770295
accuracy: 0.9939333333333333


loss: 0.00713	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 146.58it/s]


epoch: 44
loss: 0.026259785314451438
accuracy: 0.9942166666666666


loss: 0.03450	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.50it/s]


epoch: 45
loss: 0.025481143877158562
accuracy: 0.9945


loss: 0.00473	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 145.00it/s]


epoch: 46
loss: 0.024771329783966456
accuracy: 0.9947833333333334


loss: 0.00560	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 146.16it/s]


epoch: 47
loss: 0.024071517437312288
accuracy: 0.9947833333333334


loss: 0.00102	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 143.62it/s]


epoch: 48
loss: 0.023383012517237026
accuracy: 0.9953166666666666


loss: 0.00228	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 145.76it/s]


epoch: 49
loss: 0.02283256435643804
accuracy: 0.99545


loss: 0.00335	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 145.37it/s]


epoch: 50
loss: 0.02219324853330618
accuracy: 0.9955


loss: 0.00453	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 140.09it/s]


epoch: 51
loss: 0.02174785926550782
accuracy: 0.9956666666666667


loss: 0.00760	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 143.47it/s]


epoch: 52
loss: 0.021138586425955872
accuracy: 0.99595


loss: 0.01719	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 142.96it/s]


epoch: 53
loss: 0.02060043709326031
accuracy: 0.9961666666666666


loss: 0.02414	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 142.84it/s]


epoch: 54
loss: 0.020143762370123296
accuracy: 0.9963666666666666


loss: 0.00418	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 145.01it/s]


epoch: 55
loss: 0.019591135274576177
accuracy: 0.9964


loss: 0.00223	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 145.40it/s]


epoch: 56
loss: 0.01903134730005792
accuracy: 0.9966166666666667


loss: 0.00060	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.93it/s]


epoch: 57
loss: 0.018596255315283392
accuracy: 0.9966166666666667


loss: 0.00224	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 140.34it/s]


epoch: 58
loss: 0.018190589748245355
accuracy: 0.9969166666666667


loss: 0.03943	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.98it/s]


epoch: 59
loss: 0.017844628928875318
accuracy: 0.9971333333333333


loss: 0.00739	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 142.39it/s]


epoch: 60
loss: 0.01724101734860451
accuracy: 0.9971333333333333


loss: 0.00108	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 143.60it/s]


epoch: 61
loss: 0.01683112473359021
accuracy: 0.9976


loss: 0.00686	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 139.25it/s]


epoch: 62
loss: 0.016470868600689574
accuracy: 0.99725


loss: 0.00367	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 143.79it/s]


epoch: 63
loss: 0.01605681997036057
accuracy: 0.9976333333333334


loss: 0.04746	accuracy: 0.938: 100%|██████████| 3750/3750 [00:26<00:00, 143.47it/s]


epoch: 64
loss: 0.015741378479504298
accuracy: 0.9975833333333334


loss: 0.00965	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.33it/s]


epoch: 65
loss: 0.015296293843609358
accuracy: 0.9977833333333334


loss: 0.00720	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 142.41it/s]


epoch: 66
loss: 0.014985624304814458
accuracy: 0.9981166666666667


loss: 0.02025	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.22it/s]


epoch: 67
loss: 0.014630690635536545
accuracy: 0.9981166666666667


loss: 0.03783	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 144.60it/s]


epoch: 68
loss: 0.014348873496965583
accuracy: 0.9981333333333333


loss: 0.00495	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.90it/s]


epoch: 69
loss: 0.014045679359649269
accuracy: 0.9982166666666666


loss: 0.03813	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 142.30it/s]


epoch: 70
loss: 0.013668960869693547
accuracy: 0.9983166666666666


loss: 0.02866	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 140.61it/s]


epoch: 71
loss: 0.013392024453545067
accuracy: 0.9983


loss: 0.00051	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 144.20it/s]


epoch: 72
loss: 0.0130898546293601
accuracy: 0.9984333333333333


loss: 0.00258	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 139.63it/s]


epoch: 73
loss: 0.012818719208060065
accuracy: 0.99865


loss: 0.00183	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.91it/s]


epoch: 74
loss: 0.012549244513995169
accuracy: 0.9986


loss: 0.00584	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 145.23it/s]


epoch: 75
loss: 0.012215051472977696
accuracy: 0.9987


loss: 0.00219	accuracy: 1.000: 100%|██████████| 3750/3750 [00:25<00:00, 144.67it/s]


epoch: 76
loss: 0.01202339603866858
accuracy: 0.99875


loss: 0.02150	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.84it/s]


epoch: 77
loss: 0.0117356718172911
accuracy: 0.9989666666666667


loss: 0.01249	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 137.46it/s]


epoch: 78
loss: 0.01149685874246934
accuracy: 0.9988666666666667


loss: 0.00120	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.96it/s]


epoch: 79
loss: 0.01126482395027512
accuracy: 0.9989833333333333


loss: 0.01454	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 142.04it/s]


epoch: 80
loss: 0.010988294195582664
accuracy: 0.9989833333333333


loss: 0.00030	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 141.78it/s]


epoch: 81
loss: 0.01075672174852225
accuracy: 0.99905


loss: 0.00477	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 138.66it/s]


epoch: 82
loss: 0.010545064264677058
accuracy: 0.9991166666666667


loss: 0.00518	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 139.42it/s]


epoch: 83
loss: 0.01036037853757298
accuracy: 0.9990666666666667


loss: 0.00138	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 137.50it/s]


epoch: 84
loss: 0.010142709107679547
accuracy: 0.9990666666666667


loss: 0.00095	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 139.94it/s]


epoch: 85
loss: 0.009901223547610668
accuracy: 0.9991833333333333


loss: 0.00290	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 138.56it/s]


epoch: 86
loss: 0.00975058176440895
accuracy: 0.9992166666666666


loss: 0.00364	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 135.19it/s]


epoch: 87
loss: 0.009555482046029162
accuracy: 0.99935


loss: 0.01078	accuracy: 1.000: 100%|██████████| 3750/3750 [00:28<00:00, 131.78it/s]


epoch: 88
loss: 0.009336105451724143
accuracy: 0.9993166666666666


loss: 0.00270	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 134.79it/s]


epoch: 89
loss: 0.009163920573033587
accuracy: 0.9993833333333333


loss: 0.08643	accuracy: 0.938: 100%|██████████| 3750/3750 [00:27<00:00, 138.55it/s]


epoch: 90
loss: 0.008954261547887291
accuracy: 0.9993666666666666


loss: 0.00976	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 138.96it/s]


epoch: 91
loss: 0.008843529431941716
accuracy: 0.9994666666666666


loss: 0.00086	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 138.80it/s]


epoch: 92
loss: 0.00862662968638712
accuracy: 0.9994333333333333


loss: 0.00589	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 135.76it/s]


epoch: 93
loss: 0.00843886910067619
accuracy: 0.9994166666666666


loss: 0.00531	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 136.56it/s]


epoch: 94
loss: 0.008324682450539451
accuracy: 0.9995166666666667


loss: 0.02413	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 138.72it/s]


epoch: 95
loss: 0.00817112000513395
accuracy: 0.9994666666666666


loss: 0.00534	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 139.29it/s]


epoch: 96
loss: 0.008015026809667567
accuracy: 0.9995333333333334


loss: 0.00568	accuracy: 1.000: 100%|██████████| 3750/3750 [00:26<00:00, 139.84it/s]


epoch: 97
loss: 0.007850729510061986
accuracy: 0.99955


loss: 0.01057	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 134.93it/s]


epoch: 98
loss: 0.007714536315012568
accuracy: 0.9995666666666667


loss: 0.00402	accuracy: 1.000: 100%|██████████| 3750/3750 [00:27<00:00, 136.34it/s]

epoch: 99
loss: 0.007608664569630733
accuracy: 0.9996333333333334





In [79]:
img = cv2.imread('img.png', cv2.IMREAD_GRAYSCALE)
img = np.expand_dims(img, axis=0)
img = np.expand_dims(img, axis=0)
img = img.astype(np.float32)/255.0

In [80]:
t_img = torch.from_numpy(img)
our_pred = model(t_img)

In [81]:
F.softmax(our_pred).detach().numpy().argmax()

  F.softmax(our_pred).detach().numpy().argmax()


8