## Notebook for testing EKFAC with Convolutional Layers

In [1]:
import numpy as np
import torch
import torch.nn.functional as F

In [3]:
import matplotlib.pyplot as plt
%matplotlib inline

## Testing 1D Convolutional Neural Networks

We start with a 1D CNN with a single layer.  In order to sensibly organize the derivatives with respect to the weights and biases, we try to get the input into a form such that we can write the output, $h$, as 
\begin{equation}
h_{iml} = \sum_j A_{ijl}w_{mj}
\end{equation}
Here, $w$ is the 'tensor' which represents the weights of the CNN.  The first index (denoted $m$ above) represents the the channel of the output layer, while the second index (denoted $j$ above) represents *both* the input channel and the spatial location (or offset).  So if we are used to thinking about the weight kernel as having three indices, one for output channel, one for input channel, and one for spatial location (or offset), the two-index $w_{mj}$ can be viewed as flattening the input channel + spatial location index into a single index. 

However, naturally the input is not in the correct form, but is actually organized by (batch_size, input_channel, input_spatial_location).  In order to convert this to the form above, that is, to the 'tensor' $A$, we need to organize it such that every element of the input which is 'seen' by a unique element (input channel x input spatial location) 

How do we do that?  

First, we have to look at the shape of the weights for a convolutional layer.  The weight needs to have an index for the output channel, the input channel, and the spatial location (offset).  In PyTorch, the weight tensor is organized in that order: (output_channel x input_channel x spatial_location).

In [3]:
n_input_chans = 3
n_output_chans = 2
kernel_size = 4
padding = 0
stride = 1

test_conv1d_module = torch.nn.Conv1d(in_channels=n_input_chans,
                                     out_channels=n_output_chans,
                                     kernel_size=kernel_size,
                                     padding=padding,
                                     stride=stride,
                                     bias=False)

test_conv1d_weight = test_conv1d_module.weight

In [4]:
test_conv1d_weight.shape

torch.Size([2, 3, 4])

Now we can play a trick.  To reshape the input into the form described above, with the $A_{ijl}$, we need to gather all of the elements that a particular output spatial location will 'see'.  So what we can do is make a fake convolutional weight, which pretends that the output actually has ``kernel_size * input_channels`` channels, where each of these 'channels' is simply one of the elements that the weight tensor interacts with.  Then, for each output spatial location, we simply dot-product this with the weight (as in the equation above).  

**TODO Put Picture Here**

So, we need to make a fake filter.  This is what the function below does.  Here's an example of what the fake-filter looks like if we have 3 input channels, a kernel size of 4:


[[1,0,0,0],
 [0,0,0,0],
 [0,0,0,0]],
 
[[0,1,0,0],
 [0,0,0,0],
 [0,0,0,0]],
 
[[0,0,1,0],
 [0,0,0,0],
 [0,0,0,0]],
 
[[0,0,0,1],
 [0,0,0,0],
 [0,0,0,0]],
 
[[0,0,0,0],
 [1,0,0,0],
 [0,0,0,0]],
 
[[0,0,0,0],
 [0,1,0,0],
 [0,0,0,0]],
 
[[0,0,0,0],
 [0,0,1,0],
 [0,0,0,0]],
 
[[0,0,0,0],
 [0,0,0,1],
 [0,0,0,0]],
 
[[0,0,0,0],
 [0,0,0,0],
 [1,0,0,0]],
 
[[0,0,0,0],
 [0,0,0,0],
 [0,1,0,0]],
 
[[0,0,0,0],
 [0,0,0,0],
 [0,0,1,0]],
 
[[0,0,0,0],
 [0,0,0,0],
 [0,0,0,1]],


Viewing the filter in this way suggests an easy way to generate the filter.  Make the identity matrix of shape (kernel_size * input_channels, kernel_size * input_channels), and then reshape it to be (kernel_size * input_channels, input_channels, kernel_size).

In [15]:
def get_gathering_filter(mod):
    """Convolution filter that extracts input patches."""
    dimension = len(mod.kernel_size)
    if dimension == 1:
        kernel_size, = mod.kernel_size

        g_filter = torch.eye(kernel_size*mod.in_channels, 
                             dtype=mod.weight.dtype,
                             device=mod.weight.device)
        
        return g_filter.view(kernel_size*mod.in_channels, 
                             mod.in_channels, 
                             kernel_size)     
        
    elif dimension == 2:
        kernel_width, kernel_height = mod.kernel_size
        
        g_filter = torch.eye(kernel_width*kernel_height*mod.in_channels, 
                             dtype=mod.weight.dtype,
                             device=mod.weight.device)
        
        return g_filter.view(kernel_width*kernel_height*mod.in_channels, 
                             mod.in_channels, 
                             kernel_width, 
                             kernel_height)
        
    elif dimension == 3:
        kernel_width, kernel_height, kernel_depth = mod.kernel_size
        
        g_filter = torch.eye(kernel_width*kernel_height*kernel_depth*mod.in_channels, 
                             dtype=mod.weight.dtype,
                             device=mod.weight.device)
        
        return g_filter.view(kernel_width*kernel_height*kernel_depth*mod.in_channels, 
                             mod.in_channels, 
                             kernel_width, 
                             kernel_height,
                             kernel_depth)

### Testing the gathering filters for convolutional 1D, 2D, and 3D network

In [17]:
network_dim = 1

num_tests = 100

for _ in range(num_tests):
    try:
        Nbatch, input_dim, input_channels, output_channels = np.random.randint(1, 50, size=4)

        kernel_size = np.random.randint(1, input_dim)

        padding = np.random.randint(0, 5)
        
        stride = np.random.randint(1, input_dim)

        
        conv1d_model = torch.nn.Sequential(
            torch.nn.Conv1d(in_channels=input_channels,
                            out_channels=output_channels,
                            kernel_size=kernel_size,
                            padding=padding,
                            stride=stride,
                            bias=False
                            ))

        input_1d = torch.randn(Nbatch, input_channels, input_dim)
        weight_1d = list(conv1d_model.parameters())[0]

        output_1d = conv1d_model(input_1d)
        module = list(conv1d_model.modules())[1]
        filter_1d = get_gathering_filter(module)

        A = F.conv1d(input_1d, 
                     weight=filter_1d,
                     padding=padding,
                     stride=stride)

        out_test = torch.einsum('ijl,mj->iml', A, weight_1d.view(output_channels, kernel_size * input_channels))

        diff = out_test - output_1d
        print('Result of test: {}'.format(torch.max(torch.abs(diff))))
    except:
        pass

Result of test: 2.384185791015625e-06
Result of test: 1.430511474609375e-06
Result of test: 9.5367431640625e-07
Result of test: 3.5762786865234375e-07
Result of test: 4.76837158203125e-07
Result of test: 1.430511474609375e-06
Result of test: 3.5762786865234375e-07
Result of test: 1.3113021850585938e-06
Result of test: 3.5762786865234375e-07
Result of test: 3.5762786865234375e-07
Result of test: 1.430511474609375e-06
Result of test: 1.430511474609375e-06
Result of test: 5.960464477539062e-07
Result of test: 7.152557373046875e-07
Result of test: 1.7881393432617188e-06
Result of test: 2.384185791015625e-07
Result of test: 5.960464477539062e-07
Result of test: 7.152557373046875e-07
Result of test: 5.960464477539062e-07
Result of test: 5.364418029785156e-07
Result of test: 4.76837158203125e-07
Result of test: 7.152557373046875e-07
Result of test: 1.0728836059570312e-06
Result of test: 1.1920928955078125e-06
Result of test: 4.76837158203125e-07
Result of test: 1.1920928955078125e-07
Result o

In [14]:
for i in range(1):
    Nbatch = np.random.randint(1,50)
    input_dimX = np.random.randint(1,50)
    input_dimY = np.random.randint(1,50)
    input_channels = np.random.randint(1,50)
    output_channels = np.random.randint(1,50)

    kernel_X = np.random.randint(1, input_dimX)
    kernel_Y = np.random.randint(1, input_dimY)
    kernel_size = (kernel_X, kernel_Y)

    padding = np.random.randint(0, 5)
    stride = (np.random.randint(1, input_dimX), np.random.randint(1, input_dimY))

    conv2d_model = torch.nn.Sequential(
        torch.nn.Conv2d(in_channels=input_channels,
                        out_channels=output_channels,
                        kernel_size=kernel_size,
                        padding=padding,
                        stride=stride,
                        bias=False
                        ))

    input_2d = torch.randn(Nbatch, input_channels, input_dimX, input_dimY)
    weight_2d = list(conv2d_model.parameters())[0]
    print(weight_2d.size())

    print('Test with following parameters')
    print('Input X dimension: {}'.format(input_dimX))
    print('Input Y dimension: {}'.format(input_dimY))
    print('Batch size: {}'.format(Nbatch))
    print('Input channels: {}'.format(input_channels))
    print('Output channels: {}'.format(output_channels))
    print('Kernel X size: {}'.format(kernel_X))
    print('Kernel Y size: {}'.format(kernel_Y))
    print('Padding: {}'.format(padding))
    print('Stride: {}'.format(stride))

    output_2d = conv2d_model(input_2d)
    module = list(conv2d_model.modules())[1]
    filter_2d = get_gathering_filter(module)

    A = F.conv2d(input_2d, 
                 weight=filter_2d,
                 padding=padding,
                 stride=stride,
                 groups = input_channels)
    
    print('A size: {}'.format(A.size()))
    #         print('Weight size: {}'.format(weight_2d.size()))
    #         print('Output size: {}'.format(output_2d.size()))

    out_test = torch.einsum('ijkl,mj->imkl', A, weight_2d.view(output_channels, -1))

    diff = out_test - output_2d
    print('Result of test: {}'.format(torch.max(torch.abs(diff))))



torch.Size([41, 47, 23, 15])
Test with following parameters
Input X dimension: 29
Input Y dimension: 26
Batch size: 29
Input channels: 47
Output channels: 41
Kernel X size: 23
Kernel Y size: 15
Padding: 2
Stride: (25, 5)
A size: torch.Size([29, 16215, 1, 4])
Result of test: 1.8362115621566772


In [None]:
for i in range(1):
    Nbatch = np.random.randint(1,50)
    input_dimX = np.random.randint(1,50)
    input_dimY = np.random.randint(1,50)
    input_dimZ = np.random.randint(1,50)
    input_channels = np.random.randint(1,50)
    output_channels = np.random.randint(1,50)

    kernel_X = np.random.randint(1, input_dimX)
    kernel_Y = np.random.randint(1, input_dimY)
    kernel_Z = np.random.randint(1, input_dimZ)
    kernel_size = (kernel_X, kernel_Y, kernel_Z)

    padding = np.random.randint(0, 5)
    stride = (np.random.randint(1, input_dimX), 
              np.random.randint(1, input_dimY),
              np.random.randint(1, input_dimZ),
             )

    conv3d_model = torch.nn.Sequential(
        torch.nn.Conv3d(in_channels=input_channels,
                        out_channels=output_channels,
                        kernel_size=kernel_size,
                        padding=padding,
                        stride=stride,
                        bias=False
                        ))

    input_3d = torch.randn(Nbatch, input_channels, input_dimX, input_dimY, input_dimZ)
    weight_3d = list(conv3d_model.parameters())[0]
    print(weight_3d.size())

    print('Test with following parameters')
    print('Input X dimension: {}'.format(input_dimX))
    print('Input Y dimension: {}'.format(input_dimY))
    print('Input Z dimension: {}'.format(input_dimZ))
    print('Batch size: {}'.format(Nbatch))
    print('Input channels: {}'.format(input_channels))
    print('Output channels: {}'.format(output_channels))
    print('Kernel X size: {}'.format(kernel_X))
    print('Kernel Y size: {}'.format(kernel_Y))
    print('Kernel Z size: {}'.format(kernel_Z))
    print('Padding: {}'.format(padding))
    print('Stride: {}'.format(stride))

    output_3d = conv3d_model(input_3d)
    module = list(conv3d_model.modules())[1]
    filter_3d = get_gathering_filter(module)

    A = F.conv3d(input_3d, 
                 weight=filter_3d,
                 padding=padding,
                 stride=stride)

    print('A size: {}'.format(A.size()))
    #         print('Weight size: {}'.format(weight_2d.size()))
    #         print('Output size: {}'.format(output_2d.size()))

    out_test = torch.einsum('ijklo,mj->imklo', A, weight_3d.view(output_channels, -1))

    diff = out_test - output_3d
    print('Result of test: {}'.format(torch.max(torch.abs(diff))))



torch.Size([17, 45, 16, 43, 1])
Test with following parameters
Input X dimension: 18
Input Y dimension: 46
Input Z dimension: 2
Batch size: 17
Input channels: 45
Output channels: 17
Kernel X size: 16
Kernel Y size: 43
Kernel Z size: 1
Padding: 2
Stride: (15, 40, 1)


### Failed attempt at making a sparse weight in the CNN 
Since the fake weight is so sparse, I thought it would be nice to make it a sparse Tensor in pyTorch, but somehow the sparse tensor is not working.  I guess pyTorch isn't used to using sparse tensors for CNNs, but in the future, maybe this is somethign good to do.  Here's the code:

    def get_gathering_filter(mod):

        dimension = len(mod.kernel_size)
        if dimension == 1:        
            kernel_size, = mod.kernel_size

            indices_1 = np.arange(kernel_size*mod.in_channels)
            indices_2 = indices_1 // kernel_size
            indices_3 = indices_1 % kernel_size

            nonzero_indices = torch.LongTensor(np.array([indices_1,
                                        indices_2,
                                        indices_3]))

            values = torch.FloatTensor([1]*kernel_size*mod.in_channels)
            filter_size = torch.Size([kernel_size*mod.in_channels, mod.in_channels, kernel_size])

            return torch.sparse.FloatTensor(nonzero_indices, values, filter_size)
   
This has been verified to generate the right filter (by casting it to a dense tensor), so as soon as I figure out how to put a sparse weight in PyTorch, this should be good to go.