Faster sampling 2D Riemannian Gaussian #198

Merged
merged 26 commits into from Aug 25, 2022

Artim436
Contributor

Introduction

Hi, everyone.

In this issue, I would like to share an improvement that @plcrodrigues and I made to _sample_parameter_r, which is the main bottleneck when sampling SPD matrices from a Riemannian Gaussian distribution via sample_gaussian_spd.

Disclaimer: This improvement is only for n_dim=2 but we think it will be useful for the community, since many toy models can be reduced to the 2D case.

In summary, our idea is to use rejection sampling in _sample_parameter_r instead of the current implementation with slice sampling. To have a first look, here is a graph of the speedup with a fixed choice of sigma (i.e. the dispersion of the Gaussian distribution).

[Figure: comparison showing the speedup for a fixed sigma]

Theory

Breaking the problem in two. We want to sample the vector $r = (r_1, \dots, r_m)$ from the probability distribution

$$ p(r) = \dfrac{1}{Z(m, \sigma)} \exp\left(-\frac{1}{2\sigma^2}\sum_{i = 1}^m r^2_i\right) \times \prod_{i < j}\sinh\left(|r_i - r_j|/2\right) $$

which has to be done via some computational method. At the moment, pyriemann uses a slice sampling approach [1] to obtain samples from $p(r)$, but this is rather slow. We suggest changing the implementation to a rejection sampling approach [2], which exploits certain properties of $p$ to obtain better performance.

We consider at first the simplest case for our sampling procedure, that of $m = 2$. The pdf for $r = (r_1, r_2)$ simplifies to:

$$ \begin{array}{rcl} p(r_1, r_2) &=& \dfrac{1}{Z(2, \sigma)}\exp\left(-\dfrac{1}{2\sigma^2}(r_1^2 + r_2^2)\right) \times \sinh\left(|r_1 - r_2|/2\right) \end{array} $$

we can see this pdf as a mixture of two components depending on a binary variable $b$

$$ p(r) = p(r \mid b = 0) \times \mathbb{P}(b = 0) + p(r \mid b = 1) \times \mathbb{P}(b = 1) $$

with $\mathbb{P}(b = 0) = \mathbb{P}(b = 1) = 1/2$ and

$$ p(r \mid b = 0) = \dfrac{2}{Z(2, \sigma)}\exp\left(-\dfrac{1}{2\sigma^2}(r_1^2 + r_2^2)\right)\sinh\Big((r_1 - r_2)/2\Big) \times \mathbb{I}({{r_1 - r_2 \geq 0}}) $$

and

$$ p(r \mid b = 1) = \dfrac{2}{Z(2, \sigma)}\exp\left(-\dfrac{1}{2\sigma^2}(r_1^2 + r_2^2)\right) \sinh\Big((r_2 - r_1)/2\Big) \times \mathbb{I}({{r_1 - r_2 < 0}}) $$

where $\mathbb{I}$ is the indicator function.
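As a quick check, weighting the two components by $1/2$ and adding them recovers the original density: exactly one indicator is active at any point, and the argument of the active $\sinh$ is then $|r_1 - r_2|/2$.

$$ \tfrac{1}{2}\,p(r \mid b = 0) + \tfrac{1}{2}\,p(r \mid b = 1) = \dfrac{1}{Z(2, \sigma)}\exp\left(-\dfrac{1}{2\sigma^2}(r_1^2 + r_2^2)\right)\Big[\sinh\Big((r_1 - r_2)/2\Big)\,\mathbb{I}(r_1 - r_2 \geq 0) + \sinh\Big((r_2 - r_1)/2\Big)\,\mathbb{I}(r_1 - r_2 < 0)\Big] = \dfrac{1}{Z(2, \sigma)}\exp\left(-\dfrac{1}{2\sigma^2}(r_1^2 + r_2^2)\right)\sinh\left(|r_1 - r_2|/2\right) $$

The factor $2$ in each conditional comes from the symmetry of $p$ under swapping $r_1$ and $r_2$: each half-plane carries half of the total mass.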

Great, we see that to generate a sample from $p(r_1, r_2)$ we can first generate a Bernoulli variable $b \sim \mathcal{B}(1/2)$ and then sample from one of the conditional distributions. Now the question is: how do we sample from the conditional distributions?

Sampling from the conditional distributions. We can use rejection sampling to sample from $p(r \mid b = 0)$, but first we need to find a good upper bound for it. Consider the inequality, valid for all $x > 0$:

$$ \sinh(x) = \dfrac{1}{2}\exp(x) - \dfrac{1}{2}\exp(-x) \leq \dfrac{1}{2}\exp(x) $$

we can write

$$ p(r \mid b = 0) \leq \dfrac{1}{Z(2, \sigma)} \exp\left(-\dfrac{1}{2\sigma^2}(r_1^2 + r_2^2)\right) \exp\Big((r_1 - r_2)/2\Big) \times \mathbb{I}({r_1 - r_2 \geq 0}) $$

and through some rearrangements,

$$ p(r \mid b = 0) \leq \dfrac{2\pi\sigma^2 \exp(\sigma^2/4)}{Z(2, \sigma)} \times \dfrac{1}{2\pi \sigma^2}\exp\left(-\dfrac{1}{2\sigma^2}\Big(\big(r_1-\sigma^2/2\big)^2 + \big(r_2+\sigma^2/2\big)^2\Big)\right) $$
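The rearrangement is a completion of the square in each coordinate. Spelled out, with the linear terms coming from $\exp\big((r_1 - r_2)/2\big)$:

$$ -\dfrac{r_1^2}{2\sigma^2} + \dfrac{r_1}{2} = -\dfrac{1}{2\sigma^2}\big(r_1 - \sigma^2/2\big)^2 + \dfrac{\sigma^2}{8}, \qquad -\dfrac{r_2^2}{2\sigma^2} - \dfrac{r_2}{2} = -\dfrac{1}{2\sigma^2}\big(r_2 + \sigma^2/2\big)^2 + \dfrac{\sigma^2}{8} $$

Adding the two identities gives the constant $\sigma^2/8 + \sigma^2/8 = \sigma^2/4$, which is the $\exp(\sigma^2/4)$ factor, and multiplying and dividing by $2\pi\sigma^2$ normalizes the remaining Gaussian kernel. The indicator is at most $1$ and can be dropped without breaking the inequality, which yields the bound on $p(r \mid b = 0)$ just stated.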

The expression above indicates that if we want to sample from $p(r \mid b = 0)$ we can do it via rejection sampling using as auxiliary distribution

$$ g_+(r) = \dfrac{1}{2\pi \sigma^2}\exp\left(-\dfrac{1}{2\sigma^2}\Big(\big(r_1-\sigma^2/2\big)^2 + \big(r_2+\sigma^2/2\big)^2\Big)\right) $$

The algorithm goes as follows:

  1. Sample $u \sim \mathcal{U}(0, 1)$ and $r \sim g_+(r)$

  2. Check whether

$$ u < \exp\left(-\dfrac{1}{2\sigma^2}(r_1^2+r_2^2)\right)\sinh((r_1 - r_2)/2)\times \mathbb{I}({r_1 - r_2 \geq 0})\times\dfrac{1}{M g_+(r)} $$

  • If this holds, then accept $r$ as a sample from $p(r \mid b = 0)$
  • If not, reject the sample

Here $M = \pi\sigma^2 \exp(\sigma^2/4)$.
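Substituting $M$ and $g_+$ into the check in step 2, the Gaussian factors cancel and the right-hand side reduces to $2\sinh\big((r_1 - r_2)/2\big)\exp\big(-(r_1 - r_2)/2\big) = 1 - \exp\big(-(r_1 - r_2)\big)$ on the half-plane $r_1 \geq r_2$. A minimal sketch of the resulting sampler, assuming NumPy's `Generator` API; the function name is hypothetical and this is not necessarily how the merged code is written:

```python
import numpy as np


def sample_r_given_b0(sigma, rng):
    """Draw one r = (r1, r2) from p(r | b = 0) by rejection sampling.

    Hypothetical helper for illustration, not pyriemann's implementation.
    """
    # g_+ is an isotropic Gaussian with std sigma centered here
    mean = np.array([sigma**2 / 2, -sigma**2 / 2])
    while True:
        r = mean + sigma * rng.standard_normal(2)  # r ~ g_+
        u = rng.uniform()                          # u ~ U(0, 1)
        d = r[0] - r[1]
        # Simplified acceptance ratio: 1 - exp(-d) if d >= 0, else 0
        if d >= 0 and u < 1.0 - np.exp(-d):
            return r
```

Every accepted sample satisfies $r_1 \geq r_2$ by construction, which is an easy property to assert in a test.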

As you can imagine, sampling from $p(r \mid b = 1)$ follows the same logic but with a different auxiliary pdf $g_-(r)$.
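Repeating the derivation with the roles of $r_1$ and $r_2$ exchanged gives the mirrored proposal, with the same constant $M$:

$$ g_-(r) = \dfrac{1}{2\pi \sigma^2}\exp\left(-\dfrac{1}{2\sigma^2}\Big(\big(r_1+\sigma^2/2\big)^2 + \big(r_2-\sigma^2/2\big)^2\Big)\right) $$

The acceptance check becomes

$$ u < \exp\left(-\dfrac{1}{2\sigma^2}(r_1^2+r_2^2)\right)\sinh((r_2 - r_1)/2)\times \mathbb{I}({r_1 - r_2 < 0})\times\dfrac{1}{M g_-(r)} $$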

The algorithm

Summing up, the new implementation that we propose for _sample_parameter_r is based on the following algorithm:

  • Sample $b \sim \mathcal{B}(1/2)$
  • If $b = 0$, then sample $r \sim p(r \mid b = 0)$
  • If $b = 1$, then sample $r \sim p(r \mid b = 1)$

Sampling from the conditional distributions follows the rejection sampling procedure described above.
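The steps above can be sketched end to end as follows. This is an illustration under stated assumptions, not the code merged in this PR: the function name and signature are hypothetical, the proposal for $b = 1$ is taken to be the mirror image of $g_+$ (centered at $(-\sigma^2/2, \sigma^2/2)$), and the acceptance ratio is written in its algebraically simplified form $1 - \exp(-d)$, where $d$ is the non-negative difference $r_1 - r_2$ (for $b = 0$) or $r_2 - r_1$ (for $b = 1$).

```python
import numpy as np


def sample_parameter_r_2d(n_samples, sigma, random_state=None):
    """Sketch of the mixture + rejection scheme for m = 2.

    Hypothetical name and signature; not pyriemann's `_sample_parameter_r`.
    """
    rng = np.random.default_rng(random_state)
    shift = sigma**2 / 2
    samples = np.empty((n_samples, 2))
    for i in range(n_samples):
        b = rng.integers(2)  # b ~ Bernoulli(1/2)
        sign = 1.0 if b == 0 else -1.0
        # Proposal mean: (shift, -shift) for b = 0, mirrored for b = 1
        mean = sign * np.array([shift, -shift])
        while True:
            r = mean + sigma * rng.standard_normal(2)
            d = sign * (r[0] - r[1])  # r1 - r2 if b = 0, r2 - r1 if b = 1
            # Equals exp(-|r|^2 / (2 sigma^2)) sinh(d / 2) / (M g(r)) for d >= 0
            if d >= 0 and rng.uniform() < 1.0 - np.exp(-d):
                break
        samples[i] = r
    return samples
```

For example, `sample_parameter_r_2d(1000, sigma=1.0, random_state=42)` returns an array of shape `(1000, 2)` in which roughly half of the rows have $r_1 \geq r_2$.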

Why not consider more dimensions?

A natural question is why we have not considered more than two dimensions. Well, things can get quite complicated when $m$ increases...

For instance, suppose we have $r = (r_1, r_2, r_3)$ then the pdf of interest can be written as

$$ p(r_1, r_2, r_3) = \dfrac{1}{Z(3,\sigma)}\exp\Big(-\dfrac{1}{2\sigma^2}(r_1^2 + r_2^2 + r_3^2)\Big) \times \sinh\Big(\dfrac{|r_1 - r_2|}{2}\Big) \times \sinh\Big(\dfrac{|r_1 - r_3|}{2}\Big) \times \sinh\Big(\dfrac{|r_2 - r_3|}{2}\Big) $$

To use the same strategy from our 2D example, we would have to sample three Bernoulli random variables (one for each factor in the product) and consider the $2^3 = 8$ possible combinations of signs inside each of the $\sinh$. We have tried to implement this case, but our first results indicate a probability of acceptance that is too small, which makes the rejection sampling algorithm impractical.

Moreover, with one binary variable per $\sinh$ factor, the number of conditional distributions is $2^{m(m-1)/2}$, where $m$ is the dimensionality of the SPD matrices being considered, which grows even faster than $2^m$. Therefore, our algorithm does not look scalable for larger matrix dimensions.

Final remarks

This is it: we have obtained a much faster implementation for sampling 2D Gaussian SPDs than what was available in pyriemann so far. We should mention that our implementation is based on a while loop that stops once we have obtained the desired number of samples. However, we can use certain properties of the rejection sampling algorithm to calculate the probability of acceptance of a sample and write code that generates several candidates in advance. Such an algorithm can be even faster than the one we have implemented, but it requires approximating the normalizing constant of $p(r_1, r_2)$, which can be cumbersome. We will leave this extension for a future PR.
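To make the batching idea concrete, here is a rough sketch that sidesteps the normalizing constant: instead of computing the acceptance probability analytically, it estimates it empirically from a pilot batch and uses that estimate to size the following batches. This is a swapped-in shortcut rather than the extension described above; the function name, the `n_pilot` parameter, and the mirrored proposal for $b = 1$ are all assumptions made for illustration.

```python
import numpy as np


def sample_r_2d_batched(n_samples, sigma, random_state=None, n_pilot=1000):
    """Vectorized rejection sampling for m = 2 (hypothetical sketch)."""
    rng = np.random.default_rng(random_state)
    shift = sigma**2 / 2

    def draw(n):
        # One Bernoulli(1/2) label per candidate: +1 for b = 0, -1 for b = 1
        sign = np.where(rng.integers(2, size=n) == 0, 1.0, -1.0)
        mean = sign[:, None] * np.array([shift, -shift])
        r = mean + sigma * rng.standard_normal((n, 2))
        d = sign * (r[:, 0] - r[:, 1])
        # Acceptance probability 1 - exp(-d) for d >= 0, and 0 for d < 0
        p_acc = 1.0 - np.exp(-np.clip(d, 0.0, None))
        return r[rng.uniform(size=n) < p_acc]

    accepted = draw(n_pilot)
    # Empirical acceptance rate, floored to avoid division by zero
    rate = max(len(accepted) / n_pilot, 1.0 / n_pilot)
    chunks, n_have = [accepted], len(accepted)
    while n_have < n_samples:
        n_draw = int(np.ceil((n_samples - n_have) / rate)) + 1
        chunk = draw(n_draw)
        chunks.append(chunk)
        n_have += len(chunk)
    return np.concatenate(chunks)[:n_samples]
```

Drawing a fresh label $b$ for every candidate is still valid here because both branches share the same constant $M$, so accepted candidates are distributed according to the mixture.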

@@ -34,15 +34,18 @@
 samples_1 = sample_gaussian_spd(n_matrices=n_matrices,
                                 mean=mean,
                                 sigma=sigma,
-                                random_state=random_state)
+                                random_state=random_state,
+                                sampling_method='rejection')
Member

I would introduce a value called "auto" for the sampling_method parameter that would default to "rejection" if dim == 2 and "slice" otherwise. That way you don't have to expose these options in the tutorial, and sampling will just become faster for users without any code change.
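What this suggestion amounts to could be sketched as follows. The helper name and the error handling are hypothetical and only illustrate the dispatch; the actual code in pyriemann/datasets/sampling.py is not shown in this thread.

```python
def resolve_sampling_method(sampling_method, n_dim):
    """Map 'auto' to a concrete method (hypothetical helper)."""
    if sampling_method == "auto":
        # Rejection sampling is only derived for the 2D case
        return "rejection" if n_dim == 2 else "slice"
    if sampling_method not in ("slice", "rejection"):
        raise ValueError(f"Unknown sampling_method: {sampling_method!r}")
    if sampling_method == "rejection" and n_dim != 2:
        # Assumed behavior: refuse rather than silently fall back
        raise ValueError("'rejection' is only available for n_dim=2")
    return sampling_method
```

With such a default, a call that does not pass sampling_method would pick the faster path for 2x2 matrices and keep slice sampling elsewhere.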

Contributor Author

I'm not sure I understand your point. I have the impression that the default value "None" already fulfils the role of the auto variable you want to define.

Artim436 and others added 5 commits August 19, 2022 14:12
Co-authored-by: Alexandre Gramfort <alexandre.gramfort@m4x.org>
Member

@qbarthelemy qbarthelemy left a comment

Thx for this PR with such a detailed description!
Can you update test_sampling too?

Member

@agramfort agramfort left a comment

I would not expose the sampling_method in the example. Just keep the default value of the parameter.

besides LGTM

Member

@sylvchev sylvchev left a comment

Nice! Thanks for the work. I agree with @agramfort about using "auto" rather than None: it is the same idea, but I think it makes it clearer that a choice is made behind the scenes.
LGTM!

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
@sylvchev
Member

Thanks @Artim436 and all for this PR!

@sylvchev sylvchev merged commit bacebe6 into pyRiemann:master Aug 25, 2022
@Artim436 Artim436 deleted the faster_2D_riemannian_gaussian branch August 25, 2022 16:32
5 participants