Transform Feedforward-Network + solver into a Recurrent-Network #109

Closed
SimiPixel opened this issue May 29, 2022 · 5 comments

@SimiPixel
Contributor

SimiPixel commented May 29, 2022

Hello Patrick,

let me first quickly motivate my feature request.
As a side project, I am currently working on model-based optimal control. For example, in an environment that is only partially observable, stateful agents are useful. So, suppose the action selection of an agent is given by the following method:

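def select_action(params, state, observation, time):
    apply = neural_network.apply
    state, action = apply(params, state, observation, time)
    return state, action

# params, the initial recurrent state and the first observation are assumed to be set up beforehand
while True:
    # thread the recurrent state through every call
    state, action = select_action(params, state, observation, env.time)
    observation = env.step(action)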

Typically, the apply function is some recurrent neural network. Suppose the environment env is differentiable, because it is just a model of the environment (maybe another network). Now, I would like to replace the recurrent neural network with a feedforward network + solver without changing the API of the agent.
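A minimal runnable sketch of the loop above follows. It shows the recurrent state being returned and fed back on every call. rnn_apply, ToyEnv and the scalar parameter values are hypothetical stand-ins written in plain Python floats, not a neural-network library, and the loop runs for a finite horizon instead of while True.

```python
import math


# Hypothetical Elman-style cell standing in for `neural_network.apply`.
# Signature: (params, state, observation, time) -> (new_state, action).
def rnn_apply(params, state, observation, time):
    w_state, w_obs, w_out = params
    # `time` is accepted to match the interface but unused by this toy cell.
    new_state = math.tanh(w_state * state + w_obs * observation)
    action = w_out * new_state
    return new_state, action


def select_action(params, state, observation, time):
    state, action = rnn_apply(params, state, observation, time)
    return state, action


class ToyEnv:
    """Hypothetical scalar environment: position relaxes towards the action."""

    def __init__(self, dt=0.1):
        self.dt = dt
        self.time = 0.0
        self.position = 1.0

    def step(self, action):
        self.position += self.dt * (-self.position + action)
        self.time += self.dt
        return self.position  # the observation


params = (0.5, 1.0, -0.8)
env = ToyEnv()
state = 0.0                  # recurrent state, threaded through every call
observation = env.position
for _ in range(50):          # finite horizon instead of `while True`
    state, action = select_action(params, state, observation, env.time)
    observation = env.step(action)
```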

I was wondering whether constructing the following is possible and sensible, i.e. I would like to transform a choice of feedforward network + solver into a recurrent network:

def select_action(params, ode_state, observation, time):
    # pseudocode: `odeint` with this signature is a hypothetical solver API
    rhs = lambda x, u: neural_network.apply(params, x, u)
    solution, ode_state = odeint(ode_state, rhs, t1=time, u=(observation, time))
    return ode_state, solution.x(time)
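Here is a runnable plain-Python sketch of the same idea. The carried ode_state is (prev_time, x), and each call integrates from prev_time to time with the observation held fixed. The vector field (a single tanh unit), the readout and the parameter values are hypothetical. Explicit Euler with a fixed step stands in for a real ODE solver.

```python
import math


def vector_field(params, t, x, u):
    # Hypothetical feedforward vector field: one tanh unit with a decay term.
    a, b, c = params
    return math.tanh(a * x + b * u + c * t) - x


def readout(x):
    # Hypothetical action map; in the issue this would be another network.
    return 2.0 * x


def select_action(params, ode_state, observation, time, dt=0.01):
    """Recurrent-network signature built from a feedforward network + solver.

    ode_state = (prev_time, x) records where the previous call stopped.
    The ODE is integrated with explicit Euler from prev_time to `time`,
    holding the observation constant over the interval.
    """
    prev_time, x = ode_state
    n_steps = max(1, round((time - prev_time) / dt))
    h = (time - prev_time) / n_steps
    t = prev_time
    for _ in range(n_steps):
        x = x + h * vector_field(params, t, x, observation)
        t = t + h
    return (time, x), readout(x)


params = (1.5, 0.7, 0.0)     # hypothetical parameter values
ode_state = (0.0, 0.0)       # (prev_time, x) at the start
for k, obs in enumerate([1.0, 0.5, -0.2]):
    ode_state, action = select_action(params, ode_state, obs, time=0.1 * (k + 1))
```

From the caller's point of view this is indistinguishable from an RNN cell. The only difference is that the "state" includes the time at which integration last stopped.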

I would like to emphasize that select_action must remain differentiable: the x-output w.r.t. the network parameters.
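The sketch below shows why a fixed-step solve stays differentiable. Every solver step is an ordinary differentiable function of the parameters, so the derivative of the discrete solution can be propagated step by step (forward-mode sensitivity). The scalar model dx/dt = -theta*x + u is a hypothetical example. The derivative is computed by hand here purely for illustration, where automatic differentiation through the solver's steps would normally do it. The result is checked against a finite difference and against the exact ODE solution.

```python
import math


def solve_with_sensitivity(theta, x0, u, t1, n_steps):
    """Explicit Euler for dx/dt = -theta * x + u on [0, t1].

    Alongside x it propagates s = dx/dtheta of the *discrete* solution,
    i.e. forward-mode differentiation through every solver step.
    """
    h = t1 / n_steps
    x, s = x0, 0.0  # s starts at 0 because x0 does not depend on theta
    for _ in range(n_steps):
        # d/dtheta of the update x <- x + h * (-theta * x + u), using the old x
        s = s + h * (-theta * s - x)
        x = x + h * (-theta * x + u)
    return x, s


theta, x0, u, t1, n = 0.8, 1.0, 0.3, 2.0, 200
x_end, dx_dtheta = solve_with_sensitivity(theta, x0, u, t1, n)

# Central finite difference as an independent check of the derivative.
eps = 1e-6
fd = (solve_with_sensitivity(theta + eps, x0, u, t1, n)[0]
      - solve_with_sensitivity(theta - eps, x0, u, t1, n)[0]) / (2 * eps)

# Derivative of the exact solution x(t1) = u/theta + (x0 - u/theta) e^{-theta t1}.
exact = (-u / theta**2 + (u / theta**2) * math.exp(-theta * t1)
         - (x0 - u / theta) * t1 * math.exp(-theta * t1))
```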

I would love to hear your input :) Anyway, thank you in advance.

@patrick-kidger
Owner

patrick-kidger commented May 29, 2022

Yep, this is definitely possible. Diffrax is intrinsically differentiable, so no special care is needed. Untested, but perhaps something like the following:

import equinox as eqx
import diffrax as dfx
import jax.numpy as jnp

# wraps an MLP to concatenate time, state and observation together
class Func(eqx.Module):
   mlp: eqx.nn.MLP

   def __init__(self, state_size, observation_size, width_size, depth, key):
       in_size = 1 + state_size + observation_size
       self.mlp = eqx.nn.MLP(in_size, state_size, width_size, depth, key=key)

   def __call__(self, t, state, observation):
       in_ = jnp.concatenate([t[None], state, observation])
       return self.mlp(in_)

func = Func(...)
get_action = eqx.nn.MLP(...)

def select_action(model, state, observation, time):
   func, get_action = model
   prev_time, state = state
   term = dfx.ODETerm(func)
   # specify solver, dt0, stepsize_controller in whatever way you think appropriate
   solver = ...
   dt0 = ...
   stepsize_controller = ...
   sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, state, args=observation,
                         stepsize_controller=stepsize_controller)
   (state,) = sol.ys
   action = get_action(state)
   state = (time, state)
   return state, action

It's not critical, but as a nice-to-have this uses Equinox as a convenient neural network library.
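The carried state (prev_time, state) records exactly where the last call stopped. A sequence of calls therefore integrates the same ODE as one long solve, provided the observation does not change between calls and the solver uses fixed steps that line up. The plain-Python check below uses explicit Euler in place of the diffrax solver. The scalar vector field and readout are hypothetical. With an adaptive solver the step sequence can differ, so only approximate agreement should be expected there.

```python
def euler_solve(f, t0, t1, y0, args, dt):
    """Fixed-step explicit Euler from t0 to t1 (step adjusted to divide evenly)."""
    n = max(1, round((t1 - t0) / dt))
    h = (t1 - t0) / n
    y = y0
    for i in range(n):
        t = t0 + i * h
        y = y + h * f(t, y, args)
    return y


def func(t, y, observation):
    # Hypothetical scalar vector field playing the role of Func.
    return -y + observation * (1.0 + 0.1 * t)


def select_action(state, observation, time, dt=0.05):
    prev_time, y = state
    y = euler_solve(func, prev_time, time, y, observation, dt)
    action = 3.0 * y          # hypothetical readout standing in for get_action
    return (time, y), action


state = (0.0, 0.0)
for time in [0.5, 1.0, 1.5, 2.0]:
    state, action = select_action(state, 1.0, time)

# One solve over the whole horizon with the same (constant) observation.
y_direct = euler_solve(func, 0.0, 2.0, 0.0, 1.0, 0.05)
```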

@SimiPixel
Contributor Author

Thank you! I will try to find some time in the next few days to implement this, and will definitely report back!

Maybe I am just getting too excited about using a powerful new tool without rewriting any old code. But I feel that offering a transform that does exactly this would be super nice: given the feedforward network, the solver (and step-size controller), and the measurement/action mapping, it would construct a new differentiable function with the call signature of a typical recurrent network.
This is not only conceptually easier (IMO) but also enables plug-and-play integration of neural ODEs in other domains where recurrent neural networks are already well established, especially from an API standpoint.
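The transform described above could look roughly like the following plain-Python sketch. make_recurrent_cell is a hypothetical helper, not an existing diffrax or Equinox API. It takes a feedforward vector field and a readout and returns a cell with the usual recurrent signature cell(params, state, inputs) -> (new_state, output). Each call advances the ODE by one fixed interval using a single classical RK4 step, with the inputs held constant. The small scan function mimics how jax.lax.scan would unroll such a cell.

```python
import math
from typing import Callable


def make_recurrent_cell(vector_field: Callable, readout: Callable, dt: float):
    """Hypothetical transform: (vector field, readout) -> RNN-style cell.

    state = (t, x); each call takes one classical RK4 step of size dt
    with `inputs` held constant, then applies the readout.
    """
    def cell(params, state, inputs):
        t, x = state
        f = lambda tt, xx: vector_field(params, tt, xx, inputs)
        k1 = f(t, x)
        k2 = f(t + dt / 2, x + dt / 2 * k1)
        k3 = f(t + dt / 2, x + dt / 2 * k2)
        k4 = f(t + dt, x + dt * k3)
        x = x + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + dt, x), readout(params, x)
    return cell


def scan(cell, params, init_state, inputs_seq):
    """Plain-Python stand-in for jax.lax.scan: unroll the cell over a sequence."""
    state, outputs = init_state, []
    for inputs in inputs_seq:
        state, out = cell(params, state, inputs)
        outputs.append(out)
    return state, outputs


# Hypothetical usage: dx/dt = -a x + u with a linear readout c * x.
vf = lambda params, t, x, u: -params["a"] * x + u
ro = lambda params, x: params["c"] * x
cell = make_recurrent_cell(vf, ro, dt=0.1)
params = {"a": 1.0, "c": 2.0}
final_state, outputs = scan(cell, params, (0.0, 1.0), [0.0] * 10)
# with u = 0 the state should decay like exp(-t), so x(1) is close to exp(-1)
```

Any code written against the generic cell interface can then consume it unchanged. The solver choice (here one RK4 step per call) becomes an implementation detail of the transform.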

@SimiPixel
Contributor Author

Doesn't it make more sense to also include the solver state as part of the carried state?
I.e., in your example, something like this:

stepsize_controller = ...
solver = ...
dt0 = ...
sampling_rate = 100 # Hz

def select_action(model, state, observation):
   func, get_action = model
   term = dfx.ODETerm(func)
   prev_time, state, solver_state, controller_state = state
   time = prev_time + 1/sampling_rate
   sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, state, args=observation,
                         stepsize_controller=stepsize_controller, controller_state=controller_state, solver_state=solver_state,
                         saveat=dfx.SaveAt(t1=True, solver_state=True, controller_state=True))
   # update controller / solver state from the solution
   solver_state, controller_state = sol.solver_state, sol.controller_state
   (state,) = sol.ys
   action = get_action(state)
   state = (time, state, solver_state, controller_state)
   return state, action
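Carrying the controller state pays off with adaptive step sizes. If the controller restarts from a small dt0 on every call, each call has to ramp the step size up again. The plain-Python sketch below illustrates this. It is not diffrax's controller. It uses a Heun-Euler embedded pair with a basic step-size rule, and its "controller state" is simply the suggested next step size. The names adaptive_solve and run, the vector field dy/dt = -y, and the tolerances are all hypothetical. The sketch counts step attempts with and without carrying that state across 50 calls of 0.1 time units each (10 Hz).

```python
def adaptive_solve(f, t0, t1, y0, dt, rtol=1e-4, atol=1e-6):
    """Integrate dy/dt = f(t, y) from t0 to t1 with a Heun-Euler embedded pair.

    `dt` is the controller state: the suggested size of the next step.
    Returns (y1, dt_next, n_attempts).
    """
    t, y, attempts = t0, y0, 0
    while t < t1:
        clipped = dt >= t1 - t          # would this step reach (or pass) t1?
        h = t1 - t if clipped else dt
        k1 = f(t, y)
        k2 = f(t + h, y + h * k1)
        y_high = y + h / 2 * (k1 + k2)  # Heun, order 2
        y_low = y + h * k1              # Euler, order 1
        err = abs(y_high - y_low) / (atol + rtol * abs(y_high))
        attempts += 1
        factor = 5.0 if err == 0 else min(5.0, max(0.2, 0.9 * err ** -0.5))
        proposal = h * factor
        if err <= 1.0:  # accept the step
            t = t1 if clipped else t + h
            y = y_high
            # A step shortened only to land on t1 should not shrink the suggestion.
            dt = max(dt, proposal) if clipped else proposal
        else:           # reject and retry with a smaller step
            dt = proposal
    return y, dt, attempts


def run(carry_controller_state, n_calls=50, interval=0.1, dt0=1e-4):
    f = lambda t, y: -y              # hypothetical vector field
    t, y, dt, total = 0.0, 1.0, dt0, 0
    for _ in range(n_calls):
        if not carry_controller_state:
            dt = dt0                 # restart the controller on every call
        y, dt, n = adaptive_solve(f, t, t + interval, y, dt)
        t += interval
        total += n
    return y, total


y_carry, attempts_carry = run(True)
y_reset, attempts_reset = run(False)
```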

@patrick-kidger
Owner

Yes, it does. If you do this, then you should also pass made_jump: either the value from the previous step (if your vector field changes smoothly between steps) or True (if your vector field has jumps between steps).
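This flag matters for solvers that reuse work between steps, such as first-same-as-last (FSAL) Runge-Kutta methods. The last vector-field evaluation of one step is cached in the solver state and reused as the first evaluation of the next step. If the observation (args) changes between calls, that cached value was computed with the old observation and is stale. The hand-rolled Bogacki-Shampine 3(2) step below is a minimal sketch of the idea, not diffrax's implementation. With made_jump=True it re-evaluates the first stage instead of trusting the cache. The names bs3_step and f are hypothetical.

```python
def bs3_step(f, t, y, h, args, fsal_cache, made_jump):
    """One Bogacki-Shampine 3(2) step, which is FSAL (first same as last).

    fsal_cache holds f(t, y) from the end of the previous step (the solver
    state). If made_jump is True the cache is ignored and k1 is re-evaluated,
    at the cost of one extra vector-field evaluation.
    """
    k1 = f(t, y, args) if (made_jump or fsal_cache is None) else fsal_cache
    k2 = f(t + h / 2, y + h / 2 * k1, args)
    k3 = f(t + 3 * h / 4, y + 3 * h / 4 * k2, args)
    y_next = y + h * (2 / 9 * k1 + 1 / 3 * k2 + 4 / 9 * k3)
    k4 = f(t + h, y_next, args)   # becomes k1 of the next step
    return y_next, k4


f = lambda t, y, u: -y + u        # hypothetical vector field driven by input u
h = 0.1
y1, cache = bs3_step(f, 0.0, 1.0, h, 0.0, None, made_jump=False)

# The observation jumps from 0.0 to 1.0 between calls.
y_stale, _ = bs3_step(f, h, y1, h, 1.0, cache, made_jump=False)  # reuses stale k1
y_fresh, _ = bs3_step(f, h, y1, h, 1.0, cache, made_jump=True)   # re-evaluates k1
y_ref, _ = bs3_step(f, h, y1, h, 1.0, None, made_jump=True)

# If the observation does not change, reusing the cache is exact and saves work.
y_same_cached, _ = bs3_step(f, h, y1, h, 0.0, cache, made_jump=False)
y_same_fresh, _ = bs3_step(f, h, y1, h, 0.0, None, made_jump=True)
```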

@SimiPixel
Contributor Author

SimiPixel commented Jun 5, 2022

For completeness, let me post my minimal working example.
Spoiler: this uses Haiku, simply because I am already comfortable with it. Equinox would probably make this more beautiful :)

import diffrax as dfx
import jax.numpy as jnp
from acme.jax import utils 
from functools import partial 
import haiku as hk 
import jax 

sampling_rate = 100 # Hz
stepsize_controller = dfx.ConstantStepSize()
dt0 = 1/sampling_rate
solver = dfx.Euler()

action_size = 3
obs_size = 2
u_dummy = jnp.ones((action_size,))

latent_state_size = 20
hidden_layers = [50,50]
@hk.without_apply_rng
@hk.transform_with_state
def rhs(t, u):
   t = jnp.atleast_1d(t)
   x = hk.get_state("x", shape=(latent_state_size,), init=jnp.zeros, dtype=jnp.float32)
   txu = utils.batch_concat((t,x,u), num_batch_dims=0)
   X = hk.nets.MLP(hidden_layers + [latent_state_size])(txu)
   return {"~": {"x": X}}

def haiku2dfx_rhs(rhs):
   def __rhs(params):
      def _rhs(t, x, u):
         # x is simply passed through
         dxdt, x = rhs(params, x, t, u)
         del x 
         return dxdt 
      return _rhs 
   return __rhs 

# this is not great / quite confusing
dxdt = haiku2dfx_rhs(rhs.apply)

@hk.without_apply_rng
@hk.transform  
def measurement_function(x):
   x = utils.batch_concat(x, num_batch_dims=0)
   C = hk.get_parameter("C", shape=(obs_size,x.shape[-1]), dtype=jnp.float32,
      init=lambda shape, dtype: jax.random.normal(hk.next_rng_key(), shape, dtype=dtype))
   return jnp.matmul(C, x)

def gen_init_solver_state(solver: dfx.AbstractSolver, params_rhs, x0):
   term = dfx.ODETerm(dxdt(params_rhs))
   t0=0.0 
   return solver.init(term, t0=t0, t1=t0+dt0, y0=x0, args=u_dummy)

def gen_init_controller_state():
   return dt0

saveat = dfx.SaveAt(t1=True,solver_state=True,controller_state=True,made_jump=True)

def step_fun_dynamics_to_time(params, state, u, time):
   prev_time, x, solver_state, controller_state, made_jump = state
   term = dfx.ODETerm(dxdt(params["rhs"]))

   sol = dfx.diffeqsolve(term, solver, prev_time, time, dt0, x, args=u,
                         stepsize_controller=stepsize_controller, saveat=saveat,
                         solver_state=solver_state, controller_state=controller_state,
                         made_jump=made_jump
                         )
   x = sol.ys
   x = utils.squeeze_batch_dim(x)
   state = (time, x, sol.solver_state, sol.controller_state, sol.made_jump)
   obs = measurement_function.apply(params["C"], x)
   return state, obs 

def step_fun_dynamics(params, state, u):
   prev_time = state[0]
   return step_fun_dynamics_to_time(params, state, u, prev_time + dt0)

@jax.jit 
@partial(jax.vmap, in_axes=(None, None, 0))
def unrolled_step_fun_dynamics(params, state, us):
   step_fun_dynamics_constraint = lambda state, u: step_fun_dynamics(params, state, u)
   state, obss = jax.lax.scan(step_fun_dynamics_constraint, init=state, xs=us)
   return obss

# initialise parameters
params_rhs, x0 = rhs.init(jax.random.PRNGKey(1), 0.0, u_dummy)
C = measurement_function.init(jax.random.PRNGKey(1), jnp.ones((latent_state_size,)))

params = {
    "rhs": params_rhs,
    "C": C 
}

# initialise step functions state
# (t0, x0, solver_state0, controller_state0, made_jump0)
made_jump0 = False 
init_state = (0.0, x0, gen_init_solver_state(solver, params_rhs, x0), gen_init_controller_state(), made_jump0)

# make prediction
bs=32
T=5.0 
uss = jnp.ones((bs, int(T*sampling_rate), action_size))
obsss = unrolled_step_fun_dynamics(params, init_state, uss)

Thanks, Patrick, for your help. It works perfectly.
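The batching structure of unrolled_step_fun_dynamics can be mirrored in plain Python to check the expected output shape (batch, time steps, observation size). That structure is a vmap over the batch with shared params and initial state, a lax.scan over time, and a linear measurement C x after every step. Everything below is a hypothetical stand-in. It uses a 2-dimensional linear latent ODE with one Euler step per sample instead of the Haiku MLP, and lists instead of arrays.

```python
def matvec(M, v):
    return [sum(m_ij * v_j for m_ij, v_j in zip(row, v)) for row in M]


def step(params, state, u):
    """One sampling interval: Euler step of dx/dt = A x + B u, then obs = C x."""
    t, x = state
    dxdt = [ax + bu for ax, bu in zip(matvec(params["A"], x), matvec(params["B"], u))]
    x = [xi + params["dt"] * di for xi, di in zip(x, dxdt)]
    return (t + params["dt"], x), matvec(params["C"], x)


def scan(f, init, xs):
    """Plain-Python stand-in for jax.lax.scan."""
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, ys


def unrolled(params, state, us_batch):
    # Like vmap with in_axes=(None, None, 0): shared params and initial state,
    # one input sequence per batch element.
    return [scan(lambda s, u: step(params, s, u), state, us)[1] for us in us_batch]


params = {"A": [[0.0, 1.0], [-1.0, -0.1]],
          "B": [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
          "C": [[1.0, 0.0], [0.0, 1.0]],
          "dt": 0.01}                      # 100 Hz, as in the example above
init_state = (0.0, [0.0, 0.0])
bs, n_steps, action_size = 4, 500, 3       # 5 s at 100 Hz
uss = [[[1.0] * action_size for _ in range(n_steps)] for _ in range(bs)]
obsss = unrolled(params, init_state, uss)
```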
