
Complex-Valued Gaussian distributions #83376

Open
egaznep opened this issue Aug 13, 2022 · 19 comments
Labels
module: complex Related to complex number support in PyTorch module: distributions Related to torch.distributions module: random Related to random number generation in PyTorch (rng generator) triaged This issue has been looked at by a team member and triaged and prioritized into an appropriate module

Comments

egaznep commented Aug 13, 2022

🚀 The feature, motivation and pitch

I have been working on complex-valued variational autoencoders (CVAE), and for this purpose I required complex-valued Gaussian distributions. I coded my own implementation based on [1] and used it successfully. I essentially model it as a composite real-valued distribution. Would it be of interest if I created a PR to add this functionality?

[1] P. J. Schreier and L. L. Scharf, Statistical signal processing of complex-valued data: the theory of improper and noncircular signals. Cambridge: Cambridge University Press, 2010. Accessed: Feb. 13, 2022. [Online]. Available: https://doi.org/10.1017/CBO9780511815911
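To make the pitch concrete, here is a minimal, hypothetical sketch (my own illustration, not the proposed API) of sampling a proper, circularly-symmetric complex Gaussian via its composite real representation:

```python
import torch

# Hypothetical sketch: draw the real and imaginary parts from a real-valued
# normal, then view the result as a complex vector. Each real component has
# variance 1/2 so each complex coordinate has total variance 1.
n = 4
mean = torch.zeros(n, 2)       # columns: real part, imaginary part
std = 0.5 ** 0.5
z = torch.view_as_complex(torch.normal(mean, std))
# z is an n-vector of complex64 samples
```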

Alternatives

No response

Additional context

No response

cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @pbelevich

@ezyang ezyang added triaged This issue has been looked at by a team member and triaged and prioritized into an appropriate module module: complex Related to complex number support in PyTorch labels Aug 15, 2022
ezyang (Contributor) commented Aug 15, 2022

Do you have an idea what the API for this functionality would be?

lezcano (Collaborator) commented Aug 15, 2022

I think there was already an open issue discussing this, but I have not been able to find it. cc @anjali411

The API for this would simply be one where, if the mean tensor is complex, the returned tensor will be complex, and each of its 2n coordinates (n real and n imag) are distributed as a "Complex normal random vector" (see https://en.wikipedia.org/wiki/Complex_normal_distribution#Complex_normal_random_vector).

This would be equivalent to doing:

z_real = torch.normal(torch.zeros(n, 2), 1)
z = torch.view_as_complex(z_real)

Is this what you had in mind @egaznep?

@lezcano lezcano added the module: random Related to random number generation in PyTorch (rng generator) label Aug 15, 2022
egaznep (Author) commented Aug 15, 2022

> I think there was already an open issue discussing this, but I have not been able to find it. cc @anjali411
>
> The API for this would simply be one where, if the mean tensor is complex, the returned tensor will be complex, and each of its 2n coordinates (n real and n imag) are distributed as a "Complex normal random vector" (see https://en.wikipedia.org/wiki/Complex_normal_distribution#Complex_normal_random_vector).
>
> This would be equivalent to doing:
>
> z_real = torch.normal(torch.zeros(n, 2), 1)
> z = torch.view_as_complex(z_real)
>
> Is this what you had in mind @egaznep?

I was thinking of everything in the notation of the book I referred to, as people may want to define the distribution using either the 'composite real representation' or the 'augmented complex representation'. Under the hood, it probably makes more sense to represent everything in composite real form; this way the new class can offload many tasks to the existing MultivariateNormal distribution. I am not sure whether what you proposed would be equivalent.

The code would look like this:

class ComplexMultivariateNormal(torch.distributions.Distribution):
    def __init__(self, loc: torch.Tensor, covariance=None, pseudocovariance=None):
        mu_z = torch.view_as_real(loc)
        R_zz = [[R_uu, R_uv], [R_uv.T, R_vv]]  # pseudocode: blocks solved from cov. and pseudocov. acc. to Schreier et al.
        self.z = torch.distributions.MultivariateNormal(loc=mu_z, covariance_matrix=R_zz) # composite real represented

Also, this perhaps makes the label 'module:distributions' more appropriate.
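A runnable miniature of the sketch above (my own illustration, restricted to the proper case where the pseudo-covariance vanishes, so R_uv = 0 and R_uu = R_vv = Gamma / 2; the stacked ordering of mu_z is one of several equivalent choices):

```python
import torch
from torch.distributions import MultivariateNormal

# Proper (circular) special case of the composite-real construction.
n = 2
loc = torch.tensor([1.0 + 2.0j, 0.0 + 0.0j])   # complex mean mu_x
Gamma = torch.eye(n)                            # covariance (here real-valued)

mu_z = torch.cat((loc.real, loc.imag))          # composite real mean, length 2n
R_zz = torch.block_diag(Gamma / 2, Gamma / 2)   # composite real covariance
z_dist = MultivariateNormal(loc=mu_z, covariance_matrix=R_zz)

s = z_dist.sample()                 # real sample of length 2n
z = torch.complex(s[:n], s[n:])     # reassemble the complex sample
```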

lezcano (Collaborator) commented Aug 15, 2022

Alas, I don't have access to the book. The wiki page gives two possible definitions of the complex multivariate normal distribution. Could you have a look and see whether the definition from the book is equivalent to either of the two on Wikipedia?

@ezyang ezyang added the module: distributions Related to torch.distributions label Aug 15, 2022
egaznep (Author) commented Aug 15, 2022

https://en.wikipedia.org/wiki/Complex_normal_distribution#Complex_normal_random_vector

I think it follows, yes.

The "augmented complex representation" is formed by the mean, covariance, and relation (pseudo-covariance) matrices. The later section 'Relation between covariance matrices' gives the equations that yield the composite real representation.
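For readers without the book, the relations alluded to here can be sketched as follows (following the Wikipedia page's notation, with z = u + iv, covariance Gamma and relation/pseudo-covariance C; R_vu = R_uv^T for a valid covariance):

```latex
% Composite real -> augmented complex:
\Gamma = R_{uu} + R_{vv} + i\,(R_{vu} - R_{uv}), \qquad
C = R_{uu} - R_{vv} + i\,(R_{vu} + R_{uv}) \\
% Augmented complex -> composite real:
R_{uu} = \tfrac{1}{2}\operatorname{Re}(\Gamma + C), \quad
R_{vv} = \tfrac{1}{2}\operatorname{Re}(\Gamma - C), \quad
R_{uv} = \tfrac{1}{2}\operatorname{Im}(C - \Gamma), \quad
R_{vu} = \tfrac{1}{2}\operatorname{Im}(\Gamma + C)
```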

lezcano (Collaborator) commented Aug 16, 2022

Just to be sure: are you saying that the definition in the book is what Wikipedia calls the "Complex standard normal random vector", or the one I'm referring to, which Wikipedia calls the "Complex normal random vector"?

From the parameters you're mentioning, I believe you refer to the "Complex standard normal random vector", which, as you suggest, would be better suited to live under torch.distributions. CC @fritzo

egaznep (Author) commented Aug 16, 2022

https://en.wikipedia.org/wiki/Complex_normal_distribution#Complex_normal_random_vector

It is the 'complex normal random vector'. The standard one is just a special case with mean 0, covariance I and pseudo-covariance O, isn't it? (I being the identity matrix and O the zero matrix.) Still, I think distributions is a better place, because there are cases where one may want to compute a KL divergence or sample using the reparametrization trick. These are not possible if this functionality is provided only through the 'random' module.

lezcano (Collaborator) commented Aug 16, 2022

You are completely right. I got a bit lost in the notation.

IMO, it makes complete sense to support this, and it may even be possible to support it within MultivariateNormal itself. The issue here is the usual one: this is nothing but a different parametrisation of a normal distribution of dimension 2n. I do not know how we currently support different parametrisations of the same distribution (in case we do).

fritzo (Collaborator) commented Aug 18, 2022

> how we currently support different parametrisations of the same distribution

We're now leaning towards creating different distribution classes for each parametrization.

Early in the development of torch.distributions we thought it would be clean to overload different parameterizations for distributions, e.g. Categorical accepts either probs or logits; MultivariateNormal accepts either covariance_matrix or precision_matrix or scale_tril. However, the internal special-casing has led to many headaches and incompatibilities, including memoizing, pattern matching, subclassing, serialization, and conversions to/from other libraries.

I'm not sure how the complex normal would work, but I'd lean towards creating a new distribution class unless the existing MultivariateNormal can support it with no change to interface and say <10 lines change to code (code that is quite complex).
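For illustration, the probs/logits overloading mentioned above looks like this in current torch.distributions (existing behavior, not new API):

```python
import torch
from torch.distributions import Categorical

probs = torch.tensor([0.1, 0.2, 0.7])
d1 = Categorical(probs=probs)           # parametrized by probabilities
d2 = Categorical(logits=probs.log())    # same distribution, via logits

# Both objects describe the same distribution, but supporting the dual
# parametrization requires internal special-casing (memoized conversions
# between probs and logits).
assert torch.allclose(d1.probs, d2.probs)
```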

lezcano (Collaborator) commented Aug 19, 2022

It may be best to create a new class then, I think.

@egaznep would you be keen on having a go at this one?

egaznep (Author) commented Aug 19, 2022

@lezcano sure! I might not be able to attend to this right away, but I'll definitely do it.

egaznep (Author) commented Sep 8, 2022

class ComplexMultivariateNormal(MultivariateNormal):
    r"""
    Creates a complex-valued multivariate normal (also called Gaussian) 
    distribution. There are two complementary views [1]:

    - Augmented complex representation: parameterized by a complex-valued mean
    vector, and complex-valued covariance and pseudo-covariance matrices.

    - Composite real representation: parametrized by two real-valued mean vectors
    which are means of the real and the imaginary parts, and three covariance matrices:
    real, imaginary and cross-covariance.

    Args:
        augmented_complex (Tuple (mu_x, R_xx, R_xx_tilde)): consists of
        complex-valued mean, covariance and pseudo-covariance matrices.

        composite_real (Tuple (mu_u, mu_v, R_uu, R_vv, R_uv)): consists of the
        real-valued means of the real and imaginary parts, and the real,
        imaginary and cross-covariance matrices.

    Note: Under the hood, a real-valued normal distribution (the composite real
    representation) is used to emulate the complex-valued distribution. However,
    one first needs to check whether the given parameters define a complex-valued
    normal distribution, and that is easier with the augmented complex representation.
    
    [1] P. J. Schreier and L. L. Scharf, Statistical signal processing of complex-valued
    data: the theory of improper and noncircular signals. Cambridge: Cambridge University
    Press, 2010. Available: https://doi.org/10.1017/CBO9780511815911

    """
    def __init__(self, augmented_complex: Sequence[torch.Tensor] = None, composite_real: Sequence[torch.Tensor] = None):
        if (augmented_complex is not tuple) ^ (composite_real is not tuple):
            raise ValueError('Only one of the accepted representations must be supplied \
            and it must be supplied as a tuple.')
        if composite_real is tuple:
            if len(composite_real) != 5:
                raise ValueError(f'Composite real representation requires exactly 5 \
                    parameters. You supplied {len(composite_real)}.')
            mu_u, mu_v, R_uu, R_vv, R_uv = composite_real
            # obtain equivalent augmented complex representation
            mu_x = mu_u + 1j * mu_v
            R_xx = R_uu + R_vv + 1j * (R_uv.T - R_uv)
            R_xx_tilde = R_uu - R_vv + 1j * (R_uv.T + R_uv)
        if augmented_complex is tuple:
            if len(augmented_complex) != 3:
                raise ValueError(f'Augmented complex representation requires exactly 3 \
                    parameters. You supplied {len(augmented_complex)}.')
            # obtain equivalent composite real representation
            mu_x, R_xx, R_xx_tilde = augmented_complex
            mu_u, mu_v = mu_x.real, mu_x.imag
            R_uu = (R_xx + R_xx_tilde).real / 2
            R_vv = (R_xx - R_xx_tilde).real / 2
            R_uv = (R_xx_tilde - R_xx).imag / 2

        # check if the provided parameters define a complex-valued normal distribution
        # 1. R_xx is positive definite
        assert constraints.positive_definite.check(R_xx).all(), 'R_xx (covariance matrix) is not positive definite.'
        # 2. R_xx_tilde is symmetric
        assert constraints.symmetric.check(R_xx_tilde).all(), 'R_xx_tilde (pseudo-covariance matrix) is not symmetric.'
        # 3. Schur complement R_xx - R_xx_tilde inv(R_xx.conj) R_xx_tilde.conj is positive semidefinite
        assert constraints.positive_semidefinite.check(R_xx - R_xx_tilde @ R_xx.conj().inverse() @ R_xx_tilde.conj()).all()

        # form the composite vectors
        mu_z = torch.cat((mu_u, mu_v), axis=-2)
        R_zz = torch.cat(
                (torch.cat((R_uu, R_uv), axis=-1),
                torch.cat((R_uu, R_uv), axis=-1)
            ), axis=-2)
        
        super().__init__(loc=mu_z, covariance_matrix=R_zz)

I am thinking of such an implementation. It still needs some work, such as exposing the relevant parameters as lazy_property attributes and adding tests. Does this look good?

P.S. This could have been simpler, but I don't know of a way to express the Schur complement condition in a form that does not require the 'augmented complex' representation.

lezcano (Collaborator) commented Sep 8, 2022

There are a number of issues with the current implementation (e.g. the computation of R_zz is wrong, so it needs tests; and you are accepting Sequences but then checking that these are tuples rather than sequences. If you want tuples, you should type them as tuples).

Also, it's not clear to me that those 3 conditions are necessary or sufficient to check that the given parameters define a normal distribution. If anything, I think the necessary and sufficient condition would be for R_zz to be SPD, and this should already be checked within MultivariateNormal. As such, if the user already provides the composite_real parameters, I don't see the need to compute their complex counterparts.

That being said, the general idea looks fine to me, but perhaps @fritzo has better ideas (the API may be written in a cleaner way perhaps).

Also, it'd be good for the documentation to show mathematically the two ways of representing the CN distribution and how they relate to each other, in LaTeX, so the docs are self-contained rather than citing a book.

In any case, these things would be better discussed in a proper PR.

PS. Even if I believe that the Schur complement is not necessary, if you ever want to compute it, a better way to do it would be:

R_xx - R_xx_tilde @ linalg.solve(R_xx,  R_xx_tilde).conj()

General tip: never compute an inverse explicitly if you can avoid it; prefer linalg.solve.
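A small sketch of the tip (my own example, with a random Hermitian positive-definite R_xx standing in for the covariance):

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 4, dtype=torch.cdouble)
R_xx = A @ A.mH + torch.eye(4, dtype=torch.cdouble)   # Hermitian positive definite
R_xx_tilde = torch.randn(4, 4, dtype=torch.cdouble)

# Explicit inverse (avoid): materializes inv(R_xx) before multiplying.
schur_inv = R_xx - R_xx_tilde @ R_xx.conj().inverse() @ R_xx_tilde.conj()
# Preferred: a single linear solve, cheaper and numerically more stable.
schur_solve = R_xx - R_xx_tilde @ torch.linalg.solve(R_xx, R_xx_tilde).conj()

# Both forms agree (up to floating-point error).
assert torch.allclose(schur_inv, schur_solve)
```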

egaznep (Author) commented Sep 8, 2022

Thanks for the review!

> There are a number of issues with the current implementation (e.g. the computation of R_zz is wrong, so it needs tests; and you are accepting Sequences but then checking that these are tuples rather than sequences. If you want tuples, you should type them as tuples).

Typing required explicit type annotations for each tuple element, and that was significantly lengthening the signature; I wanted to avoid that by switching to a Sequence. I can make the argument a list for consistency, or I am open to other recommendations. I am definitely open to learning better ways of doing things; this is my primary reason for being here in the first place.

> Also, it's not clear to me that those 3 conditions are necessary or sufficient to check that the given parameters define a normal distribution. If anything, I think the necessary and sufficient condition would be for R_zz to be SPD, and this should already be checked within MultivariateNormal. As such, if the user already provides the composite_real parameters, I don't see the need to compute their complex counterparts.
>
> That being said, the general idea looks fine to me, but perhaps @fritzo has better ideas (the API may be written in a cleaner way perhaps).

I haven't done the pen-and-paper math to confirm, but my intuition tells me that not every 2N-dimensional real-valued Gaussian is also an N-dimensional complex Gaussian. The only way to know whether this is the case is to substitute all the complex-valued matrices in the Schur complement with their real-valued matrix expressions. That requires a very careful evaluation of the inverse, using the matrix inversion lemma, and I sadly do not have the time to run through that. If someone verifies that the Schur complement condition degenerates to an equality that is always satisfied for a positive semidefinite R_zz, then we are good to go. Otherwise I would hesitate to deviate from what the book says.

> Also, it'd be good for the documentation to show mathematically the two ways of representing the CN distribution and how they relate to each other, in LaTeX, so the docs are self-contained rather than citing a book.

I do have them already typeset as this was part of my master's thesis. I can surely bring that in.

> PS. Even if I believe that the Schur complement is not necessary, if you ever want to compute it, a better way to do it would be:
>
> R_xx - R_xx_tilde @ linalg.solve(R_xx, R_xx_tilde).conj()
>
> In any case, these things would be better discussed in a proper PR.

Thanks for the reminder. I normally pay attention to this, but this once I let it slip, as my aim was to see whether the direction of this contribution is correct. That is also why I shared it here rather than as a PR.

fritzo (Collaborator) commented Sep 8, 2022

Is there anything intrinsically complex in this distribution, or is it simply a multivariate distribution over the (n * 2)-long real vector? If the latter, then this would probably be better implemented as a

TransformedDistribution(MultivariateNormal(...), RealToComplexTransform())

where RealToComplexTransform lives in torch.distributions.transforms and converts between n-dimensional complex vectors and n * 2-dimensional real vectors. This approach also gives us a complex distribution for every real distribution already defined.

EDIT: you'll also need to create complex and complex_vector constraints in torch.distributions.constraints. This is a little different from the original feature request, but IMO it is a quite powerful addition to the torch.distributions module.
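A hypothetical sketch of what such a transform could look like (RealToComplexTransform does not exist in torch.distributions.transforms today; real_vector stands in for the codomain because the complex_vector constraint mentioned in the EDIT does not exist yet):

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform

class RealToComplexTransform(Transform):
    """Hypothetical sketch: bijectively view an (..., n, 2) real tensor as an
    (..., n) complex tensor."""
    domain = constraints.independent(constraints.real, 2)
    codomain = constraints.real_vector  # placeholder for a future complex_vector
    bijective = True

    def _call(self, x):
        return torch.view_as_complex(x.contiguous())

    def _inverse(self, y):
        return torch.view_as_real(y)

    def log_abs_det_jacobian(self, x, y):
        # The map only relabels coordinates, so the Jacobian determinant is 1.
        return x.new_zeros(x.shape[:-2])

t = RealToComplexTransform()
z = t(torch.randn(5, 3, 2))   # complex tensor of shape (5, 3)
```

Wiring it into TransformedDistribution(MultivariateNormal(...), RealToComplexTransform()) would additionally require the shape plumbing (forward_shape/inverse_shape) and the new constraints fritzo mentions.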

lezcano (Collaborator) commented Sep 8, 2022

That is a very good point. You can also have a convenience helper to transform the complex parameters to real parameters, which would also fit this pattern.

lezcano (Collaborator) commented Sep 8, 2022

Also, I am inclined to think that 2n-dimensional real normal distributions are in correspondence with n-dimensional complex normal distributions, by a dimension-count argument.
The Wikipedia page shows transformations between the two. To show the equivalence, you would need to show that applying the transformation Complex -> Real -> Complex gives the identity, and likewise for Real -> Complex -> Real. I haven't done it myself, but I'd be surprised if this weren't the case.

If this is indeed the case, then @fritzo's implementation (which is a functional version of your implementation if you remove the asserts) sounds like a great idea to me.
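A quick numeric check of the Real -> Complex -> Real round trip described above, assuming the Wikipedia formulas for both directions (arbitrary SPD R_zz; variable names are my own):

```python
import torch

torch.manual_seed(0)
n = 3
A = torch.randn(2 * n, 2 * n)
R_zz = A @ A.T + torch.eye(2 * n)      # arbitrary SPD composite-real covariance
R_uu, R_uv = R_zz[:n, :n], R_zz[:n, n:]
R_vu, R_vv = R_zz[n:, :n], R_zz[n:, n:]

# Real -> Complex: covariance Gamma and pseudo-covariance C
Gamma = R_uu + R_vv + 1j * (R_vu - R_uv)
C = R_uu - R_vv + 1j * (R_vu + R_uv)

# Complex -> Real recovers the original blocks
assert torch.allclose((Gamma + C).real / 2, R_uu)
assert torch.allclose((Gamma - C).real / 2, R_vv)
assert torch.allclose((C - Gamma).imag / 2, R_uv)
assert torch.allclose((Gamma + C).imag / 2, R_vu)
```

The other direction (Complex -> Real -> Complex) follows from the same identities.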

felixdivo (Contributor) commented:
I have proposed #92241 as a first step.

lezcano (Collaborator) commented Jan 25, 2023

@fritzo made in #92241 (review) some very valid points that should be taken into account by whoever wants to push this forward.
