
BatchNorm raises "TypeError: Expected a callable value, got inf" #234

Closed · ciupakabra opened this issue Nov 21, 2022 · 7 comments

Labels: bug (Something isn't working), question (User queries)

@ciupakabra (Author) commented Nov 21, 2022

I'm getting weird errors with BatchNorm. One example is the code below, where ODE parameters are optimized and the drift is a neural network with some BatchNorm layers. The error, thrown when vmapping a loss function, is TypeError: Expected a callable value, got inf.

I wasn't able to reduce this to something without diffrax -- double vmaps or vmapping a scan don't raise exceptions. I'm guessing this is an equinox issue, since without BatchNorm (bn=False) there are no problems.

import optax
import jax
import jax.numpy as jnp
import jax.random as jrandom
import jax.nn as jnn
import equinox as eqx
import diffrax as dx
from tqdm import tqdm

def integrate(drift, num_steps, dim, key):

    def f(t, y, args):
        return drift(jnp.concatenate([t[None], y]))

    drift_term = dx.ODETerm(f)
    solver = dx.Euler()
    y0 = jnp.zeros(dim)

    ts = jnp.linspace(0, 1, num_steps + 1)
    saveat = dx.SaveAt(ts=ts)

    sol = dx.diffeqsolve(
        drift_term,
        solver,
        0,
        1,
        1/num_steps,
        y0,
        saveat=saveat,
        max_steps=num_steps + 1,
    )

    return sol.ys

def loss(drift, num_steps, dim, key):
    path = integrate(drift, num_steps, dim, key)
    final = path[-1]
    loss = jnp.sum(final**2)
    return loss

@eqx.filter_value_and_grad
def loss_mean(drift, num_steps, dim, key, batch_size):
    loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
    key = jrandom.split(key, batch_size)
    return jnp.mean(loss_vmapped(drift, num_steps, dim, key))


class Network(eqx.Module):
    net: eqx.Module

    def __init__(self, in_size, out_size, width, depth, *, key, bn=True):

        keys = jrandom.split(key, depth + 1)
        layers = []
        if depth == 0:
            layers.append(eqx.nn.Linear(in_size, out_size, key=keys[0]))
        else:
            layers.append(eqx.nn.Linear(in_size, width, key=keys[0]))
            if bn:
                layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
            for i in range(depth - 1):
                layers.append(eqx.nn.Linear(width, width, key=keys[i + 1]))
                if bn: 
                    layers.append(eqx.experimental.BatchNorm(width, axis_name="batch"))
                layers.append(eqx.nn.Lambda(jnn.relu))
            layers.append(eqx.nn.Linear(width, out_size, key=keys[-1]))

        self.net = eqx.nn.Sequential(layers)

    def __call__(self, x):
        return self.net(x)


if __name__=="__main__":

    key = jrandom.PRNGKey(0)

    init_drift_key, train_key = jrandom.split(key, 2)

    dim = 500

    drift = Network(dim + 1, dim, 300, 2, key=init_drift_key, bn=True)

    optimizer = optax.adamw(1e-4)
    opt_state = optimizer.init(eqx.filter(drift, eqx.is_inexact_array))
    
    @eqx.filter_jit
    def make_step(drift, num_steps, dim, key, batch_size, opt_state):
        loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
        updates, opt_state = optimizer.update(
            grads, opt_state, eqx.filter(drift, eqx.is_inexact_array)
        )
        drift = eqx.apply_updates(drift, updates)
        return loss, drift, opt_state


    for step in tqdm(range(100)):
        step_key = jrandom.fold_in(train_key, step)
        loss, drift, opt_state = make_step(
            drift, 10, dim, step_key, 32, opt_state
        )
(env) andrius:/home/andrius/repos/test% python test.py  
  1%|▏          | 1/100 [00:01<03:14,  1.97s/it]
Traceback (most recent call last):
  File "/home/andrius/repos/test/test.py", line 99, in <module>
    loss, drift, opt_state = make_step(
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 82, in __call__
    return __self._fun_wrapper(False, args, kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 78, in _fun_wrapper
    dynamic_out, static_out = self._cached(dynamic, static)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 622, in cache_miss
    execute = dispatch._xla_call_impl_lazy(fun_, *tracers, **params)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl_lazy
    return xla_callable(fun, device, backend, name, donated_invars, keep_unused,
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 303, in memoized_fun
    ans = call(fun, *args)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 359, in _xla_callable_uncached
    return lower_xla_callable(fun, device, backend, name, donated_invars, False,
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/dispatch.py", line 445, in lower_xla_callable
    jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2077, in trace_to_jaxpr_final2
    jaxpr, out_type, consts = trace_to_subjaxpr_dynamic2(fun, main, debug_info)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 2027, in trace_to_subjaxpr_dynamic2
    ans = fun.call_wrapped(*in_tracers_)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 30, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/home/andrius/repos/test/test.py", line 89, in make_step
    loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 40, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 1167, in value_and_grad_f
    ans, vjp_py = _vjp(f_partial, *dyn_args, reduce_axes=reduce_axes)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 2656, in _vjp
    out_primal, out_vjp = ad.vjp(
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/ad.py", line 135, in vjp
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/ad.py", line 124, in linearize
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/interpreters/partial_eval.py", line 767, in trace_to_jaxpr_nounits
    jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/linear_util.py", line 167, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 37, in fun_value_and_grad
    return __self._fun(_x, *_args, **_kwargs)
  File "/home/andrius/repos/test/test.py", line 43, in loss_mean
    loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 1647, in vmap
    _check_callable(fun)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/jax/_src/api.py", line 181, in _check_callable
    raise TypeError(f"Expected a callable value, got {fun}")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Expected a callable value, got inf

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/andrius/repos/test/test.py", line 99, in <module>
    loss, drift, opt_state = make_step(
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 82, in __call__
    return __self._fun_wrapper(False, args, kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 78, in _fun_wrapper
    dynamic_out, static_out = self._cached(dynamic, static)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/jit.py", line 30, in fun_wrapped
    out = fun(*args, **kwargs)
  File "/home/andrius/repos/test/test.py", line 89, in make_step
    loss, grads = loss_mean(drift, num_steps, dim, key, batch_size)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 40, in __call__
    return fun_value_and_grad(diff_x, nondiff_x, *args, **kwargs)
  File "/home/andrius/repos/test/env/lib/python3.10/site-packages/equinox/grad.py", line 37, in fun_value_and_grad
    return __self._fun(_x, *_args, **_kwargs)
  File "/home/andrius/repos/test/test.py", line 43, in loss_mean
    loss_vmapped = jax.vmap(loss, (None, None, None, 0), 0, axis_name="batch")
TypeError: Expected a callable value, got inf
(env) andrius:/home/andrius/repos/test% pip list
Package           Version
----------------- ---------------------
absl-py           1.3.0
chex              0.1.5
diffrax           0.2.2
dm-tree           0.1.7
equinox           0.9.2
jax               0.3.25
jaxlib            0.3.25+cuda11.cudnn82
jaxtyping         0.2.8
numpy             1.23.5
opt-einsum        3.3.0
optax             0.1.4
pi                0.1.2
pip               22.3.1
scipy             1.9.3
setuptools        65.5.0
toolz             0.12.0
tqdm              4.64.1
typeguard         2.13.3
typing_extensions 4.4.0
wheel             0.37.1
@uuirs (Contributor) commented Nov 22, 2022

loss is redefined here:

    loss, drift, opt_state = make_step(
        drift, 10, dim, step_key, 32, opt_state
    )

Rename the loss function to loss_fn and it works.
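
For illustration, a stripped-down sketch of the shadowing (hypothetical names, independent of the script above) -- in the original it only surfaces when the jitted step is retraced, which is why it appeared to depend on BatchNorm:

import jax
import jax.numpy as jnp

# Hypothetical, stripped-down illustration of the shadowing: after the
# reassignment, `loss` names an array, so jax.vmap's callable check raises
# the same TypeError.
def loss(x):
    return jnp.sum(x ** 2)

loss = loss(jnp.ones(3))  # `loss` is now an array, not a function

try:
    jax.vmap(loss)(jnp.ones((4, 3)))
except TypeError as e:
    print(e)  # Expected a callable value, got 3.0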

@patrick-kidger (Owner) commented Nov 22, 2022

As @uuirs points out, you've redefined the variable loss from a function to an array.

That said, it's a little odd that you're only hitting this when using BatchNorm. Is BatchNorm triggering a recompilation, somehow?

@ciupakabra (Author) commented Nov 23, 2022

Yeah, it's odd that it surfaced this way, and you're right -- I checked, and with bn=False it compiles once, whereas with bn=True it compiles twice.

@ciupakabra (Author) commented Nov 24, 2022

Had some time to debug this tonight. I think it's because BatchNorm's state is first initialized with a pytree of a different PyTreeDef than the one it stores once it actually gets a state, i.e. it's first initialized as an array:

set_state(self.first_time_index, jnp.array(True))

but once it calculates the running means and variances it sets the state to a tuple of arrays:

set_state(self.state_index, lax.stop_gradient((running_mean, running_var)))

This changes the PyTreeDef of the StateIndex module, which then changes the hash of the PyTreeDef of the whole model. And since the PyTreeDef is used as a static argument somewhere in equinox's jitting logic, it gets recompiled.
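
For illustration, the same retrace can be reproduced without BatchNorm by handing eqx.filter_jit a module whose pytree structure differs between calls; Holder and f below are hypothetical stand-ins for StateIndex and the jitted step, not equinox internals:

import equinox as eqx
import jax.numpy as jnp

# Hypothetical stand-in: a module holding either None or a tuple of arrays.
# None contributes no leaves, the tuple contributes two, so the overall
# treedef -- part of the static cache key under eqx.filter_jit -- changes.
class Holder(eqx.Module):
    state: object

@eqx.filter_jit
def f(h, x):
    print("Compiling!")  # runs only when the function is (re)traced
    return x + 1.0

x = jnp.zeros(3)
f(Holder(None), x)                          # Compiling!
f(Holder(None), x)                          # cached, no print
f(Holder((jnp.zeros(3), jnp.ones(3))), x)   # Compiling! (treedef changed)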

I'm assuming the fix is a one-line change in the BatchNorm initialization?

@patrick-kidger (Owner) commented

Aren't these different state indices?

@ciupakabra (Author) commented

Sorry, you're right -- for the initialization I meant to reference:

self._state = None

To be fair, even then I don't understand some of the behaviour. A minimal example of the bug above is:

import jax
import jax.random as jrandom
import equinox as eqx

@eqx.filter_jit
def fun(bn, inp):
    print("Compiling!")
    return jax.vmap(bn, axis_name="batch")(inp)

def info(bn):
    # prints None all the time
    print(f"bn.state_index._state: {bn.state_index._state}")
    # children should be dynamic_field_values, so children[0] should be the value of _state
    children, aux = bn.state_index.tree_flatten()
    # prints a tuple of arrays after the 1st call to fun
    print(f"children[0] of bn.state_index flatten: {children[0]}")


if __name__=="__main__":
    bn = eqx.experimental.BatchNorm(10, axis_name="batch")
    inp = jrandom.normal(jrandom.PRNGKey(0), (32, 10))

    info(bn)
    fun(bn, inp)
    info(bn)
    fun(bn, inp)
    info(bn)
    fun(bn, inp)

which outputs

bn.state_index._state: None
children[0] of bn.state_index flatten: None
Compiling!
bn.state_index._state: None
children[0] of bn.state_index flatten: (DeviceArray([-0.32395738, -0.21207958, -0.31645954,  0.05969752,
             -0.11307174,  0.0944065 , -0.14875616, -0.05194131,
              0.10097986,  0.25392908], dtype=float32), DeviceArray([0.92587423, 0.95594984, 1.0194211 , 0.84475476, 0.76749337,
             0.77187854, 1.38814   , 0.8497227 , 1.1132355 , 0.86574566],            dtype=float32))
Compiling!
bn.state_index._state: None
children[0] of bn.state_index flatten: (DeviceArray([-0.32395738, -0.21207958, -0.31645954,  0.05969752,
             -0.11307174,  0.0944065 , -0.14875618, -0.05194131,
              0.10097986,  0.25392908], dtype=float32), DeviceArray([0.92587423, 0.95594984, 1.0194211 , 0.84475476, 0.76749337,
             0.77187854, 1.38814   , 0.8497227 , 1.1132355 , 0.86574566],            dtype=float32))

So you can clearly see that, at least when accessing _state through StateIndex's tree_flatten, _state changes from None to a 2-tuple of arrays, which changes the PyTreeDef of the static args.

I don't understand why bn.state_index._state keeps printing None, though. I thought this would set it to a new value:

object.__setattr__(self, "_state", new_state)

but I guess it gets deleted somewhere in between the two prints by _delete_smuggled_state.
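
For what it's worth, the structural change alone is enough to make the treedefs compare unequal -- a minimal sketch using only jax.tree_util:

import jax
import jax.numpy as jnp

# None flattens to an empty pytree, while a 2-tuple of arrays flattens to two
# leaves, so the two treedefs compare unequal -- exactly the change in static
# structure described above.
before = jax.tree_util.tree_structure(None)
after = jax.tree_util.tree_structure((jnp.zeros(10), jnp.ones(10)))
print(before)           # structure with zero leaves
print(after)            # structure with two leaves
print(before == after)  # False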

@patrick-kidger (Owner) commented

Closing as eqx.experimental.BatchNorm is now available (in theory without bugs) as eqx.nn.BatchNorm.
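
For anyone finding this later, a sketch of the replacement API, assuming a recent Equinox release with the explicit-state machinery (eqx.nn.make_with_state / eqx.nn.State); check the current docs for the exact signatures:

import equinox as eqx
import jax
import jax.random as jrandom

# State is now passed in and returned explicitly rather than smuggled through
# a StateIndex, so the model's pytree structure stays fixed under jit.
bn, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(10, axis_name="batch")
x = jrandom.normal(jrandom.PRNGKey(0), (32, 10))
out, state = jax.vmap(
    bn, in_axes=(0, None), out_axes=(0, None), axis_name="batch"
)(x, state)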
