
Wishart / InverseWishart / LKJ priors #1692
Closed
rachtsingh opened this issue Jan 1, 2019 · 42 comments

@rachtsingh commented Jan 1, 2019

Would it be possible to implement Wishart / InverseWishart / LKJ priors?

GPyTorch has them already, but when I tried mixing in TorchDistributionMixin to get something usable in Pyro, I realized that they don't have a .sample method.

I don't think it's easy to get efficient samplers for the InverseWishart (I think trying to build it via a TransformedDistribution might be too slow), but I'm curious to see other approaches. There's a TensorFlow Probability tutorial on this here, which explains some of the underlying ideas, as well as a note here.

Great work by the way! Pyro is so easy to use it's incredible.

@eb8680 (Member) commented Jan 1, 2019

You actually don't need a .sample method for a distribution unless you want to sample from it - see e.g. pyro.distributions.VonMises. Does your use case require sampling from the prior? Were you getting an error when you tried to use TorchDistributionMixin with the GPyTorch versions?

@rachtsingh (Author)

I was trying to use HMC (NUTS), which I think requires sampling (I do get a NotImplementedError). On second thought, though, this might belong in torch.distributions instead of here, so feel free to close.

@eb8680 (Member) commented Jan 1, 2019

I was trying to use HMC (NUTS), which I think requires sampling

That's just for setting the initial value - you can hack around that for now by defining a dummy sample method for your prior that just returns an arbitrary appropriately shaped value that's in the prior's support.

@rachtsingh (Author)

Got it. I just tried it with the patched method, and ran into two issues:

  1. We can't merge GPyTorch Prior classes and TorchDistributionMixin since the former subclasses nn.Module but doesn't implement forward. So I just recreated the distribution by copying their code (see below).
  2. This doesn't work either because:
NotImplementedError: Cannot transform _PositiveDefinite constraints

which makes sense.

Actually, modeling-wise, the LKJ prior would be more useful, but I tried that and it also runs into the same error.

Here's my attempt:

import math
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.distribution import Distribution

from pyro.distributions.torch_distribution import TorchDistributionMixin

# _is_valid_correlation_matrix and _batch_form_diag (used below) are helpers
# from the GPyTorch code this was copied from.

class LKJCorr(Distribution, TorchDistributionMixin):
    arg_constraints = {"n": constraints.positive_integer, "eta": constraints.positive}
    support = constraints.positive_definite
    _validate_args = True

    def __init__(self, n, eta, validate_args=False):
        if not isinstance(n, int) or n < 1:
            raise ValueError("n must be a positive integer")
        if isinstance(eta, Number):
            eta = torch.tensor(float(eta))
        self.n = torch.tensor(n, dtype=torch.long, device=eta.device)
        batch_shape = eta.shape
        event_shape = torch.Size([n, n])
        i = torch.arange(n, dtype=eta.dtype, device=eta.device)
        C = (((2 * eta.view(-1, 1) - 2 + i) * i).sum(1) * math.log(2)).view_as(eta)
        C += n * torch.sum(2 * torch.lgamma(i / 2 + 1) - torch.lgamma(i + 2))
        self.eta = eta
        self.C = C
        super(LKJCorr, self).__init__(batch_shape, event_shape, validate_args=validate_args)

    def log_prob(self, X):
        if any(s != self.n for s in X.shape[-2:]):
            raise ValueError("Correlation matrix is not of size n={}".format(self.n.item()))
        if not _is_valid_correlation_matrix(X):
            raise ValueError("Input is not a valid correlation matrix")
        log_diag_sum = torch.stack([p.cholesky(upper=True).diag().log().sum() for p in X.view(-1, *X.shape[-2:])])
        return self.C + (self.eta - 1) * 2 * log_diag_sum
    
    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        return torch.eye(self.n.item()).expand(shape)

class LKJCov(Distribution, TorchDistributionMixin):
    arg_constraints = {"n": constraints.positive_integer, "eta": constraints.positive}
    support = constraints.positive_definite
    _validate_args = True
    
    def __init__(self, n, eta, sd_prior, validate_args=False):
        correlation_prior = LKJCorr(n=n, eta=eta, validate_args=validate_args)
        self.correlation_prior = correlation_prior
        self.sd_prior = sd_prior
        super(LKJCov, self).__init__(self.correlation_prior._batch_shape, 
                                     self.correlation_prior._event_shape, 
                                     self.correlation_prior._validate_args)

    def log_prob(self, X):
        marginal_var = torch.diagonal(X, dim1=-2, dim2=-1)
        if not torch.all(marginal_var >= 0):
            raise ValueError("Variance(s) cannot be negative")
        marginal_sd = marginal_var.sqrt()
        sd_diag_mat = _batch_form_diag(1 / marginal_sd)
        correlations = torch.matmul(torch.matmul(sd_diag_mat, X), sd_diag_mat)
        log_prob_corr = self.correlation_prior.log_prob(correlations)
        log_prob_sd = self.sd_prior.log_prob(marginal_sd)
        return log_prob_corr + log_prob_sd

    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        return torch.eye(self.correlation_prior.n.item()).expand(shape)
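For context, here is a rough usage sketch of how such a prior might be dropped into a Pyro model once the .sample and constraint issues discussed here are sorted out. The model, the data shape, and the HalfCauchy prior on the standard deviations are illustrative assumptions, not part of the code above:

import pyro
import pyro.distributions as dist
import torch

def model(data):
    # data: tensor of shape (num_points, 3)
    loc = pyro.sample("loc", dist.Normal(torch.zeros(3), 10 * torch.ones(3)).to_event(1))
    sd_prior = dist.HalfCauchy(torch.ones(3)).to_event(1)
    # hypothetical use of the LKJCov prior sketched above
    cov = pyro.sample("cov", LKJCov(n=3, eta=torch.tensor(1.0), sd_prior=sd_prior))
    with pyro.plate("data", data.shape[0]):
        pyro.sample("obs", dist.MultivariateNormal(loc, covariance_matrix=cov), obs=data)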

@fehiepsi (Member) commented Jan 2, 2019

How about using an LKJ prior on the Cholesky factor? According to the Stan reference, an LKJ prior on the covariance matrix will make "the code run slower and consume more memory with more risk of numerical errors". I think LowerCholeskyTransform has an inverse method; it still needs an implementation of the .log_abs_det_jacobian() method. Then it would be ready for use in an LKJ Cholesky distribution.

@rachtsingh (Author)

Huh, that sounded right, but I think I'm missing something. Shouldn't the domain of that transform be constraints.positive_definite? And I don't think it should call torch.tril, but instead torch.cholesky (which now supports batching).

When that's figured out, the implementation would look like:

class LKJCholesky(LKJCorr):
    support = constraints.lower_cholesky

    def log_prob(self, L):
        log_diag_sum = torch.diagonal(L, dim1=-2, dim2=-1).log().sum(-1)
        return self.C + (self.eta - 1) * 2 * log_diag_sum

class LKJCholeskyCov(Distribution, TorchDistributionMixin):
    arg_constraints = {"n": constraints.positive_integer, "eta": constraints.positive}
    support = constraints.lower_cholesky
    
    def __init__(self, n, eta, sd_prior, validate_args=False):
        self.correlation_prior = LKJCholesky(n=n, eta=eta, validate_args=validate_args)
        self.sd_prior = sd_prior
        super(LKJCholeskyCov, self).__init__(self.correlation_prior._batch_shape, 
                                     self.correlation_prior._event_shape, 
                                     self.correlation_prior._validate_args)
    
    def log_prob(self, L):
        # we essentially have the (LD^{1/2}) part of the LDL decomposition
        marginal_std = torch.diagonal(L, dim1=-2, dim2=-1)
        sd_diag_mat = _batch_form_diag(1 / marginal_std)
        correlation_L = torch.matmul(L, sd_diag_mat) # or is that backwards??
        log_prob_corr = self.correlation_prior.log_prob(correlation_L)
        log_prob_sd = self.sd_prior.log_prob(marginal_std)
        return log_prob_corr + log_prob_sd
    
    def sample(self, sample_shape=torch.Size()):
        shape = self._extended_shape(sample_shape)
        return torch.eye(self.correlation_prior.n.item()).expand(shape)

@fritzo (Member) commented Jan 2, 2019

I'm missing something. Shouldn't the domain ...

https://github.com/pytorch/pytorch/blob/master/torch/distributions/transforms.py#L514 transforms a matrix with garbage in the upper triangle to a matrix with zeros in the upper triangle, hence the use of torch.tril rather than torch.cholesky. You should never need to directly access the unconstrained domain (except indirectly via a Pyro optimizer or HMC). To get a positive definite matrix, you should be able to define x = my_transform(u) as the constrained parameter and set cov = x.mm(x.t()).
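A minimal sketch of that pattern, assuming PyTorch's transform_to registry (the variable names are illustrative):

import torch
from torch.distributions import constraints, transform_to

n = 3
u = torch.randn(n, n, requires_grad=True)        # unconstrained parameter
x = transform_to(constraints.lower_cholesky)(u)  # lower triangular with positive diagonal
cov = x.mm(x.t())                                # positive definite covariance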

@fehiepsi (Member) commented Jan 3, 2019

Some day we should find a clean way to define the unconstrained domain (maybe a vector instead of a matrix) for LowerCholeskyTransform. It would reduce the number of parameters to optimize for these priors.

@elbamos (Contributor) commented Jan 27, 2019

Just tossing in a +1 on this... I'm experimenting with Pyro by converting a Stan model, and the lack of priors for covariance matrices is kind of an impediment.

@fehiepsi (Member) commented Jan 27, 2019

I guess the most complicated part of an LKJ prior (which is more numerically stable than Wishart/InverseWishart) is defining a transform from unconstrained space to the space of correlation matrices. The Stan reference gives a nice derivation of such a transform, based on this paper: https://www.sciencedirect.com/science/article/pii/S0047259X09000876. The tricky part (which requires loops) is transforming the canonical partial correlations to the Cholesky factor of the correlation matrix (the transform from z to x in the Stan reference). The math is simple, but I don't know how to code it efficiently to support batches. Maybe it would be good to start with a "loop" version if someone is interested; a sketch of such a loop is below.
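As a concrete starting point, here is a hedged sketch of such a "loop" version (unbatched, following the Stan reference manual; the helper name is illustrative, and it builds rows with torch.stack rather than in-place assignment to stay autograd-friendly, given the backprop concerns discussed further down the thread):

import torch

def cpc_to_corr_cholesky(z, d):
    # Map canonical partial correlations z (a vector of length d*(d-1)/2 with
    # entries in (-1, 1)) to the lower-triangular Cholesky factor of a d x d
    # correlation matrix, row by row as in the Stan reference.
    rows = []
    k = 0
    for i in range(d):
        row = []
        remainder = z.new_tensor(1.0)  # 1 minus the squared norm of row i so far
        for j in range(i):
            entry = z[k] * remainder.sqrt()
            remainder = remainder - entry ** 2
            row.append(entry)
            k += 1
        row.append(remainder.sqrt())                    # diagonal entry, so the row has unit norm
        row.extend([z.new_tensor(0.0)] * (d - i - 1))   # zeros above the diagonal
        rows.append(torch.stack(row))
    return torch.stack(rows)

For example, cpc_to_corr_cholesky(torch.randn(3).tanh(), 3) gives a 3 x 3 lower-triangular factor L with unit-norm rows, so L @ L.t() is a correlation matrix.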

@elbamos (Contributor) commented Jan 28, 2019

Does it need to be coded to efficiently support batches? If someone has a template (i.e., that shows what functions need to be filled in), I can help.

@fehiepsi (Member) commented Jan 28, 2019

@elbamos Here is a template which I came up with:

Step 1: define a constraint as in this script (a rough sketch of a possible check follows after these steps)

from torch.distributions.constraints import Constraint

class _CorrelationMatrix(Constraint):
    def check(self, value):
        # check that `value` is positive definite, its diagonal is equal to 1, and each entry lies in [-1, 1]

# Public interface.
correlation_matrix = _CorrelationMatrix()

Step 2: define a transform as in this script

from torch.distributions.transforms import Transform

class CorrelationMatrixTransform(Transform):
    domain = constraints.real
    codomain = constraints.correlation_matrix
    bijective = True
    sign = +1

    def __eq__(self, other):
        return isinstance(other, CorrelationMatrixTransform)

    def _call(self, x):
        # compute y from the vector x as in Stan reference:
        # https://mc-stan.org/docs/2_18/reference-manual/correlation-matrix-transform-section.html
        return y

    def _inverse(self, y):
        # see the above reference
        return x

    def log_abs_det_jacobian(self, x, y):
        # see the above reference
        return log_abs_det

Step 3: register the bijection as in this script

Step 4: define the distribution (as in PyTorch/Pyro distributions)

  • Sample method: see the algorithm in section 2.4 of https://www.sciencedirect.com/science/article/pii/S0047259X09000876. The first step is to use a Beta distribution to generate the partial correlations, a vector with entries in the interval [-1, 1]. Then use the correlation matrix transform to map it to a correlation matrix.
  • Log prob: see the Stan reference for the unnormalized probability and section 3.3 of the above paper for how to compute the normalization constant.

Hope that helps!
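As a rough illustration of step 1, one possible check might look like the following; the exact-equality tests (a tolerance would be more forgiving) and the use of torch.linalg.eigvalsh (a newer API than what was available at the time) are simplifying assumptions:

import torch
from torch.distributions.constraints import Constraint

class _CorrelationMatrix(Constraint):
    def check(self, value):
        # unit diagonal
        unit_diag = (torch.diagonal(value, dim1=-2, dim2=-1) == 1).all(-1)
        # every entry in [-1, 1]
        in_range = ((value >= -1) & (value <= 1)).all(-1).all(-1)
        # symmetric
        symmetric = (value == value.transpose(-1, -2)).all(-1).all(-1)
        # positive definite: all eigenvalues strictly positive
        positive = (torch.linalg.eigvalsh(value) > 0).all(-1)
        return unit_diag & in_range & symmetric & positive

# Public interface.
correlation_matrix = _CorrelationMatrix()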

@elbamos (Contributor) commented Jan 29, 2019 via email

@fritzo (Member) commented Jan 29, 2019

Don't we only need the transformation from an unconstrained vector to the lower Cholesky

@elbamos I think what you're describing are arbitrary covariance matrices. Correlation matrices have the additional constraint that the diagonal is all ones and the off-diagonal entries are all in [-1, 1].

EDIT: While I don't think we can use PyTorch's lower_cholesky constraint directly, I think we could translate the Stan parameterization to PyTorch. @elbamos thanks for the pointers!

@elbamos (Contributor) commented Jan 29, 2019

@fritzo and @fehiepsi: The most convenient way to work with covariance matrices, in practice, is usually to separately generate a scale vector (theta) and the lower Cholesky factor of a correlation matrix (Omega).

If you multiply diag(sqrt(theta)) * Omega, you get the lower Cholesky factor of a covariance matrix, which is the most efficient parameterization for multivariate distributions.
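A small sketch of that composition, assuming theta and L_omega have already been sampled (the names and values are illustrative):

import torch
import pyro.distributions as dist

theta = torch.tensor([1.0, 2.0, 0.5])        # marginal variances, so sqrt(theta) are the scales
L_omega = torch.eye(3)                       # lower Cholesky factor of a correlation matrix
L_cov = torch.diag(theta.sqrt()) @ L_omega   # lower Cholesky factor of the covariance
mvn = dist.MultivariateNormal(torch.zeros(3), scale_tril=L_cov)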

So what's missing from Pyro are distributions for generating the lower Cholesky factor of a correlation matrix (probably via the LKJ prior).

The question is what Transforms and Constraints need to be coded.

Ordinarily in HMC, sampling takes place in an unconstrained space, and variables are then transformed into the constrained space. If that's how Pyro works too, what we'd need is an (invertible) transformation from the unconstrained space of length-d(d-1)/2 vectors into the space of lower Cholesky factors of correlation matrices.

Since the Stan manual helpfully provides an algorithm for that transformation, I'm not sure what Constraint actually needs to be coded. It seems that all we need is the Transform.

Does this help explain it?

If I'm right about what's required here, then I've already got a prototype implementation, based on the Stan source code, of everything except log_prob. (If it's necessary to backprop through the Transform, this becomes tricky. It is probably easier to implement the Transform as a PyTorch function so we can provide a custom gradient.)

@fritzo (Member) commented Jan 29, 2019

@elbamos thanks for explaining, yes that makes sense. I was confused because the transform for torch.distributions.constraints.lower_cholesky maps a matrix (in a non-bijective way) to the lower Cholesky factor of an arbitrary positive definite matrix, rather than of a correlation matrix.

I think you're right: we can follow the Stan code to develop a new Transform, and register a new constraint called correlation, similar to positive_definite. @fehiepsi I believe this will need to unroll all n(n+1)/2 parameters of the n x n matrix into a vector, so event_shape=(n*(n+1)/2,); do you know a trick to do that, and could you help @elbamos? @elbamos I still believe you should be able to follow the steps @fehiepsi suggests.

@elbamos For context PyTorch's transforms differ from Stan's and Tensorflow's in that they also include non-bijective transforms including projections. These are often cheaper and more stable than bijective transforms. While HMC and NUTS require bijections, SVI and MAP inference can allow such overparameterized transforms. This is why lower_cholesky simply zeros out the upper-triangle: it is sometimes cheaper to overparameterize and ignore extra parameters than to perform the gather operation required to use exactly the right number of parameters.

@fehiepsi (Member) commented Jan 29, 2019

@elbamos It is great to hear that you have already come up with a prototype (I was intending to make this implementation)! I agree that an LKJ prior on the Cholesky factor is enough. An LKJ prior on the correlation matrix (which is more popular for small models) would mostly be based on the Cholesky version (except for the log_prob method, where we would have to Cholesky-factorize the correlation matrix, which is inefficient for large matrices), so we can do it later if necessary. To keep the discussion consistent, I will only talk about the LKJ prior on the Cholesky factor.

Here is the template for the LKJ prior on the Cholesky factor (quite similar to the correlation version):

from torch.distributions.constraints import Constraint
from torch.distributions.transforms import Transform

class _LowerCholeskyCorr(Constraint):
    def check(self, value):
        # check that the diagonal of `value` is positive and that each row has unit norm

# Public interface.
lower_cholesky_corr = _LowerCholeskyCorr()

class LowerCholeskyCorrTransform(Transform):
    domain = constraints.real
    codomain = constraints.lower_cholesky_corr
    # after this, it is similar to correlation version

You are right about how Pyro's HMC works. As for the constraint, it is required in order to define the support of each distribution. Under the hood, HMC first looks at the constraint on the support of a latent variable's distribution, then checks whether a bijection has been registered for that constraint (step 3 in my template). In this case, if it sees that the constraint is lower_cholesky_corr, it will use LowerCholeskyCorrTransform as the default transform; a sketch of that registration is below.

Of course, we can also just define a transform and specify it for an HMC instance through the transform argument, without having to worry about constraints. If you'd rather not worry about the Constraint part, I'll add it in another PR. :)
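A sketch of that registration step, assuming the lower_cholesky_corr constraint instance and the LowerCholeskyCorrTransform class from the template above exist; this mirrors how PyTorch's own constraints are wired to their default transforms:

from torch.distributions import biject_to, transform_to

# Once registered, biject_to(lower_cholesky_corr) returns the transform,
# which is how HMC finds the default bijection for a distribution's support.
@biject_to.register(lower_cholesky_corr)
@transform_to.register(lower_cholesky_corr)
def _transform_to_lower_cholesky_corr(constraint):
    return LowerCholeskyCorrTransform()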

@fritzo I don't have a trick to do that, and I think we can just keep the distribution's event_shape as n x n.

As I mentioned above, we can make a "loop" version as a first step. If we only need to convert a vector to a lower-triangular matrix, I know the following trick.

# x is a vector of length N * (N + 1) / 2
y = x.new_ones(N, N).tril(diagonal=0)  # ones on and below the diagonal
y[y > 0] = x                           # fill the lower triangle row by row

But I don't think it will help in this situation. I don't know how to make a non-loop version of the transform from y to the target (the lower Cholesky factor of a correlation matrix), and working with x is probably more convenient than working with y anyway. Similarly, for the transform from the correlation Cholesky factor back to its unconstrained domain, it might be inefficient to unroll the correlation matrix into a vector as a first step, because we would still use a loop to generate the entries of the unconstrained vector.

@elbamos (Contributor) commented Jan 30, 2019

I already have code for unrolling the vector. The issue is whether there will ever be a need to backprop through that code, because there doesn't appear to be a purely vectorized way of doing it. If backprop'ing is going to be necessary, then we will have to write a PyTorch function implementing a custom gradient.

The code that needs to be written is for the log_prob function.

If I'm understanding you correctly: the function of the constraint is that during sampling, Pyro will automatically select a transform that matches the domains, i.e. the Transform(s) I've written. (I've actually coded up two of them, one from the unconstrained space and one from [-1, 1], the difference being the application of tanh.)

I set up a branch: https://github.com/elbamos/pyro/blob/lkj/pyro/distributions/lkj.py

@elbamos (Contributor) commented Jan 30, 2019

Also, I took a look at the code for the lower_cholesky constraint and transform. It appears to me that the lower_cholesky constraint is fine. What's needed is a constraint on the input domain, which needs to consist of reals with event shape (n*(n+1)/2,). No?

@fehiepsi (Member)

@elbamos Could you let me know which lines in your code you are worried about for gradient backpropagation? I can't identify them. Overall, your code looks great. It would be better if we modified it to support "batching".

The log_prob normalization constant can be found in the above paper by Daniel Lewandowski. The translation from the log_prob of a correlation matrix to the log_prob of its Cholesky factor can be found at, e.g., StackExchange.

Please let me know which parts you need me to add to your code base (to avoid duplicated work). I'm happy to work on this with you.

@elbamos (Contributor) commented Jan 30, 2019

@fehiepsi I've tested the algorithm code in there separately, but I haven't tested any of it in place in those classes.

What I think will cause backprop problems is the in-place modification of the tensors when transforming the vector into a matrix.

You're right about the batching... Maybe someone else can modify the code to do that? I just find it very hard to imagine someone generating correlation matrices in batches that way, but I guess other people have use cases very different from mine.

Regarding the log prob, I think our best source for it is actually the Stan source code: https://github.com/stan-dev/math/blob/master/stan/math/prim/mat/prob/lkj_corr_cholesky_lpdf.hpp

Actually, it's very interesting to try to implement the same model in Stan and Pyro. Stan is currently running the model about 100x faster than Pyro, and the GPU (mine is a 1080Ti) actually makes things worse rather than better. Presumably this is because my Pyro implementation isn't efficiently written yet. Someone should write a tutorial on converting Stan programs to Pyro in an optimized way.

@fehiepsi (Member)

@elbamos Yes, we can use the Stan source code to verify our implementation. I'll take care of "batching". If you want to convert a vector to a matrix, you can use the trick:

# x is a vector of length N * (N + 1) / 2
y = x.new_ones(N, N).tril(diagonal=0)  # ones on and below the diagonal
y[y > 0] = x                           # fill the lower triangle row by row

Back-propagation should be fine with this version.

Regarding HMC performance, we profiled various models and observed that most of the time is spent computing the potential energy, which calls the log_prob method; that part is outside Pyro's scope. Stan does a pretty good job of compiling their code, while we rely on the PyTorch JIT to compile the potential energy. Currently the PyTorch JIT only gives about a 2x speedup, so we hope it will improve in the future.

GPUs are suited to large vectors/batches; they do not seem to give a computational advantage for the kinds of models we find in statistics textbooks, so you don't have to worry about that for now.

Could you please allow me to push commits to your branch and add some tests, so I can expand them to verify that "batching" works correctly? Thanks!

@fritzo (Member) commented Jan 30, 2019

@elbamos would you be willing to start a PR with your partial implementation? I think that would be a better venue for discussion. @fehiepsi and I can push code to (or open PRs against) your branch to help you with boilerplate.

@elbamos (Contributor) commented Jan 31, 2019

@fritzo Should I start the PR now, or wait until we're further along? I'm pretty sure you can open PRs against the branch now. I'm not sure what the correct GitHub etiquette is these days.

@fehiepsi Yeah... The model that I'm porting over does a lot of matrix slicing and reassembling, which I had hoped would be faster in Pyro because it has more powerful vectorized functions. But it's turning out that simple for loops in Stan can perform those matrix slicing operations faster than vectorized PyTorch functions.

Interestingly, the per-iteration performance seems to decline over the course of inference. I suspect this is a combination of things. One thing I learned when I was spending a lot of time building PyTorch neural networks is that to get good performance out of it, you have to optimize your code to reuse buffers. Otherwise PyTorch spends a lot of its time reallocating and destroying memory, especially on the GPU. I'm wondering if Pyro isn't optimizing for that well?

The other thing I suspect is going on is the limited support for constrained and truncated distributions in Pyro. For example, I don't see a way to tell Pyro that one of my variables has to be constrained to be positive, another to be positive-ordered, etc. See the discussion at the bottom here: https://mc-stan.org/docs/2_18/reference-manual/reject-statements-section.html. There's some additional discussion about it on the Stan mailing list. There are two problems. One is that when a distribution is truncated, it doesn't integrate to 1, and if that isn't taken into account it will confound the posterior. The second, related problem is that, because of this, a Hamiltonian sampler will tend to keep pushing against the improperly enforced constraint. This slows sampling down considerably, and you end up with a posterior bunched up around the constraint because the sampler keeps trying to explore that part of the space and can't.

@fehiepsi (Member)

@elbamos It's fine with me to make PRs to your repo. Could you please enable the Issues tab, so I can write a to-do list there?

Regarding performance, we haven't paid attention to reallocating/destroying buffers. @neerajprad might have better ideas about it than I do.

Regarding truncated distributions, I am not sure I understand what you mean. We have a PR at probtorch/pytorch#121, in case you want to follow it.

@elbamos (Contributor) commented Jan 31, 2019

@fehiepsi I've enabled the Issues tab. I'll have more time for this project this weekend.

@neerajprad (Member)

Regarding performance, for most models I would expect the GPU to be slower since HMC/NUTS is heavily sequential - we need to take many steps in the integrator, and unless the time saved from parallelizing the gradient computation within each step is significant, we are unlikely to realize any benefits.

Interestingly, the per-iteration performance seems to decline over the course of inference.

Do you see a difference even after warmup? During warmup, the performance might vary as we adapt the step size.

One thing I learned when I was spending a lot of time building pytorch neural networks, is that to get good performance out of it, you have to optimize your code to re-use buffers.

You are right about NN training, but for HMC this shouldn't be an issue, since we cannot use mini-batches and we transfer the data to the GPU all at once.

@elbamos (Contributor) commented Feb 7, 2019

Just to update folks on the current status of this: I have code up at https://github.com/elbamos/pyro/blob/lkj/pyro/distributions/lkj.py, which runs when isolated, seems to be correct, passes tests (although I'm still working on tests for the derivative and log_prob), etc.

But when I try to run it, I get the error ValueError: Model specification seems incorrect - cannot find a valid trace.

I'd appreciate some advice/suggestions/help on tracking it down. My understanding of the innards of Pyro is quite basic, so I'm not quite sure where to start. @fritzo Could I trouble you to take a peek?

Thanks.

@fehiepsi (Member) commented Feb 8, 2019

@elbamos I hope you don't mind if I take care of this issue separately.

@neerajprad (Member)

But when I try to run it, I get the error ValueError: Model specification seems incorrect - cannot find a valid trace.

This is an error that is thrown when HMC tries to find an initial trace to begin sampling from. It repeatedly samples from the prior until it finds a trace with a non-NaN value for the potential energy, and throws this error if it doesn't succeed within 100 trials. My guess is that the log_prob of the samples generated from the prior in your model is NaN or Inf for one of the distributions. If you point me to your code, I could also take a look -- I'm only familiar with HMC, though, and not the LKJ implementation.

@fehiepsi (Member) commented Feb 8, 2019

This is an error that is thrown when HMC tries to find an initial trace to begin sampling from. It repeatedly samples from the prior until it finds a trace with a non-NaN value for the potential energy, and throws this error if it doesn't succeed within 100 trials.

Yeah, I forgot about that. This should be the case.

@elbamos (Contributor) commented Feb 9, 2019

Thanks, @neerajprad, I was able to trace it to an issue in the conversion, which I've resolved. Now I'm dealing with an issue thrown by multi_normal while building a good example. I should have it worked through soon.

@fehiepsi (Member) commented Feb 9, 2019

Thanks

+1

@elbamos (Contributor) commented Feb 9, 2019

Ok, I've tracked down the issue that I'm seeing... It relates to an intersection between Pyro's sampler and precision limits on the tanh transformation. I apologize for the lengthy explanation, but I think it's necessary.

To build the lower Cholesky factor of a correlation matrix, we have to go from an unconstrained vector of the appropriate size to a lower-triangular matrix where each row has unit norm and each entry on the diagonal is positive. The transformation works on a row-by-row basis: it fills in the row up to the diagonal, and the diagonal element is then filled in with whatever is necessary for the row to have unit norm.

This works fine as long as the input is in the range (-1, 1). If, however, there are many (I haven't tried to figure out the minimum required) elements in the input that are exactly -1 or 1, then there won't be any residual left to put into the diagonal, and the transformation won't create a proper lower Cholesky factor.

The first step in the transformation is therefore to take the input and pass it through tanh() so it's in the range (-1, 1).

The problem that is arising is that when Pyro samples from the unconstrained space, it apparently begins by sampling from a very wide range. For example, in the test I just ran, the first unconstrained values provided by Pyro were [-39.7456, -44.9190, 76.7859]. In all of the tests I've run, Pyro's first unconstrained values were of similar or greater magnitude. (By comparison, Stan begins sampling uniformly from [-2, 2].)

On values in this range, at double precision, PyTorch's tanh produces output in [-1, 1] rather than (-1, 1). The transformation then fails.

In fact, in every test I've run so far, Pyro's initial unconstrained sample produced vectors that, after tanh transformation, were all exactly -1 or 1.

Advice? Suggestions?

@fehiepsi (Member) commented Feb 9, 2019

Not sure if it helps, but you can clamp the output of the tanh transform to (-1 + eps, 1 - eps), as in https://github.com/pytorch/pytorch/blob/master/torch/distributions/utils.py#L77.
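A minimal sketch of that clamping, using torch.finfo so that eps matches the tensor's dtype (the helper name is illustrative):

import torch

def clamp_to_open_unit_interval(x):
    # keep tanh output strictly inside (-1, 1)
    eps = torch.finfo(x.dtype).eps
    return x.clamp(min=-1 + eps, max=1 - eps)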

@elbamos (Contributor) commented Feb 9, 2019

@fehiepsi It will run if I do that. But I'd think that would cause a few problems:

  1. The clamp will cause a step in the gradient, which will then confuse HMC.
  2. The generated matrices won't be good samples; Pyro will still be feeding in very wide unconstrained samples, and the inputs to the transformation will all be in {-1 + eps, 1 - eps}, which obviously aren't good starting points for correlation matrices.
  3. As I just learned while writing tests, it breaks the inverse transform.

@neerajprad Do you think @fehiepsi's solution will cause problems in the HMC/NUTS geometry? How big a problem is it if the inverse transform is broken? If the solution works, then I think this is basically done, and I'll make the PR. If not, we need another solution.

@neerajprad (Member)

@elbamos - Great job debugging this! I think @fehiepsi's solution should be fine. We have had to add these epsilon factors inside many distributions for numerical stability around boundaries. The only thing I would suggest is to choose eps using torch.finfo, as in @fehiepsi's link, so that the epsilon added depends on whether the tensor is double or float. Another suggestion would be to use torch.log1p((2*x)/(1-x)) / 2 as the inverse mapping, which is more numerically stable around 0, though I am not sure whether that will help with the particular problem you are facing.
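For reference, a small sketch of that inverse mapping (mathematically atanh, written via log1p):

import torch

def stable_atanh(x):
    # equivalent to 0.5 * log((1 + x) / (1 - x)), but more accurate near x = 0
    return torch.log1p(2 * x / (1 - x)) / 2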

Do you think @fehiepsi 's solution will cause problems in the HMC/NUTS geometry?

My guess is that we will be sampling extreme values during the initial stages, when we take large steps and adjust the step size and mass matrix, and not once we enter the typical set. I think this should be safe because the Metropolis correction step would end up correcting for it (by rejecting most such proposals at the boundaries), but it might also be the case that we get stuck with values at the boundaries and don't end up in the typical set at all. Why don't you open up a PR, and we can discuss this in more detail?

@fritzo - Should we open a PR directly in torch.distributions, or aim to get a first pass merged in Pyro?

@fehiepsi (Member)

I think we should move this distribution to PyTorch later (as mentioned by @ssnl at probtorch/pytorch#150), after testing correctness with the JIT and on CPU/GPU. In the meantime, I'll open a parallel PR with tests to address the correctness (the top priority) and performance (not important right now) of the implementations.

@elbamos (Contributor) commented Feb 11, 2019

@neerajprad Ok. And good point on the inverse mapping.

I'm curious about the decision to start sampling from wide values rather than narrow ones as Stan does. Is there any discussion about the rationale somewhere online that I can read through?

@eb8680 (Member) commented Feb 11, 2019

@fehiepsi and @elbamos - I looked at the PR in question and mostly what I saw was the sort of reasonable but inevitable miscommunication that happens when smart, well-intentioned strangers discuss complicated math on the internet.

I've removed off-topic comments, and @neerajprad and others will review #1746. Please do not derail this issue.

@pyro-ppl pyro-ppl deleted a comment from elbamos Feb 11, 2019
@fritzo (Member) commented Feb 11, 2019

Should we open a PR directly in torch.distributions

I think the best place for verbose code review of statistical functions is either Pyro or https://github.com/probtorch/pytorch. I'm fine with starting in Pyro and later moving to PyTorch.

@neerajprad (Member)

Closing this with @elbamos's PR #1753. If there are any follow-up items, let us create specific issues and address them separately.
