In [1]:
from rockpool.devices.dynapse import *
import jax.numpy as jnp
import numpy as np

Could not import package: No module named 'iaf_nest'


In [2]:
# CAM array of shape (3, 5, 4), float32 — same shape as `bitmask` below.
# Axis order presumably (pre, post, gate), matching the axis comments in
# `weight_matrix` — TODO confirm. Values are small non-negative integers (0-2),
# presumably per-connection counts — unverified.
# NOTE(review): `weight_matrix` accepts this array but does not use it.
CAM = np.array(
    [
        [
            [0, 0, 0, 2],
            [0, 0, 1, 0],
            [0, 0, 0, 2],
            [0, 0, 0, 0],
            [0, 0, 0, 1],
        ],
        [
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [1, 0, 0, 0],
            [0, 0, 1, 0],
        ],
        [
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 0],
            [0, 0, 0, 1],
        ],
    ],
    dtype=np.float32,
)

In [3]:
# Integer bit masks of shape (3, 5, 4). Every value lies in 0..15, so only the
# low 4 bits matter; `bit_select` unpacks those bits to pick which of the four
# base weight currents Iw_0..Iw_3 contribute to each connection.
# Axis order presumably (pre, post, gate), per the comments in `weight_matrix` — TODO confirm.
# No dtype is given, so NumPy's default integer type is used.
bitmask = np.array(
    [
        [
            [14, 4, 3, 7],
            [7, 2, 7, 2],
            [10, 14, 8, 13],
            [10, 9, 1, 8],
            [4, 15, 11, 7],
        ],
        [
            [9, 11, 10, 13],
            [12, 7, 6, 11],
            [4, 6, 2, 11],
            [10, 5, 13, 4],
            [7, 3, 10, 3],
        ],
        [
            [10, 1, 1, 12],
            [14, 15, 0, 5],
            [14, 11, 9, 0],
            [12, 10, 10, 9],
            [4, 14, 9, 3],
        ],
    ]
)

In [4]:
# Simulated DYNAP-SE1 board. The argument 5 appears to be the neuron count
# (it matches the post axis of `bitmask`) — TODO confirm against the constructor.
simboard = DynapSE1SimBoard(5)

# The four base weight currents; `weight_matrix` combines them according to the
# bits set in `bitmask` (bit k selects Iw_k).
Iw_0, Iw_1, Iw_2, Iw_3 = (simboard.Iw_0, simboard.Iw_1, simboard.Iw_2, simboard.Iw_3)



In [5]:
def bit_select(bitmask: np.ndarray, n_bits: int = 4) -> jnp.ndarray:
    """
    bit_select unpacks integer bit masks into one boolean selection mask per bit

        0001 -> selected bit: 0
        1000 -> selected bit: 3
        0101 -> selected bits 0 and 2

    :param bitmask: integer mask(s) of any shape ``S``; only the lowest ``n_bits`` bits are inspected
    :type bitmask: np.ndarray
    :param n_bits: number of low-order bits to unpack, defaults to 4
    :type n_bits: int, optional
    :raises ValueError: if ``n_bits`` is negative
    :return: boolean array of shape ``(n_bits, *S)``; entry ``[k, ...]`` is True where bit ``k`` of ``bitmask`` is set
    :rtype: jnp.ndarray
    """
    if n_bits < 0:
        raise ValueError(f"n_bits must be non-negative, got {n_bits}")

    # jnp (not np) keeps this usable on traced arrays inside jax transformations
    mask = jnp.asarray(bitmask)

    # Bit positions shaped (n_bits, 1, ..., 1) so the shift broadcasts over every element of `mask`
    positions = jnp.arange(n_bits).reshape((n_bits,) + (1,) * mask.ndim)

    # (mask >> k) & 1 isolates bit k; a non-zero result means the bit is set
    return ((mask >> positions) & 1).astype(bool)

In [6]:
# Swap the first two axes (pre, post, gate) -> (post, pre, gate), unpack the bits
# (prepends a bit axis), then reverse every axis so the result is
# (gate, pre, post, bits) and its trailing (post, bits) axes broadcast against Iw.
bits_trans = jnp.transpose(bit_select(bitmask.transpose((1, 0, 2))))

In [7]:
# Expected (gate, pre, post, bits)
np.shape(bits_trans)

(4, 3, 5, 4)

In [8]:
def weight_matrix(
    Iw_0: jnp.ndarray,
    Iw_1: jnp.ndarray,
    Iw_2: jnp.ndarray,
    Iw_3: jnp.ndarray,
    CAM: np.ndarray,  # currently unused (original note: "assume only one")
    bitmask: np.ndarray,
) -> jnp.ndarray:
    """
    weight_matrix sums, for every connection, the base weight currents ``Iw_k``
    whose bit ``k`` is set in the connection's bit mask

        w[..., pre, post, gate] = sum_k bit_k(bitmask[..., pre, post, gate]) * Iw_k[post]

    :param Iw_0: base weight current selected by bit 0, assumed one value per post-synaptic neuron ``(post,)`` — TODO confirm
    :type Iw_0: jnp.ndarray
    :param Iw_1: base weight current selected by bit 1, same shape as ``Iw_0``
    :type Iw_1: jnp.ndarray
    :param Iw_2: base weight current selected by bit 2, same shape as ``Iw_0``
    :type Iw_2: jnp.ndarray
    :param Iw_3: base weight current selected by bit 3, same shape as ``Iw_0``
    :type Iw_3: jnp.ndarray
    :param CAM: not used by the computation; kept so existing callers keep working
    :type CAM: np.ndarray
    :param bitmask: integer bit masks of shape ``(..., pre, post, gate)``; optional leading batch axes are allowed
    :type bitmask: np.ndarray
    :return: weight matrix of shape ``(..., pre, post, gate)``
    :rtype: jnp.ndarray
    """
    # (post, bits): column k holds Iw_k, so bit k of a mask selects column k
    Iw = jnp.column_stack((Iw_0, Iw_1, Iw_2, Iw_3))

    # bit_select prepends the bit axis: (bits, ..., pre, post, gate) -> (..., pre, post, gate, bits)
    bits_last = jnp.moveaxis(bit_select(bitmask), 0, -1)

    # Iw[:, None, :] is (post, 1, bits): aligns with (post, gate, bits) and broadcasts
    # over the gate axis and any leading axes; summing over bits yields the weights
    w_rec = jnp.sum(bits_last * Iw[:, jnp.newaxis, :], axis=-1)

    return w_rec

In [9]:
# Shape of two stacked bitmasks: adds a leading batch axis
np.stack([bitmask] * 2).shape

(2, 3, 5, 4)

In [15]:
# Weights from the notebook implementation, to be cross-checked against WeightConfig
a = weight_matrix(
    Iw_0=Iw_0,
    Iw_1=Iw_1,
    Iw_2=Iw_2,
    Iw_3=Iw_3,
    CAM=CAM,
    bitmask=bitmask,
)

In [11]:
from rockpool.devices.dynapse.autoencoder import WeightConfig  

In [12]:
# Package-side weight configuration built from the same currents and bit masks;
# its `.weights` is compared with `a` later in the notebook.
# NOTE(review): the first positional argument is passed as None — its meaning is
# not visible here; TODO confirm against the WeightConfig signature.
wc = WeightConfig(None, Iw_0, Iw_1, Iw_2, Iw_3, bitmask)

HEY


In [16]:
# Cross-check the notebook implementation against the package one.
# assert_allclose also rejects shape mismatches, which `==` followed by `.all()`
# would silently broadcast away for broadcastable shapes. The small rtol tolerates
# float32 rounding in case the package sums the currents in a different order.
np.testing.assert_allclose(np.asarray(a), np.asarray(wc.weights), rtol=1e-6)

HEY
