<a href="https://colab.research.google.com/github/xzdil/dogorcat/blob/mnist/%D0%A0%D0%B0%D1%81%D0%BF%D0%BE%D0%B7%D0%BD%D0%B0%D0%B2%D0%B0%D0%BD%D0%B8%D0%B5_%D1%86%D0%B8%D1%84%D1%80.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn
from torch import optim
from torchvision import datasets, transforms 
from torch.utils.data import random_split, DataLoader

In [None]:
# Load the MNIST training set as tensors and carve out a held-out validation split.
# A dedicated, seeded generator makes both the split and the training shuffle
# reproducible across "Restart Kernel -> Run All".
data_rng = torch.Generator().manual_seed(42)

train_data = datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor())

# Hold out 5000 samples for validation; the remainder is used for training.
val_size = 5000
train_size = len(train_data) - val_size
train, val = random_split(train_data, (train_size, val_size), generator=data_rng)

# Reshuffle the training samples every epoch so SGD does not see the same
# batch order each time; the validation order is irrelevant and left as-is.
train_loader = DataLoader(train, batch_size=32, shuffle=True, generator=data_rng)
val_loader = DataLoader(val, batch_size=32)

In [None]:
# Small fully connected classifier: flattened 28x28 input -> two hidden
# layers of 64 units with ReLU -> light dropout -> 10 output scores.
n_pixels, n_hidden, n_classes = 28 * 28, 64, 10

model = nn.Sequential(
    nn.Linear(in_features=n_pixels, out_features=n_hidden),
    nn.ReLU(),
    nn.Linear(in_features=n_hidden, out_features=n_hidden),
    nn.ReLU(),
    nn.Dropout(p=0.1),
    nn.Linear(in_features=n_hidden, out_features=n_classes),
)

In [None]:
# Stochastic gradient descent over all of the model's parameters, step size 0.01.
optimiser = optim.SGD(params=model.parameters(), lr=0.01)

In [None]:
# Cross-entropy criterion applied to the model's raw output scores and the
# integer class labels; per-sample losses are averaged over the batch.
loss = nn.CrossEntropyLoss(reduction='mean')

In [None]:
def _run_epoch(model, loader, loss_fn, optimiser=None):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``optimiser`` is given the model is put in training mode and its
    parameters are updated after every batch; otherwise the model is put in
    evaluation mode and no gradients are tracked.

    Both metrics are averaged per sample rather than per batch, so a smaller
    final batch does not skew the result. An empty loader yields ``nan`` for
    both values.
    """
    training = optimiser is not None
    model.train(training)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # Gradients are only needed while training; disabling them during
    # evaluation covers both the forward pass and the loss computation.
    with torch.set_grad_enabled(training):
        for x, y in loader:
            batch_size = x.size(0)
            # Flatten each image into a vector for the fully connected model.
            logits = model(x.view(batch_size, -1))
            batch_loss = loss_fn(logits, y)

            if training:
                optimiser.zero_grad()
                batch_loss.backward()
                optimiser.step()

            # The criterion returns the batch mean, so weight it by the batch
            # size to recover a per-sample sum.
            loss_sum += batch_loss.item() * batch_size
            n_correct += (logits.argmax(dim=1) == y).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        return float('nan'), float('nan')
    return loss_sum / n_seen, n_correct / n_seen


nb_epochs = 5
for epoch in range(nb_epochs):
    # Training pass: updates the model and reports metrics measured during training.
    train_loss, train_accuracy = _run_epoch(model, train_loader, loss, optimiser)
    print(f'Epoch {epoch + 1}, train loss: {train_loss:.2f}, train accuracy: {train_accuracy:.2f}')

    # Validation pass: same metrics on held-out data, with no parameter updates.
    val_loss, val_accuracy = _run_epoch(model, val_loader, loss)
    print(f'Epoch {epoch + 1}, val loss: {val_loss:.2f}, val accuracy: {val_accuracy:.2f}')

Epoch 1,, train loss: 0.18, train accuracy: 0.95
Epoch 1,, val loss: 0.17, val accuracy: 0.95
Epoch 2,, train loss: 0.17, train accuracy: 0.95
Epoch 2,, val loss: 0.16, val accuracy: 0.95
Epoch 3,, train loss: 0.16, train accuracy: 0.95
Epoch 3,, val loss: 0.16, val accuracy: 0.95
Epoch 4,, train loss: 0.15, train accuracy: 0.96
Epoch 4,, val loss: 0.15, val accuracy: 0.95
Epoch 5,, train loss: 0.14, train accuracy: 0.96
Epoch 5,, val loss: 0.15, val accuracy: 0.96
