# Low-level functional API

## Functional evolution

First let's set up a module to play with:

In [1]:
# - Rockpool imports
from rockpool.nn.modules import RateEulerJax

# - Other useful imports
import numpy as np

# - `rich` is optional: when it is installed, its `print` replaces the builtin
#   `print` in this notebook; otherwise the builtin `print` keeps being used
try:
    from rich import print
except ImportError:
    # - Only a missing / unimportable package is tolerated here; a bare `except:`
    #   would also swallow KeyboardInterrupt and unrelated errors
    pass

# - Construct a module
# - `N` is the single argument given to the module constructor; it is re-used
#   further down as the second dimension of the input array
N = 3
mod = RateEulerJax(N)



Now if we evolve the module by calling it on some input, we get the outputs we expect — the output, the new state and a record:

In [2]:
# - Set up some input
# - `T` is the first dimension of the input array
T = 10
# - Uniform random samples in [0, 1) with shape (T, N)
# - NOTE(review): the name `input` shadows the Python builtin `input` in this
#   namespace, and no random seed is set, so values differ between runs
input = np.random.rand(T, N)
# - Calling the module returns three values, which are unpacked here
output, new_state, record = mod(input)

In [3]:
# - Display the first value returned by calling the module on `input`
print('output:', output)

In [4]:
# - Display the second value returned by calling the module on `input`
print('new_state:', new_state)

In [9]:
# - Display the state currently held by the module object
print('mod.state:', mod.state())
# - Print the module's 'neur_state' next to the one in `new_state`; the ' != '
#   label states that the two are expected to differ at this point
print(mod.state()['neur_state'], ' != ', new_state['neur_state'])

In [10]:
# - Pass the returned state to `set_attributes` and re-assign the return value
#   to `mod` (functional calling style)
mod = mod.set_attributes(new_state)
# - The ' == ' label states that the module's 'neur_state' is now expected to
#   match the one in `new_state`
print(mod.state()['neur_state'], ' == ', new_state['neur_state'])

## Functional state and attribute setting

In [6]:
# - Scale the current value of the `tau` attribute by 0.4
new_tau = mod.tau * .4
# - Assign the new value directly to the attribute
mod.tau = new_tau
# - The ' == ' label states that the attribute is expected to hold the new value
print(new_tau, ' == ', mod.tau)

In [7]:
# - Fetch the module parameters and replace the 'tau' entry with 3x its value
params = mod.parameters()
params['tau'] = params['tau'] * 3.

# - Note the functional calling style: the return value of `set_attributes`
#   is re-assigned to `mod`
mod = mod.set_attributes(params)

# - Check that the attribute was set; the ' == ' label states that `mod.tau`
#   is expected to match the modified 'tau' entry
print(params['tau'], ' == ', mod.tau)

## Functional module reset

In [11]:
# - Reset the module state
# - As with `set_attributes`, the return value is re-assigned to `mod`
mod = mod.reset_state()

# - Reset the module parameters
# - Again the return value is re-assigned to `mod`
mod = mod.reset_parameters()

## Jax flattening