In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
class OrthogonalMLP(nn.Module):
    """MLP that alternates fully-connected layers along the height and width axes.

    Instead of flattening the image up front, each ``nn.Linear`` mixes only one
    spatial axis at a time (height, width, height, width), shrinking the feature
    map ``H x W -> h1 x h1 -> h2 x h2`` before a final linear classifier over the
    flattened ``h2 * h2`` features.

    Args:
        input_size: ``(H, W)`` spatial size of the input images.
        num_classes: Number of output logits.
        hidden_sizes: ``(h1, h2)`` reduced sizes used by the first and second
            rounds of axis-wise layers (shared by both axes). The default
            ``(16, 8)`` reproduces the original hard-coded architecture, so
            parameter names and shapes (and saved ``state_dict`` files) match.
    """

    def __init__(self, input_size=(28, 28), num_classes=10, hidden_sizes=(16, 8)):
        super().__init__()

        H, W = input_size
        h1, h2 = hidden_sizes

        # Layers that mix along the height axis (H is moved to the last dim in forward).
        self.fc_h1 = nn.Linear(H, h1)   # H  -> h1
        self.fc_h2 = nn.Linear(h1, h2)  # h1 -> h2

        # Layers that mix along the width axis (already the last dim).
        self.fc_w1 = nn.Linear(W, h1)   # W  -> h1
        self.fc_w2 = nn.Linear(h1, h2)  # h1 -> h2

        # Classifier over the flattened (h2 x h2) feature map.
        self.final_fc = nn.Linear(h2 * h2, num_classes)

    def forward(self, x):
        """Return unnormalised class logits of shape ``(batch, num_classes)``.

        Args:
            x: image batch, expected as ``(batch, 1, H, W)``; the singleton
               channel dim is removed with ``squeeze(1)``.
        """
        x = x.squeeze(1)                   # (B, 1, H, W) -> (B, H, W)

        # Step 1: mix along height. nn.Linear acts on the last dim, so move H there.
        x = x.transpose(1, 2)              # (B, W, H)
        x = torch.relu(self.fc_h1(x))      # (B, W, h1)
        x = x.transpose(1, 2)              # (B, h1, W)

        # Step 2: mix along width (already the last dim).
        x = torch.relu(self.fc_w1(x))      # (B, h1, h1)

        # Step 3: mix along height again.
        x = x.transpose(1, 2)              # (B, W=h1, H=h1)
        x = torch.relu(self.fc_h2(x))      # (B, h1, h2)
        x = x.transpose(1, 2)              # (B, h2, h1)

        # Step 4: mix along width again.
        x = torch.relu(self.fc_w2(x))      # (B, h2, h2)

        # Flatten the spatial dims and classify; flatten() does not require a
        # contiguous tensor, unlike view().
        x = torch.flatten(x, start_dim=1)  # (B, h2 * h2)
        return self.final_fc(x)            # (B, num_classes)


In [3]:
# Load the MNIST dataset.
# ToTensor converts images to float tensors; Normalize((0.5,), (0.5,)) then
# applies (x - 0.5) / 0.5 to the single channel.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# Both splits share the same storage root, download flag and preprocessing.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Shuffle only the training split; the test order stays fixed for evaluation.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [4]:
# Initialize model, loss function, and optimizer.
# Seed torch's global RNG before building the model so that weight
# initialisation and the shuffle order drawn by train_loader (which takes its
# seed from the global RNG when iterated) are reproducible on Restart & Run All.
# NOTE: some CUDA kernels can still be nondeterministic on GPU.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = OrthogonalMLP().to(device)
criterion = nn.CrossEntropyLoss()  # default reduction='mean' over the batch
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [5]:
# Training configuration: how many full passes over train_loader to run.
num_epochs: int = 30

In [6]:
# Train for num_epochs passes, reporting the per-sample mean training loss per epoch.
for epoch in range(num_epochs):
    model.train()  # the evaluation cell switches to eval(); switch back each epoch
    running_loss = 0.0
    samples_seen = 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)  # mean loss over this batch
        loss.backward()
        optimizer.step()

        # Weight each batch mean by its size so the (possibly shorter) final
        # batch counts in proportion to its samples in the epoch average.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        samples_seen += batch_size

    print(f"Epoch {epoch+1}, Loss: {running_loss/max(samples_seen, 1):.4f}")


Epoch 1, Loss: 0.7666
Epoch 2, Loss: 0.3514
Epoch 3, Loss: 0.2776
Epoch 4, Loss: 0.2405
Epoch 5, Loss: 0.2148
Epoch 6, Loss: 0.1963
Epoch 7, Loss: 0.1815
Epoch 8, Loss: 0.1705
Epoch 9, Loss: 0.1621
Epoch 10, Loss: 0.1532
Epoch 11, Loss: 0.1476
Epoch 12, Loss: 0.1394
Epoch 13, Loss: 0.1358
Epoch 14, Loss: 0.1307
Epoch 15, Loss: 0.1274
Epoch 16, Loss: 0.1247
Epoch 17, Loss: 0.1215
Epoch 18, Loss: 0.1190
Epoch 19, Loss: 0.1166
Epoch 20, Loss: 0.1139
Epoch 21, Loss: 0.1127
Epoch 22, Loss: 0.1107
Epoch 23, Loss: 0.1088
Epoch 24, Loss: 0.1070
Epoch 25, Loss: 0.1056
Epoch 26, Loss: 0.1045
Epoch 27, Loss: 0.1018
Epoch 28, Loss: 0.1019
Epoch 29, Loss: 0.1001
Epoch 30, Loss: 0.1001


In [7]:
# Evaluate the model on the test split: eval mode, no autograd tracking.
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for batch_images, batch_labels in test_loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        # The predicted class is the index of the largest logit in each row.
        predicted = model(batch_images).argmax(dim=1)
        correct += predicted.eq(batch_labels).sum().item()
        total += len(batch_labels)

In [8]:
test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy:.2f}%")

# Save only the learned parameters (state_dict), not the module object itself.
checkpoint_path = 'custom_mnist_model.pth'
torch.save(model.state_dict(), checkpoint_path)
print("Model saved successfully.")

Test Accuracy: 96.73%
Model saved successfully.
