# Torch transformation-in-training pipeline prototype
This notebook gives an overview of the prototype quantization-aware-training (QAT) pipeline for parameters and activations, and of the related facilities available for Torch-backed modules in Rockpool.

This pipeline is still a work in progress and subject to change.

The torch pipeline is based on Torch's `functional_call` API, new in Torch 1.12.

## Design goals
* No need to modify pre-defined modules to make "magic quantization" modules
* General solution that can be applied widely to modules and parameters
* Convenient API for specifying transformations over parameters in a network in a "grouped" way, using Rockpool's parameter families
* Similar API for parameter- and activity-transformation
* Quantization controllable at a fine-grained level
* Provide useful and flexible transformation methods that can be used for QAT, dropout, pruning, ...

In [1]:
# - Basic imports
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.nn.combinators import Sequential, Residual

import torch

In [2]:
# - Transformation pipeline imports
import rockpool.transform.torch_transform as tt
import rockpool.utilities.tree_utils as tu

## Parameter transformations
The parameter transformation pipeline lets you apply a transformation to any parameter in the forward pass, before the module is evolved, in a configurable way. You can use this to perform quantization-aware training, random parameter attacks, connection pruning, ...

We'll begin here with a simple Rockpool SNN that uses most of the features of network composition in Rockpool, and is compatible with Xylo.

In [3]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec = True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [4]:
# - Get the 'weights' parameter family, and specify stochastic rounding
tconfig = tt.make_param_T_config(net, lambda p: tt.stochastic_rounding(p, num_levels=2**2), 'weights')
tconfig

{'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
 '2_TorchResidual': {'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'w_rec': <function __main__.<lambda>(p)>}},
 '3_LinearTorch': {'weight': <function __main__.<lambda>(p)>}}
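
To make the effect of stochastic rounding with `num_levels=2**2` concrete, here is a minimal plain-Python sketch of the general technique. It is not Rockpool's implementation: the helper name `stochastic_round_sketch` is hypothetical, and it assumes the levels are evenly spaced over a symmetric range set by the largest absolute value, which is consistent with the four-valued weights shown after `apply_T` later in this notebook. Each value is rounded up to the next level with probability equal to its fractional position between the two neighbouring levels, so the rounding is unbiased on average.

```python
import math
import random


def stochastic_round_sketch(values, num_levels, rng=random):
    """Round each value to one of `num_levels` evenly spaced levels (illustrative only)."""
    max_abs = max(abs(v) for v in values)
    if max_abs == 0 or num_levels < 2:
        return list(values)

    # Spacing between adjacent levels over the assumed range [-max_abs, max_abs]
    step = 2 * max_abs / (num_levels - 1)

    out = []
    for v in values:
        pos = (v + max_abs) / step      # position measured in units of levels
        lower = math.floor(pos)         # index of the level just below
        frac = pos - lower              # distance above that level, in [0, 1)
        # Round up with probability `frac`, otherwise round down
        level = lower + (1 if rng.random() < frac else 0)
        level = min(level, num_levels - 1)
        out.append(-max_abs + level * step)
    return out


weights = [0.5187, -0.9304, 0.0414, 0.3679, -1.2738]
print(stochastic_round_sketch(weights, num_levels=2**2, rng=random.Random(0)))
```

Repeated calls give different results for values that lie between two levels, which is the point of the stochastic variant.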

In [5]:
# - Now we add in the bias transformation
tu.tree_update(tconfig, tt.make_param_T_config(net, lambda p: tt.dropout(p, 0.3), 'biases'))
tconfig

{'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
 '2_TorchResidual': {'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'w_rec': <function __main__.<lambda>(p)>}},
 '3_LinearTorch': {'weight': <function __main__.<lambda>(p)>}}
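
Configuration trees are plain nested dictionaries keyed by module and parameter name, so combining two of them amounts to a recursive dictionary merge. The sketch below shows that idea in plain Python. `merge_tree` is a hypothetical helper that only approximates what `tu.tree_update` is assumed to do (an in-place merge), and the string values stand in for transformation functions.

```python
def merge_tree(target: dict, extra: dict) -> None:
    """Recursively merge `extra` into `target`, modifying `target` in place."""
    for key, value in extra.items():
        if isinstance(value, dict) and isinstance(target.get(key), dict):
            # Both sides have a sub-tree here: descend and merge
            merge_tree(target[key], value)
        else:
            # New leaf or replacement: insert directly
            target[key] = value


weights_cfg = {
    "0_LinearTorch": {"weight": "round"},
    "2_TorchResidual": {"1_LIFTorch": {"w_rec": "round"}},
}
bias_cfg = {"0_LinearTorch": {"bias": "dropout"}}

merge_tree(weights_cfg, bias_cfg)
print(weights_cfg)
```

Merging an empty tree leaves the target unchanged, which is what you would expect to see if a network has no parameters in the requested family.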

In [6]:
# - We now use this configuration to patch the original network with transformation modules
tnet = tt.make_param_T_network(net, tconfig)
tnet

TorchSequential  with shape (3, 3) {
    TWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        TWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        TWrapper '1_LIFTorch' with shape (5, 5) {
            LIFTorch '_mod' with shape (5, 5)
        }
    }
    TWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}
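
The patched network above wraps each configured module in a `TWrapper`. A rough plain-Python picture of what such a wrapper does is sketched below. All names (`TinyLinear`, `ParamTransformWrapper`) are hypothetical and this is not Rockpool's code: the real pipeline builds on Torch's `functional_call`, while this sketch simply passes a dictionary of transformed parameters to a toy module whose forward pass accepts one.

```python
class TinyLinear:
    """Toy module: y = sum_i w[i] * x[i], with parameters held in a dict."""

    def __init__(self, weight):
        self.params = {"weight": list(weight)}

    def forward(self, x, params=None):
        # Use the supplied parameters if given, otherwise the stored ones
        p = self.params if params is None else params
        return sum(w * xi for w, xi in zip(p["weight"], x))


class ParamTransformWrapper:
    """Apply per-parameter transforms on every call, leaving stored values intact."""

    def __init__(self, mod, t_config):
        self._mod = mod
        self._t_config = t_config  # maps parameter name -> function, or None

    def __call__(self, x):
        transformed = {
            name: self._t_config[name](value) if self._t_config.get(name) else value
            for name, value in self._mod.params.items()
        }
        return self._mod.forward(x, params=transformed)


mod = TinyLinear([0.4, -1.6, 2.2])
wrapped = ParamTransformWrapper(
    mod, {"weight": lambda w: [float(round(v)) for v in w]}
)

y_plain = mod.forward([1.0, 1.0, 1.0])   # uses the stored floating-point weights
y_wrapped = wrapped([1.0, 1.0, 1.0])     # uses rounded weights for this call only
print(y_plain, y_wrapped, mod.params)
```

The stored parameters are never overwritten by the wrapper, so an optimiser would keep updating the full-precision values while the forward pass sees the transformed ones.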

In [7]:
tnet.parameters('weights')

{'0_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[ 0.5187, -0.9304,  0.0414,  0.3679, -1.2738],
           [ 0.3060, -1.1478,  0.1919,  1.0421,  0.5358],
           [-1.2082, -0.5775,  0.7186, -0.1107, -0.8000]], requires_grad=True)}},
 '2_TorchResidual': {'0_LinearTorch': {'_mod': {'weight': Parameter containing:
    tensor([[-0.1623, -1.0501, -0.6565,  0.2511, -0.2790],
            [-0.7212,  0.1158,  0.1436, -1.0457, -0.8841],
            [ 0.1946,  0.1269,  0.1576,  0.8541, -0.8278],
            [-0.2922,  0.6165, -0.4600, -0.2130, -1.0764],
            [-0.7831, -0.3994, -1.0361, -1.0295, -0.1791]], requires_grad=True)}},
  '1_LIFTorch': {'_mod': {'w_rec': Parameter containing:
    tensor([[ 0.6341, -0.3931, -0.6714,  0.1988,  0.0611],
            [ 1.0097,  0.7559,  0.5021,  0.0831, -0.2205],
            [-0.2008,  0.4110, -0.8212, -0.0223,  0.0696],
            [ 0.2445, -0.7481, -1.0580, -0.9971,  0.1447],
            [ 0.1631,  0.5020, -0.0158,  0.7330,  

These are the un-transformed parameters, in floating-point format. But if we evolve the module by calling it, the parameters will all be transformed in the forward pass:

In [8]:
out, ns, rd = tnet(torch.ones(1, 10, 3))
out

tensor([[[-0., 0., -0.],
         [-0., -0., -0.],
         [-0., -0., -0.],
         [-0., -0., -0.],
         [-0., -0., -0.],
         [-0., -0., -0.],
         [-0., -0., -0.],
         [-0., -0., -0.],
         [3., -0., -0.],
         [7., -0., -0.]]], grad_fn=<CopySlices>)

**Training goes here!**


In [9]:
ttnet = tt.apply_T(tnet, inplace=True)
ttnet

TorchSequential  with shape (3, 3) {
    TWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        TWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        TWrapper '1_LIFTorch' with shape (5, 5) {
            LIFTorch '_mod' with shape (5, 5)
        }
    }
    TWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}
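
The call to `tt.apply_T` above bakes the configured transformations into the stored parameters. The sketch below shows the idea on nested dictionaries in plain Python. `bake_transforms` is a hypothetical helper, not part of Rockpool, and the trees are toy stand-ins for a network's parameter and configuration trees.

```python
def bake_transforms(params: dict, t_config: dict) -> None:
    """Overwrite each configured leaf of `params` with its transformed value, in place."""
    for key, transform in t_config.items():
        if key not in params:
            continue
        if isinstance(transform, dict):
            # Sub-tree: descend into the matching sub-tree of parameters
            bake_transforms(params[key], transform)
        elif transform is not None:
            # Leaf with a transformation: replace the stored value
            params[key] = transform(params[key])


params = {
    "0_Linear": {"weight": [0.52, -0.93, 0.04]},
    "1_LIF": {"tau_mem": [0.02]},
}
t_config = {"0_Linear": {"weight": lambda w: [float(round(v)) for v in w]}}

bake_transforms(params, t_config)
print(params)
```

Only the leaves named in the configuration are changed. Everything else keeps its original value.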

If we now examine the parameters, we will see the low-resolution quantized versions (still stored as floating-point numbers -- this transformation did not force the parameters to be integers).

In [10]:
ttnet.parameters('weights')

{'0_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[ 1.2738, -0.4246, -0.4246,  0.4246, -1.2738],
           [ 0.4246, -1.2738,  0.4246,  1.2738,  0.4246],
           [-0.4246, -0.4246,  0.4246,  0.4246, -0.4246]], requires_grad=True)}},
 '2_TorchResidual': {'0_LinearTorch': {'_mod': {'weight': Parameter containing:
    tensor([[-0.3588, -1.0764, -1.0764, -0.3588, -0.3588],
            [-1.0764, -0.3588,  0.3588, -1.0764, -1.0764],
            [ 0.3588, -0.3588,  0.3588,  0.3588, -1.0764],
            [-0.3588,  1.0764, -0.3588, -0.3588, -1.0764],
            [-1.0764, -0.3588, -1.0764, -1.0764, -0.3588]], requires_grad=True)}},
  '1_LIFTorch': {'_mod': {'w_rec': Parameter containing:
    tensor([[ 0.3527, -0.3527, -0.3527,  0.3527, -0.3527],
            [ 1.0580,  0.3527,  0.3527,  0.3527, -0.3527],
            [ 0.3527,  0.3527, -1.0580, -0.3527, -0.3527],
            [ 0.3527, -1.0580, -1.0580, -1.0580, -0.3527],
            [ 0.3527,  0.3527, -0.3527,  1.0580,  

In [12]:
unpatched_net = tt.remove_T_net(ttnet, inplace = True)
unpatched_net

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

### How to: Quantize to round numbers

In [13]:
w = torch.rand((5, 5)) - 0.5

num_bits = 4

tt.stochastic_rounding(
    w,
    output_range=[-2**(num_bits-1)+1, 2**(num_bits-1)],
    num_levels=2**num_bits,
)

tensor([[ 7.,  6.,  5.,  8., -4.],
        [-6.,  5., -7.,  1.,  5.],
        [ 1.,  3., -3., -2., -2.],
        [ 3.,  6., -2., -3.,  0.],
        [ 8.,  1.,  6.,  4., -1.]])
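
The result is integer-valued because of how the range and the level count were chosen. Assuming the levels are evenly spaced across `output_range` with both endpoints included (an assumption about the rounding functions, though consistent with the four-valued weights shown earlier), 16 levels over [-7, 8] are exactly one unit apart. A quick check in plain Python:

```python
num_bits = 4

# Same range and level count as in the cell above
lo, hi = -2 ** (num_bits - 1) + 1, 2 ** (num_bits - 1)
num_levels = 2 ** num_bits

# Spacing between adjacent levels when both endpoints are included
step = (hi - lo) / (num_levels - 1)
levels = [lo + i * step for i in range(num_levels)]

print(lo, hi, step)
print(levels)
```

Note that this range is [-7, 8]. A two's-complement signed 4-bit target would use [-8, 7] instead, which also gives a spacing of one with 16 levels.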

## Activity transformations
There is a similar pipeline available for activity transformations. This can be used to transform the output of modules in the forward pass, without modifying the module code.

Let's begin again with a simple SNN architecture:

In [14]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec = True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [15]:
# - Build a null configuration tree, which can be manipulated directly
tt.make_act_T_config(net)

# - Specify a transformation function as a lambda
T_fn = lambda p: tt.deterministic_rounding(p, output_range = [-128, 127], num_levels = 2**8)

# - Conveniently build a configuration tree by selecting a module class
tconf = tt.make_act_T_config(net, T_fn, LinearTorch)
tconf

{'': None,
 '0_LinearTorch': {'': <function __main__.<lambda>(p)>},
 '1_LIFTorch': {'': None},
 '2_TorchResidual': {'': None,
  '0_LinearTorch': {'': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'': None}},
 '3_LinearTorch': {'': <function __main__.<lambda>(p)>},
 '4_LIFTorch': {'': None}}

In [16]:
# - Make a transformed network by patching with the configuration
tnet = tt.make_act_T_network(net, tconf)
tnet

TorchSequential  with shape (3, 3) {
    ActWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        ActWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    ActWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}
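
An `ActWrapper` leaves the wrapped module untouched and transforms only what it returns. The following plain-Python sketch illustrates that pattern with hypothetical names (`wrap_output`, `round_to_range`, `linear`). The rounding helper assumes the input is scaled by its largest absolute value onto `output_range`, which may differ from how `tt.deterministic_rounding` chooses its input range.

```python
def round_to_range(values, output_range=(-128, 127)):
    """Scale so the largest magnitude reaches the range limit, then round and clip."""
    lo, hi = output_range
    max_abs = max(abs(v) for v in values)
    if max_abs == 0:
        return [0.0 for _ in values]
    scale = max(abs(lo), abs(hi)) / max_abs
    return [float(min(hi, max(lo, round(v * scale)))) for v in values]


def wrap_output(forward_fn, transform):
    """Return a function that calls `forward_fn` and transforms its output."""
    def wrapped(x):
        return transform(forward_fn(x))
    return wrapped


def linear(x):
    # Toy stand-in for a linear layer with fixed weights
    weights = [[0.5, -1.0], [0.25, 1.0]]
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]


quantised_linear = wrap_output(linear, round_to_range)
out = quantised_linear([1.0, 1.0])
print(linear([1.0, 1.0]), out)
```

Because the transformation is attached from outside, the same `linear` function can be used with or without quantised output.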

In [17]:
# - We evolve the module as usual
out, ns, rd = tnet(torch.ones(1, 10, 3), record = True)

In [18]:
# - Examine the recorded outputs from the network; the LinearTorch layers have quantised output
rd

{'0_LinearTorch': {},
 '0_LinearTorch_output': tensor([[[-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.],
          [-128.,   94.,  -59.,  124.,  -94.]]], requires_grad=True),
 '1_LIFTorch': {'vmem': tensor([[[-1.2176e+02,  4.1557e-01, -5.6123e+01,  9.5245e-01, -8.9416e+01],
           [-3.5340e+02,  8.6557e-01, -1.6289e+02,  5.8304e-02, -2.5952e+02],
           [-6.8391e+02,  2.0020e-01, -3.1524e+02,  9.3555e-01, -5.0224e+02],
           [-1.1031e+03,  5.2795e-01, -5.0846e+02,  2.9263e-01, -8.1009e+02],
           [-1.6015e+03,  4.6997e-02, -7.3821e+02,  2.5238e-01, -1.1761e+03],
           [-2.1705e+03,  2.2644e-01, -1.0005e+03

As expected, the outputs of the `LinearTorch` layers now take signed 8-bit integer values, stored in floating-point representation.