# Image classification with modern MLP models

## Introduction

This example implements three modern attention-free, multi-layer perceptron (MLP) based models for image
classification, demonstrated on the CIFAR-100 dataset:

1. The [MLP-Mixer](https://arxiv.org/abs/2105.01601) model, by Ilya Tolstikhin et al., based on two types of MLPs.
3. The [FNet](https://arxiv.org/abs/2105.03824) model, by James Lee-Thorp et al., based on unparameterized
Fourier Transform.
2. The [gMLP](https://arxiv.org/abs/2105.08050) model, by Hanxiao Liu et al., based on MLP with gating.

The purpose of the example is not to compare between these models, as they might perform differently on
different datasets with well-tuned hyperparameters. Rather, it is to show simple implementations of their
main building blocks.

## Setup

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms
import numpy as np
from sklearn.metrics import accuracy_score

# Set device (GPU if available, else CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Prepare the data

In [2]:
num_classes = 100
input_shape = (3, 32, 32)  # PyTorch uses (C, H, W) format

# Define data transformations
transform = transforms.Compose([
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))  # Normalize with CIFAR-100 stats
])

# Load CIFAR-100 dataset
train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

# Print dataset shapes
print(f"x_train shape: {train_dataset.data.shape} - y_train shape: {np.array(train_dataset.targets).shape}")
print(f"x_test shape: {test_dataset.data.shape} - y_test shape: {np.array(test_dataset.targets).shape}")

Files already downloaded and verified
Files already downloaded and verified
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000,)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000,)


<!--  -->

## Configure the hyperparameters

In [3]:
# Hyperparameters
weight_decay = 0.0001
batch_size = 128
num_epochs = 1  # Recommended num_epochs = 50
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 8  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Size of the data array.
embedding_dim = 256  # Number of hidden units.
num_blocks = 4  # Number of blocks.

# Print hyperparameters
print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2}")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")

Image size: 64 X 64 = 4096
Patch size: 8 X 8 = 64
Patches per image: 64
Elements per patch (3 channels): 192


## Use data augmentation

In [4]:
# Compute mean and variance of the training data
x_train_np = train_dataset.data / 255.0  # Normalize pixel values to [0, 1]
mean = np.mean(x_train_np, axis=(0, 1, 2))  # Mean across channels
std = np.std(x_train_np, axis=(0, 1, 2))  # Std across channels

print(f"Mean: {mean}")
print(f"Std: {std}")

data_augmentation = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # Resize images to the desired size
    transforms.RandomHorizontalFlip(),  # Randomly flip images horizontally
    transforms.RandomAffine(degrees=0, scale=(0.8, 1.2)),  # Random zoom
    transforms.ToTensor(),  # Convert images to PyTorch tensors
    transforms.Normalize(mean, std)  # Normalize with computed mean and std
])

# Apply data augmentation to the training dataset
train_dataset = datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=data_augmentation  # Apply data augmentation
)

# For the test dataset, use only resizing and normalization (no augmentation)
test_dataset = datasets.CIFAR100(
    root='./data',
    train=False,
    download=True,
    transform=transforms.Compose([
        transforms.Resize((image_size, image_size)),  # Resize images to the desired size
        transforms.ToTensor(),  # Convert images to PyTorch tensors
        transforms.Normalize(mean, std)  # Normalize with computed mean and std
    ])
)

Mean: [0.50707516 0.48654887 0.44091784]
Std: [0.26733429 0.25643846 0.27615047]
Files already downloaded and verified
Files already downloaded and verified


## Implement patch extraction as a layer

In [5]:
class Patches(nn.Module):
    def __init__(self, patch_size):
        super(Patches, self).__init__()
        self.patch_size = patch_size

    def forward(self, x):
        """
        Extracts patches from the input image tensor.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Tensor of shape (B, num_patches, patch_dim), where
                          num_patches = (H // patch_size) * (W // patch_size),
                          patch_dim = C * patch_size * patch_size.
        """
        # Extract patches using PyTorch's unfold operation
        B, C, H, W = x.shape
        patches = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size)
        patches = patches.contiguous().view(B, C, -1, self.patch_size, self.patch_size)  # (B, C, num_patches, patch_size, patch_size)
        patches = patches.permute(0, 2, 1, 3, 4)  # (B, num_patches, C, patch_size, patch_size)
        patches = patches.reshape(B, -1, C * self.patch_size * self.patch_size)  # (B, num_patches, patch_dim)

        return patches

## Implement position embedding as a layer

In [6]:
class PositionEmbedding(nn.Module):
    def __init__(self, sequence_length, feature_size, initializer=None):
        """
        Args:
            sequence_length (int): Length of the sequence.
            feature_size (int): Size of the feature dimension.
            initializer (callable): Initializer for the position embeddings. If None, uses default PyTorch initialization.
        """
        super(PositionEmbedding, self).__init__()
        self.sequence_length = sequence_length
        self.feature_size = feature_size

        # Initialize position embeddings
        self.position_embeddings = nn.Parameter(torch.zeros(sequence_length, feature_size))
        if initializer is not None:
            initializer(self.position_embeddings)

    def forward(self, inputs, start_index=0):
        """
        Args:
            inputs (torch.Tensor): Input tensor of shape (B, sequence_length, feature_size).
            start_index (int): Index from which to start slicing the position embeddings.

        Returns:
            torch.Tensor: Output tensor of shape (B, sequence_length, feature_size).
        """
        B, seq_len, feature_size = inputs.shape

        # Trim position embeddings to match the input sequence length
        position_embeddings = self.position_embeddings[start_index : start_index + seq_len, :]

        # Broadcast position embeddings to match the input shape
        position_embeddings = position_embeddings.unsqueeze(0).expand(B, -1, -1)

        return inputs + position_embeddings

## Build a classification model

We implement a method that builds a classifier given the processing blocks.

In [7]:
class Classifier(nn.Module):
    def __init__(self, blocks, patch_size, embedding_dim, num_patches, num_classes, dropout_rate, positional_encoding=False):
        """
        Args:
            blocks (nn.Module): A PyTorch module that processes the input tensor.
            patch_size (int): Size of the patches.
            embedding_dim (int): Dimensionality of the patch embeddings.
            num_patches (int): Number of patches per image.
            num_classes (int): Number of output classes.
            dropout_rate (float): Dropout rate.
            positional_encoding (bool): Whether to add positional encoding.
        """
        super(Classifier, self).__init__()
        self.patch_size = patch_size
        self.embedding_dim = embedding_dim
        self.num_patches = num_patches
        self.num_classes = num_classes
        self.dropout_rate = dropout_rate
        self.positional_encoding = positional_encoding

        # Layers
        self.patches_layer = Patches(patch_size)
        self.dense = nn.Linear((patch_size ** 2) * 3, embedding_dim)
        if positional_encoding:
            self.position_embedding = PositionEmbedding(num_patches, embedding_dim)
        self.blocks = blocks
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, x):
        # Create patches
        patches = self.patches_layer(x)  # (B, num_patches, patch_dim)
        # Encode patches
        x = self.dense(patches)  # (B, num_patches, embedding_dim)
        # Add positional encoding
        if self.positional_encoding:
            x = x + self.position_embedding(x)  # (B, num_patches, embedding_dim)
        # Process through blocks
        x = self.blocks(x)  # (B, num_patches, embedding_dim)
        # Global average pooling
        x = x.permute(0, 2, 1)  # (B, embedding_dim, num_patches)
        x = self.global_pool(x)  # (B, embedding_dim, 1)
        x = x.squeeze(2)  # (B, embedding_dim)
        # Apply dropout
        x = self.dropout(x)
        # Compute logits
        logits = self.classifier(x)  # (B, num_classes)
        return logits

## Define an experiment

We implement a utility function to compile, train, and evaluate a given model.

In [8]:
def run_experiment(model, learning_rate, train_dataset, test_dataset, batch_size, num_epochs, weight_decay):
    """
    Args:
        model (nn.Module): The PyTorch model to train.
        learning_rate (float): Learning rate for the optimizer.
        train_dataset (torch.utils.data.Dataset): Training dataset.
        test_dataset (torch.utils.data.Dataset): Test dataset.
        batch_size (int): Batch size for training and evaluation.
        num_epochs (int): Number of epochs to train.
        weight_decay (float): Weight decay for the optimizer.
    """
    # Set device (GPU if available, else CPU)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Split training data into training and validation sets
    train_size = int(0.9 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    # Create DataLoader objects
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Define optimizer with weight decay
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Define learning rate scheduler
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    # Define loss function
    criterion = nn.CrossEntropyLoss()

    # Early stopping parameters
    best_val_loss = float('inf')
    patience = 10
    patience_counter = 0

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        model.eval()
        val_loss = 0.0
        val_acc = 0.0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                # Calculate accuracy
                _, predicted = torch.max(outputs, 1)
                val_acc += accuracy_score(labels.cpu(), predicted.cpu())

        # Average losses and accuracy
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        val_acc /= len(val_loader)

        # Print epoch results
        print(f"Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {train_loss:.4f}, "
              f"Val Loss: {val_loss:.4f}, "
              f"Val Acc: {val_acc:.4f}")

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # Save the best model
            torch.save(model.state_dict(), "best_model.pth")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    # Load the best model
    model.load_state_dict(torch.load("best_model.pth"))

    # Evaluate on the test set
    model.eval()
    test_acc = 0.0
    test_top5_acc = 0.0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # Calculate top-1 accuracy
            _, predicted = torch.max(outputs, 1)
            test_acc += accuracy_score(labels.cpu(), predicted.cpu())

            # Calculate top-5 accuracy
            _, top5_predicted = torch.topk(outputs, 5, dim=1)
            top5_correct = torch.sum(top5_predicted == labels.view(-1, 1))
            test_top5_acc += top5_correct.item() / labels.size(0)

    # Average test accuracy
    test_acc /= len(test_loader)
    test_top5_acc /= len(test_loader)

    # Print test results
    print(f"Test Accuracy: {test_acc * 100:.2f}%")
    print(f"Test Top-5 Accuracy: {test_top5_acc * 100:.2f}%")

## The MLP-Mixer model

The MLP-Mixer is an architecture based exclusively on
multi-layer perceptrons (MLPs), that contains two types of MLP layers:

1. One applied independently to image patches, which mixes the per-location features.
2. The other applied across patches (along channels), which mixes spatial information.

This is similar to a [depthwise separable convolution based model](https://arxiv.org/abs/1610.02357)
such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization
instead of batch normalization.

### Implement the MLP-Mixer module

In [9]:
class MLPMixerLayer(nn.Module):
    def __init__(self, num_patches, hidden_units, dropout_rate):
        """
        Args:
            num_patches (int): Number of patches.
            hidden_units (int): Number of hidden units in the MLP.
            dropout_rate (float): Dropout rate.
        """
        super(MLPMixerLayer, self).__init__()

        # MLP applied across patches (mixing spatial information)
        self.mlp1 = nn.Sequential(
            nn.Linear(num_patches, num_patches),  # First dense layer
            nn.GELU(),  # GELU activation
            nn.Linear(num_patches, num_patches),  # Second dense layer
            nn.Dropout(dropout_rate)  # Dropout
        )

        # MLP applied across channels (mixing per-location features)
        self.mlp2 = nn.Sequential(
            nn.Linear(hidden_units, hidden_units),  # First dense layer
            nn.GELU(),  # GELU activation
            nn.Linear(hidden_units, hidden_units),  # Second dense layer
            nn.Dropout(dropout_rate)  # Dropout
        )

        # Layer normalization
        self.normalize = nn.LayerNorm(hidden_units, eps=1e-6)

    def forward(self, inputs):
        """
        Args:
            inputs (torch.Tensor): Input tensor of shape (B, num_patches, hidden_units).

        Returns:
            torch.Tensor: Output tensor of shape (B, num_patches, hidden_units).
        """
        # Apply layer normalization
        x = self.normalize(inputs)  # (B, num_patches, hidden_units)

        # Transpose inputs to mix across patches
        x_channels = x.transpose(1, 2)  # (B, hidden_units, num_patches)

        # Apply mlp1 on each channel independently
        mlp1_outputs = self.mlp1(x_channels)  # (B, hidden_units, num_patches)

        # Transpose back to original shape
        mlp1_outputs = mlp1_outputs.transpose(1, 2)  # (B, num_patches, hidden_units)

        # Add skip connection
        x = mlp1_outputs + inputs  # (B, num_patches, hidden_units)

        # Apply layer normalization
        x_patches = self.normalize(x)  # (B, num_patches, hidden_units)

        # Apply mlp2 on each patch independently
        mlp2_outputs = self.mlp2(x_patches)  # (B, num_patches, hidden_units)

        # Add skip connection
        x = x + mlp2_outputs  # (B, num_patches, hidden_units)

        return x

### Build, train, and evaluate the MLP-Mixer model

In [10]:
# Create the MLP-Mixer blocks
mlpmixer_blocks = nn.Sequential(
    *[MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)

# Create the classifier model
mlpmixer_classifier = Classifier(
    blocks=mlpmixer_blocks,
    patch_size=patch_size,
    embedding_dim=embedding_dim,
    num_patches=num_patches,
    num_classes=num_classes,
    dropout_rate=dropout_rate,
    positional_encoding=True
)

# Define learning rate
learning_rate = 0.005

# Run the experiment
run_experiment(
    model=mlpmixer_classifier,
    learning_rate=learning_rate,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    weight_decay=weight_decay
)

Epoch 1/1, Train Loss: 3.9079, Val Loss: 3.6088, Val Acc: 0.1701


  model.load_state_dict(torch.load("best_model.pth"))


Test Accuracy: 17.20%
Test Top-5 Accuracy: 42.14%


The MLP-Mixer model tends to have much less number of parameters compared
to convolutional and transformer-based models, which leads to less training and
serving computational cost.

As mentioned in the [MLP-Mixer](https://arxiv.org/abs/2105.01601) paper,
when pre-trained on large datasets, or with modern regularization schemes,
the MLP-Mixer attains competitive scores to state-of-the-art models.
You can obtain better results by increasing the embedding dimensions,
increasing the number of mixer blocks, and training the model for longer.
You may also try to increase the size of the input images and use different patch sizes.

## The FNet model

The FNet uses a similar block to the Transformer block. However, FNet replaces the self-attention layer
in the Transformer block with a parameter-free 2D Fourier transformation layer:

1. One 1D Fourier Transform is applied along the patches.
2. One 1D Fourier Transform is applied along the channels.

### Implement the FNet module

In [11]:
class FNetLayer(nn.Module):
    def __init__(self, embedding_dim, dropout_rate):
        """
        Args:
            embedding_dim (int): Dimensionality of the input and output features.
            dropout_rate (float): Dropout rate.
        """
        super(FNetLayer, self).__init__()

        # Feedforward network
        self.ffn = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim),  # First dense layer
            nn.GELU(),  # GELU activation
            nn.Dropout(dropout_rate),  # Dropout
            nn.Linear(embedding_dim, embedding_dim)  # Second dense layer
        )

        # Layer normalization
        self.normalize1 = nn.LayerNorm(embedding_dim, eps=1e-6)
        self.normalize2 = nn.LayerNorm(embedding_dim, eps=1e-6)

    def forward(self, inputs):
        """
        Args:
            inputs (torch.Tensor): Input tensor of shape (B, num_patches, embedding_dim).

        Returns:
            torch.Tensor: Output tensor of shape (B, num_patches, embedding_dim).
        """
        # Apply 2D Fourier Transform
        x = torch.fft.fft2(inputs).real  # (B, num_patches, embedding_dim)

        # Add skip connection
        x = x + inputs  # (B, num_patches, embedding_dim)

        # Apply layer normalization
        x = self.normalize1(x)  # (B, num_patches, embedding_dim)

        # Apply feedforward network
        x_ffn = self.ffn(x)  # (B, num_patches, embedding_dim)

        # Add skip connection
        x = x + x_ffn  # (B, num_patches, embedding_dim)

        # Apply layer normalization
        x = self.normalize2(x)  # (B, num_patches, embedding_dim)

        return x

### Build, train, and evaluate the FNet model

In [12]:
# Create the FNet blocks
fnet_blocks = nn.Sequential(
    *[FNetLayer(embedding_dim, dropout_rate) for _ in range(num_blocks)]
)

# Create the classifier model
fnet_classifier = Classifier(
    blocks=fnet_blocks,
    patch_size=patch_size,
    embedding_dim=embedding_dim,
    num_patches=num_patches,
    num_classes=num_classes,
    dropout_rate=dropout_rate,
    positional_encoding=True
)

# Define learning rate
learning_rate = 0.001

# Run the experiment
run_experiment(
    model=fnet_classifier,
    learning_rate=learning_rate,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    weight_decay=weight_decay
)

Epoch 1/1, Train Loss: 4.1169, Val Loss: 3.8115, Val Acc: 0.1178


  model.load_state_dict(torch.load("best_model.pth"))


Test Accuracy: 11.55%
Test Top-5 Accuracy: 32.89%


As shown in the [FNet](https://arxiv.org/abs/2105.03824) paper,
better results can be achieved by increasing the embedding dimensions,
increasing the number of FNet blocks, and training the model for longer.
You may also try to increase the size of the input images and use different patch sizes.
The FNet scales very efficiently to long inputs, runs much faster than attention-based
Transformer models, and produces competitive accuracy results.

## The gMLP model

The gMLP is a MLP architecture that features a Spatial Gating Unit (SGU).
The SGU enables cross-patch interactions across the spatial (channel) dimension, by:

1. Transforming the input spatially by applying linear projection across patches (along channels).
2. Applying element-wise multiplication of the input and its spatial transformation.

### Implement the gMLP module

In [13]:
class gMLPLayer(nn.Module):
    def __init__(self, num_patches, embedding_dim, dropout_rate):
        """
        Args:
            num_patches (int): Number of patches.
            embedding_dim (int): Dimensionality of the input and output features.
            dropout_rate (float): Dropout rate.
        """
        super(gMLPLayer, self).__init__()

        # Channel projection 1
        self.channel_projection1 = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 2),  # First dense layer
            nn.GELU(),  # GELU activation
            nn.Dropout(dropout_rate)  # Dropout
        )

        # Channel projection 2
        self.channel_projection2 = nn.Linear(embedding_dim, embedding_dim)

        # Spatial projection
        self.spatial_projection = nn.Linear(num_patches, num_patches)

        # Layer normalization
        self.normalize1 = nn.LayerNorm(embedding_dim, eps=1e-6)
        self.normalize2 = nn.LayerNorm(embedding_dim, eps=1e-6)

    def spatial_gating_unit(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, num_patches, embedding_dim * 2).

        Returns:
            torch.Tensor: Output tensor of shape (B, num_patches, embedding_dim).
        """
        # Split x along the channel dimension
        u, v = torch.split(x, x.size(2) // 2, dim=2)  # (B, num_patches, embedding_dim)

        # Apply layer normalization
        v = self.normalize2(v)  # (B, num_patches, embedding_dim)

        # Apply spatial projection
        v = v.transpose(1, 2)  # (B, embedding_dim, num_patches)
        v_projected = self.spatial_projection(v)  # (B, embedding_dim, num_patches)
        v_projected = v_projected.transpose(1, 2)  # (B, num_patches, embedding_dim)

        # Apply element-wise multiplication
        return u * v_projected  # (B, num_patches, embedding_dim)

    def forward(self, inputs):
        """
        Args:
            inputs (torch.Tensor): Input tensor of shape (B, num_patches, embedding_dim).

        Returns:
            torch.Tensor: Output tensor of shape (B, num_patches, embedding_dim).
        """
        # Apply layer normalization
        x = self.normalize1(inputs)  # (B, num_patches, embedding_dim)

        # Apply the first channel projection
        x_projected = self.channel_projection1(x)  # (B, num_patches, embedding_dim * 2)

        # Apply the spatial gating unit
        x_spatial = self.spatial_gating_unit(x_projected)  # (B, num_patches, embedding_dim)

        # Apply the second channel projection
        x_projected = self.channel_projection2(x_spatial)  # (B, num_patches, embedding_dim)

        # Add skip connection
        return inputs + x_projected  # (B, num_patches, embedding_dim)

### Build, train, and evaluate the gMLP model

In [14]:
# Create the gMLP blocks
gmlp_blocks = nn.Sequential(
    *[gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)

# Create the classifier model
gmlp_classifier = Classifier(
    blocks=gmlp_blocks,
    patch_size=patch_size,
    embedding_dim=embedding_dim,
    num_patches=num_patches,
    num_classes=num_classes,
    dropout_rate=dropout_rate,
    positional_encoding=True
)

# Define learning rate
learning_rate = 0.003

# Run the experiment
run_experiment(
    model=gmlp_classifier,
    learning_rate=learning_rate,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    batch_size=batch_size,
    num_epochs=num_epochs,
    weight_decay=weight_decay
)

Epoch 1/1, Train Loss: 3.9089, Val Loss: 3.6911, Val Acc: 0.1506


  model.load_state_dict(torch.load("best_model.pth"))


Test Accuracy: 17.46%
Test Top-5 Accuracy: 42.45%


As shown in the [gMLP](https://arxiv.org/abs/2105.08050) paper,
better results can be achieved by increasing the embedding dimensions,
increasing the number of gMLP blocks, and training the model for longer.
You may also try to increase the size of the input images and use different patch sizes.
Note that, the paper used advanced regularization strategies, such as MixUp and CutMix,
as well as AutoAugment.