
add translate rule for standard_gamma#30

Closed
fehiepsi wants to merge 1 commit into pyro-ppl:master from fehiepsi:gamma1

Conversation

@fehiepsi
Member

This PR adds a translation rule for standard_gamma and also makes it faster. Now standard_gamma is jittable, and the speed of its jitted version is comparable to PyTorch's. The remaining issue is to avoid the condition involving log+log, which I don't have a solution for yet (using jax.cond does not seem to solve that problem, as discussed in jax-ml/jax#415).

@neerajprad Because this needs the master branch of JAX (in particular, jax-ml/jax#495), let's wait for a new release of JAX before merging.

@fehiepsi fehiepsi requested a review from neerajprad March 12, 2019 00:24
lambda tangent, sample, key, alpha, **kwargs: tangent * _standard_gamma_grad(sample, alpha))


@partial(jit, static_argnums=(2, 3))
Member Author


@neerajprad It is really tricky. Without this line, I get an error saying "numpy.array is not hashable". I'm not sure why your primitive implementation does not require jit. It took me a whole day to debug without success. Luckily, adding this solves the problem. >"<

Member


I think this is fine, since shape and dtype probably will not vary too much within a single program. I do want to see where this gets triggered though, so I'll try to take a look at this.

Member

@neerajprad neerajprad Mar 13, 2019


@fehiepsi - I don't see any test failures if I replace this with jit. Could you put in an xfailing test or just paste an example with the error that you noticed?

Member Author


@neerajprad Sorry, I meant that I can't use this function without jit. >"<

Member


I see - the issue is that you have a number of constants of type numpy.ndarray in your trace instead of DeviceArray, and numpy objects aren't hashable. It seems that certain functions or operations within the code end up constructing plain numpy arrays, but I haven't yet tracked down where this is actually happening.
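[Editor's note] The hashability requirement comes from jit's compilation cache: arguments named in static_argnums are used as cache keys and so must be hashable (tuples and scalars are; numpy arrays are not). A minimal sketch of the pattern under discussion, using a hypothetical function rather than code from this PR:

```python
from functools import partial

from jax import jit
import jax.numpy as np


# `shape` is consumed by Python-level array construction at trace time, so it
# must be marked static; jit hashes static arguments to key its compile cache.
@partial(jit, static_argnums=(1,))
def scaled_ones(scale, shape):
    return scale * np.ones(shape)


print(scaled_ones(2.0, (3,)))
```

Passing an unhashable object (such as a numpy.ndarray) in a static position raises a TypeError at call time, which matches the error quoted above.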

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. Let's keep this phenomenon in mind if we implement more primitives and get this issue again. :)


# TODO: use lax.cond here
# boost for the case alpha < 1
boost = np.where(alpha >= 1.0, 1.0, random.uniform(key, ()) ** (1.0 / alpha))
Member Author


This is cheap compared to the rest of the sampler, so using np.where here is fine.
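[Editor's note] The `boost` line implements the standard augmentation for small shape parameters: if X ~ Gamma(alpha + 1) and U ~ Uniform(0, 1) independently, then X * U**(1/alpha) ~ Gamma(alpha). A quick numerical sanity check with plain NumPy (not code from this PR):

```python
import numpy as np

rng = np.random.default_rng(0)
alpha, n = 0.5, 200_000

# Direct Gamma(alpha) samples.
direct = rng.gamma(alpha, size=n)

# Boosted samples: draw from Gamma(alpha + 1), then scale by U ** (1 / alpha).
boosted = rng.gamma(alpha + 1.0, size=n) * rng.uniform(size=n) ** (1.0 / alpha)

# Both sample means should be close to the true mean of Gamma(alpha, 1), i.e. alpha.
print(direct.mean(), boosted.mean())
```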

@neerajprad
Member

> @neerajprad Because this needs the master branch of JAX (in particular, jax-ml/jax#495), let's wait for a new release of JAX before merging.

Sounds good! I'll play around with this in the meantime.

@neerajprad neerajprad changed the title [blocked] add translate rule for standard_gamma add translate rule for standard_gamma Mar 13, 2019
@fehiepsi fehiepsi added WIP and removed blocked labels Mar 25, 2019
@fehiepsi
Member Author

Closed! I'll open a new PR because I can't push to this branch (it has become unknown...).

@fehiepsi fehiepsi closed this Mar 26, 2019