Complex-Valued Gaussian distributions #83376
Do you have an idea what the API for this functionality would be?
I think there was already an open issue discussing this, but I have not been able to find it. cc @anjali411

The API for this would simply be one where, if the dtype is complex, the real and imaginary parts are sampled as independent real normals. This would be equivalent to doing:

z_real = torch.normal(torch.zeros(n, 2), 1)
z = torch.view_as_complex(z_real)

Is this what you had in mind @egaznep?
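The snippet above can be expanded into a self-contained, runnable form (the sample count `n = 4` is an arbitrary illustrative choice):

```python
import torch

# Draw n (real, imag) pairs from a standard real normal and reinterpret
# them as n complex samples via view_as_complex.
n = 4
z_real = torch.normal(torch.zeros(n, 2), 1.0)
z = torch.view_as_complex(z_real)
assert z.shape == (n,) and z.dtype == torch.complex64
```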
I was thinking of everything according to the notation from the book I referred to, as people may want to define the distribution either using the 'composite real representation' or the 'augmented complex representation'. Under the hood it probably makes more sense to represent everything using the composite real one; this way the new class can offload many tasks to the already existent MultivariateNormal. The code would look like this:
Also, this perhaps makes the label 'module:distributions' more appropriate.
Alas, I don't have access to the book. The wiki page gives two possible definitions for the complex multivariate normal distribution. Could you have a look and see whether the definition from that book is equivalent to either of the two in Wikipedia?
I think it follows, yes. The 'augmented complex representation' is formed by the mean, covariance and relation (pseudo-covariance) matrices. The section 'relations between covariance matrices' there describes the equations which yield the composite real representation.
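As a numerical sanity check (an illustrative sketch, not part of any proposed API), the relations between the composite real parameters (R_uu, R_vv, R_uv) and the augmented complex parameters (R_xx, R_xx_tilde) used later in this thread invert each other:

```python
import torch

torch.manual_seed(0)
n = 3
A, B = torch.randn(n, n), torch.randn(n, n)
R_uu = A @ A.T            # symmetric PSD covariance of the real part
R_vv = B @ B.T            # symmetric PSD covariance of the imaginary part
R_uv = torch.randn(n, n)  # cross-covariance (no symmetry required)

# composite real -> augmented complex
R_xx = R_uu + R_vv + 1j * (R_uv.T - R_uv)
R_xx_tilde = R_uu - R_vv + 1j * (R_uv.T + R_uv)

# augmented complex -> composite real recovers the original parameters
assert torch.allclose((R_xx + R_xx_tilde).real / 2, R_uu, atol=1e-5)
assert torch.allclose((R_xx - R_xx_tilde).real / 2, R_vv, atol=1e-5)
assert torch.allclose((R_xx_tilde - R_xx).imag / 2, R_uv, atol=1e-5)
```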
Just to be sure: are you saying that the definition in the book is what Wikipedia calls a "Complex standard normal random vector", or the one that I'm referring to, which Wikipedia calls a "Complex normal random vector"? From the parameters you're mentioning, I believe you refer to the "Complex standard normal random vector", which, as you suggest, would be better suited to live under
It is the 'complex normal random vector'. The standard one is just a special case, with mean O, covariance I and pseudo-covariance O, isn't it? (I being an identity matrix and O a zero matrix.) Still, I think distributions is a better place, because there are cases where one may want to compute a KL divergence or sample using the reparametrization trick. These are not possible if this functionality is provided only through the 'random' module.
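To illustrate the point about the reparametrization trick (an assumed example using the existing real-valued MultivariateNormal), rsample keeps gradients flowing to the distribution's parameters, which a plain random routine would not provide:

```python
import torch

# With rsample, samples are differentiable w.r.t. the distribution parameters.
loc = torch.zeros(3, requires_grad=True)
dist = torch.distributions.MultivariateNormal(loc, covariance_matrix=torch.eye(3))
z = dist.rsample()   # z = loc + L @ eps, so gradients reach loc
z.sum().backward()
assert loc.grad is not None
```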
You are completely right. I got a bit lost in the notation. IMO, it makes complete sense to support this, and it may even be possible to support it within
We're now leaning towards creating different distribution classes for each parametrization. Early in the development of torch.distributions we thought it would be clean to overload different parameterizations for distributions. I'm not sure how the complex normal would work, but I'd lean towards creating a new distribution class unless the existing
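For reference, the existing MultivariateNormal already overloads several parameterizations in one class (covariance_matrix, precision_matrix, or scale_tril), which is the pattern being discussed; the values below are arbitrary:

```python
import torch

# The same distribution constructed from two different parameterizations.
cov = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
d1 = torch.distributions.MultivariateNormal(torch.zeros(2), covariance_matrix=cov)
d2 = torch.distributions.MultivariateNormal(torch.zeros(2),
                                            precision_matrix=torch.linalg.inv(cov))
assert torch.allclose(d1.covariance_matrix, d2.covariance_matrix, atol=1e-4)
```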
It may be best to create a new class then, I think. @egaznep would you be keen on having a go at this one?
@lezcano sure! I might not be able to attend to this right away, but I'll definitely do it. |
from typing import Optional, Sequence

import torch
from torch.distributions import MultivariateNormal, constraints


class ComplexMultivariateNormal(MultivariateNormal):
    r"""
    Creates a complex-valued multivariate normal (also called Gaussian)
    distribution. There are two complementary views [1]:

    - Augmented complex representation: parameterized by a complex-valued mean
      vector and complex-valued covariance and pseudo-covariance matrices.
    - Composite real representation: parameterized by two real-valued mean
      vectors (the means of the real and the imaginary parts) and three
      real-valued covariance matrices: real, imaginary and cross-covariance.

    Args:
        augmented_complex (Tuple (mu_x, R_xx, R_xx_tilde)): the complex-valued
            mean vector, covariance matrix and pseudo-covariance matrix.
        composite_real (Tuple (mu_u, mu_v, R_uu, R_vv, R_uv)): the real-valued
            mean vectors of the real and imaginary parts, their covariance
            matrices and the cross-covariance matrix.

    Note: Under the hood a real-valued normal distribution (as in the composite
    real representation) is used to emulate the complex-valued distribution.
    However, one first needs to check that the given parameters define a
    complex-valued normal distribution, which is easier in the augmented
    complex representation.

    [1] P. J. Schreier and L. L. Scharf, Statistical signal processing of
    complex-valued data: the theory of improper and noncircular signals.
    Cambridge: Cambridge University Press, 2010.
    Available: https://doi.org/10.1017/CBO9780511815911
    """

    def __init__(self,
                 augmented_complex: Optional[Sequence[torch.Tensor]] = None,
                 composite_real: Optional[Sequence[torch.Tensor]] = None):
        # exactly one of the two representations must be supplied
        if (augmented_complex is None) == (composite_real is None):
            raise ValueError('Exactly one of the accepted representations must '
                             'be supplied, as a tuple.')
        if composite_real is not None:
            if len(composite_real) != 5:
                raise ValueError(f'Composite real representation requires exactly '
                                 f'5 parameters. You supplied {len(composite_real)}.')
            mu_u, mu_v, R_uu, R_vv, R_uv = composite_real
            # obtain the equivalent augmented complex representation
            mu_x = mu_u + 1j * mu_v
            R_xx = R_uu + R_vv + 1j * (R_uv.T - R_uv)
            R_xx_tilde = R_uu - R_vv + 1j * (R_uv.T + R_uv)
        else:
            if len(augmented_complex) != 3:
                raise ValueError(f'Augmented complex representation requires exactly '
                                 f'3 parameters. You supplied {len(augmented_complex)}.')
            mu_x, R_xx, R_xx_tilde = augmented_complex
            # obtain the equivalent composite real representation
            mu_u, mu_v = mu_x.real, mu_x.imag
            R_uu = (R_xx + R_xx_tilde).real / 2
            R_vv = (R_xx - R_xx_tilde).real / 2
            R_uv = (R_xx_tilde - R_xx).imag / 2
        # check that the provided parameters define a complex-valued normal distribution
        # 1. R_xx is positive definite
        assert constraints.positive_definite.check(R_xx), \
            'R_xx (covariance matrix) is not positive definite.'
        # 2. R_xx_tilde is symmetric
        assert constraints.symmetric.check(R_xx_tilde), \
            'R_xx_tilde (pseudo-covariance matrix) is not symmetric.'
        # 3. the Schur complement R_xx - R_xx_tilde inv(conj(R_xx)) conj(R_xx_tilde)
        #    is positive semidefinite
        schur = R_xx - R_xx_tilde @ R_xx.conj().inverse() @ R_xx_tilde.conj()
        assert constraints.positive_semidefinite.check(schur), \
            'The Schur complement is not positive semidefinite.'
        # form the composite mean vector and covariance matrix
        mu_z = torch.cat((mu_u, mu_v), dim=-1)
        R_zz = torch.cat(
            (torch.cat((R_uu, R_uv), dim=-1),
             torch.cat((R_uv.transpose(-2, -1), R_vv), dim=-1)),
            dim=-2)
        super().__init__(loc=mu_z, covariance_matrix=R_zz)

I am thinking of such an implementation. This still needs some work, such as adding the relevant parameters as

P.S. This could have been simpler, but I don't know of a way to express the Schur complement condition as something that does not require the 'augmented complex' representation.
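As a sanity check on the composite-real construction (an illustrative sketch independent of the class above; the parameter values and sample count are arbitrary), one can sample from the 2n-dimensional real MVN with the block covariance [[R_uu, R_uv], [R_uv^T, R_vv]] and verify empirically that the complex samples z = u + jv have covariance R_uu + R_vv + 1j*(R_uv.T - R_uv):

```python
import torch

torch.manual_seed(0)
n, m = 2, 100_000
R_uu = 2.0 * torch.eye(n)
R_vv = 1.0 * torch.eye(n)
R_uv = torch.tensor([[0.0, 0.3], [-0.3, 0.0]])
R_zz = torch.cat((torch.cat((R_uu, R_uv), dim=-1),
                  torch.cat((R_uv.T, R_vv), dim=-1)), dim=-2)
mvn = torch.distributions.MultivariateNormal(torch.zeros(2 * n),
                                             covariance_matrix=R_zz)
uv = mvn.sample((m,))            # (m, 2n) real samples
z = uv[:, :n] + 1j * uv[:, n:]   # (m, n) complex samples
R_xx_emp = z.T @ z.conj() / m    # estimate of E[z z^H]
R_xx = R_uu + R_vv + 1j * (R_uv.T - R_uv)
assert torch.allclose(R_xx_emp, R_xx.to(R_xx_emp.dtype), atol=0.05)
```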
There are a number of issues with the current implementation (e.g. the computation of

Also, it's not clear to me that those 3 conditions are necessary or sufficient to check that the given parameters define a normal distribution. If anything, I think the necessary and sufficient condition would be for

That being said, the general idea looks fine to me, but perhaps @fritzo has better ideas (the API may perhaps be written in a cleaner way). Also, it'd be good for the documentation to show mathematically what the two ways of representing the CN distribution are, and how they relate to each other, in LaTeX, so that the docs are self-contained rather than citing a book. In any case, these things would be better discussed in a proper PR.

PS. Even if I believe that the Schur complement is not necessary, if you ever want to compute it, a better way to do it would be:

R_xx - R_xx_tilde @ linalg.solve(R_xx, R_xx_tilde).conj()

General tip: never compute an inverse explicitly if you can avoid it; prefer linalg.solve.
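A quick numerical illustration (arbitrary values) that the solve-based expression suggested above agrees with forming the inverse explicitly:

```python
import torch

torch.manual_seed(0)
n = 4
M = torch.randn(n, n, dtype=torch.complex64)
R_xx = M @ M.mH + n * torch.eye(n)   # Hermitian positive definite
S = torch.randn(n, n, dtype=torch.complex64)
R_xx_tilde = (S + S.T) / 2           # symmetric pseudo-covariance

# Explicit inverse (discouraged) vs. a single linear solve
via_inverse = R_xx - R_xx_tilde @ torch.linalg.inv(R_xx.conj()) @ R_xx_tilde.conj()
via_solve = R_xx - R_xx_tilde @ torch.linalg.solve(R_xx, R_xx_tilde).conj()
assert torch.allclose(via_inverse, via_solve, atol=1e-3)
```

The two agree because conj(inv(R_xx) @ R_xx_tilde) equals inv(conj(R_xx)) @ conj(R_xx_tilde); the solve variant is cheaper and numerically more stable.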
Thanks for the review!
Typing required explicit type definitions for each tuple element and that was significantly lengthening the header. I wanted to avoid that by switching to a sequence. I can make the object a
I haven't done the pen-and-paper math to confirm, but my intuition tells me that not every 2N-dimensional real-valued Gaussian is also an N-dimensional complex-valued Gaussian. The only way to know if this is the case is to substitute all the complex-valued matrices in the Schur complement with their real-valued matrix expressions. That requires a very careful evaluation of the inverse, using the matrix inversion lemma. I sadly do not have the time to run through that. If someone verifies that the Schur complement condition degenerates to a simple equality that is always satisfied for a positive semidefinite
I do have them already typeset as this was part of my master's thesis. I can surely bring that in.
Thanks for the reminder. I normally pay attention to this, but this once I let it slip, as my aim was to see whether this contribution is headed in the right direction. That is also why I shared it here rather than as a PR.
Is there anything intrinsically complex in this distribution, or is it simply a multivariate distribution over the

where

EDIT: you'll also need to create
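One way to realize this idea is a bijective transform between real (..., 2) tensors and complex tensors, which could then feed a TransformedDistribution. The class below is a hypothetical sketch (the name `ViewAsComplexTransform` does not exist in PyTorch, and the codomain constraint is a placeholder):

```python
import torch
from torch.distributions import constraints
from torch.distributions.transforms import Transform


class ViewAsComplexTransform(Transform):
    """Hypothetical bijection from real tensors of shape (..., 2) to
    complex tensors of shape (...)."""
    domain = constraints.real
    codomain = constraints.real  # placeholder; a complex constraint would be needed
    bijective = True

    def _call(self, x):
        return torch.view_as_complex(x.contiguous())

    def _inverse(self, y):
        return torch.view_as_real(y)

    def log_abs_det_jacobian(self, x, y):
        # Reinterpreting (real, imag) pairs as complex numbers does not
        # change the underlying real density, so log|det J| is zero.
        return torch.zeros(x.shape[:-1])


t = ViewAsComplexTransform()
x = torch.randn(5, 3, 2)
z = t(x)
assert z.shape == (5, 3) and z.is_complex()
assert torch.allclose(t.inv(z), x)
```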
That is a very good point. You can also have a convenience helper to transform the complex parameters to real parameters, which would also fit this pattern.
Also, I am inclined to think that 2n-dimensional real normal distributions are in correspondence with n-dimensional complex normal distributions because of a dimensionality-count argument. If this is indeed the case, then @fritzo's implementation (which is a functional version of your implementation if you remove the asserts) sounds like a great idea to me.
I have proposed #92241 as a first step. |
@fritzo made some very valid points in #92241 (review) that should be taken into account by whoever wants to push this forward.
🚀 The feature, motivation and pitch
I have been working on complex-valued variational autoencoders (CVAEs), and for this purpose I required complex-valued Gaussian distributions. I coded my own implementation based on [1] and used it successfully; I essentially model it as a composite real-valued distribution. Would it be of interest if I created a PR to add this functionality?
[1] P. J. Schreier and L. L. Scharf, Statistical signal processing of complex-valued data: the theory of improper and noncircular signals. Cambridge: Cambridge University Press, 2010. Accessed: Feb. 13, 2022. [Online]. Available: https://doi.org/10.1017/CBO9780511815911
Alternatives
No response
Additional context
No response
cc @fritzo @neerajprad @alicanb @nikitaved @ezyang @anjali411 @dylanbespalko @mruberry @lezcano @pbelevich