Use Numba, Cython, and JAX for Python-implemented Ops #10
I've blogged about using TensorFlow and PyTorch ops in Theano. These might be useful for reference as well.
@dfm, those are cool examples. Do you happen to have any experience with the underlying compilation subsystems? For instance, @nouiz mentions that some Python overhead can be removed by calling the Numba-compiled code from C in Theano. I'm assuming that such functionality would involve those subsystems.
Interesting! I have written lots of custom C `Op`s for Theano, but I don't have much experience with the Numba side. This might be beyond my expertise! What sort of use case are you imagining for something like that?
The use case is simply efficient interoperability with Numba, Cython, and JAX-compatible code within Theano. It would also be nice to explore the idea of automatically compiling Python `Op`s.
I've put together a prototype that compiles Theano graphs to JAX. It only handles a small set of `Op`s so far.
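As a rough sketch of what such a graph-to-JAX conversion could look like (the names `SCALAR_TO_JAX` and `jax_funcify` are hypothetical here, not necessarily the prototype's actual API):

```python
import jax.numpy as jnp
from theano.tensor.elemwise import Elemwise

# Hypothetical dispatch from Theano scalar-op classes to JAX functions.
SCALAR_TO_JAX = {"Add": jnp.add, "Mul": jnp.multiply, "Exp": jnp.exp}

def jax_funcify(op):
    """Return a JAX callable implementing a supported Theano `Op` (sketch only)."""
    if isinstance(op, Elemwise):
        name = type(op.scalar_op).__name__
        if name in SCALAR_TO_JAX:
            # NB: Theano's Add/Mul are variadic, while jnp.add/jnp.multiply
            # are binary, so a real implementation would reduce over inputs.
            return SCALAR_TO_JAX[name]
    raise NotImplementedError(f"No JAX conversion for {op}")
```

A full linker would then walk `FunctionGraph.toposort()` and thread the converted callables together.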
That's awesome! I would be interested in how this compares in speed to Theano's C backend. Also, what other `Op`s not currently supported could be tricky?
Well, I believe the advantage here is that we could use Python, C, and JAX (or Numba, Cython, etc.) together. Overall, most Theano `Op`s look straightforward to convert. Finally, after going through all these parts of the codebase, it's clear that a good refactoring is needed (e.g. clarifying the basic interfaces).
This is my main question: IIUC, Theano's grad is explicitly defined, so we can just map it to Python/JAX operations? If so, would there be any way to bypass that and use JAX's grad? And is there any advantage/disadvantage to doing so?
As far as I can tell, we don't need to use the JAX gradient functions, but, if there's some advantage in doing so, it seems like we could. The first thing that comes to mind is that Theano gradients are built symbolically and pass through Theano's graph optimizations like any other graph, so we wouldn't be giving anything up by not using JAX's.
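To make that point concrete, here's the standard Theano pattern (no JAX involved): the gradient is just another symbolic graph, so whatever backend compiles the graph also compiles its gradients.

```python
import theano
import theano.tensor as tt

x = tt.vector("x")
cost = tt.sum(x ** 2)

# Theano builds the gradient graph symbolically; it then goes through
# Theano's graph optimizations before any backend compiles it.
g = theano.grad(cost, x)
f = theano.function([x], g)

print(f([1.0, 2.0, 3.0]))  # [2. 4. 6.]
```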
I can see how that's a powerful story for CPU, but what about GPU/TPUs?
I fully expect `Scan` to be the most difficult. However, its use in PyMC3 is fairly limited (I see a usage for computing the grads, though). What other difficult `Op`s do you foresee? That we get graph optimizations for gradients through Theano, and hence don't need them from symjax, makes total sense, and I don't see anything we'd lose that way.
I completely agree. If we're getting Theano fit for the 21st century, we should allow ourselves to break backwards compatibility, abandon old Python versions, and potentially rename the package if the refactor is significant. In general, I strongly believe we should start with the simple stuff and stay very close to what PyMC3 needs. For example, an amazing first milestone would be the logp eval of a simple Binomial (or Normal) model (see the sketch below): do just what is necessary for that, then slowly get more elaborate and see where it breaks down. And a big question for me is still the speed.
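For reference, that milestone target might look like this in plain PyMC3 (nothing backend-specific yet; the goal would be to run this logp evaluation through the new backend):

```python
import numpy as np
import pymc3 as pm

with pm.Model() as model:
    p = pm.Beta("p", alpha=1.0, beta=1.0)
    pm.Binomial("y", n=10, p=p, observed=np.array([3, 5, 7]))

# Evaluate the model's log-probability at the default starting point.
print(model.logp(model.test_point))
```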
Another alternative is to have an XLA linker and compile to XLA-optimized kernels directly, turning Theano into another XLA frontend.
JAX uses XLA, so I'm not sure what you would gain by going directly to it.
Thanks for chiming in, @nouiz. Any other thoughts on this approach?
It sounds good. Just make sure to keep the execution engine in C, like the `CLinker`; I think the call to `streamline` does that. If the loop over each operation is in Python, the Python overhead will kill the performance. But I'm not convinced the current version will give a speed-up vs. Theano, as JAX uses XLA on GPU, to my knowledge (I work on XLA). Still, it looks like a good proof of concept.
All right, I've provided a context that makes it much easier to add more JAX conversions. In its current form, it will also compile the entire graph to JAX; however, if a single unsupported `Op` appears in the graph, it falls back entirely (no partial/subgraph compilation yet). Anyway, this whole endeavour seems pretty clear-cut now. For that matter, after we add a conversion for `Scan`, most of the graphs we care about should be covered.
This is incredible progress!
Is that because of the call overhead? I would hope we just get a C-extension-like callable, like with Cython, with minimal call overhead compared to Python. How difficult do you think `Scan` will be?
I did some speed comparisons: https://gist.github.com/twiecki/38dc98197eed5594c5518a3971064c92 On small arrays, C mode is about 6x faster than JAX, but with larger arrays it seems to be about the same, which is a great starting point.
I updated the gist with a simple PyMC3 model; we can just pass the JAX mode to the logp function compilation.
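Roughly, the usage would look like this, assuming the prototype registers its mode under the name "JAX" (the mode name is an assumption; check the gist for the actual spelling):

```python
import numpy as np
import theano
import theano.tensor as tt

x = tt.matrix("x")
y = tt.exp(x).sum()

f_c = theano.function([x], y)                # default C backend
f_jax = theano.function([x], y, mode="JAX")  # hypothetical mode name

a = np.random.rand(100, 100).astype(theano.config.floatX)
assert np.allclose(f_c(a), f_jax(a))
```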
I think iteratively solving those individual problems until it compiles will be a good route to take. |
No, it's just that I haven't set up the logic for computing subgraphs (and determining which to actually use).
I think it could be annoying to fully implement, but, at the end of the day, it's mostly a matter of porting the Python version of `Scan` to JAX (see the sketch below).
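For instance, a `Scan` that accumulates a running sum maps onto `jax.lax.scan` roughly like this (a minimal sketch; a real conversion would also have to handle taps, multiple outputs, shared-variable updates, etc.):

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # One iteration of the loop body: return (next carry, per-step output).
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.arange(5.0)
final, ys = jax.lax.scan(step, 0.0, xs)
print(ys)  # [ 0.  1.  3.  6. 10.]
```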
Awesome, much appreciated!
That seems to make sense to me. Where do we think we'd notice the most speed-up? Graphs that use loops that can be parallelized/fused? I don't have a GPU on hand to try, so can't do any testing with that.
Yes, that was my plan, and it should be pretty straightforward. We can also go through the list of `Op`s and add conversions as needed.
Tested @twiecki's gist; it seems JAX primitives like `jax.vmap` and `jax.grad` do not work yet:

```python
jax.grad(theano_jax_fn)(*test_input_vals)

test_input_vals2 = [
    np.tile(np.arange(1000).astype(float), (1, 1000, 1)),
    np.tile(np.arange(1000).astype(float), (1, 1000, 1)),
]
jax.vmap(theano_jax_fn)(*test_input_vals2)
```

The full trace:

```
---------------------------------------------------------------------------
Exception Traceback (most recent call last)
<ipython-input-30-241055d0c563> in <module>
4 ]
5
----> 6 jax.vmap(theano_jax_fn)(*test_input_vals2)
~/miniconda3/lib/python3.7/site-packages/jax/api.py in batched_fun(*args)
871 _ = _mapped_axis_size(in_tree, args_flat, in_axes_flat, "vmap")
872 out_flat = batching.batch(flat_fun, args_flat, in_axes_flat,
--> 873 lambda: flatten_axes("vmap out_axes", out_tree(),
874 out_axes))
875 return tree_unflatten(out_tree(), out_flat)
~/miniconda3/lib/python3.7/site-packages/jax/interpreters/batching.py in batch(fun, in_vals, in_dims, out_dim_dests)
32 # executes a batched version of `fun` following out_dim_dests
33 batched_fun = batch_fun(fun, in_dims, out_dim_dests)
---> 34 return batched_fun.call_wrapped(*in_vals)
35
36 @lu.transformation_with_aux
~/miniconda3/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
148 gen = None
149
--> 150 ans = self.f(*args, **dict(self.params, **kwargs))
151 del args
152 while stack:
~/miniconda3/lib/python3.7/site-packages/theano/compile/function_module.py in __call__(self, *args, **kwargs)
811 s.storage[0] = s.type.filter(
812 arg, strict=s.strict,
--> 813 allow_downcast=s.allow_downcast)
814
815 except Exception as e:
~/miniconda3/lib/python3.7/site-packages/theano/tensor/type.py in filter(self, data, strict, allow_downcast)
148 # data has to be converted.
149 # Check that this conversion is lossless
--> 150 converted_data = theano._asarray(data, self.dtype)
151 # We use the `values_eq` static function from TensorType
152 # to handle NaN values.
~/miniconda3/lib/python3.7/site-packages/theano/misc/safe_asarray.py in _asarray(a, dtype, order)
32 dtype = theano.config.floatX
33 dtype = np.dtype(dtype) # Convert into dtype object.
---> 34 rval = np.asarray(a, dtype=dtype, order=order)
35 # Note that dtype comparison must be done by comparing their `num`
36 # attribute. One cannot assume that two identical data types are pointers
~/miniconda3/lib/python3.7/site-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
83
84 """
---> 85 return array(a, dtype, copy=False, order=order)
86
87
~/miniconda3/lib/python3.7/site-packages/jax/core.py in __array__(self, *args, **kw)
448 "JAX Tracer instance; in that case, you can instead write "
449 "`jax.device_put(x)[idx]`.")
--> 450 raise Exception(msg)
451
452 def __init__(self, trace: Trace):
Exception: Bad input argument to theano function with name "<ipython-input-3-da3732ec7fad>:10" at index 0 (0-based).
Backtrace when that variable is created:
File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2858, in run_cell
raw_cell, store_history, silent, shell_futures)
File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 2886, in _run_cell
return runner(coro)
File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
coro.send(None)
File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3063, in run_cell_async
interactivity=interactivity, compiler=compiler, result=result)
File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3254, in run_ast_nodes
if (await self.run_code(code, result, async_=asy)):
File "/home/junpenglao/miniconda3/lib/python3.7/site-packages/IPython/core/interactiveshell.py", line 3331, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-a27f936a3f77>", line 4, in <module>
x = tt.matrix('x')
The numpy.ndarray conversion method __array__() was called on the JAX Tracer object Traced<ShapedArray(float32[1000,1000])>with<BatchTrace(level=2/0)>
with val = array([[[ 0., 1., 2., ..., 997., 998., 999.],
[ 0., 1., 2., ..., 997., 998., 999.],
[ 0., 1., 2., ..., 997., 998., 999.],
...,
[ 0., 1., 2., ..., 997., 998., 999.],
[ 0., 1., 2., ..., 997., 998., 999.],
[ 0., 1., 2., ..., 997., 998., 999.]]])
batch_dim = 0.
This error can occur when a JAX Tracer object is passed to a raw numpy function, or a method on a numpy.ndarray object. You might want to check that you are using `jnp` together with `import jax.numpy as jnp` rather than using `np` via `import numpy as np`. If this error arises on a line that involves array indexing, like `x[idx]`, it may be that the array being indexed `x` is a raw numpy.ndarray while the indices `idx` are a JAX Tracer instance; in that case, you can instead write `jax.device_put(x)[idx]`.
```
@junpenglao, I didn't know one could insert folds into GitHub comments!
Yeah, the compiled Theano function still filters its inputs through NumPy (that's the `np.asarray` call in the trace), so JAX tracers can't pass through the Theano `Function` wrapper; the transforms would have to apply to the underlying JAX function. In order to implement this, I had originally thought to use `jax.jit` on the whole thing.
Ah, that makes sense. In that case, would it be possible to have a function that outputs a non-`jit` function?
Yeah, you can construct a Theano `FunctionGraph` and convert that directly to an un-`jit`ed JAX function.
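For reference, `jax.grad` and `jax.vmap` compose fine on pure JAX callables; it's the Theano `Function` wrapper, with its NumPy input filtering, that breaks the tracing. A minimal illustration:

```python
import jax
import jax.numpy as jnp

def f(x, y):
    # A pure JAX function: no NumPy conversion of its inputs.
    return jnp.sum(x * y)

df = jax.grad(f)                    # differentiates w.r.t. x
dfv = jax.vmap(df, in_axes=(0, 0))  # batches over the leading axis

x = jnp.ones((4, 3))
y = jnp.ones((4, 3))
print(dfv(x, y).shape)  # (4, 3)
```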
Closes #10. Well, at least the JAX part, but Cython and Numba implementations can follow very similar approaches and accomplish the same thing.
Here's an old example of a Numba-enabled Theano `Op`. We can most likely do something similar for Cython and JAX, as well.

As an extension to that example, it would be nice to have the compiled function be the `Op.perform` method itself; that way, we could attempt to compile existing `Op`s with little-to-no changes.
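A minimal sketch of that idea, with the compiled function backing `Op.perform` (the `DoubleOp` name and its computation are placeholders, not the linked example's code):

```python
import numba
import numpy as np
import theano.tensor as tt
from theano.gof import Op, Apply

@numba.njit
def _double(x):
    # Placeholder computation, compiled by Numba in nopython mode.
    return x * 2.0

class DoubleOp(Op):
    __props__ = ()

    def make_node(self, x):
        x = tt.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        # The Numba-compiled function does the actual work.
        output_storage[0][0] = np.asarray(_double(inputs[0]))
```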