In [15]:
import torch
from torch import nn, optim
from torchvision import datasets
from torchvision.transforms import v2

In [39]:
# transform = v2.Compose(
#     [
#         v2.ToTensor(),
#         v2.Normalize((0.5,), (0.5,))
#     ]
# )

transforms = v2.Compose([
    # v2.RandomResizedCrop(size=(28, 28), antialias=True),
    # v2.RandomHorizontalFlip(p=0.5),
    v2.ToImage(),  # Преобразование в изображение
    v2.ToDtype(torch.float32, scale=True),  # Преобразование в тензор с масштабированием
    v2.Normalize((0.5,), (0.5,))
])
mnist = datasets.MNIST(root='./data', train=True, download=True, transform=transforms)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms)

In [40]:
train_loader = torch.utils.data.DataLoader(mnist, batch_size=1000, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=1000, shuffle=False)

In [45]:
class MNISTClassification(nn.Module):
    def __init__(self):
        super(MNISTClassification, self).__init__()
        self.linear = nn.Linear(28*28, 10)
    
    def forward(self, image):
        out = image.view(-1, 28*28)
        out = self.linear(out)
        return out
    
mnist_model = MNISTClassification()
criterion = nn.CrossEntropyLoss()
optimizer  = optim.SGD(mnist_model.parameters(), lr=0.05)

In [46]:
from tqdm import tqdm

num_epochs = 20
for epoch in range(num_epochs):
    mnist_model.train()
    with tqdm(total=len(train_loader), desc=f'Epoch [{epoch+1:>2}/{num_epochs}]') as pbar:
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = mnist_model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            pbar.set_postfix({'mnist loss': f'{loss.item():.4f}'})
            pbar.update(1)

mnist_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        outputs = mnist_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
print(f'Accuracy of the model on the test images: {100 * correct / total:.2f}%')

Epoch [ 1/20]: 100%|██████████| 60/60 [00:29<00:00,  2.05it/s, mnist loss=0.6243]
Epoch [ 2/20]: 100%|██████████| 60/60 [00:29<00:00,  2.04it/s, mnist loss=0.4723]
Epoch [ 3/20]: 100%|██████████| 60/60 [00:24<00:00,  2.44it/s, mnist loss=0.4490]
Epoch [ 4/20]: 100%|██████████| 60/60 [00:24<00:00,  2.49it/s, mnist loss=0.4255]
Epoch [ 5/20]: 100%|██████████| 60/60 [00:24<00:00,  2.47it/s, mnist loss=0.3591]
Epoch [ 6/20]: 100%|██████████| 60/60 [00:24<00:00,  2.48it/s, mnist loss=0.3475]
Epoch [ 7/20]: 100%|██████████| 60/60 [00:24<00:00,  2.49it/s, mnist loss=0.3571]
Epoch [ 8/20]: 100%|██████████| 60/60 [00:23<00:00,  2.50it/s, mnist loss=0.3098]
Epoch [ 9/20]: 100%|██████████| 60/60 [00:23<00:00,  2.50it/s, mnist loss=0.3500]
Epoch [10/20]: 100%|██████████| 60/60 [00:23<00:00,  2.52it/s, mnist loss=0.3420]
Epoch [11/20]: 100%|██████████| 60/60 [00:23<00:00,  2.54it/s, mnist loss=0.3232]
Epoch [12/20]: 100%|██████████| 60/60 [00:23<00:00,  2.51it/s, mnist loss=0.3687]
Epoch [13/20]: 1

Accuracy of the model on the test images: 91.63%
