# Torch transformation-in-training pipeline prototype
This notebook gives an overview of the prototype pipeline and facilities available for parameter and activation quantization-aware training (QAT) with Torch-backed modules in Rockpool.

This is still a work in progress and is subject to change.

The torch pipeline is based on Torch's `functional_call` API, new in Torch 1.12.
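The core idea of `functional_call` can be shown with a minimal sketch in plain Torch, independent of Rockpool. A module is evaluated with a substitute set of parameters, here rounded copies of its weights, and the parameters it stores are left untouched. The choice of `torch.nn.Linear` and of rounding to multiples of 0.25 is arbitrary and only for illustration. The import falls back to the Torch 1.12 location when `torch.func` is not available.

```python
import torch

try:
    from torch.func import functional_call  # Torch >= 2.0
except ImportError:
    from torch.nn.utils.stateless import functional_call  # Torch 1.12 and 1.13

# - A plain Torch layer; it knows nothing about quantisation
layer = torch.nn.Linear(3, 2, bias=False)
original_weight = layer.weight.detach().clone()

# - Transformed copies of the parameters: round weights to multiples of 0.25
transformed = {
    name: torch.round(param * 4) / 4 for name, param in layer.named_parameters()
}

# - Evaluate the layer with the transformed parameters instead of its own
x = torch.ones(1, 3)
y = functional_call(layer, transformed, (x,))
```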

## Design goals
* No need to modify pre-defined modules to make "magic quantization" modules
* General solution that can be applied widely to modules and parameters
* Convenient API for specifying transformations over parameters in a network in a "grouped" way, using Rockpool's parameter families
* Similar API for parameter- and activity-transformation
* Quantization controllable at a fine-grained level
* Useful and flexible transformation methods, usable for QAT, dropout, pruning, ...

In [2]:
# - Basic imports
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.nn.combinators import Sequential, Residual

import torch

In [3]:
# - Transformation pipeline imports
import rockpool.transform.torch_transform as tt
import rockpool.utilities.tree_utils as tu

## Parameter transformations
The parameter transformation pipeline lets you apply configurable transformations to any parameter in the forward pass, before the module is evolved. You could use this to perform quantisation-aware training, random parameter attacks, connection pruning, and so on.
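A parameter transformation is essentially a function that takes a parameter and returns a transformed parameter of the same shape, as in the `lambda p: ...` transformations used below. To illustrate connection pruning, the hypothetical function `prune_smallest` below zeroes the smallest-magnitude entries of a weight matrix. It works on nested Python lists so that it stays self-contained. A transformation used with this pipeline would operate on Torch tensors instead.

```python
def prune_smallest(weights, fraction):
    """Zero the `fraction` of entries with the smallest magnitude (ties may prune a few more)."""
    magnitudes = sorted(abs(w) for row in weights for w in row)
    k = int(fraction * len(magnitudes))
    if k == 0:
        # - Nothing to prune: return an unmodified copy
        return [list(row) for row in weights]
    threshold = magnitudes[k - 1]
    return [[0.0 if abs(w) <= threshold else w for w in row] for row in weights]


W = [[0.1, -0.5], [0.3, -0.05]]
pruned = prune_smallest(W, 0.5)
```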

We'll begin here with a simple Rockpool SNN that uses most of the features of network composition in Rockpool, and is compatible with Xylo.

In [4]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec = True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net


TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

Now we build a configuration that describes the transformation to apply to each parameter. We will transform weights with :py:func:`.stochastic_rounding` and biases with :py:func:`.dropout`. Parameter families let us select which parameters to transform, and with which transformation.

In [5]:
# - Get the 'weights' parameter family, and specify stochastic rounding
tconfig = tt.make_param_T_config(net, lambda p: tt.stochastic_rounding(p, num_levels=2**2), 'weights')
tconfig

{'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
 '2_TorchResidual': {'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'w_rec': <function __main__.<lambda>(p)>}},
 '3_LinearTorch': {'weight': <function __main__.<lambda>(p)>}}
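The following pure-Python sketch shows how a configuration tree like this could be assembled from parameter families. The nested-dictionary parameter layout and the function `make_config_sketch` are hypothetical. Family labels other than `'weights'` and `'biases'` are illustrative only. Branches with no selected parameters are dropped. This mirrors the tree above, where modules without parameters in the `'weights'` family do not appear.

```python
def make_config_sketch(tree, T_fn, family):
    """Map every parameter in `family` to `T_fn`, keeping the module nesting.

    `tree` is a hypothetical nested dict: submodule names map to sub-trees,
    and parameter names map to `(value, family)` tuples.
    """
    config = {}
    for name, node in tree.items():
        if isinstance(node, dict):
            # - Recurse into submodules; keep only branches with selected parameters
            sub_config = make_config_sketch(node, T_fn, family)
            if sub_config:
                config[name] = sub_config
        else:
            _value, node_family = node
            if node_family == family:
                config[name] = T_fn
    return config


# - A toy parameter tree mirroring the network above (values are placeholders)
net_params = {
    "0_LinearTorch": {"weight": ([[0.1]], "weights")},
    "1_LIFTorch": {"tau_mem": (0.02, "taus")},
    "2_TorchResidual": {
        "0_LinearTorch": {"weight": ([[0.2]], "weights")},
        "1_LIFTorch": {"w_rec": ([[0.3]], "weights"), "tau_mem": (0.02, "taus")},
    },
    "3_LinearTorch": {"weight": ([[0.4]], "weights")},
    "4_LIFTorch": {"tau_mem": (0.02, "taus")},
}


def identity(p):
    # - Placeholder transformation
    return p


weights_config = make_config_sketch(net_params, identity, "weights")
biases_config = make_config_sketch(net_params, identity, "biases")
```

In this toy tree no parameter carries the `'biases'` label, so the corresponding configuration is empty.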

In [6]:
# - Now we add in the bias transformation
tu.tree_update(tconfig, tt.make_param_T_config(net, lambda p: tt.dropout(p, 0.3), 'biases'))
tconfig

{'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
 '2_TorchResidual': {'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'w_rec': <function __main__.<lambda>(p)>}},
 '3_LinearTorch': {'weight': <function __main__.<lambda>(p)>}}
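`tu.tree_update` merges the second tree into the first. A recursive dictionary merge along the following lines captures the idea. `tree_update_sketch` is a hypothetical stand-in, not Rockpool's implementation. In this sketch, merging an empty tree leaves the target unchanged.

```python
def tree_update_sketch(target, additional):
    """Recursively merge `additional` into `target` in place, and return `target`."""
    for key, value in additional.items():
        if isinstance(value, dict) and isinstance(target.get(key), dict):
            # - Both sides are sub-trees: merge them recursively
            tree_update_sketch(target[key], value)
        else:
            # - New key, or a leaf: insert or replace
            target[key] = value
    return target


def round_T(p):
    # - Placeholder for a weight transformation
    return p


def drop_T(p):
    # - Placeholder for a bias transformation
    return p


tconfig_sketch = {"0_LinearTorch": {"weight": round_T}}
tree_update_sketch(tconfig_sketch, {"0_LinearTorch": {"bias": drop_T}})
tree_update_sketch(tconfig_sketch, {})  # - An empty tree changes nothing
```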

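The configuration tree is unchanged: no parameters in this network belong to the `'biases'` family, so no dropout transformations were added. We then use this transformation configuration tree to patch the network with transformation modules, using the :py:func:`.make_param_T_network` helper function.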

In [7]:
# - We now use this configuration to patch the original network with transformation modules
tnet = tt.make_param_T_network(net, tconfig)
tnet

TorchSequential  with shape (3, 3) {
    TWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        TWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        TWrapper '1_LIFTorch' with shape (5, 5) {
            LIFTorch '_mod' with shape (5, 5)
        }
    }
    TWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

Each module with transformed parameters is now wrapped in a :py:class:`.TWrapper` module. In the forward pass, these wrapper modules apply the configured transformations to the wrapped module's parameters, inject the transformed parameters, and then evolve the wrapped module as usual. The original module needs no special support; it simply uses the transformed parameters passed to it.

The parameters are held by the original modules in untransformed form, so any parameter updates during training are applied to the untransformed parameters.
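The wrapping idea can be sketched in plain Python. `ToyLinear` and `ToyTWrapper` are hypothetical classes, not Rockpool's :py:class:`.TWrapper`. On every call, the wrapper transforms the stored parameters and passes the results to the wrapped module. The stored parameters, which an optimiser would update, are never overwritten.

```python
class ToyLinear:
    """Toy linear layer: weight[i][j] connects input i to output j."""

    def __init__(self, weight):
        self.params = {"weight": weight}  # - stored, untransformed parameters

    def forward(self, x, params=None):
        # - Use substitute parameters if supplied, otherwise the stored ones
        weight = (params if params is not None else self.params)["weight"]
        return [sum(xi * wij for xi, wij in zip(x, column)) for column in zip(*weight)]


class ToyTWrapper:
    """Transform the wrapped module's parameters in the forward pass."""

    def __init__(self, module, T_config):
        self._mod = module
        self.T_config = T_config  # - parameter name -> transformation function

    def forward(self, x):
        # - Build transformed parameters; unconfigured ones pass through unchanged
        transformed = {
            name: self.T_config.get(name, lambda p: p)(value)
            for name, value in self._mod.params.items()
        }
        return self._mod.forward(x, transformed)


def round_weights(w):
    return [[round(v) for v in row] for row in w]


lin = ToyLinear([[0.26, -0.74], [1.1, 0.49]])
wrapped = ToyTWrapper(lin, {"weight": round_weights})
out = wrapped.forward([1.0, 1.0])
```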

If we investigate the :py:meth:`.Module.parameters` of the network we can see this structure:

In [8]:
tnet.parameters('weights')

{'0_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[-1.1399e+00,  1.4406e-01,  2.5429e-01, -9.5625e-01,  5.6410e-01],
           [-8.6451e-04, -7.1668e-01, -6.2177e-02, -1.2791e+00,  1.2483e+00],
           [-9.4208e-01,  3.3010e-01,  5.8619e-01, -4.7097e-01, -3.7961e-02]],
          requires_grad=True)}},
 '2_TorchResidual': {'0_LinearTorch': {'_mod': {'weight': Parameter containing:
    tensor([[ 0.3838,  0.9660,  0.9430,  0.4541, -0.1916],
            [ 0.9990, -0.7228,  0.7537, -0.3913,  0.3117],
            [ 0.6027,  0.7698, -1.0840, -0.3215,  0.5501],
            [ 0.8245, -0.2296, -0.1558,  0.0450,  0.2370],
            [-0.3791,  0.9442, -0.8692,  0.7645,  0.5853]], requires_grad=True)}},
  '1_LIFTorch': {'_mod': {'w_rec': Parameter containing:
    tensor([[ 0.1149, -0.6803,  0.2665,  1.0365,  0.5285],
            [ 0.5471,  0.5533,  0.8679, -0.9686, -0.9631],
            [-0.5324, -0.6320,  0.8431, -0.9942, -1.0700],
            [ 0.2207, -0.7819,  0.6502,

These are the un-transformed parameters, in floating-point format. But if we evolve the module by calling it, the parameters will all be transformed in the forward pass:

In [9]:
out, ns, rd = tnet(torch.ones(1, 10, 3))
out

tensor([[[-0., -0.,  0.],
         [-0.,  0., -0.],
         [-0.,  4., -0.],
         [-0., 10., -0.],
         [-0., 18., -0.],
         [-0., 23., -0.],
         [-0., 24., -0.],
         [-0., 22., -0.],
         [-0., 16., -0.],
         [-0.,  6., -0.]]], grad_fn=<CopySlices>)

**Training goes here!**

Here you can train the model, interacting with it as with any other Rockpool :py:class:`.TorchModule`.
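During training, the optimiser updates the stored, untransformed parameters, while the forward pass sees the transformed ones. Rounding has zero gradient almost everywhere, so QAT commonly uses a straight-through estimator (STE). The STE rounds in the forward pass but treats the rounding as the identity when computing gradients. The scalar sketch below shows this generic technique in plain Python. It is not necessarily how Rockpool's transformations compute gradients. `quantise` and `train_ste` are hypothetical helpers.

```python
def quantise(w, step=0.25):
    """Round `w` to the nearest multiple of `step`."""
    return round(w / step) * step


def train_ste(w, xs, ys, lr=0.1, epochs=20):
    """Fit y = quantise(w) * x by SGD on squared error, using a straight-through estimator."""
    for _ in range(epochs):
        for x, y in zip(xs, ys):
            w_q = quantise(w)  # - forward pass uses the quantised weight
            err = w_q * x - y
            grad = 2 * err * x  # - STE: treat d(w_q)/dw as 1
            w -= lr * grad  # - update the float, untransformed weight
    return w


xs = [1.0, 2.0, -1.0]
ys = [0.75 * x for x in xs]  # - target weight 0.75 is exactly representable
w_trained = train_ste(0.0, xs, ys)
```

The stored weight settles at about 0.65, not 0.75; only its quantised value matches the target. Small updates accumulate in the float weight until they cross a quantisation boundary, which is why the untransformed parameters are the ones that get trained.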

Once the model is trained, you may want to access the transformed parameters. You have two options:

1. You can "burn in" the transformation with the helper function :py:func:`.apply_T`. This evaluates the transformations and stores the results as the "real" parameters within the module:

In [18]:
ttnet = tt.apply_T(tnet)
ttnet

TorchSequential  with shape (3, 3) {
    ActWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        ActWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    ActWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

If we now examine the parameters, we will see the low-resolution quantized versions (still stored as floating-point numbers; this transformation does not force the parameters to be integers).

In [19]:
ttnet.parameters('weights')

{'0_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[-1.2824,  1.1413,  0.0877,  0.0578, -1.0581],
           [ 0.5104,  1.1349, -1.3793, -0.1282,  0.2503],
           [-0.3207, -0.3518,  0.9708, -0.0555, -1.1052]], requires_grad=True)}},
 '2_TorchResidual': {'0_LinearTorch': {'_mod': {'weight': Parameter containing:
    tensor([[-0.0842,  0.2129, -0.2631, -0.6283,  0.1849],
            [-0.4626, -0.7207,  0.1398,  0.2352, -0.2340],
            [-0.0220, -0.5523,  0.2323, -0.2215, -0.1106],
            [-0.9892,  0.2199, -0.2059,  0.4504, -0.6871],
            [-0.5974, -0.9785,  0.1854, -0.0626, -0.7767]], requires_grad=True)}},
  '1_LIFTorch': {'w_rec': Parameter containing:
   tensor([[ 0.6062, -0.9756, -0.3870, -0.6646,  0.0906],
           [-0.7449, -0.8844,  0.0129, -0.9246,  0.9288],
           [ 0.0158,  0.1328,  0.6948,  0.9710, -0.2518],
           [ 0.3792, -0.4621,  0.8133, -0.4422, -0.1458],
           [-0.3411,  0.4592, -0.0581,  0.3192,  1.0715]], requ

You can now convert the network back to the original "unpatched" structure with the helper function :py:func:`.remove_T_net`.

In [12]:
unpatched_net = tt.remove_T_net(ttnet)
unpatched_net

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

Compare this with the original network above.

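2. Alternatively, you can "unpatch" the network with :py:func:`.remove_T_net` without burning in the transformations, and then apply post-training quantisation by whatever method you prefer. This may be preferable if you have included "destructive" transformations such as :py:func:`.dropout`, since burning them in would permanently fix a single random dropout pattern in the parameters.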
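The difference between the two options can be sketched with nested dictionaries. The hypothetical function `apply_T_sketch` below evaluates each configured transformation once and returns the results as new, plain parameters, which corresponds to option 1. Option 2 would instead keep the untransformed `params` and quantise them later by other means. With a random transformation such as dropout, the single evaluation in option 1 would fix one random outcome permanently. This is an illustration, not the implementation of :py:func:`.apply_T`.

```python
def apply_T_sketch(params, config):
    """Return a new tree with each transformation in `config` applied once to `params`."""
    burned = {}
    for name, value in params.items():
        T = config.get(name)
        if isinstance(value, dict):
            # - Sub-module: recurse with its part of the configuration (if any)
            burned[name] = apply_T_sketch(value, T or {})
        elif T is not None:
            burned[name] = T(value)  # - transformation evaluated a single time
        else:
            burned[name] = value  # - unconfigured parameters pass through unchanged
    return burned


params = {"0_LinearTorch": {"weight": [0.26, -0.74]}, "1_LIFTorch": {"threshold": 1.0}}
config = {"0_LinearTorch": {"weight": lambda w: [round(v) for v in w]}}

burned = apply_T_sketch(params, config)
```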

### How to: Quantize to integer values
We might want to quantize to integer levels, for example when targeting processors such as Xylo that use integer logic and integer parameter representations. This is possible with :py:func:`.stochastic_rounding`.

The cell below shows how to use :py:func:`.stochastic_rounding` to target signed integer parameter values. By default, :py:func:`.stochastic_rounding` ensures that zero in the input space maps to zero in the output space.

In [30]:
w = torch.rand((5, 5)) - 0.5

num_bits = 4

tt.stochastic_rounding(
    w,
    output_range=[-2**(num_bits-1)+1, 2**(num_bits-1)],
    num_levels=2**num_bits,
)

tensor([[-4., -5., -1.,  6.,  6.],
        [ 1., -1., -3.,  3.,  1.],
        [ 3., -5.,  4., -3.,  8.],
        [-3.,  7.,  5.,  3.,  6.],
        [ 5., -7.,  4., -7.,  1.]])
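The following pure-Python sketch shows what stochastic rounding does. The function `stochastic_round_sketch` is hypothetical. Its details, a symmetric input range set by the largest magnitude and evenly spaced output levels, are assumptions and do not describe :py:func:`.stochastic_rounding`. Each value is rounded down or up to a neighbouring level at random. The probability of rounding up equals the fractional distance above the lower level, so the result is unbiased on average. In this sketch, zero maps exactly to zero only when the output range is symmetric and the number of levels is odd, for example `[-7, 7]` with 15 levels.

```python
import math
import random


def stochastic_round_sketch(values, output_range, num_levels, rng):
    """Stochastically round `values` onto `num_levels` evenly spaced levels in `output_range`."""
    max_abs = max(abs(v) for v in values) or 1.0  # - avoid dividing by zero
    in_lo, in_hi = -max_abs, max_abs  # - symmetric input range
    out_lo, out_hi = output_range
    step = (out_hi - out_lo) / (num_levels - 1)

    result = []
    for v in values:
        # - Position of v, measured in level steps from the lowest level
        pos = (v - in_lo) / (in_hi - in_lo) * (num_levels - 1)
        lower = math.floor(pos)
        # - Round up with probability equal to the fractional part
        level = lower + (1 if rng.random() < pos - lower else 0)
        result.append(out_lo + level * step)
    return result


rng = random.Random(0)
values = [-0.5, -0.2, 0.0, 0.13, 0.5]
q = stochastic_round_sketch(values, (-7, 7), 15, rng)
```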

## Activity transformations
There is a similar pipeline available for activity transformations. This can be used to transform the output of modules in the forward pass, without modifying the module code.

Let's begin again with a simple SNN architecture:

In [31]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec = True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net


TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

We need to build a configuration with which to patch the network. We can conveniently select the modules to transform by their class. Here we'll round the output activations of the :py:class:`.LinearTorch` layers to 8-bit signed integers, using the function :py:func:`.deterministic_rounding`.

In [14]:
# - Build a null configuration tree, which can be manipulated directly
null_tconf = tt.make_act_T_config(net)

# - Specify a transformation function as a lambda
T_fn = lambda p: tt.deterministic_rounding(p, output_range = [-128, 127], num_levels = 2**8)

# - Conveniently build a configuration tree by selecting a module class
tconf = tt.make_act_T_config(net, T_fn, LinearTorch)
tconf

{'': None,
 '0_LinearTorch': {'': <function __main__.<lambda>(p)>},
 '1_LIFTorch': {'': None},
 '2_TorchResidual': {'': None,
  '0_LinearTorch': {'': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'': None}},
 '3_LinearTorch': {'': <function __main__.<lambda>(p)>},
 '4_LIFTorch': {'': None}}
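In the tree above, the empty key `''` appears to hold the transformation (or `None`) for each module's own output, alongside the entries for its submodules. The following pure-Python sketch builds a tree of this shape by selecting modules by class. The toy classes and `make_act_config_sketch` are hypothetical.

```python
class ToyModule:
    """Toy module that may contain named submodules."""

    def __init__(self, **submodules):
        self.submodules = submodules


class ToyLinearLayer(ToyModule):
    pass


class ToyLIFLayer(ToyModule):
    pass


def make_act_config_sketch(module, T_fn=None, target_class=None):
    """Build an activity configuration tree; '' holds the transform for this module's output."""
    selected = target_class is not None and isinstance(module, target_class)
    config = {"": T_fn if selected else None}
    for name, submodule in module.submodules.items():
        config[name] = make_act_config_sketch(submodule, T_fn, target_class)
    return config


net_sketch = ToyModule(**{
    "0_Linear": ToyLinearLayer(),
    "1_LIF": ToyLIFLayer(),
    "2_Residual": ToyModule(**{"0_Linear": ToyLinearLayer(), "1_LIF": ToyLIFLayer()}),
})


def identity_T(x):
    # - Placeholder output transformation
    return x


null_conf = make_act_config_sketch(net_sketch)
linear_conf = make_act_config_sketch(net_sketch, identity_T, ToyLinearLayer)
```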

Now we patch the network, analogously to the parameter transformation above, using the helper function :py:func:`.make_act_T_network`:

In [15]:
# - Make a transformed network by patching with the configuration
tnet = tt.make_act_T_network(net, tconf)
tnet

TorchSequential  with shape (3, 3) {
    ActWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        ActWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    ActWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

Again, the network has been patched, this time with :py:class:`.ActWrapper` modules, each of which handles the transformation of a single wrapped module's output.
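The activity wrapper can be sketched in the same way: call the wrapped module, then transform its output. `ToyActWrapper` and `symmetric_round_sketch` are hypothetical. The latter is a simple symmetric 8-bit scheme that maps the largest magnitude to ±127 and then rounds to the nearest integer. It is swapped in for illustration and is not necessarily what :py:func:`.deterministic_rounding` computes.

```python
def symmetric_round_sketch(values, max_int=127):
    """Scale so the largest magnitude maps to `max_int`, then round to the nearest integer."""
    max_abs = max(abs(v) for v in values) or 1.0  # - avoid dividing by zero
    return [float(round(v / max_abs * max_int)) for v in values]


class ToyActWrapper:
    """Call the wrapped module, then transform its output."""

    def __init__(self, module, T_fn):
        self._mod = module
        self.T_fn = T_fn

    def __call__(self, x):
        return self.T_fn(self._mod(x))


def toy_linear(x):
    # - Stand-in for a linear layer with three outputs
    return [0.3 * x, -0.9 * x, 0.05 * x]


wrapped = ToyActWrapper(toy_linear, symmetric_round_sketch)
out = wrapped(1.0)
```

As in the recorded outputs below, the results are integer values held as floats.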

Now we evolve the network as usual, and check the outputs of the :py:class:`.LinearTorch` layers:

In [32]:
# - We evolve the module as usual
out, ns, rd = tnet(torch.ones(1, 10, 3), record = True)

In [33]:
# - Examine the recorded outputs from the network; the LinearTorch layers have quantised output
rd

{'0_LinearTorch': {},
 '0_LinearTorch_output': tensor([[[ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.],
          [ -73.,  127.,  -22.,   -9., -127.]]], requires_grad=True),
 '1_LIFTorch': {'vmem': tensor([[[-3.3144e+03,  5.1770e-01, -9.9886e+02, -4.0862e+02, -5.7661e+03],
           [-3.7952e+03,  9.9487e-02, -1.1437e+03, -4.6790e+02, -6.6025e+03],
           [-4.2906e+03,  1.4648e-03, -1.2931e+03, -5.2897e+02, -7.4644e+03],
           [-4.7981e+03,  9.7461e-01, -1.4460e+03, -5.9154e+02, -8.3474e+03],
           [-5.3153e+03,  8.9099e-01, -1.6019e+03, -6.5531e+02, -9.2472e+03],
           [-5.8401e+03,  8.7622e-01, -1.7600e+03

As expected, the outputs of the :py:class:`.LinearTorch` layers are now signed 8-bit integer values, stored as floating-point numbers.