Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ATen Distributions.cu + Poisson #58

Closed
wants to merge 7 commits into
base: master
from

Conversation

Projects
None yet
4 participants
@rachtsingh
Copy link

rachtsingh commented Dec 31, 2017

I got a bit into the weeds and the scope expanded, but I think this is the right way to do it. This PR (will) implement a pointwise Poisson sampling method for CPU/CUDA, and should be followed up quickly with ports of the Gamma / Dirichlet samplers, and other things we might want to use from distributions.c.

So far:

  • Renames Generator inside THCRandomTensor.h to be THCGenerator, and fix CUDAGenerator instantiation
  • Implement Poisson distribution + torch.poisson (Variable only for now, but we can wrap it).
  • Write tests for the distribution.
  • Implement a Philox-based sampler for rejection sampling.

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch from b0f67da to 00f3abd Jan 2, 2018

@rachtsingh

This comment has been minimized.

Copy link
Author

rachtsingh commented Jan 2, 2018

Ah, not quite ready for review yet, by the way - I still need to debug this CUDA issue and write tests.

I tried refactoring to use ATen instead (which is much cleaner), but I can't figure out how to sample a uniform given a CUDAGenerator reference (which doesn't seem to have a reference to any curandState*-types). If anyone has a pointer that would be great.

@fritzo
Copy link
Collaborator

fritzo left a comment

Note you should register this in the EXAMPLES list in test_distributions.py. When you do so, tests will fail because you do not implement .entropy(). If you don't want to implement .entropy() (which as I understand has no closed form for Poisson), you should allow NotImplementedError in TestDistributionShapes.test_entropy_shape():

- actual_shape = dist.entropy().size()
+ try:
+     actual_shape = dist.entropy().size()
+ except NotImplementedError:
+     continue
result[i] = y; \
} \
} \
#define GENERATE_KERNEL1(NAME, T, ARG1, CURAND_T, CURAND_FUNC, TRANSFORM) \

This comment has been minimized.

@fritzo

fritzo Jan 2, 2018

Collaborator

nit: Consider reverting unnecessary whitespace changes to ease review by PyTorch folks.

This comment has been minimized.

@rachtsingh

rachtsingh Jan 3, 2018

Author

Thanks! I usually squash before review but I'll be more on top of it.

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch from 00f3abd to 420291b Jan 3, 2018

@rachtsingh rachtsingh changed the title Implementation of Poisson distribution + torch.poisson(...) Implement ATen Distributions.cu + Distributions.cpp Jan 3, 2018

@gchanan

This comment has been minimized.

Copy link

gchanan commented Jan 3, 2018

On your question, I don't think we should add calls to get random numbers directly to the generator; the normal pattern is to pass in the generator to a function that generates the random numbers, right?

@rachtsingh

This comment has been minimized.

Copy link
Author

rachtsingh commented Jan 3, 2018

Yeah, that sounds right. I'm not sure about casting the result of unsafeGetTH from void * to be able to sample though - is it ok to change the generator to output the specialized Generator when parsing native_functions.yaml?

EDIT: taken care of, I didn't realize how at::globalContext works. I poked around with CUDA pointers for a bit and realized it's the same as the type's context, at least on my machine. Let me know if there's a multi-GPU issue.

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch from 420291b to c45daf3 Jan 6, 2018

@rachtsingh rachtsingh changed the title Implement ATen Distributions.cu + Distributions.cpp Implement ATen Distributions.cu + Poisson Jan 6, 2018

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch from c45daf3 to 5c1251c Jan 6, 2018

//g if (!getApplyGrid(totalElements, grid)) {
//g return false;
//g }
//g grid = dim3(1);

This comment has been minimized.

@rachtsingh

rachtsingh Jan 6, 2018

Author

This is temporary - looking for advice on how to best limit the blockDim to 256 when calling CUDA_tensor_apply2

@rachtsingh

This comment has been minimized.

Copy link
Author

rachtsingh commented Jan 6, 2018

Ok, I think it's probably ready for review. @fritzo, given the C++/CUDA changes, it's probably best to ask someone from PyTorch to review as well? I don't know what exactly the plans are for CUDAGenerator, etc.

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch from 5c1251c to 4ff1328 Jan 7, 2018

@fritzo
Copy link
Collaborator

fritzo left a comment

Python code looks good. just one comment on parameter name. I have not reviewed CUDA code; let's do that review on pytorch/pytorch.


class Poisson(Distribution):
r"""
Creates a Poisson distribution parameterized by `_lambda`, the rate parameter.

This comment has been minimized.

@fritzo

fritzo Jan 8, 2018

Collaborator

nit: preceding underscore usually denotes a private variable. It's probably safer to call this lam or lambda_. Alternatively it would be nice to call it rate which is both semantically meaningful and improves compatibility with tensorflow distributions.

This comment has been minimized.

@fritzo

fritzo Jan 8, 2018

Collaborator

For example self._lambda would be the obvious member name but that looks private. Instead self.lambda_ or self.rate are clearly public.

This comment has been minimized.

@rachtsingh

rachtsingh Jan 8, 2018

Author

Both sound good to me. Let's use rate since that's more evocative? I.e. like mean/std in Normal already, and matches up with Exponential, which is the right idea.

@fritzo

This comment has been minimized.

Copy link
Collaborator

fritzo commented Jan 8, 2018

@apaszke How should we test the CUDA sampler? Do we use the same tests in test_distributions.py that we use for CPU samplers, or should these new tests live in test_cuda.py?

@apaszke

This comment has been minimized.

Copy link

apaszke commented Jan 8, 2018

I think test_distributions.py is ok. Just make unittest skip the tests if CUDA is unavailable (look at test_nn.py and test_autograd.py for examples).

} // at::native::dist

Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
Tensor ret = lambda.type().toScalarType(kDouble).zeros(lambda.sizes());

This comment has been minimized.

@apaszke

apaszke Jan 8, 2018

It's better to use tensor instead of zeros if you don't depend on the initial values of the tensor (it will be uninitialized)


Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
Tensor ret = lambda.type().toScalarType(kDouble).zeros(lambda.sizes());
auto lambda_ = lambda.toType(ScalarType::Double);

This comment has been minimized.

@apaszke

apaszke Jan 8, 2018

Is this unstable in fp32? fp64 math is extremely slow on non-Tesla GPUs, and we should avoid it

* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/

This comment has been minimized.

@apaszke

apaszke Jan 8, 2018

Is this the right place for this copyright notice?

@@ -750,6 +751,7 @@ static PyMethodDef TorchMethods[] = {
{"_standard_gamma", (PyCFunction)THPModule_standard_gamma, METH_VARARGS | METH_KEYWORDS, NULL},
{"_dirichlet_grad", (PyCFunction)THPModule_dirichlet_grad, METH_VARARGS | METH_KEYWORDS, NULL},
{"bernoulli", (PyCFunction)THPModule_bernoulli, METH_VARARGS | METH_KEYWORDS, NULL},
{"poisson", (PyCFunction)THPModule_poisson, METH_VARARGS | METH_KEYWORDS, NULL},

This comment has been minimized.

@apaszke

apaszke Jan 8, 2018

Let's avoid adding more of the sampling methods to the global namespace. torch.distributions is the way to expose those samplers

@rachtsingh

This comment has been minimized.

Copy link
Author

rachtsingh commented Jan 8, 2018

Ok, think I addressed comments (thanks for the review!), and am removing poisson from the global namespace [running tests locally right now]. I'll move this PR to pytorch/pytorch afterwards.

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch 3 times, most recently from 8bd52eb to 3d39052 Jan 9, 2018

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch from 3b11b8b to 9a677c3 Jan 9, 2018

} // at::native::dist

Tensor _s_poisson_cuda(const Tensor& lambda, Generator* gen) {
Tensor ret = lambda.type().tensor(lambda.sizes());
auto lambda_ = lambda.toType(ScalarType::Float);
dispatch_all<void, dist::PoissonOpCUDA>(ret.type(), "poisson", ret, lambda_, dist::get_states(gen));
dispatch_floating_types<void, dist::PoissonOpCUDA>(ret.type(), "poisson", ret, lambda_, dist::get_states(gen));

This comment has been minimized.

@gchanan

gchanan Jan 9, 2018

if you are having trouble with half, try DISPATCH_ALL_FLOATING_TYPES; dispatch_all uses Half and DISPATCH_ALL_FLOATING_TYPES uses half; we should resolve this, probably always using Half (and providing device conversion functions?). CC @colesbury

This comment has been minimized.

@rachtsingh

rachtsingh Jan 9, 2018

Author

Thanks, that works!

This comment has been minimized.

@apaszke

apaszke Jan 9, 2018

Ideally we would always use half for device functions and at::Half for host functions.

@rachtsingh

This comment has been minimized.

Copy link
Author

rachtsingh commented Jan 9, 2018

After doing some checks, it looks like curand_poisson without using the more robust API uses either an inverse CDF sampler or the transformed rejection method, neither of which are accurate for low values of lambda. So, it fails the distribution tests (good thing they work!).

I think there's 3 options - (1) to merge as-is (by raising the threshold on the CUDA version), but to warn that for low values of lambda, the results may be biased, or (2) merge the CPU version, not the GPU version, or (3) punt on this until we get Philox-based PRNGs up, which was what I was planning to tackle next.

@apaszke

This comment has been minimized.

Copy link

apaszke commented Jan 9, 2018

I'm fine with all options. I'll leave the decision to @fritzo

@fritzo

This comment has been minimized.

Copy link
Collaborator

fritzo commented Jan 9, 2018

Option 1 sounds very reasonable. You could also file an issue at probtorch/pytorch noting that we should revisit this.

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch from b1acc71 to 5a7d114 Jan 15, 2018

@rachtsingh

This comment has been minimized.

Copy link
Author

rachtsingh commented Jan 15, 2018

OK, fixed the issues and added a Philox-based RNG. I don't think the sampling algorithm is the most accurate, but we can work on it in the future I think.

Also, I think there's a bug in the handling of failure_rate (note that I decreased that to make the tests pass), but I'll have to double check the math before fixing.

Let me know if this is good, I'll squash the commits to be 1 (or whatever you think is semantic), and then update the pytorch PR.

@rachtsingh rachtsingh force-pushed the rachtsingh:poisson branch from fc385ab to 1a6d1e5 Jan 16, 2018

@rachtsingh

This comment has been minimized.

Copy link
Author

rachtsingh commented Jan 18, 2018

(discussing in pytorch#4556)

@rachtsingh rachtsingh closed this Jan 18, 2018

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.