In [4]:
from itertools import product

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


In [5]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

In [6]:
batch_sizes = [32, 128, 512]
activation_functions = {'ReLU': nn.ReLU(), 'Sigmoid': nn.Sigmoid(), 'Tanh': nn.Tanh(), 'LeakyReLU': nn.LeakyReLU()}
optimizers = {'SGD': optim.SGD, 'Adam': optim.Adam, 'RMSprop': optim.RMSprop}
num_epochs = [10, 30, 50]
early_stopping_patience = [2, 15]
depths = [2, 4, 6]
widths = [32, 128, 512]
dropouts = [0.2, 0.5]
l2_lambdas = [0.001, 0.01]

def get_data_loader(batch_size):
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader

In [7]:
def create_mlp(depth, width, activation_fn, dropout=0.0):
    layers = []
    input_size = 28 * 28
    layers.append(nn.Linear(input_size, width))
    layers.append(activation_fn)

    for _ in range(depth - 1):
        if dropout > 0:
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(width, width))
        layers.append(activation_fn)

    layers.append(nn.Linear(width, 10))
    return nn.Sequential(*layers)

def train_and_evaluate(depth, width, activation, batch_size, optimizer_name, lr=0.001, dropout=0.0, l2_lambda=0.0, early_stopping_patience=None, max_epochs=50):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_loader, test_loader = get_data_loader(batch_size)
    model = create_mlp(depth, width, activation_functions[activation], dropout).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optimizers[optimizer_name](model.parameters(), lr=lr, weight_decay=l2_lambda)

    best_acc = 0.0
    patience_counter = 0

    for epoch in range(max_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.view(images.size(0), -1).to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.view(images.size(0), -1).to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        accuracy = correct / total

        if early_stopping_patience:
            if accuracy > best_acc:
                best_acc = accuracy
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= early_stopping_patience:
                    break

    return best_acc

In [8]:
param_grid = product(depths, widths, activation_functions.keys(), batch_sizes, optimizers.keys(), dropouts, l2_lambdas, num_epochs, early_stopping_patience)
results = [(depth, width, activation, batch_size, optimizer_name, dropout, l2_lambda, epochs, patience,
            train_and_evaluate(depth, width, activation, batch_size, optimizer_name, dropout=dropout, l2_lambda=l2_lambda, early_stopping_patience=patience, max_epochs=epochs))
           for depth, width, activation, batch_size, optimizer_name, dropout, l2_lambda, epochs, patience in param_grid]

for res in results:
    print(f"Depth: {res[0]}, Width: {res[1]}, Activation: {res[2]}, Batch: {res[3]}, Opt: {res[4]}, Dropout: {res[5]}, L2: {res[6]}, Epochs: {res[7]}, Patience: {res[8]}, Acc: {res[9]:.4f}")

KeyboardInterrupt: 