Added eqx.experimental.noinline #126

Closed
wants to merge 3 commits into from

Conversation

@patrick-kidger (Owner) commented Jul 1, 2022

TL;DR: XLA sub-graphs!

Background

At present, JAX inlines the entire computation into a single XLA graph.

However, many scientific computing applications involve defining some modestly complicated function and then calling this function numerous times in different contexts. For example the vector field for an ODE must be traced 22 times when using the Kvaerno5 solver with automatic initial step size selection. (In ways that cannot easily be tidied up into a lax.scan or similar.)

Inlining without awareness of this repeated structure makes the compiler less efficient than it could be. I know of current examples with compile times of about an hour.
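
As a toy illustration of the inlining behaviour (the vector field and solver here are hypothetical stand-ins, not the real benchmark): each call site stamps a fresh copy of the function's operations into the jaxpr, and hence into the XLA graph.

import jax
import jax.numpy as jnp

def vf(t, y):
    # Stands in for a modestly complicated vector field.
    return -jnp.sin(y) + jnp.cos(t * y)

def pseudo_solver(y0):
    # A solver evaluates the vector field in several distinct contexts
    # (stages, error estimates, step-size heuristics, ...); every call
    # here is inlined as a separate copy of vf's operations.
    k1 = vf(0.0, y0)
    k2 = vf(0.5, y0 + 0.5 * k1)
    k3 = vf(1.0, y0 + 0.5 * k2)
    return y0 + k3

print(jax.make_jaxpr(pseudo_solver)(1.0))  # sin/cos each appear three times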

To support this use case there's been talk for quite a while about adding support to JAX or XLA for sub-programs, e.g.
google/jax#10974
google/jax#4572
google/jax#3847
google/jax#9298

no-inline decorator

Introducing equinox.experimental.noinline. This decorator places the function in a separate XLA computation graph, and links it up with the main one via jax.experimental.host_callback.call. The decorated function is only traced once; only one copy of it exists as a jaxpr; it is only compiled once. It can still be transformed via grad, vmap etc.
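
A minimal usage sketch (the vector field is a hypothetical stand-in; only the decorator itself is what this PR adds):

import equinox as eqx
import jax
import jax.numpy as jnp

@eqx.experimental.noinline
def vector_field(t, y):
    # Traced once and compiled once, in its own XLA graph; each call
    # site links to it via host_callback rather than inlining a copy.
    return -jnp.sin(y) + jnp.cos(t * y)

def step(t, y, dt):
    return y + dt * vector_field(t, y)

# Per the description above, this still composes with JAX transformations:
dstep_dy = jax.grad(step, argnums=1)
batched_step = jax.vmap(step, in_axes=(None, 0, None))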

Running the included benchmark benchmarks/noinline.py, we obtain a compile-time reduction from 36 seconds to 25 seconds, at the expense of a large runtime increase, from 0.002 seconds to 0.6 seconds. In practice that's still a good net saving (36 seconds -> 25.6 seconds) in the common use case that you're developing and debugging your program.

Going further and switching the solver in the benchmark from dfx.Kvaerno5() to dfx.Dopri8() gives even better results: a compile-time reduction from 23 seconds to 8 seconds (!), with a runtime increase from 0.002 seconds to 0.16 seconds. (I chose not to make this the default benchmark, just because I have other unrelated plans for improving the compile time of Dopri8.)

Limitations

  1. All the awful hackery and monkey-patching of JAX internals needed to make this work.
  2. This will only be efficient on the CPU. On the GPU it'll entail copies to and from the device. However, I speculate that this isn't actually necessary, and may just be a limitation of our current use of host_callback?
  3. The runtime performance has cratered. I speculate that a lot of this cost is due to the back-and-forth switching via Python, again due to our use of host_callback. (Flame graphs TBD.) Possibly also something GIL-related?

Nonetheless, our main use case is on the CPU, and the compile-time improvements on the benchmark represent speedups of 1.5x to 3x, which is enough to make me happy. This is something we're looking forward to relying on, as those 1+ hour compile times are really biting us.

CC

@shoyer and @YouJiacheng as I know you've both wanted this functionality in the past.
@FedericoV (+team) for using this.

Also speculatively tagging @gnecula (I have it in my head that you're behind host_callback?) and @mattjj for possible interest. (Feel free to ignore this.)

1. Previously, using a callable `out` parameter was experimental for
   `filter_vmap` -- because it monkey-patched JAX internals -- and
   unavailable for `filter_pmap`. It has now been updated to work for
   both, using only the public JAX API. (Hurrah.)

2. Drive-by: Added `eqx.filter_eval_shape`, as it looked like this was
   going to be useful as part of implementing the previous feature. In
   the end this wasn't the case, but we get a new feature anyway.

3. Drive-by: Fixed a crash bug when filter-jit'ing a partial-wrapped
   function whilst using a PyTree as its filter spec (`fn`).
@YouJiacheng commented Jul 3, 2022

I wonder why we cannot implement this functionality by registering the user function as a primitive, with the help of mlir.lower_fun and mlir.cache_lowering.
Though it seems that in the MHLO=>HLO pass, a mlir.cache_lowering-wrapped function will still be expanded:
https://github.com/google/jax/blob/118db407f24f20a77ef491d504290f8c29d57d05/jax/_src/lax/lax.py#L1952-L1955

And it might be possible to generate a binary and then wrap it in a CustomCall, which will never be inlined, avoiding the cost of host_callback.

@patrick-kidger (Owner, Author) commented

Doing so with MLIR: right! After implementing the above I started wondering something similar. As an end goal, could we arrange to generate a single XlaComputation that all usages point at, via XLA's call operation? If need be, wrap it in a fake while loop (the same trick jax.checkpoint uses) to prevent any inlining analysis that might slow things down; but I assume just having a single computation object, rather than numerous copies, would already be enough to improve things sufficiently.

The implications of this would be quite dramatic: it would enable a "bounded while loop" construct -- i.e. the core backbone of diffeq solvers, nonlinear optimisers, etc. This construct requires recursively tracing nested operations of the form lax.cond(pred, fn, jax.checkpoint(fn), *operands). The fact that they're nested means that at present we end up with an exponential explosion -- each use of the above involves tracing two copies of fn, each of which traces two further copies, and so on -- which makes this infeasible for compile-time performance. As such, all current diffeq solvers / nonlinear optimisers / etc. hack around this and suffer substantial runtime performance penalties as a result. Being able to apply a noinline decorator to fn would avoid this issue, and thus drop the compile-time complexity from O(exp n) to just O(n). This would be amazing!
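
To make the explosion concrete, here is a toy version of that nesting (hypothetical predicate and body; exact trace counts may vary with JAX-internal tracing caches):

import jax
import jax.lax as lax

def body(x):
    print("tracing body")  # fires once per trace
    return x + 1

def make_level(inner):
    # One level of the construction: both branches of the cond are
    # traced, so `inner` is traced twice per level -- giving 2**n
    # copies of `body` after n levels of nesting.
    def level(x):
        return lax.cond(x < 100, inner, jax.checkpoint(inner), x)
    return level

fn = body
for _ in range(3):
    fn = make_level(fn)

jax.make_jaxpr(fn)(0)  # "tracing body" is printed 2**3 = 8 times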

That said, a noinline implementation like the above sounds simple enough (if you know what you're doing), and as per the previous paragraph this feature request is clearly important enough, that I would have thought something like it would already exist in core JAX by now. We're really bumping up against the edge of my knowledge wrt lowerings here, though, and perhaps this is impossible for some reason / entirely separate XLA graphs really are necessary.

@YouJiacheng commented Jul 3, 2022

If wrapping with a fake rng-based while loop can successfully prevent inlining (really?), then implementing noinline should be easier.
But I guess this trick can only prevent computation from being moved outside; XLA can still clone the code at each call site.

@YouJiacheng commented

My trial:

from functools import partial, lru_cache

import jax
from jax import core
from jax.interpreters import mlir, xla, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun

def foo(x):
    print('traced')
    return x + 1

@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)

foo_p = core.Primitive('foo')
foo_p.def_impl(partial(xla.apply_primitive, foo_p))
foo_p.def_abstract_eval(lambda x: abstract_eval_fun(foo, x))
mlir.register_lowering(foo_p, mlir.cache_lowering(mlir.lower_fun(foo, multiple_results=False)))


def bar(x):
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    return x

print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir()) # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string()) # post-compilation HLO

It is only inlined at the post-compilation HLO level:

traced
{ lambda ; a:i32[]. let
    b:i32[] = foo a
    c:i32[] = foo b
    d:i32[] = foo c
    e:i32[] = foo d
  in (e,) }
traced
module @jit_bar.0 {
  func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = call @foo(%arg0) : (tensor<i32>) -> tensor<i32>
    %1 = call @foo(%0) : (tensor<i32>) -> tensor<i32>
    %2 = call @foo(%1) : (tensor<i32>) -> tensor<i32>
    %3 = call @foo(%2) : (tensor<i32>) -> tensor<i32>
    return %3 : tensor<i32>
  }
  func private @foo(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<1> : tensor<i32>
    %1 = mhlo.add %arg0, %0 : tensor<i32>
    return %1 : tensor<i32>
  }
}

traced
HloModule jit_bar.1

ENTRY %main.22 (Arg_0.1: s32[]) -> s32[] {
  %Arg_0.1 = s32[] parameter(0)
  %constant_7 = s32[] constant(4)
  ROOT %add.3 = s32[] add(s32[] %Arg_0.1, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/add" source_file="*" source_line=28}
}

@YouJiacheng commented

from functools import partial, lru_cache

import jax
from jax import core, lax
from jax.interpreters import mlir, xla, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten, tree_map
import jax.linear_util as lu
from jax._src.api_util import flatten_fun, shaped_abstractify

import numpy as np

def foo(x):
    print('traced')
    return x + 1

def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *(shaped_abstractify(a) for a in args_flat))
    return tree_unflatten(out_tree(), out)

cached_abstract_eval_fun = lru_cache(abstract_eval_fun)

def _dummy_result(aval: core.AbstractValue):
  if aval is core.abstract_token:
    return lax.create_token()
  else:
    return lax.full(aval.shape, 0, aval.dtype)

def wrapped(fun, *args, **kwargs):
    avals_out = abstract_eval_fun(fun, *args, **kwargs)
    dummies_like_result = tree_map(_dummy_result, avals_out)
    carry_init = (np.int32(0), dummies_like_result, args, kwargs)
    def cond(carry):
        counter, _, _, _ = carry
        return counter < lax.rng_uniform(np.int32(1), np.int32(2), shape=())

    def body(carry):
        counter, _, args, kwargs = carry
        results = fun(*args, **kwargs)
        return (counter + 1, results, args, kwargs)

    carry_res = lax.while_loop(cond, body, carry_init)
    return carry_res[1]

foo_p = core.Primitive('foo')
foo_p.def_impl(partial(xla.apply_primitive, foo_p))
foo_p.def_abstract_eval(lambda x: cached_abstract_eval_fun(foo, x))
mlir.register_lowering(foo_p, mlir.cache_lowering(mlir.lower_fun(lambda x: wrapped(foo, x), multiple_results=False)))


def bar(x):
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    x = foo_p.bind(x)
    return x

print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir())
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string())

XLA STILL inlines all the foo calls -- copying the while loop 4 times!!!

HloModule jit_bar.1

%region_0.2 (arg_tuple.3: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.3 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.10 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.3), index=0
  %constant_7 = s32[] constant(1)
  %add.9 = s32[] add(s32[] %get-tuple-element.10, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %get-tuple-element.18 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.3), index=2
  %add.8 = s32[] add(s32[] %get-tuple-element.18, s32[] %constant_7), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  ROOT %tuple.6 = (s32[], s32[], s32[]) tuple(s32[] %add.9, s32[] %add.8, s32[] %get-tuple-element.18)
}

%region_1.11 (arg_tuple.12: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.12 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.13 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.12), index=0
  %constant_776 = s32[] constant(1)
  ROOT %compare.20 = pred[] compare(s32[] %get-tuple-element.13, s32[] %constant_776), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %rng-get-and-update-state = u64[2]{0} rng-get-and-update-state(), delta=1
}

%region_0.30 (arg_tuple.31: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.31 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.22 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.31), index=0
  %constant_35 = s32[] constant(1)
  %add.37 = s32[] add(s32[] %get-tuple-element.22, s32[] %constant_35), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %get-tuple-element.30 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.31), index=2
  %add.36 = s32[] add(s32[] %get-tuple-element.30, s32[] %constant_35), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  ROOT %tuple.9 = (s32[], s32[], s32[]) tuple(s32[] %add.37, s32[] %add.36, s32[] %get-tuple-element.30)
}

%region_1.39 (arg_tuple.40: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.40 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.41 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.40), index=0
  %constant_782 = s32[] constant(1)
  ROOT %compare.48 = pred[] compare(s32[] %get-tuple-element.41, s32[] %constant_782), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %rng-get-and-update-state.1 = u64[2]{0} rng-get-and-update-state(), delta=1
}

%region_0.58 (arg_tuple.59: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.59 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.37 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.59), index=0
  %constant_63 = s32[] constant(1)
  %add.65 = s32[] add(s32[] %get-tuple-element.37, s32[] %constant_63), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %get-tuple-element.45 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.59), index=2
  %add.64 = s32[] add(s32[] %get-tuple-element.45, s32[] %constant_63), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  ROOT %tuple.13 = (s32[], s32[], s32[]) tuple(s32[] %add.65, s32[] %add.64, s32[] %get-tuple-element.45)
}

%region_1.67 (arg_tuple.68: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.68 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.69 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.68), index=0
  %constant_788 = s32[] constant(1)
  ROOT %compare.76 = pred[] compare(s32[] %get-tuple-element.69, s32[] %constant_788), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %rng-get-and-update-state.2 = u64[2]{0} rng-get-and-update-state(), delta=1
}

%region_0.86 (arg_tuple.87: (s32[], s32[], s32[])) -> (s32[], s32[], s32[]) {
  %arg_tuple.87 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.49 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.87), index=0
  %constant_91 = s32[] constant(1)
  %add.93 = s32[] add(s32[] %get-tuple-element.49, s32[] %constant_91), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %get-tuple-element.57 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.87), index=2
  %add.92 = s32[] add(s32[] %get-tuple-element.57, s32[] %constant_91), metadata={op_name="jit(bar)/jit(main)/while/body/add" source_file="/home/jiacheng/***/sub.py" source_line=54}
  ROOT %tuple.16 = (s32[], s32[], s32[]) tuple(s32[] %add.93, s32[] %add.92, s32[] %get-tuple-element.57)
}

%region_1.95 (arg_tuple.96: (s32[], s32[], s32[])) -> pred[] {
  %arg_tuple.96 = (s32[], s32[], s32[]) parameter(0)
  %get-tuple-element.97 = s32[] get-tuple-element((s32[], s32[], s32[]) %arg_tuple.96), index=0
  %constant_794 = s32[] constant(1)
  ROOT %compare.104 = pred[] compare(s32[] %get-tuple-element.97, s32[] %constant_794), direction=LT, metadata={op_name="jit(bar)/jit(main)/while/cond/lt" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %rng-get-and-update-state.3 = u64[2]{0} rng-get-and-update-state(), delta=1
}

ENTRY %main.114 (Arg_0.1: s32[]) -> s32[] {
  %constant_379 = s32[] constant(0)
  %copy.24 = s32[] copy(s32[] %constant_379)
  %copy.19 = s32[] copy(s32[] %copy.24)
  %copy.12 = s32[] copy(s32[] %copy.24)
  %copy.13 = s32[] copy(s32[] %copy.24)
  %copy.6 = s32[] copy(s32[] %copy.24)
  %copy.7 = s32[] copy(s32[] %copy.24)
  %copy = s32[] copy(s32[] %copy.24)
  %copy.1 = s32[] copy(s32[] %copy.24)
  %Arg_0.1 = s32[] parameter(0)
  %tuple.4 = (s32[], s32[], s32[]) tuple(s32[] %copy, s32[] %copy.1, s32[] %Arg_0.1)
  %while.0 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.4), condition=%region_1.11, body=%region_0.2, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
  %get-tuple-element.21 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.0), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %tuple.7 = (s32[], s32[], s32[]) tuple(s32[] %copy.6, s32[] %copy.7, s32[] %get-tuple-element.21)
  %while.1 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.7), condition=%region_1.39, body=%region_0.30, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
  %get-tuple-element.36 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.1), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %tuple.11 = (s32[], s32[], s32[]) tuple(s32[] %copy.12, s32[] %copy.13, s32[] %get-tuple-element.36)
  %while.2 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.11), condition=%region_1.67, body=%region_0.58, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
  %get-tuple-element.48 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.2), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
  %tuple.14 = (s32[], s32[], s32[]) tuple(s32[] %copy.24, s32[] %copy.19, s32[] %get-tuple-element.48)
  %while.3 = (s32[], s32[], s32[]) while((s32[], s32[], s32[]) %tuple.14), condition=%region_1.95, body=%region_0.86, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}, backend_config="{\"known_trip_count\":{\"n\":\"1\"}}"
  ROOT %get-tuple-element.3 = s32[] get-tuple-element((s32[], s32[], s32[]) %while.3), index=1, metadata={op_name="jit(bar)/jit(main)/while[cond_nconsts=0 body_nconsts=0]" source_file="/home/jiacheng/***/sub.py" source_line=54}
}

@patrick-kidger (Owner, Author) commented

If wrapping with a fake rng-based while loop can successfully prevent inlining (really?), then implementing noinline should be easier. But I guess this trick can only prevent computation from being moved outside; XLA can still clone the code at each call site.

Right, so I think the fake while loop would prevent inlining at the HLO->compile-to-whatever-backend level, i.e. prevent optimisation of the function wrt its calling context. Indeed I didn't mean that it would prevent the cloning issue.

XLA STILL inlines all the foo calls -- copying the while loop 4 times!!!

Bother. Looking at mlir.cache_lowering, it looks like this does exactly what I was hoping might work -- producing only a single function object and then re-using it wherever possible. It's really unfortunate that JAX/XLA seems to be limited like this.

Perhaps the way forward really is to implement a "better" version of host_callback that fits this use case. The new work on effectful jaxprs might make this a lot easier -- these have so many use cases, and I'm very excited to be getting algebraic effects in JAX -- since I assume it should remove a lot of the rewriting/hackery/etc. that host_callback seems to bring in.

Base automatically changed from no-experimental-filter-vmap to main July 5, 2022 15:29
@YouJiacheng commented Jul 6, 2022

Hi! I prototyped a (possibly) simpler version than host_callback, directly using mlir.emit_python_callback.
Thanks to jaxlib 0.3.11 adding xla_python_gpu_callback, we don't need to use the complex outfeed mechanism -- at the cost of some performance, since xla_python_gpu_callback uses numpy to transfer data.
(I found it is 20% faster than host_callback.call for a (1024, 1024) float32 identity function, but spends nearly 3x the time for (8, 1024, 1024) or larger float32.)

from functools import lru_cache

import jax
from jax import core
from jax.interpreters import mlir, partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten
import jax.linear_util as lu
from jax._src.api_util import flatten_fun

def foo(x):
    print('traced')
    print(type(x))
    return x + 1

def foo_lowered(x):
    print(type(x))
    return x + 1

@lru_cache
def abstract_eval_fun(fun, *args, **kwargs):
    args_flat, in_tree = tree_flatten((args, kwargs))
    wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
    out = pe.abstract_eval_fun(wrapped_fun.call_wrapped, *args_flat)
    return tree_unflatten(out_tree(), out)

callback_p = core.Primitive('callback')
callback_p.def_impl(lambda *args, callback, **_: callback(*args))
callback_p.def_abstract_eval(lambda *args, callback, **_: abstract_eval_fun(callback, *args))

def callback_lowering(ctx, *args, callback, callback_lowered):
    try:
        iter(abstract_eval_fun(callback, *ctx.avals_in))
    except TypeError:
        f = lambda *args: (callback_lowered(*args),)
    else:
        f = callback_lowered
    result, keepalive = mlir.emit_python_callback(ctx.module_context.platform, f, args, ctx.avals_in, ctx.avals_out, False)
    ctx.module_context.add_keepalive(keepalive)
    return result

mlir.register_lowering(callback_p, mlir.cache_lowering(callback_lowering))


def bar(x):
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    x = callback_p.bind(x, callback=foo, callback_lowered=foo_lowered)
    return x

print(jax.make_jaxpr(bar)(1))
print(jax.jit(bar)(1))
print(jax.jit(bar).lower(1).compiler_ir()) # MHLO
print(jax.jit(bar).lower(1).compile().compiler_ir()[0].to_string()) # post-compilation HLO
traced
<class 'jax.interpreters.partial_eval.DynamicJaxprTracer'>
{ lambda ; a:i32[]. let
    b:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] a
    c:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] b
    d:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] c
    e:i32[] = callback[
      callback=<function foo at 0x7f78a3d13790>
      callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>
    ] d
  in (e,) }
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
5
module @jit_bar.1 {
  func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = call @callback(%arg0) : (tensor<i32>) -> tensor<i32>
    %1 = call @callback(%0) : (tensor<i32>) -> tensor<i32>
    %2 = call @callback(%1) : (tensor<i32>) -> tensor<i32>
    %3 = call @callback(%2) : (tensor<i32>) -> tensor<i32>
    return %3 : tensor<i32>
  }
  func.func private @callback(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<94158890291952> : tensor<i64>
    %1 = "mhlo.custom_call"(%0, %arg0) {api_version = 2 : i32, backend_config = "94158890291952", call_target_name = "xla_python_gpu_callback", called_computations = [], has_side_effect = false} : (tensor<i64>, tensor<i32>) -> tuple<tensor<i32>>
    %2 = "mhlo.get_tuple_element"(%1) {index = 0 : i32} : (tuple<tensor<i32>>) -> tensor<i32>
    return %2 : tensor<i32>
  }
}

HloModule jit_bar.2, entry_computation_layout={(s32[])->s32[]}

ENTRY %main.26 (Arg_0.1: s32[]) -> s32[] {
  %constant_0 = s64[] constant(94158882601360)
  %Arg_0.1 = s32[] parameter(0)
  %custom-call.0 = (s32[]) custom-call(s64[] %constant_0, s32[] %Arg_0.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.0 = s32[] get-tuple-element((s32[]) %custom-call.0), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.1 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.0), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.1 = s32[] get-tuple-element((s32[]) %custom-call.1), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.2 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.1), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  %get-tuple-element.2 = s32[] get-tuple-element((s32[]) %custom-call.2), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
  %custom-call.3 = (s32[]) custom-call(s64[] %constant_0, s32[] %get-tuple-element.2), custom_call_target="xla_python_gpu_callback", api_version=API_VERSION_STATUS_RETURNING, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}, backend_config="94158882601360"
  ROOT %get-tuple-element.3 = s32[] get-tuple-element((s32[]) %custom-call.3), index=0, metadata={op_name="jit(bar)/jit(main)/callback[callback=<function foo at 0x7f78a3d13790> callback_lowered=<function foo_lowered at 0x7f78a3d13ca0>]" source_file="/home/jiacheng/***/sub.py" source_line=45}
}

@gnecula commented Jul 6, 2022

BTW, host_callback uses the CustomCall implementation already by default on CPU (no outfeed).

There is a new mechanism being developed to replace host_callback, but it is not yet ready.

In any case, using host_callback here seems more like a workaround than the best solution.

@patrick-kidger (Owner, Author) commented

Closing in favour of #218, which adds equinox.internal.noinline. This is a heavily improved version of this PR:

  • Much faster.
  • Uses a custom primitive directly, instead of monkey-patching host_callback.call.
  • Offers the ability to recompile only the part of a computation graph that has changed, without needing to recompile the entire graph. For example:
    import equinox as eqx
    import jax.numpy as jnp
    from equinox.internal import noinline  # as added in #218
    
    def abstract(x, y):
        return jnp.broadcast_arrays(x, y)[0]
    
    def f(x, y):
        print("Compiling f!")
        return x + y
    
    def g(x, y):
        print("Compiling g!")
        return x * y
    
    f = noinline(f, abstract)
    g = noinline(g, abstract)
    
    def call(fn, x, y):
        print("Compiling call!")
        return fn(x, y)
    
    call = eqx.filter_jit(call)
    call(f, 1, 1)  # Compiling call! Compiling f!
    call(g, 1, 1)  # Compiling g!  [But does not recompile call!]
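
    Here abstract specifies the shared output structure (the broadcast of x and y) that both f and g produce, which is what allows call to be reused across different noinline'd functions without recompilation.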

@patrick-kidger patrick-kidger deleted the noinline branch November 2, 2022 23:18