# Torch transformation-in-training pipeline prototype
This notebook gives an overview of the prototype pipeline and facilities for quantization-aware training (QAT) of parameters and activations, available for Torch-backed modules in Rockpool.

This is still work-in-progress and subject to change.

The Torch pipeline is built on Torch's `functional_call` API, introduced in Torch 1.12.
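
As a rough illustration of the mechanism (not Rockpool's own code), `functional_call` evaluates a module with a substitute dictionary of parameters and leaves the parameters stored in the module untouched. The sketch below uses a plain `torch.nn.Linear` and an arbitrary rounding transformation chosen only for illustration. The import fallback assumes `torch.func.functional_call` on Torch 2.x and `torch.nn.utils.stateless.functional_call` on Torch 1.12 and 1.13.

```python
import torch

try:
    from torch.func import functional_call  # Torch >= 2.0
except ImportError:
    from torch.nn.utils.stateless import functional_call  # Torch 1.12 / 1.13

torch.manual_seed(0)

# - A plain module; its stored parameters should stay unmodified
lin = torch.nn.Linear(3, 2, bias=False)
w_stored = lin.weight.detach().clone()

# - Illustrative transformation: round every parameter to a grid with step 0.25
transformed = {name: torch.round(p * 4) / 4 for name, p in lin.named_parameters()}

# - Evaluate the module with the transformed parameters substituted in
x = torch.ones(1, 3)
y = functional_call(lin, transformed, (x,))
```

Because the transformed tensors are recomputed from the stored parameters on each call, the stored floating-point values themselves are never overwritten.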

## Design goals
* No need to modify pre-defined modules to make "magic quantization" modules
* General solution that can be applied widely to modules and parameters
* Convenient API for specifying transformations over parameters in a network in a "grouped" way, using Rockpool's parameter families
* Similar API for parameter- and activity-transformation
* Quantization controllable at a fine-grained level
* Provide useful and flexible transformation methods: these can be used for QAT, dropout, pruning, and so on

In [1]:
# - Basic imports
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.nn.combinators import Sequential, Residual

import torch

In [2]:
# - Transformation pipeline imports
import rockpool.transform.torch_transform as tt
import rockpool.utilities.tree_utils as tu

## Parameter transformations
The parameter transformation pipeline lets you insert transformations on any parameter in the forward pass, before evolution, in a configurable way. You can use this to perform quantization-aware training, random parameter attacks, connection pruning, and so on.
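
The transformation configurations used in this notebook are nested dictionaries ("trees") that mirror the module hierarchy, with a transformation function at each parameter to be transformed. As a rough sketch of the idea, assuming a plain-dict parameter tree, the hypothetical helper `apply_T_tree` applies such a configuration and returns new values, leaving the original tree unchanged. The module and parameter names are only illustrative.

```python
from typing import Any, Dict


def apply_T_tree(params: Dict[str, Any], config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: return a copy of `params` with functions from `config` applied to matching leaves."""
    out = {}
    for key, value in params.items():
        T = config.get(key)
        if isinstance(value, dict):
            # - Recurse into sub-modules, using the matching sub-config if there is one
            out[key] = apply_T_tree(value, T if isinstance(T, dict) else {})
        elif callable(T):
            # - Transform this parameter
            out[key] = T(value)
        else:
            # - No transformation configured: pass the parameter through unchanged
            out[key] = value
    return out


params = {
    "0_LinearTorch": {"weight": [0.26, -0.74]},
    "1_LIFTorch": {"bias": [0.1, 0.2], "tau_mem": [0.02, 0.02]},
}
config = {"0_LinearTorch": {"weight": lambda w: [round(v * 2) / 2 for v in w]}}

t_params = apply_T_tree(params, config)
```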

We'll begin here with a simple Rockpool SNN that uses most of the features of network composition in Rockpool, and is compatible with Xylo.

In [None]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net

In [None]:
# - Get the 'weights' parameter family, and specify stochastic rounding
tconfig = tt.make_param_T_config(
    net, lambda p: tt.stochastic_rounding(p, num_levels=2**2), "weights"
)
tconfig
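
Stochastic rounding picks one of the two quantization levels on either side of a value at random, rounding up with probability equal to the value's fractional distance above the lower level, so that the rounded value equals the original in expectation. Below is a minimal pure-Python sketch on a uniform grid of spacing `step`. `stochastic_round_to_grid` is a hypothetical name, and the sketch leaves out the mapping between input and output ranges that `tt.stochastic_rounding` performs.

```python
import math
import random


def stochastic_round_to_grid(x: float, step: float, rng: random.Random) -> float:
    """Round `x` to a neighboring multiple of `step`, rounding up with probability equal to the fractional part."""
    scaled = x / step
    lower = math.floor(scaled)
    frac = scaled - lower
    return (lower + (1 if rng.random() < frac else 0)) * step


rng = random.Random(0)
samples = [stochastic_round_to_grid(0.3, 1.0, rng) for _ in range(10_000)]
mean = sum(samples) / len(samples)
```

Each sample is either 0.0 or 1.0, and the sample mean lies close to 0.3.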

In [None]:
# - Now we add in the bias transformation
tu.tree_update(
    tconfig, tt.make_param_T_config(net, lambda p: tt.dropout(p, 0.3), "biases")
)
tconfig
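
Conceptually, combining the weight and bias configurations is a recursive merge of two trees. The hypothetical `merge_trees` below sketches that idea (it is not the source of `tu.tree_update`): nested dictionaries are merged key by key, and nodes from the second tree are inserted into a copy of the first. Strings stand in for the transformation functions.

```python
from typing import Any, Dict


def merge_trees(target: Dict[str, Any], additional: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical recursive merge: insert every node of `additional` into a copy of `target`."""
    merged = dict(target)
    for key, value in additional.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # - Both trees have a sub-tree here: merge them recursively
            merged[key] = merge_trees(merged[key], value)
        else:
            # - Insert (or overwrite) the node
            merged[key] = value
    return merged


weights_conf = {
    "0_LinearTorch": {"weight": "rounding"},
    "2_TorchResidual": {"0_LinearTorch": {"weight": "rounding"}},
}
biases_conf = {
    "1_LIFTorch": {"bias": "dropout"},
    "2_TorchResidual": {"1_LIFTorch": {"bias": "dropout"}},
}
conf = merge_trees(weights_conf, biases_conf)
```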

In [None]:
# - We now use this configuration to patch the original network with transformation modules
tnet = tt.make_param_T_network(net, tconfig)
tnet

In [None]:
tnet.parameters("weights")

These are the untransformed floating-point parameters. When we evolve the module by calling it, however, the configured parameters are transformed in the forward pass:

In [None]:
out, ns, rd = tnet(torch.ones(1, 10, 3))
out

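**Training goes here!**

Once training is finished, `tt.apply_T` applies the configured transformations to the stored parameters:
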
In [None]:
ttnet = tt.apply_T(tnet, inplace=True)
ttnet

If we now examine the parameters, we will see the low-resolution quantized versions (still stored as floating-point numbers; this transformation does not force the parameters to be integers).

In [None]:
ttnet.parameters("weights")

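`tt.remove_T_net` then removes the transformation wrappers, returning a plain network that holds the transformed parameters:

In [None]: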
unpatched_net = tt.remove_T_net(ttnet, inplace=True)
unpatched_net

### How to: Quantize to integer values

In [None]:
w = torch.rand((5, 5)) - 0.5

num_bits = 4

# - Quantize to signed 4-bit integers in [-8, 7]
tt.stochastic_rounding(
    w,
    output_range=[-(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1],
    num_levels=2**num_bits,
)
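
Why this gives integers: assuming the quantization levels are spaced evenly across the output range $[a, b]$, the spacing between $L$ levels is $(b - a) / (L - 1)$. For the signed 4-bit range $[-8, 7]$ with $2^4 = 16$ levels, that is $15 / 15 = 1$, so every level is an integer. The hypothetical helper below simply lists the levels for a given range.

```python
def level_grid(output_range, num_levels):
    """Evenly spaced quantization levels spanning `output_range` (hypothetical helper)."""
    lo, hi = output_range
    step = (hi - lo) / (num_levels - 1)
    return [lo + i * step for i in range(num_levels)]


num_bits = 4
grid4 = level_grid([-(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1], 2**num_bits)
grid8 = level_grid([-128, 127], 2**8)
```

The same reasoning covers the `[-128, 127]` range with `2**8` levels used for activities in this notebook: the spacing is again exactly 1.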

## Activity transformations
There is a similar pipeline available for activity transformations. This can be used to transform the output of modules in the forward pass, without modifying the module code.
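
Conceptually, an activity transformation wraps a module so that a function is applied to its output on each evaluation. The sketch below assumes a simplified module that maps an input list to an output list (real Rockpool modules return `(output, state, record)`), and uses a hypothetical `ActTWrapperSketch` class to round the wrapped module's output.

```python
from typing import Callable, List

Fn = Callable[[List[float]], List[float]]


class ActTWrapperSketch:
    """Hypothetical wrapper: evaluate `module`, then apply `T_fn` to its output."""

    def __init__(self, module: Fn, T_fn: Fn):
        self.module = module
        self.T_fn = T_fn

    def __call__(self, x: List[float]) -> List[float]:
        return self.T_fn(self.module(x))


# - A stand-in "module" and an output transformation
scale = lambda x: [2.6 * v for v in x]
round_output = lambda y: [float(round(v)) for v in y]

wrapped = ActTWrapperSketch(scale, round_output)
out = wrapped([1.0, -1.0, 0.5])
```

The wrapped module itself is unchanged. Only the output seen by downstream layers is transformed.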

Let's begin again with a simple SNN architecture:

In [None]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net

In [None]:
# - Build a null configuration tree, which can be manipulated directly
tt.make_act_T_config(net)

# - Specify a transformation function as a lambda
T_fn = lambda p: tt.deterministic_rounding(
    p, output_range=[-128, 127], num_levels=2**8
)

# - Conveniently build a configuration tree by selecting a module class
tconf = tt.make_act_T_config(net, T_fn, LinearTorch)
tconf
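
Deterministic rounding always snaps to the nearest level rather than choosing at random. Here is a minimal sketch, assuming a linear map from a clipping `input_range` onto `num_levels` evenly spaced levels across `output_range`. The helper name and the explicit `input_range` are assumptions for illustration, and this is not necessarily how `tt.deterministic_rounding` chooses its input range.

```python
def deterministic_round(x, input_range, output_range, num_levels):
    """Hypothetical helper: map `x` linearly onto evenly spaced levels and round to the nearest one."""
    in_lo, in_hi = input_range
    out_lo, out_hi = output_range
    step = (out_hi - out_lo) / (num_levels - 1)
    # - Clip to the input range, then rescale to the output range
    x = min(max(x, in_lo), in_hi)
    y = out_lo + (x - in_lo) / (in_hi - in_lo) * (out_hi - out_lo)
    # - Snap to the nearest level
    return out_lo + round((y - out_lo) / step) * step


signed_8bit = dict(input_range=[-1.0, 1.0], output_range=[-128, 127], num_levels=2**8)
values = [deterministic_round(v, **signed_8bit) for v in (-1.0, 0.5, 2.0)]
```

Here `2.0` lies outside the input range, so it is clipped to the top level, 127.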

In [None]:
# - Make a transformed network by patching with the configuration
tnet = tt.make_act_T_network(net, tconf)
tnet

In [None]:
# - We evolve the module as usual
out, ns, rd = tnet(torch.ones(1, 10, 3), record=True)

In [None]:
# - Examine the recorded outputs from the network; the LinearTorch layers have quantized outputs
rd

As expected, the outputs of the Linear layers now take signed 8-bit integer values in the range [-128, 127], while still being stored as floating-point numbers.

## Decay transformations
When decays are trained, the decay parameters of the LIF neurons, $\exp(-dt/\tau)$, can be quantized to match the way decay is implemented on Xylo, namely bitshift subtraction:

$$V_{mem} \rightarrow V_{mem} - \frac{V_{mem}}{2^N} = V_{mem} \cdot \left(1 - \frac{1}{2^N}\right)$$

where $N$ is the number of bits shifted.
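
To make this concrete: a decay $\alpha = \exp(-dt/\tau)$ can be represented by a bitshift $N$ when $\alpha \approx 1 - 2^{-N}$, that is, $N \approx -\log_2(1 - \alpha)$. The sketch below rounds $N$ to the nearest integer and compares the result against an integer bitshift subtraction. The helper names and the choice of rounding in the $N$ domain are assumptions for illustration; this is not the implementation of `tt.t_decay`.

```python
import math


def decay_to_bitshift(alpha: float) -> int:
    """Hypothetical helper: nearest integer N with 1 - 2**-N close to alpha (requires 0 < alpha < 1)."""
    return round(-math.log2(1.0 - alpha))


def bitshift_to_decay(n: int) -> float:
    """Decay factor implemented by a bitshift subtraction of n bits."""
    return 1.0 - 2.0**-n


def bitshift_decay_step(v: int, n: int) -> int:
    """One integer decay step, v - (v >> n); Python's >> floors toward minus infinity for negative v."""
    return v - (v >> n)


dt, tau = 1e-3, 20e-3
alpha = math.exp(-dt / tau)  # ~0.951
n = decay_to_bitshift(alpha)  # 4
alpha_q = bitshift_to_decay(n)  # 0.9375
v_next = bitshift_decay_step(1000, n)  # 938, close to 1000 * 0.9375 = 937.5
```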

In [3]:
# - Build a network to use
# - Activate decay training for the last layer
net_decay = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3, decay_training=True),
)
net_decay

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [5]:
tconfig_decay = tt.make_param_T_config(net_decay, lambda p: tt.t_decay(p), "decays")
print(tconfig_decay["4_LIFTorch"])

{'alpha': <function <lambda> at 0x7fae52552790>, 'beta': <function <lambda> at 0x7fae52552790>}

In [6]:
t_net_decay = tt.make_param_T_network(net_decay, tconfig_decay)
print(t_net_decay)

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    TWrapper '4_LIFTorch' with shape (3, 3) {
        LIFTorch '_mod' with shape (3, 3)
    }
}


## Building a network with bitshift decays 
If `BitShift_training=True` is passed to the LIF neurons, the membrane and synaptic decays are applied directly as bitshift subtractions, with the shifts held in the parameters `dash_mem` and `dash_syn`. To quantize them, it is therefore enough to round them to integers.
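
Assuming the trained quantities are the shifts themselves, it helps to see which time constants integer shifts correspond to. Equating $\exp(-dt/\tau)$ with $1 - 2^{-N}$ gives $\tau = -dt / \ln(1 - 2^{-N})$. The hypothetical helper below tabulates this for an assumed time step of 1 ms.

```python
import math


def bitshift_to_tau(n: float, dt: float) -> float:
    """Time constant whose per-step decay exp(-dt / tau) equals 1 - 2**-n."""
    return -dt / math.log(1.0 - 2.0**-n)


dt = 1e-3  # assumed simulation time step (1 ms)
for n in range(1, 9):
    print(f"N = {n}: tau = {bitshift_to_tau(n, dt) * 1e3:.2f} ms")
```

Each extra bit of shift roughly doubles the time constant, so integer shifts cover a wide range of decays, but only coarsely.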


In [8]:
# - Build a network to use
# activate the decay training for the last layer
net_bitshift = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3, BitShift_training=True),
)
net_bitshift

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [13]:
tconfig_bitshift = tt.make_param_T_config(
    net_bitshift, lambda p: tt.round_passthrough(p), "bitshifts"
)
print(tconfig_bitshift["4_LIFTorch"])

{'dash_mem': <function <lambda> at 0x7fadbe1cf280>, 'dash_syn': <function <lambda> at 0x7fadbe1cf280>}
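
This notebook does not document `tt.round_passthrough`. Its name suggests rounding in the forward pass while passing gradients through unchanged, which is a straight-through estimator. Plain `torch.round` does not do this, because its gradient is zero. Below is a minimal sketch of that technique, with `round_ste` as a hypothetical name.

```python
import torch


def round_ste(x: torch.Tensor) -> torch.Tensor:
    """Round in the forward pass; use the identity gradient in the backward pass."""
    # - The detached term adds the rounding error without contributing to gradients
    return x + (torch.round(x) - x).detach()


dash = torch.tensor([3.4, 5.6], requires_grad=True)
dash_q = round_ste(dash)
dash_q.sum().backward()
```

Here `dash_q` holds `[3., 6.]`, while `dash.grad` is all ones, so the unrounded shifts keep receiving useful gradients during training.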
