add translate rule for standard_gamma #30
fehiepsi wants to merge 1 commit into pyro-ppl:master from fehiepsi:gamma1
Conversation
    lambda tangent, sample, key, alpha, **kwargs: tangent * _standard_gamma_grad(sample, alpha))

    @partial(jit, static_argnums=(2, 3))
@neerajprad It is really tricky. Without this line, I get an error saying "numpy.array is not hashable". I'm not sure why your primitive implementation does not require jit. It took me a whole day to debug without success. Luckily, adding this solves the problem. >"<
I think this is fine, since shape and dtype probably will not vary too much within a single program. I do want to see where this gets triggered though, so I'll try to take a look at this.
@fehiepsi - I don't see any test failures if I replace this with jit. Could you put in an xfailing test or just paste an example with the error that you noticed?
@neerajprad Sorry, I meant that I can't use this function without jit. >"<
I see - the issue is that you have a number of constants of type numpy.ndarray in your trace instead of DeviceArray, and numpy objects aren't hashable. It seems that certain functions or operations within the code end up constructing original numpy arrays, but I haven't yet tracked down where this is actually happening.
That makes sense. Let's keep this phenomenon in mind in case we implement more primitives and hit this issue again. :)
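The hashability constraint discussed above can be illustrated with a minimal sketch (not the PR's code): jit caches one compiled program per distinct value of each static_argnums argument, using those values as cache keys, so they must be hashable. A raw numpy.ndarray is not hashable, which is exactly the error reported here.

```python
from functools import partial

import numpy as onp           # "original" numpy
import jax.numpy as np        # jax numpy, aliased as in this PR
from jax import jit

@partial(jit, static_argnums=(1,))
def scale(x, factor):
    # `factor` is static: jit specializes and caches on its value,
    # so that value must be hashable (Python scalars are; arrays are not).
    return x * factor

print(scale(np.arange(3.0), 2.0))          # works: 2.0 is hashable
try:
    scale(np.arange(3.0), onp.array(2.0))  # raw numpy arrays are unhashable
except (TypeError, ValueError) as e:
    print("error:", type(e).__name__)
```

Passing a traced or concrete array where a static argument is expected fails at call time, which matches the "numpy.array is not hashable" symptom above.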
    # TODO: use lax.cond here
    # boost for the case alpha < 1
    boost = np.where(alpha >= 1.0, 1.0, random.uniform(key, ()) ** (1.0 / alpha))
this is cheap compared to the other operations, so using np.where here is fine
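For context, the boost above uses the standard identity that if X ~ Gamma(alpha + 1) and U ~ Uniform(0, 1), then X * U**(1/alpha) ~ Gamma(alpha), which extends the alpha >= 1 sampler to alpha < 1. A hedged sketch of the idea, using today's jax.random.gamma in place of the PR's own sampler (boosted_gamma is an illustrative name, not the PR's function):

```python
import jax.numpy as np   # old-style alias, as used in this PR
from jax import random

def boosted_gamma(key, alpha):
    # Illustrative sketch: sample Gamma(alpha) for any alpha > 0 by
    # boosting a Gamma(alpha + 1) sample when alpha < 1.
    key_g, key_u = random.split(key)
    alpha_boosted = np.where(alpha >= 1.0, alpha, alpha + 1.0)
    sample = random.gamma(key_g, alpha_boosted)
    # boost for the case alpha < 1; a no-op (1.0) otherwise
    boost = np.where(alpha >= 1.0, 1.0, random.uniform(key_u, ()) ** (1.0 / alpha))
    return sample * boost
```

Because both branches of np.where are evaluated, the uniform draw happens unconditionally, but as noted above it is cheap relative to the gamma rejection sampler itself.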
Sounds good! I'll play around with this in the meantime.
Closed! I'll open a new PR because I can't push to this branch (it becomes unknown...).
This PR adds a translation rule for standard_gamma and also makes it faster. Now, standard_gamma is jittable, and the speed of its jitted version is comparable to PyTorch's. The remaining issue is to avoid the condition involving log + log, which I don't have a solution for yet (using jax.cond does not seem to solve that problem, as discussed in jax-ml/jax#415). @neerajprad Because this needs the master branch of JAX (in particular, jax-ml/jax#495), let's wait for a new release of jax before merging.