### LeNet

In [None]:
import torch

def default_device():
    """Return the preferred torch device: CUDA if available, else Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device('cuda')
    # `torch.backends.mps` is absent on older PyTorch builds; look it up
    # defensively so those installs fall back to CPU instead of raising.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device('mps')
    return torch.device('cpu')


device = default_device()

In [None]:
import torch.nn as nn
from torchinfo import summary

In [None]:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    The linear head is sized for 28x28 inputs: each unpadded 5x5 convolution
    trims 4 pixels and each 2x2 max-pool halves the size, so the spatial size
    goes 28 -> 24 -> 12 -> 8 -> 4, leaving a 16x4x4 feature map.

    Args:
        in_channels: number of input image channels (default 1).
        num_classes: number of output logits (default 10).
    """

    # (channels, height, width) of the feature map entering the linear head.
    _FEATURE_SHAPE = (16, 4, 4)

    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
        self.fc1 = nn.Linear(16 * 4 * 4, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map images of shape (N, C, 28, 28), or unbatched (C, 28, 28), to logits.

        Returns a tensor of shape (N, num_classes); an unbatched input yields
        a single row, i.e. shape (1, num_classes).

        Raises:
            RuntimeError: if the feature map is not 16x4x4, i.e. the input
                images are not 28x28.
        """
        x = torch.max_pool2d(torch.relu(self.conv1(x)), 2)
        x = torch.max_pool2d(torch.relu(self.conv2(x)), 2)
        # Validate the per-sample shape explicitly: with -1 in the batch slot,
        # `view` alone would silently refold a wrong-sized feature map into a
        # different number of rows instead of failing.
        feature_shape = tuple(x.shape[-3:])
        if feature_shape != self._FEATURE_SHAPE:
            raise RuntimeError(
                f'LeNet expects 28x28 inputs; got a feature map of shape '
                f'{feature_shape}, expected {self._FEATURE_SHAPE}'
            )
        x = x.view(-1, 16 * 4 * 4)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
# torchinfo's input_size includes the batch dimension: (N, C, H, W).
summary(LeNet().to(device), input_size=(1, 1, 28, 28))

### Train

In [None]:
import sys

import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm

torch.manual_seed(0)
np.random.seed(0)

model = LeNet().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.02)
criterion = nn.CrossEntropyLoss()

# NOTE(review): presumably the commonly quoted MNIST pixel mean/std — confirm.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)

test_dataset = datasets.MNIST('data', train=False, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

n_epochs = 10
losses_history = []
accuracies_history = []
n_test = len(test_loader.dataset)  # loop-invariant, computed once


def _train_epoch(net, loader, opt, loss_fn, device):
    """Run one optimization pass over `loader`; return the sum of per-batch losses."""
    net.train()
    epoch_loss = 0.0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        opt.zero_grad()
        loss = loss_fn(net(x), y)
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
    return epoch_loss


def _count_correct(net, loader, device):
    """Return how many samples in `loader` have argmax prediction equal to the label."""
    net.eval()
    correct = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            correct += (net(x).argmax(1) == y).sum().item()
    return correct


for epoch in tqdm(range(n_epochs), file=sys.stdout):
    total_loss = _train_epoch(model, train_loader, optimizer, criterion, device)
    accuracy = _count_correct(model, test_loader, device) / n_test

    # History stores log10 of the summed epoch loss alongside test accuracy.
    losses_history.append(np.log10(total_loss))
    accuracies_history.append(accuracy)

    # Report every other epoch, and always the final one.
    if epoch % 2 == 0 or epoch == n_epochs - 1:
        tqdm.write(f'Epoch {epoch}, Loss: {total_loss:.2f}, Accuracy: {accuracy:.2f}')
    



In [None]:
import matplotlib.pyplot as plt

# losses_history holds log10 of each epoch's summed training loss, and
# accuracies_history holds test-set accuracy; label the curves accordingly.
plt.plot(losses_history, label='log10(train loss)')
plt.plot(accuracies_history, label='Test accuracy')
plt.xlabel('Epoch')
plt.legend()
plt.show()

print('Accuracy:', accuracies_history[-1])