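# Subspace-constrained CNN on MNIST

Trains a LeNet-style CNN on MNIST whose layers are all built from one shared, low-dimensional trainable vector `theta` (see *Net definition*), then reports test-set loss and accuracy.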

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader

## Data

In [2]:
# ToTensor turns the PIL images MNIST yields into float tensors in [0, 1].
# Images are deliberately NOT flattened here: the network's first layer is a
# convolution, and the network flattens internally before its linear layers.
dataset_transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor()]
)

In [3]:
# MNIST training split (60k images).
# download=True lets the notebook run from a fresh machine; when the files are
# already present under root, torchvision does not download them again.
train = torchvision.datasets.MNIST(
    root="~/.torchdata/",
    train=True,
    download=True,
    # natively stored as PIL images; the transform converts them to tensors
    transform=dataset_transform,
)

In [4]:
# MNIST held-out test split (10k images).
# download=True lets the notebook run from a fresh machine; when the files are
# already present under root, torchvision does not download them again.
test = torchvision.datasets.MNIST(
    root="~/.torchdata/",
    train=False,
    download=True,
    transform=dataset_transform,
)

In [5]:
# Shuffled mini-batches of 100 for training. Each batch is a pair
# (images: torch.Size([100, 1, 28, 28]), labels: torch.Size([100])).
train_loader = DataLoader(dataset=train, shuffle=True, batch_size=100)

In [6]:
# Fixed-order (unshuffled) batches of 500 for evaluation.
test_loader = DataLoader(dataset=test, shuffle=False, batch_size=500)

## Net definition

In [7]:
from net import SubspaceConv2d, SubspaceLinear
from torch.nn.parameter import Parameter

In [8]:
# Length of the shared trainable vector theta used by SubspaceConstrainedLeNet.
intrinsic_dim = 100

# Seed torch's global RNG before the network is built and before train_loader
# is first iterated. This makes the parameter initialisation (presumably
# including any random projections inside the Subspace* layers -- TODO confirm)
# and the shuffle order repeat across runs.
SEED = 0
torch.manual_seed(SEED);

In [9]:
class SubspaceConstrainedLeNet(nn.Module):
    """
    Subspace constrained version of PyImageSearch's LeNet implementation.

    All four weighted layers (conv1, conv2, fc1, fc2) are Subspace* layers that
    receive the same trainable vector ``theta`` of shape (subspace_dim, 1).
    Presumably they derive their weights from it (see ``net.SubspaceConv2d`` /
    ``net.SubspaceLinear`` -- TODO confirm), so training searches a
    ``subspace_dim``-dimensional space. ``theta`` starts at zero.

    Architecture:
        conv(5x5, 1->20) -> ReLU -> maxpool(2)
        -> conv(5x5, 20->50) -> ReLU -> maxpool(2)
        -> flatten -> fc(800->500) -> ReLU -> fc(500->10) -> LogSoftmax

    Input: presumably (N, 1, 28, 28) MNIST images. fc1's in_features=800
    (= 50 * 4 * 4) only matches a 28x28 input if the convolutions use no
    padding (TODO confirm SubspaceConv2d's default).
    Output: (N, 10) log-probabilities (LogSoftmax over dim 1).
    """

    def __init__(self, subspace_dim=None):
        """
        Args:
            subspace_dim: length of ``theta``. ``None`` (the default) falls
                back to the module-level ``intrinsic_dim``, read at
                construction time, so ``SubspaceConstrainedLeNet()`` uses the
                notebook's configured dimension.
        """
        super().__init__()

        if subspace_dim is None:
            subspace_dim = intrinsic_dim

        # Shared, zero-initialised parameter handed to every Subspace* layer below.
        self.theta = Parameter(torch.zeros((subspace_dim, 1)))

        self.conv1 = SubspaceConv2d(self.theta, in_channels=1, out_channels=20, kernel_size=(5, 5), stride=1)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.conv2 = SubspaceConv2d(self.theta, in_channels=20, out_channels=50, kernel_size=(5, 5), stride=1)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.flatten1 = nn.Flatten()

        # 800 = 50 channels * 4 * 4: 28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4
        self.fc1 = SubspaceLinear(self.theta, in_features=800, out_features=500)
        self.relu3 = nn.ReLU()

        self.fc2 = SubspaceLinear(self.theta, in_features=500, out_features=10)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Map a batch of images to per-class log-probabilities of shape (N, 10)."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.maxpool2(x)

        x = self.flatten1(x)
        x = self.fc1(x)
        x = self.relu3(x)
        x = self.fc2(x)
        x = self.logsoftmax(x)

        return x

## Training

In [10]:
num_epochs = 5

net = SubspaceConstrainedLeNet()
# Adam over every registered parameter of the model. For this model that is
# presumably just the shared theta (TODO confirm the Subspace* layers register
# nothing else).
opt = torch.optim.Adam(params=net.parameters(), lr=0.01)

In [11]:
# Running logs: per-batch training loss, and accuracy from each test evaluation.
loss_history, acc_history = [], []

In [12]:
# Train: num_epochs full passes over train_loader, one optimizer step per batch.
net.train()

for epoch in range(num_epochs):
    for batch_id, (features, target) in enumerate(train_loader):
        # features: torch.Size([100, 1, 28, 28]), target: torch.Size([100])
        opt.zero_grad()
        preds = net(features)  # log-probabilities, hence nll_loss below
        loss = F.nll_loss(preds, target)
        loss.backward()

        loss_value = loss.item()
        loss_history.append(loss_value)
        opt.step()

        # Progress report every 100 batches.
        if not batch_id % 100:
            print(loss_value)

2.3108291625976562
2.2809832096099854
2.279046058654785
2.269514799118042
2.2809548377990723
2.2762463092803955
2.261420488357544
2.255058765411377
2.246279716491699
2.179980516433716
2.120734930038452
1.877232551574707
1.5798449516296387
1.2773282527923584
1.0793684720993042
1.2323675155639648
1.1547588109970093
0.9760134816169739
1.0293822288513184
0.9469051957130432
0.9689795970916748
0.9759196639060974
1.149322271347046
1.015008807182312
0.8724449276924133
0.8221571445465088
0.8196656703948975
0.9515354633331299
0.9104640483856201
0.9277781844139099


In [13]:
# Test: evaluate the trained network on the held-out split.
net.eval()

test_loss = 0.0
correct = 0

# Nothing is backpropagated here, so skip building the autograd graph.
with torch.no_grad():
    for features, target in test_loader:
        output = net(features)  # (batch, 10) log-probabilities
        # Sum per batch, then divide by the dataset size below. This makes the
        # average exact even if the last batch is smaller than batch_size.
        test_loss += F.nll_loss(output, target, reduction="sum").item()
        pred = torch.argmax(output, dim=-1)  # get the index of the max log-probability
        correct += pred.eq(target).sum().item()

num_test = len(test_loader.dataset)
test_loss /= num_test
accuracy = 100. * correct / num_test  # plain float, not a tensor
acc_history.append(accuracy)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, num_test,
    accuracy))


Test set: Average loss: 0.8512, Accuracy: 7236/10000 (72%)



This notebook looks ready to be run as a script.