In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision 
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [3]:
class FcBlock(nn.Module):
    """
    Fully connected layer followed by LeakyReLU, applied along the last axis.

    Every row (a vector of length ``in_dim`` on the last axis) is transformed
    independently with the same weights.

    Input shape:  (..., in_dim), e.g. (B, H, W) with in_dim == W
    Output shape: (..., out_dim), e.g. (B, H, out_dim)
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.LeakyReLU()

    def forward(self, z):
        # Affine map over the last axis, then the non-linearity
        return self.activation(self.linear(z))

class TransposeFcBlock(nn.Module):
    """
    Fully connected layer followed by LeakyReLU, applied along the middle axis.

    Every column (a vector of length ``in_dim`` on axis 1) is transformed
    independently with the same weights.

    Input shape:  (B, in_dim, W)
    Output shape: (B, out_dim, W)
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.linear = nn.Linear(in_dim, out_dim)
        self.activation = nn.LeakyReLU()

    def forward(self, x):
        # nn.Linear works on the last axis, so bring axis 1 to the back first
        cols_last = x.permute(0, 2, 1)                      # (B, W, in_dim)
        mixed = self.activation(self.linear(cols_last))     # (B, W, out_dim)
        # Restore the original axis order
        return mixed.permute(0, 2, 1)                       # (B, out_dim, W)

In [6]:
class OrthogonalMLP(nn.Module):
    """
    Classifier that alternates width-wise and height-wise fully connected blocks.

    Shape flow for an input of shape (B, in_height, in_width):
        fc1  -> (B, in_height, hidden_dim)
        tfc1 -> (B, hidden_dim, hidden_dim)
        fc2  -> (B, hidden_dim, bottleneck_dim)
        tfc2 -> (B, bottleneck_dim, bottleneck_dim)
        ofc  -> (B, num_classes)            (after flattening)

    Args:
        in_height: size of axis 1 of the input (image height).
        in_width: size of axis 2 of the input (image width).
        hidden_dim: width of the first pair of blocks.
        bottleneck_dim: width of the second pair of blocks.
        num_classes: number of output logits.
    """
    def __init__(self, in_height, in_width, hidden_dim=16, bottleneck_dim=4, num_classes=10):
        super().__init__()
        self.fc1 = FcBlock(in_width, hidden_dim)
        self.tfc1 = TransposeFcBlock(in_height, hidden_dim)
        self.fc2 = FcBlock(hidden_dim, bottleneck_dim)
        self.tfc2 = TransposeFcBlock(hidden_dim, bottleneck_dim)
        # The flattened feature size is known statically, so use a regular
        # Linear layer: all parameters exist as soon as the model is built and
        # can safely be handed to an optimizer before the first forward pass.
        self.ofc = nn.Linear(bottleneck_dim * bottleneck_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for x of shape (B, H, W)."""
        z = self.tfc1(self.fc1(x))
        z = self.tfc2(self.fc2(z))
        # Collapse everything but the batch axis before the output layer
        hat_y = self.ofc(z.flatten(1))
        return hat_y

In [8]:
# Run configuration: compute device, RNG seed, data geometry and hyper-parameters.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seed = 0  # Seed for PyTorch's global random number generator
# Seed before any weights are initialised or batches are shuffled so that a
# fresh "Restart & Run All" repeats the same run (some GPU kernels may still
# be non-deterministic).
torch.manual_seed(seed)
B = 64  # Batch size
H, W = 28, 28  # MNIST image dimensions
# NOTE(review): the next two constants are not read by the code visible in
# this notebook; kept in case other cells rely on them.
reduced_dim = 16  # Reduced row dimension
seq_output_dim = 10  # Final output dimension (for classification)
lr = 0.001  # Learning rate passed to the optimizer
epochs = 30  # Number of passes over the training set

# Preprocessing: convert each image to a tensor, then standardise it with the
# given mean / std constants (presumably the MNIST training-set statistics).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Train and test splits share the same root directory and preprocessing;
# the data is downloaded on first use.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='./data', train=is_train, transform=transform, download=True)
    for is_train in (True, False)
)

# Only the training batches are shuffled; evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=B, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=B, shuffle=False)

# Build the network and move it to the selected device
model = OrthogonalMLP(in_height=H, in_width=W)
model = model.to(device)
# Cross-entropy over the class logits, optimised with Adam
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# Train for `epochs` passes over the training set, reporting the mean batch
# loss and the per-sample accuracy after each pass.
for epoch in range(epochs):
    model.train()
    total_loss, correct = 0, 0

    for images, labels in train_loader:
        # Move the batch to the device and drop the channel axis: (B, 1, H, W) -> (B, H, W)
        images = images.to(device).squeeze(1)
        labels = labels.to(device)

        # One optimisation step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Running statistics for the epoch summary
        total_loss += loss.item()
        hits = (outputs.argmax(dim=1) == labels)
        correct += hits.sum().item()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}, Accuracy: {correct/len(train_dataset):.4f}")


Epoch 1/30, Loss: 0.9378, Accuracy: 0.6791
Epoch 2/30, Loss: 0.3953, Accuracy: 0.8821
Epoch 3/30, Loss: 0.2898, Accuracy: 0.9135
Epoch 4/30, Loss: 0.2427, Accuracy: 0.9276
Epoch 5/30, Loss: 0.2189, Accuracy: 0.9345
Epoch 6/30, Loss: 0.2042, Accuracy: 0.9394
Epoch 7/30, Loss: 0.1934, Accuracy: 0.9417
Epoch 8/30, Loss: 0.1838, Accuracy: 0.9443
Epoch 9/30, Loss: 0.1768, Accuracy: 0.9465
Epoch 10/30, Loss: 0.1720, Accuracy: 0.9480
Epoch 11/30, Loss: 0.1672, Accuracy: 0.9495
Epoch 12/30, Loss: 0.1624, Accuracy: 0.9514
Epoch 13/30, Loss: 0.1580, Accuracy: 0.9519
Epoch 14/30, Loss: 0.1557, Accuracy: 0.9531
Epoch 15/30, Loss: 0.1525, Accuracy: 0.9539
Epoch 16/30, Loss: 0.1487, Accuracy: 0.9546
Epoch 17/30, Loss: 0.1465, Accuracy: 0.9559
Epoch 18/30, Loss: 0.1448, Accuracy: 0.9551
Epoch 19/30, Loss: 0.1423, Accuracy: 0.9565
Epoch 20/30, Loss: 0.1410, Accuracy: 0.9570
Epoch 21/30, Loss: 0.1392, Accuracy: 0.9580
Epoch 22/30, Loss: 0.1378, Accuracy: 0.9584
Epoch 23/30, Loss: 0.1354, Accuracy: 0.95

In [None]:
# Evaluate the trained model on the test split and report top-1 accuracy.
model.eval()  # switch the model to evaluation mode
correct, total = 0, 0
with torch.no_grad():  # gradients are not needed for evaluation
    for inputs, targets in test_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        inputs = inputs.squeeze(1)  # Remove channel dimension (B, 1, H, W) -> (B, H, W)
        logits = model(inputs)
        # Use argmax, as the training loop does, rather than unpacking
        # torch.max into `_`: binding `_` at notebook top level shadows
        # IPython's "last output" variable.
        predicted = logits.argmax(dim=1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)
accuracy = 100 * correct / total  # percentage of correctly classified samples

print(f"Test Accuracy: {accuracy:.2f}%")

Test Accuracy: 95.63%
