# [𝝺] Low-level functional API

## Functional evolution

First, let's set up a module to play with:

In [1]:
# - Switch off warnings (keeps the tutorial output uncluttered)
import warnings

warnings.filterwarnings("ignore")

# - Rockpool imports
from rockpool.nn.modules import RateJax

# - Other useful imports
import numpy as np

# - Use `rich` for prettier printing when it is installed; otherwise fall back
#   to the builtin `print`. Only a missing package is tolerated here; any other
#   error during import should surface instead of being silently swallowed.
try:
    from rich import print
except ImportError:
    pass

# - Construct a module with N units
N = 3
mod = RateJax(N)

Now if we evolve the module, we get the outputs we expect. Calling the module returns a tuple `(output, new_state, record)`:

In [2]:
# - Set up some input: a (T, N) array of uniform random values in [0, 1).
#   A seeded generator makes the example reproducible under Restart & Run All,
#   and it does not disturb numpy's global random state.
SEED = 0
T = 10
rng = np.random.default_rng(SEED)
input = rng.random((T, N))  # NOTE: shadows the builtin `input`; name kept in case later cells use it

# - Evolve the module; the call returns the output, the new state and a record
output, new_state, record = mod(input)

In [3]:
# - Display the output returned by evolving the module
print("output:", output, sep=" ")

In [4]:
# - Display the state dictionary returned alongside the output
print("new_state:", new_state, sep=" ")

In [5]:
# - Inspect the state stored on the module itself. The "!=" label records the
#   expected result: the stored "x" differs from the "x" returned by the call.
print("mod.state:", mod.state(), sep=" ")
print(mod.state()["x"], " != ", new_state["x"], sep=" ")

In [6]:
# - Write the evolved state back: `set_attributes` returns a module, which we
#   rebind to `mod`
mod = mod.set_attributes(new_state)

# - The stored "x" should now match the returned one
print(mod.state()["x"], " == ", new_state["x"], sep=" ")

## Functional state and attribute setting

In [7]:
# - Attributes can also be set by ordinary assignment
new_tau = mod.tau * 0.4
mod.tau = new_tau

# - Confirm that the assigned time constants were stored
print(new_tau, " == ", mod.tau, sep=" ")

In [8]:
# - Fetch the parameter dictionary and triple the time constants held in it
params = mod.parameters()
params["tau"] = params["tau"] * 3.0

# - Functional calling style: `set_attributes` returns a module that we rebind
#   to `mod`
mod = mod.set_attributes(params)

# - Confirm that the modified time constants were applied
print(params["tau"], " == ", mod.tau, sep=" ")

## Functional module reset

In [9]:
# - Reset the module state, then the module parameters. Each reset returns a
#   module, so the two calls chain naturally.
mod = mod.reset_state().reset_parameters()

## Jax flattening