Use Numba, Cython, and JAX for Python-implemented Ops #10

Closed
brandonwillard opened this issue Jun 7, 2020 · 23 comments · Fixed by #21
Labels
enhancement New feature or request

Comments

@brandonwillard
Member

brandonwillard commented Jun 7, 2020

Here's an old example of a Numba-enabled Theano Op. We can most likely do something similar for Cython and JAX, as well.

As an extension to that example, it would be nice to have the compiled function be the Op.perform method itself; that way, we could attempt to compile existing Ops with little-to-no changes.
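Here is a minimal sketch of the delegating pattern. It assumes Numba may or may not be installed; when it isn't, the decorator falls back to plain Python. AxpyOp and _axpy_kernel are hypothetical names. Only the perform(node, inputs, output_storage) signature mirrors Theano's Op convention, and Theano itself is not imported. Compiling perform itself, as suggested above, would presumably require its body to fit the subset of Python that Numba can compile.

```python
try:
    from numba import njit  # compile the kernel with Numba when it is available
except ImportError:
    def njit(func):
        # Fallback: leave the function as plain Python
        return func


@njit
def _axpy_kernel(a, x, y):
    # Scalar kernel computing a * x + y
    return a * x + y


class AxpyOp:
    """Hypothetical scalar Op computing a * x + y via a compiled kernel."""

    def perform(self, node, inputs, output_storage):
        a, x, y = inputs
        # Theano-style Ops write each result into output_storage[i][0]
        output_storage[0][0] = _axpy_kernel(a, x, y)


out = [[None]]
AxpyOp().perform(None, (2.0, 3.0, 1.0), out)
print(out[0][0])  # 7.0
```

Because only the kernel is compiled, the Op's interface stays unchanged. That matches the goal of compiling existing Ops with little-to-no changes.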

@brandonwillard brandonwillard added the enhancement New feature or request label Jun 7, 2020
@dfm
Contributor

dfm commented Jun 7, 2020

I've blogged about using TensorFlow and PyTorch ops in Theano. These might be useful for reference as well.

@brandonwillard brandonwillard changed the title Use Numba and JAX for Python-implemented Ops Use Numba, Cython, and JAX for Python-implemented Ops Jun 8, 2020
@brandonwillard
Member Author

@dfm, those are cool examples. Do you happen to have any experience with the underlying compilation subsystem (e.g. Linker)? I would like to make good use of these compilation mechanisms if we do something like this.

For instance, @nouiz mentions that some Python overhead can be removed by calling the Numba-compiled code from C in Theano. I'm assuming that such functionality would involve those subsystems.

@dfm
Contributor

dfm commented Jun 8, 2020

Interesting! I have written lots of custom C Ops for Theano, but I don't have much experience with the Numba side. This might be beyond my expertise!

What sort of use case are you imagining for something like that?

@brandonwillard
Member Author

The use case is simply efficient interoperability with Numba, Cython, and JAX-compatible code within Theano. It would also be nice to explore the idea of automatically compiling Python Op.perform methods.

@brandonwillard
Member Author

I've put together a prototype Linker class that demonstrates how Theano graphs can be compiled using JAX and evaluated—alongside Theano's normal Python and C implementations—in this Gist.

It only handles Composite Theano Ops right now, but that covers a lot of plain NumPy-compatible graphs. Also, I didn't get the output quite right (e.g. I think it wraps the results in an extra list), but it's still good enough to show that this idea is absolutely reasonable and doesn't require extreme efforts.

The Linker approach seems like the best, because it doesn't require any changes to existing Ops and it's the right context for compiling entire graphs/subgraphs—instead of single Ops—and I'm assuming that's necessary for any such compilation to be worthwhile.
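Below is a toy version of the conversion-registry idea behind jax_funcify. It rests on a few assumptions. The Op classes, Node, funcify, and link are hypothetical stand-ins, not Theano's or the Gist's code. Converted nodes are also plain Python callables rather than jax.numpy functions. The linker converts every node once and composes the results into a single graph function. In a JAX backend, that composed function is what would be handed to jax.jit, so the Python-level loop would only run while JAX traces it.

```python
import math
from dataclasses import dataclass
from functools import singledispatch
from typing import Any, Callable, Dict, List, Tuple


# Hypothetical Op types for a miniature graph (not Theano's classes)
class Add:
    pass


class Mul:
    pass


class Exp:
    pass


@dataclass
class Node:
    op: Any
    inputs: Tuple[str, ...]  # names of the variables this node consumes
    output: str              # name of the variable this node produces


@singledispatch
def funcify(op: Any) -> Callable:
    # Fallback for Op types with no registered conversion
    raise NotImplementedError(f"no conversion for {type(op).__name__}")


@funcify.register
def _(op: Add):
    return lambda x, y: x + y


@funcify.register
def _(op: Mul):
    return lambda x, y: x * y


@funcify.register
def _(op: Exp):
    return math.exp


def link(nodes: List[Node], inputs: List[str], outputs: List[str]) -> Callable:
    """Convert each node once, then compose them into a single graph function."""
    # Conversion happens here, at link time, not on every call
    thunks = [(funcify(node.op), node.inputs, node.output) for node in nodes]

    def graph_fn(*args: Any) -> Tuple[Any, ...]:
        env: Dict[str, Any] = dict(zip(inputs, args))
        for fn, in_names, out_name in thunks:  # nodes assumed topologically sorted
            env[out_name] = fn(*(env[name] for name in in_names))
        return tuple(env[name] for name in outputs)

    return graph_fn


# exp(x * y) + x
graph = [
    Node(Mul(), ("x", "y"), "t0"),
    Node(Exp(), ("t0",), "t1"),
    Node(Add(), ("t1", "x"), "z"),
]
fn = link(graph, inputs=["x", "y"], outputs=["z"])
print(fn(0.0, 5.0))  # (1.0,)
```

New Ops are supported by registering another conversion; existing Op classes are untouched, which is the property that makes the Linker approach attractive.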

@twiecki
Contributor

twiecki commented Jul 26, 2020

That's awesome! I would be interested in how this compares in speed to Theano's C backend. Also, which other Ops that aren't currently supported could be tricky?

@brandonwillard
Member Author

brandonwillard commented Jul 26, 2020

Well, I believe the advantage here is that we could use Python, C, and JAX (or Numba, Cython, etc.) together.

The Scan Op might be a tricky one, but this is mostly due to the cumbersome encoding of features (e.g. taps, initial values, etc.) employed by the Scan Op (a lot of which should be greatly simplified by the work referenced in #19). I don't know enough about it, but it seems like vmap might map to Elemwise, as well. There are also Theano's IncSubtensor operations, which need to be mapped to JAX's jax.ops.index_update/jax.ops.index_add, and control flow operators like IfElse. Last but not least, Theano's grad functionality needs to be bridged with JAX's.
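JAX arrays are immutable, so Theano's IncSubtensor operations have to become functional updates that return a new array. The 2020-era JAX API for this was jax.ops.index_update/jax.ops.index_add; later JAX releases replaced them with x.at[idx].set(y) and x.at[idx].add(y). The sketch below shows only the semantics and uses Python lists as stand-ins for arrays. set_subtensor and inc_subtensor are hypothetical helpers that borrow Theano's names. In inc_subtensor, repeated indices accumulate, as with numpy.add.at and JAX's .at[idx].add().

```python
from typing import List, Sequence


def set_subtensor(x: Sequence[float], idx: Sequence[int], y: Sequence[float]) -> List[float]:
    """Return a copy of x with the positions in idx replaced by y; x is not mutated."""
    out = list(x)
    for i, v in zip(idx, y):
        out[i] = v
    return out


def inc_subtensor(x: Sequence[float], idx: Sequence[int], y: Sequence[float]) -> List[float]:
    """Return a copy of x with y added at the positions in idx.

    Repeated indices accumulate (unbuffered add), like numpy.add.at.
    """
    out = list(x)
    for i, v in zip(idx, y):
        out[i] += v
    return out


x = [0.0, 1.0, 2.0]
print(set_subtensor(x, [0, 2], [9.0, 8.0]))  # [9.0, 1.0, 8.0]
print(inc_subtensor(x, [1, 1], [1.0, 1.0]))  # [0.0, 3.0, 2.0]
print(x)                                     # [0.0, 1.0, 2.0], unchanged
```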

Overall, most Theano Ops provide more than enough actionable runtime information to make just about any mapping to JAX-based operations, and, with only support for the Ops mentioned above, we would already have a great deal of coverage.

Finally, after going through all these parts of the codebase, it's clear that a good refactoring is needed (e.g. clarify the basic interface using abc—and perhaps dataclasses—features, combine related functionality into separate modules, start adding type annotations and useful docstrings).
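Here is one way the abc/dataclasses idea could look, as a sketch. Op becomes an abstract base class, so a subclass that forgets perform fails at instantiation instead of at run time, and graph nodes become dataclasses. The make_node/perform method names follow Theano's Op API, but Variable, Apply, Op, and Neg here are simplified, hypothetical definitions.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, List, Sequence


@dataclass(frozen=True)
class Variable:
    name: str


@dataclass
class Apply:
    op: "Op"
    inputs: List[Variable]
    outputs: List[Variable] = field(default_factory=list)


class Op(ABC):
    """Abstract Op interface: concrete subclasses must implement both methods."""

    @abstractmethod
    def make_node(self, *inputs: Variable) -> Apply:
        """Build the symbolic Apply node for these inputs."""

    @abstractmethod
    def perform(self, node: Apply, inputs: Sequence[Any],
                output_storage: List[List[Any]]) -> None:
        """Compute numeric outputs and write them into output_storage."""


class Neg(Op):
    def make_node(self, x: Variable) -> Apply:
        return Apply(self, [x], [Variable(f"neg({x.name})")])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = -inputs[0]


node = Neg().make_node(Variable("x"))
out = [[None]]
Neg().perform(node, [3], out)
print(node.outputs[0].name, out[0][0])  # neg(x) -3
```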

@junpenglao
Contributor

Last but not least, Theano's grad functionality needs to be bridged with JAX's.

This is my main question: IIUC, Theano's grad is explicitly defined, so we can just map it to Python/JAX operations? If so, would there be any way to bypass that and use JAX's grad instead? And what would be the advantages/disadvantages of doing so?

@brandonwillard
Member Author

brandonwillard commented Jul 26, 2020

As far as I can tell, we don't need to use the JAX gradient functions, but, if there's some advantage in doing so, it seems like we could.

The first thing that comes to mind regarding Theano grad vs. JAX's is that we would lose the ability to apply Theano graph optimizations to the gradient graphs, and that sounds like a big loss. To me, that's reason enough to forgo use of JAX's gradient functions.
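Here is a toy illustration of why keeping gradients symbolic matters: the derivative of a graph is itself a graph, so the same rewrite passes can clean it up. The expression types, grad, and simplify below are hypothetical and far simpler than Theano's graph and optimizer. For instance, the x*0 to 0 rewrite ignores NaN/inf inputs.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Const:
    value: float


@dataclass(frozen=True)
class Var:
    name: str


@dataclass(frozen=True)
class Add:
    a: "Expr"
    b: "Expr"


@dataclass(frozen=True)
class Mul:
    a: "Expr"
    b: "Expr"


Expr = Union[Const, Var, Add, Mul]


def grad(e: Expr, wrt: str) -> Expr:
    """Symbolic derivative of e with respect to the variable named wrt."""
    if isinstance(e, Const):
        return Const(0.0)
    if isinstance(e, Var):
        return Const(1.0 if e.name == wrt else 0.0)
    if isinstance(e, Add):
        return Add(grad(e.a, wrt), grad(e.b, wrt))
    if isinstance(e, Mul):  # product rule
        return Add(Mul(grad(e.a, wrt), e.b), Mul(e.a, grad(e.b, wrt)))
    raise TypeError(f"unsupported expression: {e!r}")


def simplify(e: Expr) -> Expr:
    """Bottom-up rewrites: constant folding plus x+0, x*0 and x*1 identities."""
    if not isinstance(e, (Add, Mul)):
        return e
    a, b = simplify(e.a), simplify(e.b)
    zero, one = Const(0.0), Const(1.0)
    if isinstance(e, Add):
        if isinstance(a, Const) and isinstance(b, Const):
            return Const(a.value + b.value)
        if a == zero:
            return b
        if b == zero:
            return a
        return Add(a, b)
    if isinstance(a, Const) and isinstance(b, Const):
        return Const(a.value * b.value)
    if a == zero or b == zero:
        return zero
    if a == one:
        return b
    if b == one:
        return a
    return Mul(a, b)


x, y = Var("x"), Var("y")
f = Add(Mul(x, y), x)       # f = x*y + x
raw = grad(f, "x")          # unsimplified: (1*y + x*0) + 1
print(simplify(raw))        # the rewritten gradient graph: y + 1
```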

@twiecki
Contributor

twiecki commented Jul 27, 2020

Well, I believe the advantage here is that we could use Python, C, and JAX (or Numba, Cython, etc.) together.

I can see how that's a powerful story for CPU, but what about GPU/TPUs?

The Scan Op might be a tricky one, but this is mostly due to the cumbersome encoding of features (e.g. taps, initial values, etc.) employed by the Scan Op (a lot of which should be greatly simplified by the work referenced in #19).

I fully expect Scan to be the most difficult. However, its use in PyMC3 is fairly limited (I do see it used for computing the grads, though). What other difficult Ops might there be? We can leave convolutions for later. What about switch, IfElse, slinalg stuff, shared, set_subtensor and printing?

That we get graph optimizations for gradients through theano and hence don't need them from symjax makes total sense and I don't see anything we'd lose that way.

Finally, after going through all these parts of the codebase, it's clear that a good refactoring is needed (e.g. clarify the basic interface using abc—and perhaps dataclasses—features, combine related functionality into separate modules, start adding type annotations and useful docstrings).

I completely agree. If we're getting Theano fit for the 21st century we should allow ourselves to break backwards compatibility, abandon old Python versions, and potentially rename the package if the refactor is significant.

In general I strongly believe we should start with the simple stuff and stay very close to what PyMC3 needs. For example, an amazing first milestone would be evaluating the logp of a simple Binomial (or Normal) model: do just what is necessary for that, then slowly make things more elaborate and see where it breaks down.

And a big question for me is still the speed.
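For the logp milestone suggested above, a pure-Python reference is handy for checking whatever a JAX-compiled graph returns. This sketch assumes iid Normal data and uses the standard log-density, log p = -n (log sigma + 0.5 log(2 pi)) - sum_i (x_i - mu)^2 / (2 sigma^2), with derivative d log p / d mu = sum_i (x_i - mu) / sigma^2. normal_logp_dlogp is a hypothetical helper, not PyMC3's logp_dlogp_function.

```python
import math
from typing import Sequence, Tuple


def normal_logp_dlogp(data: Sequence[float], mu: float, sigma: float) -> Tuple[float, float]:
    """Return (log-likelihood, d/dmu log-likelihood) for iid Normal(mu, sigma) data."""
    n = len(data)
    resid = [x - mu for x in data]
    logp = (-n * (math.log(sigma) + 0.5 * math.log(2 * math.pi))
            - sum(r * r for r in resid) / (2 * sigma ** 2))
    dlogp_dmu = sum(resid) / sigma ** 2
    return logp, dlogp_dmu


print(normal_logp_dlogp([1.0, 3.0, -0.5], mu=0.3, sigma=1.7))
```

A central finite difference in mu is a cheap way to confirm that the gradient matches the log-density.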

@junpenglao
Contributor

Another alternative is to have an XLA linker and compile to XLA-optimized kernels directly, turning Theano into another XLA frontend.

@nouiz
Contributor

nouiz commented Jul 27, 2020

JAX uses XLA, so I'm not sure what you would gain by going to it directly.

@twiecki
Contributor

twiecki commented Jul 27, 2020

Thanks for chiming in @nouiz, any other thoughts on this approach?

@nouiz
Contributor

nouiz commented Jul 27, 2020

It sounds good. Just make sure to keep the execution engine in C, like the CLinker; I think the call to streamline does that. If the loop over each operation is in Python, the Python overhead will kill the performance.
Also, it would be best if the JAX "thunk" had a C interface, so we don't pay the Python overhead.

But I'm not convinced the current version will give a speed-up vs. Theano, since, to my knowledge, JAX uses XLA on GPU (I work on XLA).
The strength of XLA code generation vs. Theano is that it can fuse more operations together, not that the fused code itself is better. So, to get some speed-up, you would need to create bigger Composite Ops that Theano can't compile but that JAX/XLA can.

But it looks like a good proof of concept.
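To make the fusion point concrete, here is a toy contrast between evaluating exp(x) * 2 + 1 in two ways. The first uses three separate elementwise passes, each producing a temporary. The second uses one fused pass, which is roughly what a Composite or an XLA-fused kernel does. Python lists stand in for arrays; the functions are illustrative only and not meant as a benchmark.

```python
import math
from typing import List


def unfused(x: List[float]) -> List[float]:
    # One pass, and one temporary list, per elementwise op
    t0 = [math.exp(v) for v in x]
    t1 = [v * 2.0 for v in t0]
    return [v + 1.0 for v in t1]


def fused(x: List[float]) -> List[float]:
    # A single pass applying the whole scalar expression to each element
    return [math.exp(v) * 2.0 + 1.0 for v in x]


data = [0.0, 1.0, -2.5]
print(fused(data) == unfused(data))  # True: same result, fewer passes and temporaries
```

On real arrays, the fused form reads and writes memory once instead of three times. This is where larger fused regions can pay off.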

@brandonwillard
Member Author

All right, I've provided a context that makes it much easier to add more Op conversions and included [Inc]Subtensor implementations and more: updated Gist.

In its current form, it will also compile the entire graph to JAX; however, if a single unsupported Op is encountered, then it falls back to Python. We could make it compile subgraphs, but first we need to determine when a subgraph is even worth compiling (e.g. a subgraph with a single node probably isn't worth running in JAX).
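One simple heuristic for the subgraph question is sketched below. It assumes each node has already been marked as convertible or not. The function walks a topological order, groups maximal runs of convertible nodes, and only sends runs of at least min_size nodes to JAX. A contiguous run in a topological order contains every node on any path between two of its members, so it can be compiled as a unit. The partition function and the min_size threshold are hypothetical, and the heuristic ignores any real cost model.

```python
from typing import List, Sequence, Tuple

Segment = Tuple[str, int, int]  # (backend, start, stop) with a half-open node range


def partition(supported: Sequence[bool], min_size: int = 2) -> List[Segment]:
    """Split a topologically ordered node list into "jax" and "python" segments."""
    segments: List[Segment] = []
    i, n = 0, len(supported)
    while i < n:
        # Find the maximal run [i, j) of nodes sharing the same "supported" flag
        j = i
        while j < n and supported[j] == supported[i]:
            j += 1
        backend = "jax" if supported[i] and j - i >= min_size else "python"
        if backend == "python" and segments and segments[-1][0] == "python":
            # Extend the previous Python segment instead of starting a new one
            segments[-1] = ("python", segments[-1][1], j)
        else:
            segments.append((backend, i, j))
        i = j
    return segments


print(partition([True, True, True, False, True, False]))
# [('jax', 0, 3), ('python', 3, 6)]: the lone supported node at index 4 stays in Python
```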

Anyway, this whole endeavour seems pretty clear-cut now. For that matter, after we add a conversion for Scan, this might actually be worth using!

@twiecki
Contributor

twiecki commented Jul 31, 2020

This is incredible progress!

We could make it compile subgraphs, but first we need to determine when a subgraph is even worth compiling (e.g. a subgraph with a single node probably isn't worth running in JAX).

Is that because of the call overhead? I would hope we just get a C-extension-like callable, like with Cython, with minimal call overhead compared to Python.

How difficult do you think Scan would be? I saw JAX has a native While implementation.

@twiecki
Contributor

twiecki commented Jul 31, 2020

I did some speed comparisons: https://gist.github.com/twiecki/38dc98197eed5594c5518a3971064c92

On small arrays, c mode is about 6x faster than JAX, but with larger arrays it seems to be about the same, which is a great starting point.

@twiecki
Contributor

twiecki commented Jul 31, 2020

I updated the gist with a simple PyMC3 model; we can just pass the JAX mode to the logp_dlogp_function method. It currently doesn't compile, failing with:

<ipython-input-1-e78ca8561185>:270: UserWarning: JaxLinker could not JAXify graph: Could not find signature for jax_funcify: <Alloc, Apply>

I think iteratively solving those individual problems until it compiles will be a good route to take.
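The warning reports one missing conversion at a time. To drive the iterate-until-it-compiles loop faster, one option is to list every unconverted Op type in a graph at once. This sketch uses functools.singledispatch's dispatch() to find Op types that fall through to the base implementation. Elemwise, Alloc, and Scan here are empty placeholder classes, and funcify and missing_conversions are hypothetical. The "Could not find signature" wording suggests the Gist uses a multipledispatch-style dispatcher instead, where the lookup would differ.

```python
from collections import Counter
from functools import singledispatch


# Placeholder Op types (not Theano's classes)
class Elemwise:
    pass


class Alloc:
    pass


class Scan:
    pass


@singledispatch
def funcify(op):
    raise NotImplementedError(f"no conversion for {type(op).__name__}")


@funcify.register
def _(op: Elemwise):
    return lambda *args: args


def missing_conversions(ops) -> Counter:
    """Count the Op types that would hit the unregistered fallback."""
    fallback = funcify.dispatch(object)
    return Counter(type(op).__name__ for op in ops
                   if funcify.dispatch(type(op)) is fallback)


graph_ops = [Elemwise(), Alloc(), Elemwise(), Scan(), Alloc()]
print(missing_conversions(graph_ops))
```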

@brandonwillard
Member Author

Is that because of the call overhead? I would hope we just get a C-extension-like callable, like with Cython, with minimal call overhead compared to Python.

No, it's just that I haven't set up the logic for computing subgraphs (and determining which to actually use).

How difficult do you think Scan would be? I saw JAX has a native While implementation.

I think it could be annoying to fully implement, but—at the end of the day—it's mostly a matter of porting the Python version of Scan.execute. That porting work might require logic to determine when a Scan fits the form of jax.lax.scan, jax.lax.reduce, jax.lax.fori_loop, jax.lax.map, jax.lax.while_loop, etc., or we might be able to use jax.ops.index_update and preserve most of the Scan logic as-is (I think the former will be better and/or easier, though).
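The JAX documentation describes jax.lax.scan and jax.lax.fori_loop in terms of equivalent pure-Python implementations. Below is a list-based paraphrase of those semantics. It also shows how a Scan with two taps (t-2 and t-1) could be expressed by carrying the last two values. That tap-to-carry translation is an assumption about one possible conversion strategy, not a description of an existing converter. The real lax.scan additionally stacks the outputs into arrays and requires the carry to keep a fixed shape and dtype.

```python
from typing import Any, Callable, Iterable, List, Optional, Tuple


def scan(f: Callable[[Any, Any], Tuple[Any, Any]], init: Any,
         xs: Optional[Iterable[Any]] = None, length: Optional[int] = None):
    """List-based paraphrase of the semantics of jax.lax.scan."""
    if xs is None:
        xs = [None] * length  # scan over a length only, with no per-step inputs
    carry = init
    ys: List[Any] = []
    for x in xs:
        carry, y = f(carry, x)  # new carry and this step's output
        ys.append(y)
    return carry, ys


def fori_loop(lower: int, upper: int, body: Callable[[int, Any], Any], init: Any) -> Any:
    """Paraphrase of the semantics of jax.lax.fori_loop."""
    val = init
    for i in range(lower, upper):
        val = body(i, val)
    return val


# Cumulative sum: the carry is the running total, each output is the total so far
total, cumsum = scan(lambda c, x: (c + x, c + x), 0, [1, 2, 3])


# A Scan with taps t-2 and t-1, written by carrying the last two values
def fib_step(carry, _):
    prev2, prev1 = carry
    nxt = prev2 + prev1
    return (prev1, nxt), nxt


_, fib = scan(fib_step, (0, 1), length=5)
print(total, cumsum, fib)  # 6 [1, 3, 6] [1, 2, 3, 5, 8]
```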

I did some speed comparisons: https://gist.github.com/twiecki/38dc98197eed5594c5518a3971064c92

Awesome, much appreciated!

On small arrays, c mode is about 6x faster than JAX, but with larger arrays it seems to be about the same, which is a great starting point.

That seems to make sense to me. Where do we think we'd notice the most speed-up? Graphs that use loops that can be parallelized/fused? I don't have a GPU on hand to try, so can't do any testing with that.

I think iteratively solving those individual problems until it compiles will be a good route to take.

Yes, that was my plan, and it should be pretty straightforward. We can also go through the list of jax.lax.* operators and get all the low-hanging fruit.

@brandonwillard brandonwillard linked a pull request Jul 31, 2020 that will close this issue
@junpenglao
Contributor

I tested @twiecki's gist, and it seems JAX transformations like jax.vmap and jax.grad don't work yet:

import jax
import numpy as np

# theano_jax_fn and test_input_vals are defined in @twiecki's gist
jax.grad(theano_jax_fn)(*test_input_vals)

test_input_vals2 = [
    np.tile(np.arange(1000).astype(float), (1, 1000, 1)),
    np.tile(np.arange(1000).astype(float), (1, 1000, 1)),
]

jax.vmap(theano_jax_fn)(*test_input_vals2)

This fails with a "The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced..." error.

full trace
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
<ipython-input-30-241055d0c563> in <module>
      4 ]
      5 
----> 6 jax.vmap(theano_jax_fn)(*test_input_vals2)

~/miniconda3/lib/python3.7/site-packages/jax/api.py in batched_fun(*args)
    871     _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
    872     out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 873                               lambda: flatten_axes("vmap out_axes", out_tree(),
    874                                                    out_axes))
    875     return tree_unflatten(out_tree(), out_flat)

~/miniconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
     32   # executes a batched version of `fun` following out_dim_dests
     33   batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34   return batched_fun.call_wrapped(*in_vals)
     35 
     36 @lu.transformation_with_aux

~/miniconda3/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    148     gen = None
    149 
--> 150     ans = self.f(*args, **dict(self.params, **kwargs))
    151     del args
    152     while stack:

~/miniconda3/lib/python3.7/site-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs)
    811                         s.storage[0] = s.type.filter(
    812                             arg, strict=s.strict,
--> 813                             allow_downcast=s.allow_downcast)
    814 
    815                     except Exception as e:

~/miniconda3/lib/python3.7/site-packages/theano/tensor/type.py in filter(self, data, strict, allow_downcast)
    148                     # data has to be converted.
    149                     # Check that this conversion is lossless
--> 150                     converted_data = theano._asarray(data, self.dtype)
    151                     # We use the `values_eq` static function from TensorType
    152                     # to handle NaN values.

~/miniconda3/lib/python3.7/site-packages/theano/misc/safe_asarray.py in _asarray(a, dtype, order)
     32         dtype = theano.config.floatX
     33     dtype = np.dtype(dtype)  # Convert into dtype object.
---> 34     rval = np.asarray(a, dtype=dtype, order=order)
     35     # Note that dtype comparison must be done by comparing their `num`
     36     # attribute. One cannot assume that two identical data types are pointers

~/miniconda3/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83 
     84     """
---> 85     return array(a, dtype, copy=False, order=order)
     86 
     87 

~/miniconda3/lib/python3.7/site-packages/jax/core.py in __array__(self, *args, **kw)
    448            "JAX Tracer instance; in that case, you can instead write "
    449            "`jax.device_put(x)[idx]`.")
--> 450     raise Exception(msg)
    451 
    452   def __init__(self, trace: Trace):

Exception: Bad input argument to theano function with name "<ipython-input-3-da3732ec7fad>:10" at index 0 (0-based).  
Backtrace when that variable is created:

  File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2858, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2886, in _run_cell
    return runner(coro)
  File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3063, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3254, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-a27f936a3f77>", line 4, in <module>
    x = tt.matrix('x')
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[1000,1000])>with<BatchTrace(level=2/0)>
  with val = array([[[  0.,   1.,   2., ..., 997., 998., 999.],
                     [  0.,   1.,   2., ..., 997., 998., 999.],
                     [  0.,   1.,   2., ..., 997., 998., 999.],
                     ...,
                     [  0.,   1.,   2., ..., 997., 998., 999.],
                     [  0.,   1.,   2., ..., 997., 998., 999.],
                     [  0.,   1.,   2., ..., 997., 998., 999.]]])
       batch_dim = 0.

This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.

@brandonwillard
Member Author

brandonwillard commented Aug 5, 2020

@junpenglao, I didn't know one could insert folds into GitHub comments!

I tested @twiecki's gist, and it seems JAX transformations like jax.vmap and jax.grad don't work yet:

Yeah, in order to implement jax.grad, I believe we would have to go some other route. The theano_jax_fn you're using in that example looks to be the jax.jit-ed function; is that what jax.grad is supposed to take as an argument? I would think it would take the JAX graph implied by the function created here (i.e. the one passed to jax.jit). Same goes for jax.vmap.
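The traceback shows where this breaks. The Theano Function's input filtering calls np.asarray on the JAX tracer, which has to stay abstract for jax.grad/jax.vmap to work. A toy analogy reproduces the same failure mode using forward-mode dual numbers, not JAX's actual tracer machinery. Differentiating a function built only from overloaded arithmetic works, while a wrapper that forces its input to a concrete float does not. Dual, grad, pure_fn, and wrapped_fn are all hypothetical.

```python
class Dual:
    """Forward-mode AD value val + eps * dot; a toy stand-in for a tracer."""

    def __init__(self, val: float, dot: float = 0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # Product rule on the derivative part
        return Dual(self.val * other.val, self.dot * other.val + self.val * other.dot)

    __rmul__ = __mul__

    def __float__(self):
        # Like a tracer's __array__: concretizing would silently drop the derivative
        raise TypeError("cannot convert a traced Dual to a concrete float")


def grad(f):
    """Derivative of a scalar function f, computed in forward mode."""
    return lambda x: f(Dual(x, 1.0)).dot


def pure_fn(x):
    return 3 * x * x + x  # only overloaded arithmetic, so it can be traced


def wrapped_fn(x):
    x = float(x)  # input "filtering" forces a concrete value, like np.asarray
    return pure_fn(x)


print(grad(pure_fn)(2.0))  # 13.0, i.e. 6x + 1 at x = 2
```

Calling grad(wrapped_fn)(2.0) raises TypeError, mirroring the __array__ error above. That is why the transformation needs the unwrapped function rather than the wrapper around it.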

I had originally thought to use jax.vmap in the Elemwise conversion, but I really don't know enough about these JAX functions to say when/where it's appropriate in the conversion. For me, this whole exercise is mostly about familiarizing myself with JAX and XLA.

@junpenglao
Contributor

Ah, that makes sense. In that case, would it be possible to have a function that outputs a non-jitted function?

@brandonwillard
Member Author

Ah, that makes sense. In that case, would it be possible to have a function that outputs a non-jitted function?

Yeah, you can construct a Theano FunctionGraph for your graph and pass it to jax_funcify (which uses the implementation here). That's how the JAX functions are constructed for the jax.jit call (see here).

twiecki pushed a commit that referenced this issue Sep 27, 2020
Closes #10.

Well, at least the JAX part, but Cython and Numba implementations can follow
very similar approaches and accomplish the same thing.
5 participants