Transform Feedforward-Network + solver into a Recurrent-Network #109
Comments
Yep, this is definitely possible. Diffrax is intrinsically differentiable so no special care is needed. Untested, but perhaps something like the following:

```python
import equinox as eqx
import diffrax as dfx
import jax.numpy as jnp


# wraps an MLP to concatenate state and observation together
class Func(eqx.Module):
    mlp: eqx.nn.MLP

    def __init__(self, state_size, observation_size, width_size, depth, key):
        in_size = 1 + state_size + observation_size
        self.mlp = eqx.nn.MLP(in_size, state_size, width_size, depth, key=key)

    def __call__(self, t, state, observation):
        in_ = jnp.concatenate([t[None], state, observation])
        return self.mlp(in_)


func = Func(...)
get_action = eqx.nn.MLP(...)


def select_action(model, state, observation, time):
    func, get_action = model
    prev_time, state = state
    term = dfx.ODETerm(func)
    # specify solver, dt0, stepsize_controller in whatever way you think appropriate
    solver = ...
    dt0 = ...
    stepsize_controller = ...
    sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, state, args=observation,
                          stepsize_controller=stepsize_controller)
    (state,) = sol.ys
    action = get_action(state)
    state = (time, state)
    return state, action
```

It's not critical, but as a nice-to-have this uses Equinox as a convenient neural network library.
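To actually unroll this over a whole trajectory, you could then wrap `select_action` in a `jax.lax.scan`. Again untested, and `state_size`, `observation_size`, `num_steps` and the fixed `sampling_rate` below are just placeholders rather than anything defined above:

```python
import jax
import jax.numpy as jnp

# placeholder values -- use whatever matches your environment
state_size = 4
observation_size = 3
num_steps = 100
sampling_rate = 100  # Hz

model = (func, get_action)


def scan_step(carry, observation):
    state, time = carry
    # `state` is the (prev_time, ode_state) tuple carried between calls
    state, action = select_action(model, state, observation, time)
    return (state, time + 1 / sampling_rate), action


t0 = jnp.array(0.0)
dt = jnp.array(1 / sampling_rate)
y0 = jnp.zeros(state_size)                                # initial ODE state
observations = jnp.zeros((num_steps, observation_size))   # dummy observation trajectory
init_carry = ((t0, y0), t0 + dt)
_, actions = jax.lax.scan(scan_step, init_carry, observations)
```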
Thank you! Hopefully I will find some time over the next few days to try to implement this, and I will definitely report back! Maybe I am just getting too excited by access to a new, powerful tool without having to rewrite any old code. But I feel that offering a transform that does exactly this would be super nice: given the feedforward network, the solver (and stepsize controller), and the measurement/action mapping, it would construct a new differentiable function with the call signature of a typical recurrent network.
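For concreteness, a rough sketch of the kind of transform I have in mind, building directly on the `select_action` pattern above; the name `feedforward_to_recurrent` and the fixed `sampling_rate` argument are just placeholders, not anything that exists in Diffrax:

```python
import diffrax as dfx


def feedforward_to_recurrent(func, get_action, solver, dt0, stepsize_controller, sampling_rate):
    """Wrap a vector field + solver into a step function with a recurrent-cell signature."""

    def select_action(state, observation):
        prev_time, ode_state = state
        time = prev_time + 1 / sampling_rate
        term = dfx.ODETerm(func)
        sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, ode_state,
                              args=observation, stepsize_controller=stepsize_controller)
        (ode_state,) = sol.ys
        action = get_action(ode_state)
        return (time, ode_state), action

    return select_action
```

Calling this once and reusing the returned function would then give the usual `new_state, action = step(state, observation)` signature of a recurrent cell.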
Doesn't it make more sense to also include the solver state as part of the state of the differential equation?

```python
stepsize_controller = ...
solver = ...
dt0 = ...
sampling_rate = 100  # Hz


def select_action(model, state, observation):
    func, get_action = model
    term = dfx.ODETerm(func)
    prev_time, state, solver_state, controller_state = state
    time = prev_time + 1 / sampling_rate
    sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, state, args=observation,
                          stepsize_controller=stepsize_controller,
                          controller_state=controller_state, solver_state=solver_state)
    # update controller / solver state
    solver_state, controller_state = ...
    (state,) = sol.ys
    action = get_action(state)
    state = (time, state, solver_state, controller_state)
    return state, action
```
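For reference, the updated states can be read back off the returned solution object if they are requested via `SaveAt`, which is also what my full example further down ends up doing. A minimal sketch, reusing the names from the snippet above:

```python
saveat = dfx.SaveAt(t1=True, solver_state=True, controller_state=True)
sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, state, args=observation,
                      stepsize_controller=stepsize_controller, saveat=saveat,
                      solver_state=solver_state, controller_state=controller_state)
solver_state = sol.solver_state
controller_state = sol.controller_state
```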
Yes, it does. If you do this then you should also pass
For completeness let me post my minimal working example.

```python
import diffrax as dfx
import jax.numpy as jnp
from acme.jax import utils
from functools import partial
import haiku as hk
import jax

sampling_rate = 100  # Hz
stepsize_controller = dfx.ConstantStepSize()
dt0 = 1 / sampling_rate
solver = dfx.Euler()

action_size = 3
obs_size = 2
u_dummy = jnp.ones((action_size,))
latent_state_size = 20
hidden_layers = [50, 50]


@hk.without_apply_rng
@hk.transform_with_state
def rhs(t, u):
    t = jnp.atleast_1d(t)
    x = hk.get_state("x", shape=(latent_state_size,), init=jnp.zeros, dtype=jnp.float32)
    txu = utils.batch_concat((t, x, u), num_batch_dims=0)
    X = hk.nets.MLP(hidden_layers + [latent_state_size])(txu)
    return {"~": {"x": X}}


def haiku2dfx_rhs(rhs):
    def __rhs(params):
        def _rhs(t, x, u):
            # x is simply passed through
            dxdt, x = rhs(params, x, t, u)
            del x
            return dxdt
        return _rhs
    return __rhs


# this is not great / quite confusing
dxdt = haiku2dfx_rhs(rhs.apply)


@hk.without_apply_rng
@hk.transform
def measurement_function(x):
    x = utils.batch_concat(x, num_batch_dims=0)
    C = hk.get_parameter("C", shape=(obs_size, x.shape[-1]), dtype=jnp.float32,
                         init=lambda shape, dtype: jax.random.normal(hk.next_rng_key(), shape, dtype=dtype))
    return jnp.matmul(C, x)


def gen_init_solver_state(solver: dfx.AbstractSolver, params_rhs, x0):
    term = dfx.ODETerm(dxdt(params_rhs))
    t0 = 0.0
    return solver.init(term, t0=t0, t1=t0 + dt0, y0=x0, args=u_dummy)


def gen_init_controller_state():
    return dt0


saveat = dfx.SaveAt(t1=True, solver_state=True, controller_state=True, made_jump=True)


def step_fun_dynamics_to_time(params, state, u, time):
    prev_time, x, solver_state, controller_state, made_jump = state
    term = dfx.ODETerm(dxdt(params["rhs"]))
    sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, x, args=u,
                          stepsize_controller=stepsize_controller, saveat=saveat,
                          solver_state=solver_state, controller_state=controller_state,
                          made_jump=made_jump)
    x = sol.ys
    x = utils.squeeze_batch_dim(x)
    state = (time, x, sol.solver_state, sol.controller_state, sol.made_jump)
    obs = measurement_function.apply(params["C"], x)
    return state, obs


def step_fun_dynamics(params, state, u):
    prev_time = state[0]
    return step_fun_dynamics_to_time(params, state, u, prev_time + dt0)


@jax.jit
@partial(jax.vmap, in_axes=(None, None, 0))
def unrolled_step_fun_dynamics(params, state, us):
    step_fun_dynamics_constraint = lambda state, u: step_fun_dynamics(params, state, u)
    state, obss = jax.lax.scan(step_fun_dynamics_constraint, init=state, xs=us)
    return obss


# initialise parameters
params_rhs, x0 = rhs.init(jax.random.PRNGKey(1), 0.0, u_dummy)
C = measurement_function.init(jax.random.PRNGKey(1), jnp.ones((latent_state_size,)))
params = {
    "rhs": params_rhs,
    "C": C,
}

# initialise step function state
# (t0, x0, solver_state0, controller_state0, made_jump0)
made_jump0 = False
init_state = (0.0, x0, gen_init_solver_state(solver, params_rhs, x0), gen_init_controller_state(), made_jump0)

# make prediction
bs = 32
T = 5.0
uss = jnp.ones((bs, int(T * sampling_rate), action_size))
obsss = unrolled_step_fun_dynamics(params, init_state, uss)
```

Thanks Patrick for your help. It works perfectly.
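And since the whole point was that everything stays differentiable, a quick sanity check one can add at the end; this is just a sketch, with `loss_fn` and the dummy `obs_targets` made up for illustration:

```python
obs_targets = jnp.zeros_like(obsss)  # dummy targets, purely for illustration


def loss_fn(params, init_state, uss, obs_targets):
    obsss = unrolled_step_fun_dynamics(params, init_state, uss)
    return jnp.mean((obsss - obs_targets) ** 2)


# gradients flow through the solver, the haiku RHS and the measurement map
grads = jax.grad(loss_fn)(params, init_state, uss, obs_targets)
```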
Hello Patrick,
let me first quickly motivate my feature request.
As a side project I am currently working on model-based optimal control. For, e.g., an only partially observable environment, stateful agents are useful. So, suppose the action selection of an agent is given by the following method.
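Schematically, I mean something like this (just a placeholder sketch; the exact network does not matter):

```python
def select_action(params, state, observation):
    # `apply` is some recurrent network (e.g. a GRU cell): previous hidden state
    # and current observation in, new hidden state and action out
    state, action = apply(params, state, observation)
    return state, action
```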
Typically, the apply-function is some recurrent neural network. Suppose the environment env is differentiable, because it is just some model of the environment (maybe another network). Now, I would like to replace the recurrent neural network with a feedforward network + solver without changing the API of the agent. I was wondering whether constructing something like this is possible and sensible, i.e. I would like to transform a choice of feedforward network + solver into a recurrent network.
I would like to emphasise that this select_action must remain differentiable: the output w.r.t. the network parameters. I would love to hear your input :) Anyway, thank you in advance.