# Подготовка модели распознавания рукописных букв и цифр

In [4]:
from pathlib import Path

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchinfo import summary
from torchvision import datasets
from torchvision.transforms import ToTensor, Compose, Normalize

In [5]:
from torchvision.datasets import EMNIST
# Open the "balanced" split of EMNIST from the local data/ directory; with
# download=False nothing is fetched, so the files must already be on disk.
# NOTE(review): `dataset` is not referenced by the later cells shown in this
# notebook, which train on FashionMNIST instead - confirm which is intended.
dataset = EMNIST(root='data/', split='balanced', download=False)

In [6]:
# Preprocessing applied to every image: convert it to a tensor, then
# normalise the single channel with mean 0.5 and std 0.5.
transform = Compose([ToTensor(), Normalize(mean=[0.5], std=[0.5])])

# FashionMNIST train / test splits kept under data/ (download requested).
train_dataset = datasets.FashionMNIST(
    root='data/',
    train=True,
    transform=transform,
    download=True,
)
val_dataset = datasets.FashionMNIST(
    root='data/',
    train=False,
    transform=transform,
    download=True,
)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST\raw\train-images-idx3-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████████| 26.4M/26.4M [00:07<00:00, 3.57MB/s]


Extracting data/FashionMNIST\raw\train-images-idx3-ubyte.gz to data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████████████████████████████████████████████████████████████████████████| 29.5k/29.5k [00:00<00:00, 753kB/s]


Extracting data/FashionMNIST\raw\train-labels-idx1-ubyte.gz to data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST\raw\t10k-images-idx3-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████████| 4.42M/4.42M [00:00<00:00, 5.32MB/s]


Extracting data/FashionMNIST\raw\t10k-images-idx3-ubyte.gz to data/FashionMNIST\raw

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST\raw\t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████████████████| 5.15k/5.15k [00:00<?, ?B/s]

Extracting data/FashionMNIST\raw\t10k-labels-idx1-ubyte.gz to data/FashionMNIST\raw






In [8]:
class CNN(nn.Module):
    """Small convolutional classifier: conv -> ReLU -> 2x2 max-pool -> two linear layers.

    Args:
        n_classes: Number of output logits produced per sample.
        in_channels: Number of channels of the input images. Defaults to 1.
        image_size: Height/width of the (square) input images. Defaults to 28,
            which yields the 6272 flattened features used previously.

    The layers live in ``self.model`` (an ``nn.Sequential``) in a fixed order,
    so ``state_dict`` keys are ``model.0.*``, ``model.4.*`` and ``model.6.*``.
    """

    def __init__(self, n_classes, in_channels=1, image_size=28):
        super().__init__()
        conv_channels = 32
        # The 3x3 convolution with padding=1 keeps the spatial size; the 2x2
        # max-pool then halves it (rounding down).
        pooled_size = image_size // 2
        flat_features = conv_channels * pooled_size * pooled_size

        self.model = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=conv_channels,
                      kernel_size=3,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Flatten(),

            nn.Linear(in_features=flat_features, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=n_classes)
        )

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.model(x)

In [9]:
# Seed PyTorch's global random generator before the weights are created so
# that a fresh "Restart & Run All" starts from the same initialisation.
torch.manual_seed(42)

# Classifier with 10 output classes.
model = CNN(10)

In [12]:
def train(model, optimizer, loss_f, train_loader, val_loader, n_epoch, val_fre):
    """Train ``model`` and periodically validate it, printing progress.

    Args:
        model: Module to optimise; it is switched to training mode.
        optimizer: Optimizer that owns ``model``'s parameters.
        loss_f: Callable ``loss_f(output, target)`` returning a scalar loss tensor.
        train_loader: Iterable of ``(data, target)`` training batches.
        val_loader: Iterable of ``(data, target)`` validation batches.
        n_epoch: Number of passes over ``train_loader``.
        val_fre: Validation runs after every epoch where ``epoch % val_fre == 0``.
    """
    model.train()
    for epoch in range(n_epoch):
        loss_sum = 0
        n_batches = 0
        print(f'Epoch: {epoch}')
        for step, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            # squeeze(1) only drops dim 1 when it has size 1; it leaves
            # (batch, n_classes) logits with n_classes > 1 untouched.
            output = model(data).squeeze(1)
            loss = loss_f(output, target)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            n_batches += 1
            if step % 10 == 0:
                print(f'Iter: {step} \tLoss: {loss.item()}')
        # Count batches explicitly so an empty loader reports nan instead of
        # failing on an unbound loop variable.
        mean_loss = loss_sum / n_batches if n_batches else float('nan')
        print(f'Mean Train Loss: {mean_loss:.6f}', end='\n\n')
        if epoch % val_fre == 0:
            # Pass the loss explicitly rather than relying on a global name.
            validate(model, val_loader, loss_f)

def validate(model, val_loader, loss_f=None):
    """Evaluate ``model`` on ``val_loader`` and print mean loss and accuracy.

    Args:
        model: Module to evaluate. It is put in eval mode for the duration of
            the call and returned to whichever mode it was in before.
        val_loader: Iterable of ``(data, target)`` batches.
        loss_f: Callable ``loss_f(output, target)`` returning a scalar loss
            tensor. Defaults to ``nn.CrossEntropyLoss()`` when omitted.

    Returns:
        Tuple ``(mean_loss, accuracy)``; both are ``nan`` when the loader
        yields no batches.
    """
    if loss_f is None:
        loss_f = nn.CrossEntropyLoss()

    was_training = model.training
    model.eval()
    loss_sum = 0
    correct = 0
    n_batches = 0
    n_samples = 0
    with torch.no_grad():
        for data, target in val_loader:
            output = model(data).squeeze(1)
            loss_sum += loss_f(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
            n_batches += 1
            # Count the samples actually seen so accuracy stays right for
            # samplers, dropped batches, or plain iterables of batches.
            n_samples += target.size(0)

    mean_loss = loss_sum / n_batches if n_batches else float('nan')
    acc = correct / n_samples if n_samples else float('nan')
    print(f'Val Loss: {mean_loss:.6f} \tAccuracy: {acc}')
    model.train(was_training)
    return mean_loss, acc

In [18]:
# Schedule: 5 epochs in total; validation runs on epochs where
# epoch % val_fre == 0.
n_epoch = 5
val_fre = 2

# Cross-entropy objective, minimised with plain SGD at learning rate 0.1.
loss_f = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1)

In [19]:
# Batches of 1000 samples; only the training loader shuffles its data.
train_loader = DataLoader(dataset=train_dataset, batch_size=1000, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=1000, shuffle=False)

In [20]:
# Run the training loop with the model, optimizer, loss, loaders and
# schedule defined in the cells above.
train(
    model=model,
    optimizer=optimizer,
    loss_f=loss_f,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epoch=n_epoch,
    val_fre=val_fre,
)

Epoch: 0
Iter: 0 	Loss: 0.27292799949645996
Iter: 10 	Loss: 0.29700931906700134
Iter: 20 	Loss: 0.2925114631652832
Iter: 30 	Loss: 0.2900494635105133
Iter: 40 	Loss: 0.3156362473964691
Iter: 50 	Loss: 0.30177170038223267
Mean Train Loss: 0.285626

Val Loss: 0.310513 	Accuracy: 0.8905
Epoch: 1
Iter: 0 	Loss: 0.25417447090148926
Iter: 10 	Loss: 0.26691535115242004
Iter: 20 	Loss: 0.30897217988967896
Iter: 30 	Loss: 0.27155235409736633
Iter: 40 	Loss: 0.29286110401153564
Iter: 50 	Loss: 0.24470190703868866
Mean Train Loss: 0.276570

Epoch: 2
Iter: 0 	Loss: 0.2676997184753418
Iter: 10 	Loss: 0.25573089718818665
Iter: 20 	Loss: 0.2900027930736542
Iter: 30 	Loss: 0.2714817225933075
Iter: 40 	Loss: 0.23893235623836517
Iter: 50 	Loss: 0.2585972547531128
Mean Train Loss: 0.268903

Val Loss: 0.332435 	Accuracy: 0.8796
Epoch: 3
Iter: 0 	Loss: 0.31727132201194763
Iter: 10 	Loss: 0.3202836513519287
Iter: 20 	Loss: 0.24793411791324615
Iter: 30 	Loss: 0.26075279712677
Iter: 40 	Loss: 0.25859478116035

RuntimeError: Parent directory checkpoints does not exist.

In [32]:
# torch.save raises "RuntimeError: Parent directory checkpoints does not
# exist" when the target folder is missing, so create it before writing.
checkpoint_path = Path('checkpoints/cnn.pth')
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)

# Persist only the parameters (state_dict), not the whole module object.
torch.save(model.state_dict(), checkpoint_path)

In [34]:
# Read the saved parameters back from disk (weights_only=True is passed to
# torch.load) and copy them into the model.
state_dict = torch.load('checkpoints/cnn.pth', weights_only=True)
model.load_state_dict(state_dict)


<All keys matched successfully>