# Reimplementing intrinsic dim

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import math
import torchvision
import torch.nn.functional as F
import numpy as np
from torch.nn.parameter import Parameter

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd

## Data

In [3]:
dataset_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(), 
    torchvision.transforms.Lambda(lambda x: torch.flatten(x))
])

In [4]:
train = torchvision.datasets.MNIST(
    root="~/.torchdata/", download=False, 
    # natively stored as PIL images
    transform=dataset_transform
)

In [5]:
test = torchvision.datasets.MNIST(
    root="~/.torchdata/", download=False, 
    train=False,
    transform=dataset_transform
)

In [6]:
train

Dataset MNIST
    Number of datapoints: 60000
    Root location: /home/tnwei/.torchdata/
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Lambda()
           )

In [7]:
test

Dataset MNIST
    Number of datapoints: 10000
    Root location: /home/tnwei/.torchdata/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Lambda()
           )

In [8]:
train.data.shape

torch.Size([60000, 28, 28])

In [9]:
train_loader = DataLoader(train, batch_size=100, shuffle=True)
# Returns (torch.Size([100, 784]), torch.Size([100]))

In [10]:
test_loader = DataLoader(test, batch_size=500, shuffle=False)

## Reimplementing using only one theta prime

In [11]:
class SubspaceLinear(nn.Module):
    def __init__(
        self,
        in_features, out_features, 
        theta_prime, 
        bias: bool = True, # the rest is by the numbers
        device = None,
        dtype = None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}

        super().__init__() 
        
        # Mirror nn.Linear init
        self.in_features = in_features
        self.out_features = out_features
        self.subspace_features = theta_prime.shape[0] # (intrinsic_dim, 1)
        self.theta_prime = theta_prime
        
        # Weight has shape (out_features, in_features)
        # Therefore P x theta_prime is:
        # (out_features, in_features, subspace_features) X (subspace_features, 1)
        
        # Create and init theta, save theta_zero
        self.theta = torch.empty((out_features, in_features), **factory_kwargs)
        nn.init.kaiming_uniform_(self.theta, a=math.sqrt(5))
        self.theta_zero = self.theta.detach().clone()
        
        # Generate projection matrix for weights
        self.proj_mat_weights = torch.empty((out_features, in_features, self.subspace_features), **factory_kwargs)
        nn.init.kaiming_uniform_(self.proj_mat_weights, a=math.sqrt(5))
        
        if bias:
            # Create and init bias, save bias zero
            self.bias = torch.empty(out_features, **factory_kwargs)
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.theta_zero)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            nn.init.uniform_(self.bias, -bound, bound)    
            self.bias_zero = self.bias.detach().clone()
            
            # Generate projection matrix for bias
            self.proj_mat_bias = torch.empty((out_features, self.subspace_features), **factory_kwargs)
            nn.init.kaiming_uniform_(self.proj_mat_bias, a=math.sqrt(5))
            
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        # in nn.Linear:
        # return F.linear(x, self.weight, self.bias)
        # torch.mm is for matrices only! torch.matmul is the one that can do broadcasting
        theta = self.theta_zero + torch.squeeze(torch.matmul(self.proj_mat_weights, self.theta_prime), dim=-1)
        bias = self.bias_zero + torch.squeeze(torch.matmul(self.proj_mat_bias, self.theta_prime), dim=-1)
        return F.linear(x, theta, bias)
    
    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, subspace_features={}, bias={}'.format(
            self.in_features, self.out_features, self.subspace_features, self.bias is not None
        )

In [12]:
intrinsic_dim = 10

In [13]:
theta_prime = Parameter(torch.empty((intrinsic_dim, 1)))
theta_prime.data.fill_(0);

In [14]:
sslin1 = SubspaceLinear(in_features=784, out_features=200, theta_prime=theta_prime)

In [15]:
for i, j in sslin1.named_parameters():
    print(i, j.shape)

theta_prime torch.Size([10, 1])


In [16]:
sslin2 = SubspaceLinear(in_features=10, out_features=5, theta_prime=theta_prime)

In [17]:
for i, j in sslin2.named_parameters():
    print(i, j.shape)

theta_prime torch.Size([10, 1])


In [18]:
class SubspaceConstrainedMNIST(nn.Module):
    def __init__(self, subspace_features: int, device="cpu"):
        """
        Paper uses 784-200-200-10
        ref: https://arxiv.org/pdf/1804.08838.pdf
        
        Ref in github:
        https://github.com/uber-research/intrinsic-dimension/blob/9754ebe1954e82973c7afe280d2c59850f281dca/intrinsic_dim/model_builders.py#L81
        """
        super().__init__()
        intrinsic_dim = 10
        self.theta_prime = Parameter(torch.empty((intrinsic_dim, 1)))
        self.theta_prime.data.fill_(0);
    
        self.hidden1 = SubspaceLinear(in_features=784, out_features=200, theta_prime=self.theta_prime, device=device)
        self.hidden2 = SubspaceLinear(in_features=200, out_features=10, theta_prime=self.theta_prime, device=device)
        
    def forward(self, x):
        x = self.hidden1(x)
        x = F.relu(x)
        x = self.hidden2(x)
        x = F.relu(x)
        x = F.log_softmax(x, dim=-1)  # (batch_size, dims)
        return x

I think this is it.

## Training

In [19]:
def train(net, num_epochs, train_loader, device="cpu"):
    opt = torch.optim.Adam(net.parameters(), lr=1e-3)
    net.train()
    loss_history = []
    acc_history = []
    
    for _ in range(num_epochs):
        for batch_id, (features, target) in enumerate(train_loader):
            # forward pass, calculate loss and backprop!
            opt.zero_grad()
            preds = net(features.to(device))
            loss = F.nll_loss(preds, target.to(device))
            loss.backward()
            loss_history.append(loss.item())
            opt.step()

            if batch_id % 100 == 0:
                print(loss.item())
                
    # Verified don't need to return the net
    return loss_history, acc_history

In [20]:
def eval(net, test_loader, device="cpu"):
    net.eval()
    test_loss = 0
    correct = 0

    for features, target in test_loader:
        output = net(features.to(device))
        test_loss += F.nll_loss(output, target.to(device)).item()
        pred = torch.argmax(output, dim=-1) # get the index of the max log-probability
        correct += pred.eq(target.to(device)).cpu().sum()

    test_loss = test_loss
    test_loss /= len(test_loader) # loss function already averages over batch size
    accuracy = 100. * correct / len(test_loader.dataset)
    acc_history.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))
    
    return test_loss, correct.item()

In [21]:
loss_histories = {}
acc_histories = {}
test_losses = {}
corrects = {}

In [22]:
dims = [10, 50, 100, 200, 500, 600, 700, 750, 800, 1000, 1500, 2000]

In [23]:
for d in dims:
    # loss_histories_per_dim = {}
    # acc_histories_per_dim = {}
    # test_losses_per_dim = {}
    corrects_per_dim = {}
    for i in range(10):
        # Use cpu since at this size, GPU with overhead is slower than using CPU directly
        ssnet = SubspaceConstrainedMNIST(subspace_features=d, device="cpu")
        loss_history, acc_history = train(ssnet, 15, train_loader, device="cpu")
        test_loss, correct = eval(ssnet, test_loader, device="cpu")

        # Store everything
        # loss_histories_per_dim[i] = loss_history
        # acc_histories_per_dim[i] = acc_history
        # test_losses_per_dim[i] = test_loss
        corrects_per_dim[i] = correct / 10000 * 100
        
    # loss_histories[d] = loss_histories_per_dim
    # acc_histories[d] = acc_histories_per_dim
    # test_losses[d] = test_losses_per_dim
    corrects[d] = corrects_per_dim

2.3053252696990967
2.308882474899292
2.289276599884033
2.298297882080078
2.2959940433502197
2.2994325160980225
2.2925140857696533
2.296787977218628
2.287163496017456
2.2917532920837402
2.296682119369507
2.280731439590454
2.29170823097229
2.2875003814697266
2.3002285957336426
2.2807681560516357
2.282684803009033
2.290861129760742
2.2743921279907227
2.2719695568084717
2.2855136394500732
2.286449909210205
2.2929840087890625
2.2904856204986572
2.2872276306152344
2.2857015132904053
2.2785494327545166
2.3058128356933594
2.294235944747925
2.2696692943573
2.291292190551758
2.2872579097747803
2.2878148555755615
2.277557849884033
2.2964565753936768
2.2697627544403076
2.277078151702881
2.2655797004699707
2.288762331008911
2.278978109359741
2.2851271629333496
2.281541109085083
2.296551465988159
2.2647645473480225
2.3020975589752197
2.2774465084075928
2.282496690750122
2.308363914489746
2.2827672958374023
2.270899534225464
2.292191505432129
2.2825169563293457
2.2883548736572266
2.2766218185424805
2

KeyboardInterrupt: 

In [None]:
dim_scores = pd.DataFrame.from_dict(corrects)

In [None]:
dim_scores.columns = [str(i) for i in dim_scores.columns]

In [None]:
dim_scores

In [None]:
dim_scores.to_csv("mnist-scores.csv")

In [None]:
plt.plot(dim_scores.T)

This crashed at some point. I think it is due to the dense matrix mults at large sizes. Putting into script so it is easier to see the output.