# 🛠 Low-level `Module` API

## Constructing a `Module`

In [1]:
# - Switch off warnings
import warnings

warnings.filterwarnings("ignore")

# - Useful imports
#   Use `rich`'s `print` for prettier output when `rich` is installed;
#   otherwise keep the built-in `print`. Only `ImportError` is caught, so
#   genuine errors (and KeyboardInterrupt) are not silently swallowed.
try:
    from rich import print
except ImportError:
    pass

# - Example of constructing a module
import numpy as np
from rockpool.nn.modules import Rate

# - Instantiate a `Rate` Module containing 4 neurons, then display it
mod = Rate(4)
print(mod)

In [2]:
# - Construct a Module with concrete parameters
#   Same 4-neuron Module as above, but `tau` is supplied explicitly as an
#   array of ones (one value per neuron) instead of using its default.
mod = Rate(4, tau=np.ones(4))
print(mod)

## Evolving a `Module`

In [3]:
# - Drive the module with T time steps of uniform random input
#   (one column per input channel) and report the output's shape
T = 5
input = np.random.random((T, mod.size_in))
output, _, _ = mod(input)
print("Output shape: {}".format(output.shape))

In [4]:
# - Request the recorded state
#   Passing `record=True` asks the module to also return its recorded
#   internal state as the third element of the result tuple.
output, _, recorded_state = mod(input, record=True)
print("Recorded state:", recorded_state)

## Parameters, State and SimulationParameters

## Building a network with `Module`s

In [5]:
# - Build a simple network
from rockpool.nn.modules import Module
from rockpool.parameters import Parameter
from rockpool.nn.modules import RateJax


class ffwd_net(Module):
    """
    A minimal feed-forward layer: an input weight matrix followed by a layer
    of ``RateJax`` neurons.

    The input is multiplied by the weight matrix ``w_ffwd``, and the result is
    passed through ``shape[-1]`` neurons. For example, ``ffwd_net((4, 6))``
    maps 4 input channels onto 6 neurons.

    Note:
        ``evolve`` in this class returns empty state and record dictionaries.
        It does not forward a ``record`` flag, and it does not return the
        sub-module's state.
    """

    # - Provide an `__init__` method to specify required parameters and modules
    #   Here you check, define and initialise whatever parameters and
    #   state you need for your module.
    def __init__(
        self,
        shape,
        *args,
        **kwargs,
    ):
        """
        Args:
            shape: Shape of the module, e.g. ``(N_in, N_out)``. Through
                ``self.shape`` it is used as the shape of ``w_ffwd``, and
                ``shape[-1]`` sets the number of neurons.
        """
        # - Call superclass initialisation
        #   This is always required for a `Module` class.
        #   NOTE(review): `shape` is passed by keyword while `*args` are passed
        #   positionally, so any extra positional argument is bound *before*
        #   `shape`. If `Module.__init__` takes `shape` as its first positional
        #   parameter, that would raise a TypeError. Confirm before passing
        #   extra positional args.
        super().__init__(shape=shape, *args, **kwargs)

        # - Specify weights attribute
        #   We need a weights matrix for our input weights.
        #   We specify the shape explicitly, and provide an initialisation function.
        #   We also specify a family for the parameter, "weights". This is used to
        #   query parameters conveniently, and is a good idea to provide.
        #   The weights start at zero, so the neurons receive all-zero input
        #   until the weights are changed.
        self.w_ffwd = Parameter(
            shape=self.shape,
            init_func=np.zeros,
            family="weights",
        )

        # - Specify and add a submodule
        #   These will be the neurons in our layer, to receive the weighted
        #   input signals. This sub-module will be automatically configured
        #   internally, to specify the required state and parameters.
        self.neurons = RateJax(self.shape[-1])

    # - The `evolve` method contains the internal logic of your module
    #   `evolve` takes care of passing data in and out of the module,
    #   and between sub-modules if present.
    def evolve(self, input_data, *args, **kwargs):
        """
        Pass ``input_data`` through the input weights and then the neurons.

        Returns:
            ``(output, {}, {})``: the neuron output, plus empty state and
            record dictionaries (the sub-module's state and record are
            discarded).
        """
        # - Pass input data through the input weights
        x = input_data @ self.w_ffwd

        # - Pass the signals through the neurons
        x, _, _ = self.neurons(x)

        # - Return the module output
        return x, {}, {}

### Writing an `evolve()` method that returns state and record

In [6]:
def evolve(self, input_data, record: bool = False, *args, **kwargs):
    """
    Evolve the feed-forward network, returning output, new state and record.

    Args:
        input_data: Input signal, multiplied on the right by ``self.w_ffwd``.
        record: If ``True``, also collect internal signals into the record
            dictionary and ask the ``neurons`` sub-module to record as well.

    Returns:
        ``(output, new_state, recorded_state)``. ``new_state`` holds the
        sub-module state under the key ``"neurons"``. ``recorded_state`` holds
        ``"weighted_input"`` (only when ``record`` is true) and the
        sub-module's record under ``"neurons"``.
    """
    # - Initialise state and record dictionaries
    new_state = {}
    recorded_state = {}

    # - Pass input data through the input weights
    x = input_data @ self.w_ffwd

    # - Add an internal signal record to the record dictionary
    if record:
        recorded_state["weighted_input"] = x

    # - Pass the signals through the neurons, passing through the `record` argument
    x, submod_state, submod_record = self.neurons(x, record=record)

    # - Nest the sub-module's state and record under its attribute name.
    #   `dict.update` accepts a single mapping (or key/value pairs), not a
    #   separate key and value, so assign by key directly.
    new_state["neurons"] = submod_state
    recorded_state["neurons"] = submod_record

    # - Return the module output
    return x, new_state, recorded_state

## Inspecting a `Module`

In [7]:
# - Instantiate the network (4 input channels onto 6 neurons) and show it
my_mod = ffwd_net(shape=(4, 6))
print(my_mod)

In [8]:
# - Show module parameters
#   `parameters()` with no argument returns the module's parameters
#   without filtering by family.
print("Parameters:", my_mod.parameters())

In [9]:
# - Show module state
#   `state()` returns the module's current state values, as opposed to the
#   trainable parameters shown above.
print("State:", my_mod.state())

In [10]:
# - Return parameters from particular families
#   Passing a family name filters the parameters. "weights" matches `w_ffwd`,
#   which `ffwd_net.__init__` declares with family="weights".
#   NOTE(review): "taus" is presumably the family used for the `RateJax`
#   neurons' time constants; confirm in the rockpool sources.
print("Module time constants:", my_mod.parameters("taus"))
print("Module weights:", my_mod.parameters("weights"))

In [11]:
# - Access parameters directly
#   Parameters are plain attributes: `w_ffwd` on the network itself, and
#   `tau` on the `neurons` sub-module created in `ffwd_net.__init__`.
print(".w_ffwd:", my_mod.w_ffwd)
print(".neurons.tau:", my_mod.neurons.tau)

## `Module` API reference