<a href="https://colab.research.google.com/github/polevev/kaggle/blob/main/ideal_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt

# Load MNIST. Images are converted to tensors; the train=False split is used
# as the validation set.
transform = transforms.ToTensor()
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
val_dataset = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Only the training loader shuffles; validation order is kept fixed.
dataloader = {
    "train": DataLoader(train_dataset, batch_size=32, shuffle=True),
    "valid": DataLoader(val_dataset, batch_size=32, shuffle=False),
}

# Direct aliases to the same loader objects stored in the dict above.
train_dataloader = dataloader["train"]
val_dataloader = dataloader["valid"]

# Compute classification accuracy of a model over a validation dataloader
def compute_validation_accuracy(model, val_dataloader):
    """Return the fraction of samples in *val_dataloader* that *model* classifies correctly.

    The model is switched to eval mode (and left in it) and gradients are
    disabled while iterating. The predicted class of each sample is the
    argmax of the model output along dimension 1.

    Args:
        model: Callable module exposing ``eval()``; called on each input batch.
        val_dataloader: Iterable yielding ``(inputs, targets)`` pairs.

    Returns:
        ``correct / total`` as a float, or ``0`` when the loader yields no samples.
    """
    model.eval()
    n_hits = 0
    n_seen = 0

    with torch.no_grad():
        for inputs, targets in val_dataloader:
            predicted = model(inputs).argmax(dim=1)
            n_hits += (predicted == targets).float().sum().item()
            n_seen += len(targets)

    # Guard against an empty loader to avoid dividing by zero.
    if n_seen > 0:
        return n_hits / n_seen
    return 0

# Build a one-hidden-layer classifier around the given activation module
def create_model(activation, *, in_features=784, hidden_features=128, num_classes=10):
    """Return a Flatten -> Linear -> activation -> Linear network.

    Args:
        activation: Module instance placed between the two linear layers.
            It is inserted as-is (not copied), so passing the same instance
            to several calls makes those models share that module object.
        in_features: Size of each flattened input sample. The default of 784
            matches a flattened 28x28 MNIST image.
        hidden_features: Width of the hidden layer.
        num_classes: Number of output logits.

    Returns:
        An ``nn.Sequential`` producing raw (un-normalized) class scores of
        shape ``(batch, num_classes)``.
    """
    return nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features, hidden_features),
        activation,
        nn.Linear(hidden_features, num_classes),
    )

# Activation functions to compare, keyed by display name. Each name is also
# the class name in torch.nn, so the module is instantiated by lookup.
activations = {
    name: getattr(nn, name)()
    for name in ('ELU', 'ReLU', 'Sigmoid')
}

# Train one model per activation function and record its validation accuracy
# after every epoch.
max_epochs = 10
validation_accuracies = {name: [] for name in activations}

for activation_name, activation in activations.items():
    print(f"\nTraining with {activation_name} activation...")

    # Fresh network, loss function and optimizer for this activation.
    model = create_model(activation)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    for epoch in range(max_epochs):
        # The validation pass below puts the model in eval mode, so switch
        # back to train mode at the start of each epoch.
        model.train()
        for x_batch, y_batch in dataloader['train']:
            optimizer.zero_grad()
            loss = criterion(model(x_batch), y_batch)
            loss.backward()
            optimizer.step()

        # Evaluate on the validation set once the epoch's training pass is done.
        val_acc = compute_validation_accuracy(model, dataloader['valid'])
        validation_accuracies[activation_name].append(val_acc)
        print(f"[EPOCH]: {epoch+1}, Validation Accuracy: {val_acc:.4f}")

# Plot validation accuracy per epoch, one curve per activation function
plt.figure(figsize=(10, 6))
for activation_name, val_accs in validation_accuracies.items():
    # Take the x-axis from the recorded history, not from max_epochs: if a
    # training run was interrupted or max_epochs changed before re-plotting,
    # the lengths would differ and plt.plot would raise on mismatched x/y.
    epochs = range(1, len(val_accs) + 1)
    plt.plot(epochs, val_accs, label=activation_name)

plt.xlabel('Epoch')
plt.ylabel('Validation Accuracy')
plt.title('Validation Accuracy vs Epoch for Different Activation Functions')
plt.legend()
plt.grid(True)
plt.show()