In [1]:
%load_ext autoreload

In [2]:
%autoreload 2

# Подготовка

## Датасет
Для тестирования LeNet-5 можно использовать датасет MNIST: изображения в нём имеют размер $28 \times 28$, что подходит под вход сети без ресайза (первая свёртка использует padding=2).

In [3]:
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Both MNIST splits share the storage root, the download flag and the
# tensor conversion; only the `train` flag differs between them.
mnist_kwargs = dict(root='../data', download=True, transform=transforms.ToTensor())

# Training split (train=True) and test split (train=False).
train_data = datasets.MNIST(train=True, **mnist_kwargs)
test_data = datasets.MNIST(train=False, **mnist_kwargs)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ../data/MNIST/raw/train-images-idx3-ubyte.gz


100.0%


Extracting ../data/MNIST/raw/train-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ../data/MNIST/raw/train-labels-idx1-ubyte.gz


100.0%


Extracting ../data/MNIST/raw/train-labels-idx1-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw/t10k-images-idx3-ubyte.gz


100.0%


Extracting ../data/MNIST/raw/t10k-images-idx3-ubyte.gz to ../data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100.0%

Extracting ../data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ../data/MNIST/raw






## Проверка использования GPU

In [4]:
# Show whether PyTorch can use CUDA on this machine.
torch.cuda.is_available()

True

In [5]:
# Show the index of the CUDA device PyTorch currently has selected.
torch.cuda.current_device()

0

In [6]:
# Show the name of the CUDA device; no device argument is passed here.
torch.cuda.get_device_name()

'NVIDIA GeForce GTX 1070 Ti'

## Подготовка модели и лоадеров

In [8]:
# Mini-batch size shared by the training and test loaders.
batch_size = 64
# Reshuffle the training samples every epoch so SGD does not see the same
# batches in the same fixed order each time (DataLoader defaults to
# shuffle=False).
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
# The test split is iterated in its stored order.
test_loader = DataLoader(test_data, batch_size=batch_size)

In [29]:
from model.lenet5 import LeNet5
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from evaluation.eval import ClassificationEvaluator
from evaluation.metrics import Metric

# Build the network from its YAML description.
# nc=10 is presumably the number of output classes (the model printed further
# down ends in a 10-unit layer) — TODO confirm.
model = LeNet5('../config/lenet-5.yaml', nc=10)

# SGD with momentum over all parameters of the model.
optimizer = SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
)

# NOTE(review): the model printed further down ends in ReLU -> Softmax, while
# CrossEntropyLoss expects raw logits (it applies log-softmax itself) —
# confirm whether the final activations should be removed in the config.
loss_fn = CrossEntropyLoss()

# Metrics handed to the evaluator.
reported_metrics = [
    Metric.Accuracy,
    Metric.Precision,
    Metric.Recall,
    Metric.F1,
]
evaluator = ClassificationEvaluator(reported_metrics)

In [30]:
from trainer.supervised import ClassificationTrainer

# Bundle the model, both data loaders, the loss, the optimizer and the
# evaluator into one trainer object.
# NOTE(review): the unit of eval_freq (batches? samples?) is not visible
# from this notebook — confirm against ClassificationTrainer.
trainer = ClassificationTrainer(
    model,
    train_loader,
    test_loader,
    loss_fn,
    optimizer,
    evaluator,
    eval_freq=10000,
)

In [31]:
# Display the layer stack stored in model.nn.
model.nn

Sequential(
  (0): Conv(
    (conv): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False)
    (act): ReLU()
  )
  (1): AvgPool(
    (avg): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (2): Conv(
    (conv): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), bias=False)
    (act): ReLU()
  )
  (3): AvgPool(
    (avg): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (4): Flatten(start_dim=1, end_dim=-1)
  (5): FC(
    (fc): Linear(in_features=400, out_features=120, bias=True)
    (act): ReLU()
  )
  (6): FC(
    (fc): Linear(in_features=120, out_features=84, bias=True)
    (act): ReLU()
  )
  (7): FC(
    (fc): Linear(in_features=84, out_features=10, bias=True)
    (act): ReLU()
  )
  (8): Softmax(
    (sm): Softmax(dim=1)
  )
)

# Тестирование

In [None]:
# Number of times trainer.train() is called below.
epochs = 100

# `epoch` is not used inside the loop body; it only counts iterations.
# Presumably each trainer.train() call performs one pass over train_loader — TODO confirm.
for epoch in range(epochs):
    trainer.train()

Metrics for id 0:
Accuracy | Precision | Recall | F1
0.1028   | 0.1028    | 0.1028 | 0.1028
Loss for id 0: 0.030923862087726593
Metrics for id 0:
Accuracy | Precision | Recall | F1
0.8055   | 0.8055    | 0.8055 | 0.8054999999999999
