Added eqx.experimental.noinline
#126
Conversation
1. Previously, using a callable `out` parameter was experimental for `filter_vmap` -- because it monkey-patched JAX internals -- and unavailable for `filter_pmap`. It has now been updated to work for both, using only the public JAX API. (Hurrah.)
2. Drive-by: added `eqx.filter_eval_shape`, as it looked like this was going to be useful as part of implementing the previous feature. In the end this wasn't the case, but we get a new feature anyway.
3. Drive-by: fixed a crash bug when filter-jit'ing a partial-wrapped function whilst using a PyTree as its filter spec (`fn`).
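For context on item 2, `eqx.filter_eval_shape` presumably plays the same role as the underlying `jax.eval_shape`, which computes output shapes and dtypes abstractly, without performing any actual computation. A minimal sketch of the JAX API it builds on (the function `f` here is just an illustrative stand-in):

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.concatenate([x, x], axis=0)

# Abstract evaluation only: no FLOPs are performed, no arrays materialised.
out = jax.eval_shape(f, jax.ShapeDtypeStruct((3, 4), jnp.float32))
print(out.shape, out.dtype)  # (6, 4) float32
```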
I wonder why we cannot implement this functionality by registering the user function as a primitive, with the help of …. And it might be possible to generate a binary and then wrap it with ….
Doing so with MLIR: right! After implementing the above I started wondering something similar. As an end goal, could we arrange to generate a single …? The implications of this would be quite dramatic: it would enable a "bounded while loop" construct -- i.e. the core backbone of diffeq solvers, nonlinear optimisers, etc. This construct requires recursively tracing nested operations of the form …. That said, a ….
If wrapping in an rng-based fake while loop can successfully prevent inlining (really?), then we could implement ….
My trial:

```python
from functools import partial, lru_cache

import jax
from jax import core
from jax.interpreters import mlir, xla, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun


def foo(x):
    print('traced')
    return x + 1


@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)


foo_p = core.Primitive('foo')
foo_p.def_impl(partial(xla.apply_primitive, foo_p))
foo_p.def_abstract_eval(lambda x: abstract_eval_fun(foo, x))
mlir.register_lowering(foo_p, mlir.cache_lowering(mlir.lower_fun(foo, multiple_results=False)))


def bar(x):
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    return x


print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir())  # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string())  # post-compilation HLO
```

Only inlined at the post-compilation HLO level.
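As an aside on the trial above: the `lru_cache` around `abstract_eval_fun` is what keeps the (potentially expensive) abstract evaluation from re-running on every `bind` of the primitive. A stdlib-only illustration of that memoisation behaviour (the `calls` list and the tuple argument are just for demonstration):

```python
from functools import lru_cache

calls = []

@lru_cache
def abstract_eval(shape):
    # Stands in for the expensive tracing/abstract-eval step.
    # lru_cache memoises on the (hashable) arguments.
    calls.append(shape)
    return shape

for _ in range(4):  # four binds of the same primitive with the same aval...
    abstract_eval((3, 4))

print(len(calls))  # ...but the underlying function ran only once
```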
```python
from functools import partial, lru_cache

import jax
from jax import core, lax
from jax.interpreters import mlir, xla, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten, tree_map
import jax.linear_util as lu
from jax._src.api_util import flatten_fun, shaped_abstractify
import numpy as np


def foo(x):
    print('traced')
    return x + 1


def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *(shaped_abstractify(a) for a in args_flat))
    return tree_unflatten(out_tree(), out)


cached_abstract_eval_fun = lru_cache(abstract_eval_fun)


def _dummy_result(aval: core.AbstractValue):
    if aval is core.abstract_token:
        return lax.create_token()
    else:
        return lax.full(aval.shape, 0, aval.dtype)


def wrapped(fun, *args, **kwargs):
    avals_out = abstract_eval_fun(fun, *args, **kwargs)
    dummies_like_result = tree_map(_dummy_result, avals_out)
    carry_init = (np.int32(0), dummies_like_result, args, kwargs)

    def cond(carry):
        counter, _, _, _ = carry
        # Data-dependent bound: always 1 at runtime, but opaque to the compiler.
        return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())

    def body(carry):
        counter, _, args, kwargs = carry
        results = fun(*args, **kwargs)
        return (counter + 1, results, args, kwargs)

    carry_res = lax.while_loop(cond, body, carry_init)
    return carry_res[1]


foo_p = core.Primitive('foo')
foo_p.def_impl(partial(xla.apply_primitive, foo_p))
foo_p.def_abstract_eval(lambda x: cached_abstract_eval_fun(foo, x))
mlir.register_lowering(foo_p, mlir.cache_lowering(mlir.lower_fun(lambda x: wrapped(foo, x), multiple_results=False)))


def bar(x):
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    return x


print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir())
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string())
```

XLA STILL inlines all `foo` -- copying the while loop 4 times!!
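For readers unfamiliar with the construct being (ab)used above: `lax.while_loop(cond, body, init)` repeatedly applies `body` while `cond` holds, and its body is traced exactly once, producing a single structured loop in the IR rather than an unrolled sequence. A minimal sketch (function names are illustrative only):

```python
import jax
from jax import lax

def count_to_ten(x):
    # The body is traced once; the repetition lives in the loop construct,
    # not in the traced program.
    return lax.while_loop(lambda c: c < 10, lambda c: c + 1, x)

print(jax.jit(count_to_ten)(0))  # 10
```

The trial above exploits exactly this: wrapping `foo` in a while loop with an opaque trip count was hoped to stop XLA from flattening it away -- but as shown, XLA clones the loop body at each call site anyway.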
Right, so I think the fake while loop would prevent inlining at the HLO->compile-to-whatever-backend level, i.e. prevent optimisation of the function wrt its calling context. Indeed I didn't mean that it would prevent the cloning issue.
Bother. Looking at …, perhaps the way forward really is to implement a "better" version of ….
Hi! I prototyped a (possibly) simpler version than …:

```python
from functools import lru_cache

import jax
from jax import core
from jax.interpreters import mlir, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun


def foo(x):
    print('traced')
    print(type(x))
    return x + 1


def foo_lowered(x):
    print(type(x))
    return x + 1


@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)


callback_p = core.Primitive('callback')
callback_p.def_impl(lambda *args, callback, **_: callback(*args))
callback_p.def_abstract_eval(lambda *args, callback, **_: abstract_eval_fun(callback, *args))


def callback_lowering(ctx, *args, callback, callback_lowered):
    try:
        iter(abstract_eval_fun(callback, *ctx.avals_in))
    except TypeError:
        f = lambda *args: (callback_lowered(*args),)
    else:
        f = callback_lowered
    result, keepalive = mlir.emit_python_callback(ctx.module_context.platform, f, args, ctx.avals_in, ctx.avals_out, False)
    ctx.module_context.add_keepalive(keepalive)
    return result


mlir.register_lowering(callback_p, mlir.cache_lowering(callback_lowering))


def bar(x):
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    return x


print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir())  # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string())  # post-compilation HLO
```
BTW, there is a new mechanism being developed to replace …. In any case, using ….
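The replacement mechanism alluded to above plausibly refers to what later landed as JAX's public callback APIs. Assuming a reasonably recent JAX, `jax.pure_callback` stages a host-side Python/NumPy function into a jitted computation without reaching into internals like `mlir.emit_python_callback`; a sketch (the names `host_fn` and `f` are illustrative):

```python
import jax
import jax.numpy as jnp
import numpy as np

def host_fn(x):
    # Runs as ordinary Python/NumPy on the host, outside the XLA graph.
    return np.sin(x)

@jax.jit
def f(x):
    # We must declare the result shape/dtype, since XLA cannot trace host_fn.
    shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return jax.pure_callback(host_fn, shape, x)

print(f(jnp.float32(0.0)))  # 0.0
```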
Closing in favour of #218, which adds ….
TL;DR: XLA sub-graphs!
Background
At present, JAX inlines the entire computation into a single XLA graph. However, many scientific computing applications involve defining some modestly complicated function and then calling this function numerous times in different contexts -- in ways that cannot easily be tidied up into a `lax.scan` or similar. For example, the vector field for an ODE must be traced 22 times when using the Kvaerno5 solver with automatic initial step size selection. Inlining without awareness of the repeated structure means the compiler is less efficient than it could be. I know of current examples with compile times about an hour long.
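For contrast, when the repeated structure *is* regular, `lax.scan` already avoids the problem: the body is traced and compiled once, and the repetition lives in the loop construct. A minimal sketch of the pattern that unfortunately doesn't fit the ODE-solver case (`vector_field` here is a trivial stand-in):

```python
import jax
import jax.numpy as jnp

def vector_field(x):
    return x + 1  # stand-in for a modestly complicated function

def repeated(x):
    # Apply vector_field 4 times; it is traced once, not four times.
    def body(carry, _):
        return vector_field(carry), None
    out, _ = jax.lax.scan(body, x, None, length=4)
    return out

print(jax.jit(repeated)(jnp.float32(0.0)))  # 4.0
```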
To support this use case there's been talk for quite a while about adding support to JAX or XLA for sub-programs, e.g. google/jax#10974, google/jax#4572, google/jax#3847, google/jax#9298.
no-inline decorator
Introducing `equinox.experimental.noinline`. This decorator places the function in a separate XLA computation graph, and links it up with the main one via `jax.experimental.host_callback.call`. The decorated function is only traced once; only one copy of it exists as a jaxpr; it is only compiled once. It can still be transformed via grad, vmap etc.

Running the included benchmark `benchmarks/noinline.py` we obtain a reduction in compile time from 36 seconds to 25 seconds, at the expense of a large runtime increase, from 0.002 seconds to 0.6 seconds. In practice that's still a good net saving (36 seconds -> 25.6 seconds) in the common use-case that you're developing + debugging your program.

Going further and switching the solver in the benchmark from `dfx.Kvaerno5()` to `dfx.Dopri8()` gives even better results: a compile time reduction of 23 seconds -> 8 seconds (!), with a runtime increase of 0.002 seconds -> 0.16 seconds. (I chose not to implement this as the default benchmark, just because I have other unrelated plans for improving the compile time of `Dopri8`.)
Limitations
The main one is the runtime overhead of going via `host_callback`. Why is `host_callback` so slow? (Flame graphs TBD.) Possibly also something GIL related?

Nonetheless, our main use-case is on the CPU, and the overall compile-time improvements on the benchmark represent compile speed improvements of 1.5x to 3x, which is enough to make me happy. This is something we're looking forward to relying on, as those 1+ hour compile times are really biting us.
CC
@shoyer and @YouJiacheng as I know you've both wanted this functionality in the past.
@FedericoV (+team) for using this.
Also speculatively tagging @gnecula (I have it in my head that you're behind `host_callback`?) and @mattjj for possible interest. (Feel free to ignore this.)