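# Converting a LIF Network to a DynapSim Network

This notebook builds a small sequential LIF network with Rockpool and maps it to a DynapSim specification with `mapper`. It then quantizes that specification with `autoencoder_quantization` and rebuilds a network from the resulting configuration. Finally, it compares the first-layer weights of the network before and after quantization.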

In [1]:
# - Rockpool imports
from rockpool.nn.modules import LinearJax
from rockpool.nn.modules.jax.jax_lif_ode import LIFODEJax
from rockpool.nn.combinators import Sequential

# - Numpy
import numpy as np
import seaborn as sns
from rockpool.devices.dynapse import (
    mapper,
    autoencoder_quantization,
    config_from_specification,
    dynapsim_net_from_config,
    dynapsim_net_from_spec,
)

# - Plotting imports and config
import matplotlib.pyplot as plt
%matplotlib inline
# - Default figure geometry for every plot in this notebook
plt.rcParams.update({'figure.figsize': [12, 4], 'figure.dpi': 300})



In [2]:
# - Layer widths: input channels, then the sizes of the three LIF layers
Nin, N1, N2, N3 = 12, 20, 8, 4

# - Input scaling factor
#   NOTE(review): not referenced in the visible cells — confirm it is used later
input_scale = 20.0

# - Time step passed to every LIFODEJax layer (presumably seconds — confirm)
dt = 1e-3

In [3]:
# - Generate a network using the sequential combinator.
#   Each stage is a bias-free linear projection followed by a recurrent LIF
#   layer of the same width as the projection's output:
#     LinearJax (Nin, N1) -> LIFODEJax (N1, N1)
#     LinearJax (N1,  N2) -> LIFODEJax (N2, N2)
#     LinearJax (N2,  N3) -> LIFODEJax (N3, N3)
#   Modules are created in the same order as a hand-written list would create them,
#   so Sequential assigns the same names (0_LinearJax, 1_LIFODEJax, ...).
layer_sizes = [Nin, N1, N2, N3]
layers = []
for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
    layers.append(LinearJax((n_in, n_out), has_bias=False))
    layers.append(LIFODEJax((n_out, n_out), dt=dt, has_rec=True))

modFFwd = Sequential(*layers)

print(modFFwd)



JaxSequential  with shape (12, 4) {
    LinearJax '0_LinearJax' with shape (12, 20)
    LIFODEJax '1_LIFODEJax' with shape (20, 20)
    LinearJax '2_LinearJax' with shape (20, 8)
    LIFODEJax '3_LIFODEJax' with shape (8, 8)
    LinearJax '4_LinearJax' with shape (8, 4)
    LIFODEJax '5_LIFODEJax' with shape (4, 4)
}


In [4]:
# - Map the network's computational graph to a DynapSim specification dict
lif_graph = modFFwd.as_graph()
spec = mapper(lif_graph)
spec

{'mapped_graph': GraphHolder "JaxSequential__11523041616_transformed_SE_11523161632" with 12 input nodes -> 32 output nodes,
 'weights_in': array([[ 0.29121522,  0.51561257,  0.67196828, -0.45759972,  0.49078786,
         -0.64449581, -0.44781407, -0.20569105, -0.25809057,  0.66788125,
          0.01850795, -0.5465905 ,  0.69127915,  0.12291428, -0.20380376,
         -0.54291643,  0.27950044, -0.66336056,  0.3397322 , -0.13787086,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0.        ,  0.        ],
        [-0.6283843 ,  0.16730094, -0.57627076, -0.40209394,  0.42936083,
         -0.22841963, -0.44805747, -0.1655184 ,  0.16999162,  0.33310825,
          0.02883294, -0.43060643,  0.13300746, -0.64090356, -0.17115936,
         -0.16211612,  0.24803549,  0.25094584,  0.58060647,  0.28224586,
          0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
          0

In [5]:
# - Simulated network built from the float-valued spec (this runs before the
#   quantization cell below replaces `spec`)
mod_se2_post_clustering = dynapsim_net_from_spec(
    **spec,
)

In [6]:
# - Quantize the specification and merge the quantized entries into the spec.
#   Copy the dict first instead of calling `spec.update` on the original object:
#   the mapper cell displayed that object, and mutating it would silently change
#   what that earlier output refers to.
#   Note: re-running this cell on its own would feed the already-quantized `spec`
#   back into `autoencoder_quantization`. Re-run from the mapper cell instead.
quantized_entries = autoencoder_quantization(**spec)
spec = dict(spec)
spec.update(quantized_entries)
spec


{'mapped_graph': GraphHolder "JaxSequential__11523041616_transformed_SE_11523161632" with 12 input nodes -> 32 output nodes,
 'weights_in': [array([[ 6, 15, 15, 13, 13, 15, 13,  5,  1, 15,  1, 15, 15,  0,  8, 15,
           3, 15, 12,  4,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [15,  4, 15,  7, 14,  9,  7,  1,  1, 10,  0, 14,  8, 15,  3,  8,
           3,  2,  7,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 2,  8,  2, 13,  4, 14,  8,  0,  0,  2,  0, 15, 12,  0, 13, 15,
          12, 15, 15, 13,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [ 9,  7,  2,  2,  2,  7, 14,  3,  3,  1,  7, 15, 13, 15,  4,  9,
           2, 11, 15,  8,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [15, 13,  8, 12, 12,  1, 15, 15, 12,  0, 15, 15,  3, 15, 14, 12,
           2,  7, 10,  9,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
         [11, 15,  0,  0, 15, 11,  5,  0,  3,  4, 15,  8, 14,  0, 15,  5,
          14,  8,  5, 15,  0,  0,  0,  0, 

In [7]:
# - Derive a device configuration object and the input channel map from the quantized spec
config_and_channels = config_from_specification(**spec)
config, input_channel_map = config_and_channels

In [8]:
# - Rebuild a simulated network from the configuration (quantized weights)
mod_se2_reconstructed = dynapsim_net_from_config(
    config,
    input_channel_map,
)

In [9]:
# - First-layer weights of the network built from the pre-quantization spec
weights_post_clustering = mod_se2_post_clustering[0].weight
weights_post_clustering

array([[ 0.29121522,  0.51561257,  0.67196828, -0.45759972,  0.49078786,
        -0.64449581, -0.44781407, -0.20569105, -0.25809057,  0.66788125,
         0.01850795, -0.5465905 ,  0.69127915,  0.12291428, -0.20380376,
        -0.54291643,  0.27950044, -0.66336056,  0.3397322 , -0.13787086,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ],
       [-0.6283843 ,  0.16730094, -0.57627076, -0.40209394,  0.42936083,
        -0.22841963, -0.44805747, -0.1655184 ,  0.16999162,  0.33310825,
         0.02883294, -0.43060643,  0.13300746, -0.64090356, -0.17115936,
        -0.16211612,  0.24803549,  0.25094584,  0.58060647,  0.28224586,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ],
       [-0.24831037, -0.07157614, -0.14221636,  0.3726

In [10]:
# - First-layer weights of the network rebuilt from the quantized configuration.
#   Compare with the pre-quantization weights to see the effect of quantization.
weights_reconstructed = mod_se2_reconstructed[0].weight
weights_reconstructed

array([[ 0.31539187,  0.61595609,  0.61595609, -0.4600281 ,  0.4600281 ,
        -0.61595609, -0.4600281 , -0.30410011, -0.14463623,  0.61595609,
         0.14463623, -0.61595609,  0.61595609,  0.        , -0.15592799,
        -0.61595609,  0.30056422, -0.61595609,  0.31539187, -0.15946388,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ],
       [-0.61595609,  0.15946388, -0.61595609, -0.4600281 ,  0.47131986,
        -0.30056422, -0.4600281 , -0.14463623,  0.14463623,  0.31185598,
         0.        , -0.47131986,  0.15592799, -0.61595609, -0.30056422,
        -0.15592799,  0.30056422,  0.15592799,  0.4600281 ,  0.30056422,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ],
       [-0.15592799, -0.15592799, -0.15592799,  0.4600