# MNIST with limited intrinsic dimension

The math in the paper doesn't seem too daunting. Hope I'm not wrong.

Normal training takes the gradient of the loss with respect to the weights in their native dimension $D$, then updates the weights within that space.

Training within a random subspace instead uses:

$$ \theta ^ D = \theta_0^D + P \theta ^ d  $$

Where:
+ $\theta ^ D$ are the weights in their native dimension
+ $\theta_0 ^ D$ are the initial weights, frozen after initialization
+ $P$ is a randomly generated $D \times d$ projection matrix, also frozen
+ $\theta ^ d$ are the effective weights in the lower-dimensional subspace, and are the only values updated on every training step.
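
As a minimal sketch of the equation above, assuming a toy quadratic loss and made-up sizes `D = 20` and `d = 3` (neither comes from the paper), only $\theta^d$ is handed to the optimizer, while $\theta_0^D$ and $P$ stay fixed. Starting $\theta^d$ at zero means training starts exactly at $\theta_0^D$.

```python
import torch

torch.manual_seed(0)

D, d = 20, 3                      # toy sizes, chosen only for illustration

theta_0 = torch.randn(D)          # frozen initial weights in native space
P = torch.randn(D, d)             # frozen random projection matrix
theta_d = torch.zeros(d, requires_grad=True)  # the only trainable tensor

target = torch.randn(D)           # hypothetical target for a toy loss

def native_weights():
    # theta^D = theta_0^D + P theta^d
    return theta_0 + P @ theta_d

opt = torch.optim.SGD([theta_d], lr=0.01)

loss_before = ((native_weights() - target) ** 2).mean().item()
for _ in range(100):
    opt.zero_grad()
    loss = ((native_weights() - target) ** 2).mean()
    loss.backward()               # gradient flows through P into theta_d only
    opt.step()
loss_after = ((native_weights() - target) ** 2).mean().item()

print(loss_before, loss_after)
```

The loss can only go down as far as the 3-dimensional subspace allows, which is the whole point of the experiment: find the smallest `d` for which that is still good enough.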

Cobbled together a regular MNIST training loop, referring to:
+ https://gist.github.com/kdubovikov/eb2a4c3ecadd5295f68c126542e59f0a
+ https://github.com/uber-research/intrinsic-dimension/blob/master/intrinsic_dim/model_builders.py#L81
+ https://arxiv.org/pdf/1804.08838.pdf

## Figuring out how to incorporate the subspace training logic

Reading the code on a second pass, found that the individual layers are wrapped in custom projected layers, e.g.:

```python
# ref: https://github.com/uber-research/intrinsic-dimension/blob/master/intrinsic_dim/model_builders.py
from keras_ext.layers import RProjDense, OffsetCreatorDenseProj
from keras_ext.engine import ExtendedModel
offset_creator_class = OffsetCreatorDenseProj

# this chunk is pulled from definition of `build_model_mnist_fc`
for _ in range(depth):
    xx = RProjDense(
        offset_creator_class, vv, width, activation='relu', 
        kernel_initializer='he_normal', kernel_regularizer=l2(weight_decay)
    )(xx)

logits = RProjDense(
    offset_creator_class, vv, 10, kernel_initializer='he_normal', 
    kernel_regularizer=l2(weight_decay)
)(xx)
model = ExtendedModel(input=input_images, output=logits)
```

The whole reason there's so much code lying around is probably due to the custom layers needed for keras. Reading the definitions in https://github.com/uber-research/intrinsic-dimension/blob/master/keras_ext/rproj_layers.py, the custom layers subclass the keras Layer object, modifying `add_weight` and `add_non_trainable_weight` to allow for the subspace training operation. There is a custom layer for each normal layer used! Dense, Dense2D, BatchNorm etc.

This is a lot of customization to pile on before any of it can be used, and that is on top of finding the intrinsic dimension being an iterative process.

I suppose using Keras for research is bound to require shims and splints. Subclassing the PyTorch layers is doable too, but hopefully there is another way to approach it. 

In [30]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import math
import torchvision
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline

## Data

In [3]:
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Lambda(lambda x: torch.flatten(x))
])

In [4]:
train = torchvision.datasets.MNIST(
    root="~/.torchdata/", download=False, 
    # natively stored as PIL images
    transform=dataset_transform
)

In [5]:
test = torchvision.datasets.MNIST(
    root="~/.torchdata/", download=False, 
    train=False,
    transform=dataset_transform
)

In [6]:
train

Dataset MNIST
    Number of datapoints: 60000
    Root location: /home/tnwei/.torchdata/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Lambda()
           )

In [7]:
test

Dataset MNIST
    Number of datapoints: 10000
    Root location: /home/tnwei/.torchdata/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Lambda()
           )

In [8]:
train.data.shape

torch.Size([60000, 28, 28])

In [9]:
train_loader = DataLoader(train, batch_size=100, shuffle=True)
# Each batch is (features, target) with shapes (torch.Size([100, 784]), torch.Size([100]))

In [10]:
test_loader = DataLoader(test, batch_size=500, shuffle=False)

## Custom blocks with subspace training - Dense

Looking at how `torch.nn.Linear` is implemented at https://pytorch.org/docs/stable/_modules/torch/nn/modules/linear.html#Linear, I think it is possible to implement a drop-in replacement layer for `nn.Linear`.

In [31]:
class SubspaceLinear(nn.Module):
    def __init__(
        self,
        in_features, out_features, 
        subspace_features, # this is new!
        bias: bool = True, # the rest is by the numbers
        device = None,
        dtype = None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__() 
        
        # Mirror nn.Linear init
        self.in_features = in_features
        self.out_features = out_features
        self.subspace_features = subspace_features
        
        # Not a Parameter! theta only supplies the frozen initial weights
        self.theta = torch.empty((out_features, in_features), **factory_kwargs)

        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)

        nn.init.kaiming_uniform_(self.theta, a=math.sqrt(5))

        # After init, save fixed weights as a buffer: it follows .to(device)
        # and is saved in state_dict, but is never seen by the optimizer
        self.register_buffer("theta_zero", self.theta.detach().clone())

        # Generate the frozen projection matrix, also as a buffer
        proj_mat = torch.empty((out_features, subspace_features), **factory_kwargs)
        nn.init.kaiming_uniform_(proj_mat, a=math.sqrt(5))
        self.register_buffer("proj_mat", proj_mat)
        # TODO: Init this properly
        
        # Init theta prime, which will be actually used
        self.theta_prime = Parameter(torch.empty((subspace_features, in_features), **factory_kwargs))
    
        # According to https://pytorch.org/docs/stable/generated/torch.nn.functional.linear.html
        # Weight has shape (out_features, in_features)
        # Therefore P x theta_prime is:
        # (out_features, subspace_features) X (subspace_features, in_features)
        
        self.reset_parameters()
        
    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        
        nn.init.kaiming_uniform_(self.theta_prime, a=math.sqrt(5))
        
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.theta)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)
        
    def forward(self, x):
        # in nn.Linear:
        # return F.linear(x, self.weight, self.bias)
        theta = self.theta_zero + torch.mm(self.proj_mat, self.theta_prime)
        return F.linear(x, theta, self.bias)
    
    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, subspace_features={}, bias={}'.format(
            self.in_features, self.out_features, self.subspace_features, self.bias is not None
        )

In [32]:
sslinear = SubspaceLinear(in_features=100, out_features=2, subspace_features=10)

In [33]:
for i, j in sslinear.named_parameters():
    print(i, j.shape)

bias torch.Size([2])
theta_prime torch.Size([10, 100])


In [34]:
sslinear.theta_zero.shape, sslinear.proj_mat.shape, sslinear.theta_prime.shape

(torch.Size([2, 100]), torch.Size([2, 10]), torch.Size([10, 100]))

In [35]:
(torch.mm(sslinear.proj_mat, sslinear.theta_prime) + sslinear.theta_zero).shape

torch.Size([2, 100])

The repo README refers to three random projection types: denseproj, sparseproj, and fastfoodproj. The dense projection code does not have specific inits for the theta prime tensor. So, subspace projection here should be good ... ? There must be a reason why fastfood transform exists, but that's for later.
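
On the `TODO` about initializing the projection matrix: as far as I recall, the paper normalizes the columns of $P$ to unit length, so that a unit step in $\theta^d$ is a unit step in $\theta^D$. A minimal sketch of a dense projection built that way, assuming a standard normal draw; the helper name `make_dense_projection` is made up here and is not from the repo:

```python
import torch
import torch.nn.functional as F

def make_dense_projection(native_dim, subspace_dim, generator=None):
    """Random (native_dim, subspace_dim) matrix with unit-length columns."""
    P = torch.randn(native_dim, subspace_dim, generator=generator)
    # Normalize along dim=0 so each column has L2 norm 1
    return F.normalize(P, p=2, dim=0)

gen = torch.Generator().manual_seed(0)
P = make_dense_projection(native_dim=1000, subspace_dim=10, generator=gen)

print(P.shape)                    # torch.Size([1000, 10])
print(P.norm(dim=0))              # every entry is close to 1
```

In high dimensions, independently drawn columns like these are also close to orthogonal, which is why a plain random matrix can stand in for a proper orthonormal basis.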

## Combine into net

In [95]:
class SubspaceConstrainedMNIST(nn.Module):
    def __init__(self):
        """
        Paper uses 784-200-200-10; this net has one hidden layer (784-200-10)
        ref: https://arxiv.org/pdf/1804.08838.pdf
        
        Ref in github:
        https://github.com/uber-research/intrinsic-dimension/blob/9754ebe1954e82973c7afe280d2c59850f281dca/intrinsic_dim/model_builders.py#L81
        """
        super().__init__()
        self.hidden1 = SubspaceLinear(784, 200, subspace_features=1)
        self.hidden2 = SubspaceLinear(200, 10, subspace_features=1)
        
    def forward(self, x):
        x = self.hidden1(x)
        x = F.relu(x)
        x = self.hidden2(x)
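        # NOTE: this ReLU clamps the output logits to >= 0 before log_softmax,
        # which is probably unintended; kept because the results below used it
        x = F.relu(x)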
        x = F.log_softmax(x, dim=-1)  # (batch_size, dims)
        return x

## Training

In [96]:
net = SubspaceConstrainedMNIST()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)
num_epochs = 5

In [97]:
for i, j in net.named_parameters():
    print(i, j.shape)

hidden1.bias torch.Size([200])
hidden1.theta_prime torch.Size([1, 784])
hidden2.bias torch.Size([10])
hidden2.theta_prime torch.Size([1, 200])


In [98]:
for j in net.parameters():
    print(j.shape)

torch.Size([200])
torch.Size([1, 784])
torch.Size([10])
torch.Size([1, 200])


In [99]:
loss_history = []
acc_history = []

In [100]:
# Train
net.train()

for _ in range(num_epochs):
    for batch_id, (features, target) in enumerate(train_loader):
        # forward pass, calculate loss and backprop!
        opt.zero_grad()
        preds = net(features)
        loss = F.nll_loss(preds, target)
        loss.backward()
        loss_history.append(loss.item())
        opt.step()

        if (batch_id % 100 == 0):
            print(batch_id, loss.item())

0 2.302417278289795
100 2.2474958896636963
200 2.229177951812744
300 2.1434550285339355
400 2.1148269176483154
500 1.9995251893997192
0 1.9329489469528198
100 1.9272170066833496
200 1.9117159843444824
300 1.8802971839904785
400 1.8392847776412964
500 1.906437635421753
0 1.8706496953964233
100 1.8790347576141357
200 1.8126376867294312
300 1.8047877550125122
400 1.8428906202316284
500 1.8741843700408936
0 1.8446906805038452
100 1.802445888519287
200 1.7921345233917236
300 1.837571382522583
400 1.795865774154663
500 1.8554120063781738
0 1.8762071132659912
100 1.7685798406600952
200 1.7978787422180176
300 1.8438409566879272
400 1.708366870880127
500 1.735275387763977


The network was giving NaN loss at first. Forked into 2.1 to troubleshoot; turns out I forgot to init the projection matrix. Once I put that in, the network seems to be doing OK.
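
A cheap guard against this kind of bug is to scan the module for NaN or Inf values before training. This is a sketch, and the helper name `find_nonfinite_tensors` is hypothetical. It only sees parameters and registered buffers, so tensors stored as plain attributes would need to be checked separately, and it will not catch uninitialized memory that happens to hold finite garbage.

```python
import torch
import torch.nn as nn

def find_nonfinite_tensors(module):
    """Return names of parameters and buffers containing NaN or Inf."""
    bad = []
    named = list(module.named_parameters()) + list(module.named_buffers())
    for name, tensor in named:
        if not torch.isfinite(tensor).all():
            bad.append(name)
    return bad

# Toy module with a deliberately broken buffer
toy = nn.Linear(4, 2)
toy.register_buffer("proj_mat", torch.full((2, 3), float("nan")))

print(find_nonfinite_tensors(toy))   # ['proj_mat']
```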

## Test

In [101]:
net.eval()

test_loss = 0
correct = 0

with torch.no_grad():  # no gradients needed for evaluation
    for features, target in test_loader:
        output = net(features)
        test_loss += F.nll_loss(output, target).item()
        pred = torch.argmax(output, dim=-1) # get the index of the max log-probability
        correct += pred.eq(target).sum().item()

test_loss /= len(test_loader) # loss function already averages over batch size
accuracy = 100. * correct / len(test_loader.dataset)
acc_history.append(accuracy)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    accuracy))


Test set: Average loss: 1.7814, Accuracy: 3328/10000 (33%)



Varied the subspace dimension a few times, but there was not much change. Realized that even at subspace features = 1, there are still many degrees of freedom left! 

In [102]:
for i, j in net.named_parameters():
    print(i, j.shape)

hidden1.bias torch.Size([200])
hidden1.theta_prime torch.Size([1, 784])
hidden2.bias torch.Size([10])
hidden2.theta_prime torch.Size([1, 200])


$1 \times 784 + 1 \times 200 = 984$ trainable weights in the two `theta_prime` tensors, plus $200 + 10 = 210$ biases.
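
The same count can be sketched in plain Python for other subspace sizes. This assumes the per-layer scheme used in this notebook, where `theta_prime` has shape `(subspace_features, in_features)` and the bias is trained directly; the function name `count_params` is made up for illustration.

```python
def count_params(layer_shapes, subspace_features):
    """Count trainable vs native parameters for per-layer P @ theta_prime.

    layer_shapes: list of (in_features, out_features) tuples.
    """
    trainable = 0
    native = 0
    for in_features, out_features in layer_shapes:
        # theta_prime is (subspace_features, in_features); bias is trained as-is
        trainable += subspace_features * in_features + out_features
        # a regular nn.Linear has out*in weights plus out biases
        native += out_features * in_features + out_features
    return trainable, native

layers = [(784, 200), (200, 10)]
for s in (1, 10, 100):
    print(s, count_params(layers, s))
```

For `s = 1` this gives 1194 trainable values out of 159010 native ones, so the smallest setting here is already far above the few-hundred range where an intrinsic dimension sweep would want to start.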

Gotta see how the repo constrains degrees of freedom properly.
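
One way to get exactly $d$ trainable numbers, going by the equation at the top rather than by the repo's code (this is not checked against the repo), is to flatten every parameter into one vector and share a single $P$ and $\theta^d$ across all layers. A rough sketch, assuming PyTorch 2.0 or newer for `torch.func.functional_call`; the class name `GlobalSubspaceWrapper` is made up:

```python
import math
import torch
import torch.nn as nn
from torch.func import functional_call

class GlobalSubspaceWrapper(nn.Module):
    """Hypothetical wrapper: trains one theta_d shared by every parameter."""

    def __init__(self, base, subspace_dim):
        super().__init__()
        self.base = base
        self.names, self.shapes, chunks = [], [], []
        for name, p in base.named_parameters():
            p.requires_grad_(False)            # freeze the native weights
            self.names.append(name)
            self.shapes.append(p.shape)
            chunks.append(p.detach().flatten())
        theta_0 = torch.cat(chunks)            # flat theta_0 of length D
        self.register_buffer("theta_0", theta_0)
        P = torch.randn(theta_0.numel(), subspace_dim)
        # unit-length columns; frozen because it is a buffer
        self.register_buffer("P", P / P.norm(dim=0, keepdim=True))
        self.theta_d = nn.Parameter(torch.zeros(subspace_dim))

    def forward(self, x):
        flat = self.theta_0 + self.P @ self.theta_d   # theta^D
        params, offset = {}, 0
        for name, shape in zip(self.names, self.shapes):
            n = math.prod(shape)
            params[name] = flat[offset:offset + n].view(shape)
            offset += n
        # run the base module with the reconstructed weights
        return functional_call(self.base, params, (x,))

base = nn.Sequential(nn.Linear(784, 200), nn.ReLU(), nn.Linear(200, 10))
wrapper = GlobalSubspaceWrapper(base, subspace_dim=10)

out = wrapper(torch.randn(5, 784))
n_trainable = sum(p.numel() for p in wrapper.parameters() if p.requires_grad)
print(out.shape, n_trainable)
```

The dense $P$ here costs $D \times d$ floats, which is presumably where the sparse and fastfood projections mentioned in the README come in for larger networks.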