Skip to content

Commit

Permalink
Corrected a bug on the rejection_sampling_2D algorithm and updated th…
Browse files Browse the repository at this point in the history
…e documentation

minor corrections

correct last pr number in whatsnew

Update pyriemann/datasets/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

Update pyriemann/datasets/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

Update pyriemann/datasets/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

Update pyriemann/datasets/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

modified whatsnew.rst

minor updates
  • Loading branch information
mhurte committed Jun 19, 2023
1 parent b0270a8 commit d32f2fb
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 19 deletions.
8 changes: 5 additions & 3 deletions doc/whatsnew.rst
Expand Up @@ -21,7 +21,7 @@ v0.5.dev

- Add tests for matrix operators and distances for HPD matrices, complete doc and add references. :pr:`234` by :user:`qbarthelemy`

- Enhance tangent space module to process HPD matrices, and complete tests. :pr:`236` by :user:`qbarthelemy`
- Enhance tangent space module to process HPD matrices. :pr:`236` by :user:`qbarthelemy`

- Fix regression introduced in :func:`pyriemann.spatialfilters.Xdawn` by :pr:`214`. :pr:`242` by :user:`qbarthelemy`

Expand All @@ -31,9 +31,11 @@ v0.5.dev

- Correct transform and predict_proba of :class:`pyriemann.classification.MeanField`. :pr:`247` by :user:`qbarthelemy`

- Enhance mean module to process HPD matrices, and complete tests. :pr:`243` by :user:`qbarthelemy`
- Enhance mean module to process HPD matrices. :pr:`243` by :user:`qbarthelemy`

- Correct :func:`pyriemann.utils.distance.distance_mahalanobis`, keeping only real part. :pr:`248` by :user:`qbarthelemy`
- Correct :func:`pyriemann.utils.distance.distance_mahalanobis`, keeping only real part. :pr:`249` by :user:`qbarthelemy`

- Fix an issue in :func:`pyriemann.datasets.sampling._rejection_sampling_2D`. :pr:`250` by :user: `mhurte`

v0.4 (Feb 2023)
---------------
Expand Down
39 changes: 23 additions & 16 deletions pyriemann/datasets/sampling.py
Expand Up @@ -65,16 +65,12 @@ def _rejection_sampling_2D_gfunction_plus(sigma, r_sample):
.. versionadded:: 0.4
"""
mu_a = np.array([sigma**2 / 2, -(sigma**2) / 2])
mu_a = np.array([-sigma**2 / 2, (sigma**2) / 2])
cov_matrix = (sigma**2) * np.eye(2)
m = np.pi * (sigma**2) * np.exp(sigma**2 / 4)
if r_sample[0] >= r_sample[1]:
num = (
np.exp(-1 / (2 * sigma**2) * np.sum(r_sample**2))
* np.sinh((r_sample[0] - r_sample[1]) / 2)
/ m
)
den = multivariate_normal.pdf(r_sample, mean=mu_a, cov=cov_matrix)
num = _pdf_r(r_sample, sigma)
den = multivariate_normal.pdf(r_sample, mean=mu_a, cov=cov_matrix)*m
return num / den
return 0

Expand All @@ -101,20 +97,18 @@ def _rejection_sampling_2D_gfunction_minus(sigma, r_sample):
.. versionadded:: 0.4
"""
mu_b = np.array([-(sigma**2) / 2, sigma**2 / 2])
mu_b = np.array([(sigma**2) / 2, -sigma**2 / 2])
cov_matrix = (sigma**2) * np.eye(2)
m = np.pi * (sigma**2) * np.exp(sigma**2 / 4)
if r_sample[0] < r_sample[1]:
num = (
np.exp(-1 / (2 * sigma**2) * np.sum(r_sample**2))
* np.sinh((r_sample[1] - r_sample[0]) / 2)
)
num = _pdf_r(r_sample, sigma)
den = multivariate_normal.pdf(r_sample, mean=mu_b, cov=cov_matrix) * m
return num / den
return 0


def _rejection_sampling_2D(n_samples, sigma, random_state=None):
def _rejection_sampling_2D(n_samples, sigma, random_state=None,
return_acceptance_rate=False):
"""Rejection sampling algorithm for the 2D case.
Implementation of a rejection sampling algorithm. The implementation
Expand All @@ -129,25 +123,35 @@ def _rejection_sampling_2D(n_samples, sigma, random_state=None):
Dispersion of the Riemannian Gaussian distribution.
random_state : int, RandomState instance or None, default=None
Pass an int for reproducible output across multiple function calls.
return_acceptance_rate : boolean, default=False
Whether to return the acceptance rate with the sample (number of
samples obtained divided by the number of samples generated by
the algorithm).
.. versionadded:: 0.5
Returns
-------
r_samples : ndarray, shape (n_samples, n_dim)
Samples of the r parameters of the Riemannian Gaussian distribution.
acceptance_rate : float
acceptance rate empirically computed for the generation of the sample
Notes
-----
.. versionadded:: 0.4
"""
mu_a = np.array([sigma**2 / 2, -(sigma**2) / 2])
mu_b = np.array([-(sigma**2) / 2, sigma**2 / 2])
mu_a = np.array([-sigma**2 / 2, (sigma**2) / 2])
mu_b = np.array([(sigma**2) / 2, -sigma**2 / 2])
cov_matrix = (sigma**2) * np.eye(2)
r_samples = []
cpt = 0
acc = 0
rs = check_random_state(random_state)
while cpt != n_samples:
if rs.binomial(1, 0.5, 1) == 1:
acc += 1
if (rs.binomial(1, 0.5, 1) == 1):
r_sample = multivariate_normal.rvs(mu_a, cov_matrix, 1, rs)
res = _rejection_sampling_2D_gfunction_plus(sigma, r_sample)
if rs.rand(1) < res:
Expand All @@ -159,6 +163,9 @@ def _rejection_sampling_2D(n_samples, sigma, random_state=None):
if rs.rand(1) < res:
r_samples.append(r_sample)
cpt += 1

if return_acceptance_rate:
return np.array(r_samples), n_samples / acc
return np.array(r_samples)


Expand Down

0 comments on commit d32f2fb

Please sign in to comment.