# Torch transformation-in-training pipeline prototype
This notebook gives an overview of the prototype parameter and activation quantization-aware-training pipeline and facilities available for Torch-backed modules in Rockpool.

This is still work-in-progress and subject to change.

The torch pipeline is based on Torch's `functional_call` API, new in Torch 1.12.
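To illustrate the underlying mechanism, here is a minimal sketch in plain PyTorch rather than Rockpool. `functional_call` evaluates a module with a substitute set of parameters, so a transformed copy of each parameter can be used in the forward pass while the stored parameters stay untouched. In Torch 1.12 the function lives in `torch.nn.utils.stateless`. This sketch assumes Torch 2.0 or later, where it is available as `torch.func.functional_call`. The `quantize` helper is hypothetical.

```python
import torch
from torch.func import functional_call

torch.manual_seed(0)
lin = torch.nn.Linear(3, 2, bias=False)

def quantize(p: torch.Tensor, step: float = 0.25) -> torch.Tensor:
    # Hypothetical transformation: snap each value to a multiple of `step`.
    # Note: torch.round has zero gradient; QAT usually needs a straight-through estimator.
    return torch.round(p / step) * step

# Substitute parameter dictionary holding the transformed values
t_params = {name: quantize(p) for name, p in lin.named_parameters()}

x = torch.ones(1, 3)
# The forward pass uses the transformed weights...
y = functional_call(lin, t_params, (x,))

# ...while the module's stored weights are left unchanged
print("stored: ", lin.weight.data)
print("forward:", t_params["weight"].data)
```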

## Design goals
* No need to modify pre-defined modules to make "magic quantization" modules
* General solution that can be applied widely to modules and parameters
* Convenient API for specifying transformations over parameters in a network in a "grouped" way, using Rockpool's parameter families
* Similar API for parameter- and activity-transformation
* Quantization controllable at a fine-grained level
* Provide useful and flexible transformation methods, which can be used for QAT, dropout, pruning, and more

In [1]:
# - Basic imports
from rockpool.nn.modules import LinearTorch, LIFTorch
from rockpool.nn.combinators import Sequential, Residual

import torch

In [2]:
# - Transformation pipeline imports
import rockpool.transform.torch_transform as tt
import rockpool.utilities.tree_utils as tu

## Parameter transformations
The parameter transformation pipeline lets you apply configurable transformations to any parameter in the forward pass, before evolution. You can use this to perform quantization-aware training, random parameter attacks, connection pruning, and so on.
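A transformation is simply a function that maps a tensor to a tensor, so techniques like these can be written as ordinary PyTorch functions and passed to `tt.make_param_T_config` in place of the lambdas used in this notebook. The two functions below are hypothetical illustrations, not part of Rockpool: magnitude-based pruning, and a multiplicative Gaussian parameter attack.

```python
import torch

def magnitude_prune(p: torch.Tensor, fraction: float = 0.5) -> torch.Tensor:
    # Zero out the `fraction` of elements with the smallest magnitude
    k = int(fraction * p.numel())
    if k == 0:
        return p
    # k-th smallest absolute value (1-indexed) is the pruning threshold
    threshold = p.abs().flatten().kthvalue(k).values
    return p * (p.abs() > threshold).to(p.dtype)

def gaussian_attack(p: torch.Tensor, rel_std: float = 0.1) -> torch.Tensor:
    # Multiplicative Gaussian noise, scaled relative to each parameter's value
    return p * (1.0 + rel_std * torch.randn_like(p))

torch.manual_seed(0)
w = torch.randn(10, 10)
w_pruned = magnitude_prune(w, fraction=0.3)
w_attacked = gaussian_attack(w, rel_std=0.1)
print("fraction pruned:", (w_pruned == 0).float().mean().item())
```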

We'll begin here with a simple Rockpool SNN that uses most of the features of network composition in Rockpool, and is compatible with Xylo.

In [3]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [4]:
# - Get the 'weights' parameter family, and specify stochastic rounding
tconfig = tt.make_param_T_config(
    net, lambda p: tt.stochastic_rounding(p, num_levels=2**2), "weights"
)
tconfig

{'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
 '2_TorchResidual': {'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'w_rec': <function __main__.<lambda>(p)>}},
 '3_LinearTorch': {'weight': <function __main__.<lambda>(p)>}}
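For intuition, here is a sketch of generic stochastic rounding. It is an assumption about how such a function could work, not Rockpool's implementation. Values are mapped onto `num_levels` evenly spaced levels spanning $[-\max|x|, \max|x|]$, and each value rounds up to the next level with probability equal to its fractional distance from the level below, so the result is unbiased on average. The quantized weights shown later in this notebook take four values symmetric about zero (e.g. ±0.4696 and ±1.4088 for the first layer), which is consistent with this scheme.

```python
import torch

def stochastic_round_sketch(x: torch.Tensor, num_levels: int) -> torch.Tensor:
    # Requires num_levels >= 2. Levels span [-max|x|, max|x|], symmetric about zero.
    max_abs = x.abs().max()
    if max_abs == 0:
        return x.clone()
    levels = torch.linspace(-max_abs.item(), max_abs.item(), num_levels)
    step = levels[1] - levels[0]

    # Position of each element in units of the level spacing
    pos = (x - levels[0]) / step
    lower = torch.floor(pos).clamp(0, num_levels - 2)
    frac = pos - lower

    # Round up with probability `frac`, so that E[output] == x
    round_up = (torch.rand_like(x) < frac).to(x.dtype)
    return levels[0] + (lower + round_up) * step

torch.manual_seed(0)
w = torch.randn(3, 5)
print(stochastic_round_sketch(w, num_levels=2**2))
```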

In [5]:
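# - Now add a dropout transformation for the 'biases' parameter family
# - (the printed tree is unchanged: no module in this network has a 'biases' parameter)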
tconfig = tu.tree_update(
    tconfig, tt.make_param_T_config(net, lambda p: tt.dropout(p, 0.3), "biases")
)
tconfig

{'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
 '2_TorchResidual': {'0_LinearTorch': {'weight': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'w_rec': <function __main__.<lambda>(p)>}},
 '3_LinearTorch': {'weight': <function __main__.<lambda>(p)>}}
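Configurations are nested dictionaries that mirror the module hierarchy, so combining two of them amounts to a recursive merge. Below is a plain-Python sketch of such a merge. It assumes that entries from the second tree take precedence and that sub-dictionaries present in both trees are merged key by key. `tree_update_sketch` is a hypothetical helper, not `tu.tree_update` itself.

```python
from typing import Any, Dict

def tree_update_sketch(target: Dict[str, Any], additions: Dict[str, Any]) -> Dict[str, Any]:
    # Return a new nested dict with `additions` merged into `target`;
    # the input trees are not modified
    merged = dict(target)
    for key, value in additions.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Both trees have a sub-tree here: merge recursively
            merged[key] = tree_update_sketch(merged[key], value)
        else:
            # Leaf or new branch: take the value from `additions`
            merged[key] = value
    return merged

w_cfg = {"0_LinearTorch": {"weight": "round"}, "2_TorchResidual": {"0_LinearTorch": {"weight": "round"}}}
b_cfg = {"2_TorchResidual": {"0_LinearTorch": {"bias": "dropout"}}}
print(tree_update_sketch(w_cfg, b_cfg))
```

Merging an empty tree leaves the configuration unchanged.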

In [6]:
# - We now use this configuration to patch the original network with transformation modules
tnet = tt.make_param_T_network(net, tconfig)
tnet

TorchSequential  with shape (3, 3) {
    TWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        TWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        TWrapper '1_LIFTorch' with shape (5, 5) {
            LIFTorch '_mod' with shape (5, 5)
        }
    }
    TWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [7]:
tnet.parameters("weights")

{'0_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[-0.8337, -0.4811, -1.0898, -0.7309, -0.2193],
           [ 0.8017,  0.7001, -0.1010, -1.2180,  1.4088],
           [-1.3568, -0.6440,  0.1773, -1.0603,  0.7051]], requires_grad=True)}},
 '2_TorchResidual': {'0_LinearTorch': {'_mod': {'weight': Parameter containing:
    tensor([[ 0.2452, -0.1040, -0.6809, -0.5192, -1.0748],
            [-0.7669, -0.2566,  0.6481, -0.1794, -0.1849],
            [ 0.5732, -0.7922, -0.7750,  0.7075, -0.5805],
            [ 0.5087, -0.1379, -0.3209,  0.3345,  0.2006],
            [ 0.9543, -1.0458,  0.9624, -0.3144,  0.1960]], requires_grad=True)}},
  '1_LIFTorch': {'_mod': {'w_rec': Parameter containing:
    tensor([[-0.5553,  0.5043, -0.8002,  1.0671,  0.5552],
            [ 0.0891, -1.0943, -0.3825, -0.7623,  0.8179],
            [ 0.3327,  0.8391,  0.6506,  0.1105, -0.1883],
            [-0.0614, -0.9433, -0.6902, -0.4699,  0.3375],
            [-0.4792,  0.7939,  0.0106,  0.4522, -

These are the un-transformed parameters, in floating-point format. But if we evolve the module by calling it, the parameters will all be transformed in the forward pass:

In [8]:
out, ns, rd = tnet(torch.ones(1, 10, 3))
out

tensor([[[  0.,   2.,   0.],
         [  2.,   8.,   0.],
         [  0.,  17.,   0.],
         [  0.,  32.,   0.],
         [  0.,  47.,   0.],
         [  0.,  65.,   0.],
         [  0.,  89.,   4.],
         [  0., 117.,  27.],
         [  0., 154.,  46.],
         [  0., 200.,  70.]]], grad_fn=<CopySlices>)

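**Training goes here!**

After training, `tt.apply_T` applies the transformations to the stored parameter values themselves:
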
In [9]:
ttnet = tt.apply_T(tnet, inplace=True)
ttnet

TorchSequential  with shape (3, 3) {
    TWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        TWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        TWrapper '1_LIFTorch' with shape (5, 5) {
            LIFTorch '_mod' with shape (5, 5)
        }
    }
    TWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}
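Conceptually, applying the transformations means evaluating each transformation once and writing the result back into the stored parameter. A minimal plain-PyTorch sketch of that idea follows. `apply_transforms_in_place` is a hypothetical helper that takes a flat dictionary of parameter names rather than Rockpool's nested configuration tree, and it is not a description of how `tt.apply_T` is implemented. Note that for a stochastic transformation such as stochastic rounding, this approach fixes a single random draw.

```python
import torch

def apply_transforms_in_place(module: torch.nn.Module, transforms: dict) -> torch.nn.Module:
    # `transforms` maps parameter names (as in `named_parameters()`) to functions
    params = dict(module.named_parameters())
    # no_grad: the overwrite itself should not be tracked by autograd
    with torch.no_grad():
        for name, T in transforms.items():
            params[name].copy_(T(params[name]))
    return module

torch.manual_seed(0)
lin = torch.nn.Linear(4, 3)
bias_before = lin.bias.detach().clone()
apply_transforms_in_place(lin, {"weight": lambda p: torch.round(p * 4) / 4})
print(lin.weight)
```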

If we now examine the parameters, we see their low-resolution quantized versions. These are still stored as floating-point numbers; this transformation does not force the parameters to be integers.

In [10]:
ttnet.parameters("weights")

{'0_LinearTorch': {'_mod': {'weight': Parameter containing:
   tensor([[-1.4088, -0.4696, -1.4088, -0.4696,  0.4696],
           [ 1.4088,  1.4088,  0.4696, -1.4088,  1.4088],
           [-1.4088, -0.4696, -0.4696, -1.4088,  1.4088]], requires_grad=True)}},
 '2_TorchResidual': {'0_LinearTorch': {'_mod': {'weight': Parameter containing:
    tensor([[ 0.3583, -0.3583, -0.3583, -1.0748, -1.0748],
            [-1.0748, -0.3583,  0.3583, -0.3583, -0.3583],
            [ 0.3583, -0.3583, -0.3583,  0.3583, -0.3583],
            [ 0.3583, -0.3583, -0.3583,  0.3583,  0.3583],
            [ 1.0748, -1.0748,  1.0748, -0.3583,  0.3583]], requires_grad=True)}},
  '1_LIFTorch': {'_mod': {'w_rec': Parameter containing:
    tensor([[-1.0943,  0.3648, -1.0943,  1.0943,  0.3648],
            [ 0.3648, -1.0943, -0.3648, -1.0943,  0.3648],
            [ 0.3648,  0.3648,  0.3648,  0.3648, -0.3648],
            [ 0.3648, -1.0943, -0.3648, -0.3648,  0.3648],
            [-0.3648,  1.0943,  0.3648,  0.3648, -
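
To strip out the `TWrapper` modules and recover a network with the original structure, use `tt.remove_T_net`:
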
In [11]:
unpatched_net = tt.remove_T_net(ttnet, inplace=True)
unpatched_net

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

### How to: Quantize to round numbers

In [12]:
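# - Random weights in [-0.5, 0.5)
w = torch.rand((5, 5)) - 0.5

num_bits = 4

# - Quantize onto 2**num_bits levels spanning [-7, 8], giving integer-valued outputs
tt.stochastic_rounding(
    w,
    output_range=[-(2 ** (num_bits - 1)) + 1, 2 ** (num_bits - 1)],
    num_levels=2**num_bits,
)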

tensor([[ 1.,  1.,  2., -3., -7.],
        [-5., -2., -5.,  7., -1.],
        [-2.,  4., -6.,  4., -5.],
        [-7.,  7.,  6., -1., -6.],
        [-2.,  4.,  4., -5.,  2.]])
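A quick check of why this produces whole numbers: with `num_levels` evenly spaced levels that include both endpoints of `output_range`, the spacing is `(hi - lo) / (num_levels - 1)`. Here the range [-7, 8] spans 15 units across 15 intervals, so the spacing is exactly 1. The even spacing is an assumption, but it is consistent with the integer output above.

```python
num_bits = 4
num_levels = 2**num_bits
lo, hi = -(2 ** (num_bits - 1)) + 1, 2 ** (num_bits - 1)  # -7, 8

# Evenly spaced levels, as torch.linspace(lo, hi, num_levels) would produce
step = (hi - lo) / (num_levels - 1)
levels = [lo + k * step for k in range(num_levels)]
print(step, levels)
```

The same arithmetic applies to the activity example below, where `output_range=[-128, 127]` with `num_levels=2**8` also gives a spacing of 255 / 255 = 1.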

## Activity transformations
There is a similar pipeline available for activity transformations. This can be used to transform the output of modules in the forward pass, without modifying the module code.
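The idea of transforming a module's output without touching its code can be illustrated in plain PyTorch with a forward hook. This is only an analogy: Rockpool's pipeline instead wraps modules in `ActWrapper` objects, and the rounding-and-clamping transformation here is a hypothetical stand-in.

```python
import torch

def quantize_output_hook(module, inputs, output):
    # A forward hook that returns a value replaces the module's output;
    # here: round to integers and clamp to the signed 8-bit range
    return torch.clamp(torch.round(output), -128, 127)

torch.manual_seed(0)
lin = torch.nn.Linear(3, 5)
handle = lin.register_forward_hook(quantize_output_hook)
y = lin(100 * torch.randn(10, 3))
handle.remove()  # the module code was never modified; removing the hook restores it
print(y)
```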

Let's begin again with a simple SNN architecture:

In [13]:
# - Build a network to use
net = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3),
)
net

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [14]:
# - Build a null configuration tree, which can be manipulated directly
tt.make_act_T_config(net)

# - Specify a transformation function as a lambda
T_fn = lambda p: tt.deterministic_rounding(
    p, output_range=[-128, 127], num_levels=2**8
)

# - Conveniently build a configuration tree by selecting a module class
tconf = tt.make_act_T_config(net, T_fn, LinearTorch)
tconf

{'': None,
 '0_LinearTorch': {'': <function __main__.<lambda>(p)>},
 '1_LIFTorch': {'': None},
 '2_TorchResidual': {'': None,
  '0_LinearTorch': {'': <function __main__.<lambda>(p)>},
  '1_LIFTorch': {'': None}},
 '3_LinearTorch': {'': <function __main__.<lambda>(p)>},
 '4_LIFTorch': {'': None}}

In [15]:
# - Make a transformed network by patching with the configuration
tnet = tt.make_act_T_network(net, tconf)
tnet

TorchSequential  with shape (3, 3) {
    ActWrapper '0_LinearTorch' with shape (3, 5) {
        LinearTorch '_mod' with shape (3, 5)
    }
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        ActWrapper '0_LinearTorch' with shape (5, 5) {
            LinearTorch '_mod' with shape (5, 5)
        }
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    ActWrapper '3_LinearTorch' with shape (5, 3) {
        LinearTorch '_mod' with shape (5, 3)
    }
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [16]:
# - We evolve the module as usual
out, ns, rd = tnet(torch.ones(1, 10, 3), record=True)

In [17]:
# - Examine the recorded outputs from the network; the LinearTorch layers have quantized output
rd

{'0_LinearTorch': {},
 '0_LinearTorch_output': tensor([[[  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.],
          [  -2.,   68., -105., -128.,   18.]]], requires_grad=True),
 '1_LIFTorch': {'vmem': tensor([[[-1.9025e+00,  6.8360e-01, -9.9879e+01, -1.2176e+02,  1.2213e-01],
           [-5.5218e+00,  8.6280e-01, -2.8989e+02, -3.5340e+02,  5.2538e-01],
           [-1.0686e+01,  5.6142e-01, -5.6102e+02, -6.8391e+02,  4.0171e-01],
           [-1.7236e+01,  9.4843e-01, -9.0489e+02, -1.1031e+03,  2.1217e-02],
           [-2.5024e+01,  2.7499e-01, -1.3138e+03, -1.6015e+03,  6.7770e-01],
           [-3.3914e+01,  1.0071e-02, -1.7805e+03

As expected, the outputs of the `LinearTorch` layers are now signed 8-bit integer values, represented as floating-point numbers.

## Decay transformations
When training decays, the LIF decay parameters $\exp(-dt/\tau)$ (`alpha` and `beta` in `LIFTorch`) can be quantized to match the way decay is implemented in Xylo, which uses bitshift subtraction:

$$V_{mem} \rightarrow V_{mem} - \frac{V_{mem}}{2^N} = V_{mem} \cdot \left(1 - \frac{1}{2^N}\right)$$

where $N$ is an integer bitshift.
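As a worked example, the sketch below finds the bitshift $N$ whose decay $1 - 2^{-N}$ is closest (in the $\log_2$ domain) to a given $\alpha = \exp(-dt/\tau)$. Because $1 - \exp(-dt/\tau) \approx dt/\tau$ for $\tau \gg dt$, this gives $N \approx \log_2(\tau/dt)$. The helper names and the rounding rule are our own assumptions; this is not a description of how `tt.t_decay` is implemented.

```python
import math

def nearest_bitshift(alpha: float) -> int:
    # Solve (1 - 2**-N) ~= alpha for N, rounding in the log2 domain (requires alpha < 1)
    return round(-math.log2(1.0 - alpha))

def bitshift_decay(n: int) -> float:
    # Effective multiplicative decay of a bitshift subtraction by n bits
    return 1.0 - 2.0 ** (-n)

def bitshift_step(v: int, n: int) -> int:
    # Integer form of the bitshift subtraction: v - (v >> n)
    return v - (v >> n)

dt = 1e-3
for tau in (2e-3, 10e-3, 50e-3):
    alpha = math.exp(-dt / tau)
    n = nearest_bitshift(alpha)
    print(f"tau={tau}: alpha={alpha:.4f} -> N={n}, quantized alpha={bitshift_decay(n):.4f}")
```

With $dt = 1$ ms, these time constants map to $N = 1, 3, 6$.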

In [20]:
# - Build a network to use
# - Activate decay training in the last layer
net_decay = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3, leak_mode="decays"),
)
net_decay

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [21]:
tconfig_decay = tt.make_param_T_config(net_decay, lambda p: tt.t_decay(p), "decays")
print(tconfig_decay["4_LIFTorch"])

{'alpha': <function <lambda> at 0x7f8ed3137ca0>, 'beta': <function <lambda> at 0x7f8ed3137ca0>}


In [22]:
t_net_decay = tt.make_param_T_network(net_decay, tconfig_decay)
print(t_net_decay)

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    TWrapper '4_LIFTorch' with shape (3, 3) {
        LIFTorch '_mod' with shape (3, 3)
    }
}


## Building a network with bitshift decays 
When `leak_mode="bitshifts"` is passed to `LIFTorch`, the membrane and synaptic decays are parameterized directly as bitshifts (`dash_mem` and `dash_syn`) and applied by bitshift subtraction. To quantize them, it is enough to round them to integers.
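The name `round_passthrough` suggests rounding in the forward pass while passing gradients straight through in the backward pass, so that the bitshifts remain trainable. The sketch below shows the generic straight-through estimator in plain PyTorch; treat it as an assumption about the behavior, not as Rockpool's code. `round_ste` is a hypothetical name.

```python
import torch

def round_ste(x: torch.Tensor) -> torch.Tensor:
    # Forward: round(x). Backward: gradient of the identity, because the
    # rounding residual is detached from the autograd graph.
    return x + (torch.round(x) - x).detach()

dash = torch.tensor([2.3, 4.7, 5.2], requires_grad=True)
out = round_ste(dash)
out.sum().backward()
print(out, dash.grad)
```

By contrast, `torch.round` alone has zero gradient almost everywhere, which would stop the bitshift parameters from being trained.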


In [25]:
# - Build a network to use
# - Activate bitshift decay training in the last layer
net_bitshift = Sequential(
    LinearTorch((3, 5)),
    LIFTorch(5),
    Residual(
        LinearTorch((5, 5)),
        LIFTorch(5, has_rec=True),
    ),
    LinearTorch((5, 3)),
    LIFTorch(3, leak_mode="bitshifts"),
)
net_bitshift

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    LIFTorch '4_LIFTorch' with shape (3, 3)
}

In [26]:
tconfig_bitshift = tt.make_param_T_config(
    net_bitshift, lambda p: tt.round_passthrough(p), "bitshifts"
)
print(tconfig_bitshift["4_LIFTorch"])

{'dash_mem': <function <lambda> at 0x7f8ed3063700>, 'dash_syn': <function <lambda> at 0x7f8ed3063700>}


In [27]:
t_net_bitshift = tt.make_param_T_network(net_bitshift, tconfig_bitshift)
print(t_net_bitshift)

TorchSequential  with shape (3, 3) {
    LinearTorch '0_LinearTorch' with shape (3, 5)
    LIFTorch '1_LIFTorch' with shape (5, 5)
    TorchResidual '2_TorchResidual' with shape (5, 5) {
        LinearTorch '0_LinearTorch' with shape (5, 5)
        LIFTorch '1_LIFTorch' with shape (5, 5)
    }
    LinearTorch '3_LinearTorch' with shape (5, 3)
    TWrapper '4_LIFTorch' with shape (3, 3) {
        LIFTorch '_mod' with shape (3, 3)
    }
}
