Skip to content

Commit

Permalink
Corrected a bug on the rejection_sampling_2D algorithm and updated th… (
Browse files Browse the repository at this point in the history
#250)

* Corrected a bug on the rejection_sampling_2D algorithm and updated the documentation

Update pyriemann/datasets/sampling.py

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>

* correct last pr number in whatsnew

* Solving whatsnew conflict

* Solving whatsnew conflict 2

* tweaking test_tlrotate parameters to have a reasonably good example for unit test

* typo correction

* last corrections

---------

Co-authored-by: Quentin Barthélemy <q.barthelemy@gmail.com>
Co-authored-by: Pedro L. C. Rodrigues <pedro.rodrigues01@gmail.com>
  • Loading branch information
3 people committed Jun 19, 2023
1 parent 672b6b5 commit e496485
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 17 deletions.
2 changes: 2 additions & 0 deletions doc/whatsnew.rst
Expand Up @@ -35,6 +35,8 @@ v0.5.dev

- Correct :func:`pyriemann.utils.distance.distance_mahalanobis`, keeping only real part. :pr:`249` by :user:`qbarthelemy`

- Fix :func:`pyriemann.datasets.sampling.sample_gaussian_spd` used with ``sampling_method=rejection`` on 2D matrices. :pr:`250` by :user:`mhurte`

v0.4 (Feb 2023)
---------------

Expand Down
40 changes: 24 additions & 16 deletions pyriemann/datasets/sampling.py
Expand Up @@ -65,16 +65,12 @@ def _rejection_sampling_2D_gfunction_plus(sigma, r_sample):
.. versionadded:: 0.4
"""
mu_a = np.array([sigma**2 / 2, -(sigma**2) / 2])
mu_a = np.array([-sigma**2 / 2, (sigma**2) / 2])
cov_matrix = (sigma**2) * np.eye(2)
m = np.pi * (sigma**2) * np.exp(sigma**2 / 4)
if r_sample[0] >= r_sample[1]:
num = (
np.exp(-1 / (2 * sigma**2) * np.sum(r_sample**2))
* np.sinh((r_sample[0] - r_sample[1]) / 2)
/ m
)
den = multivariate_normal.pdf(r_sample, mean=mu_a, cov=cov_matrix)
num = _pdf_r(r_sample, sigma)
den = multivariate_normal.pdf(r_sample, mean=mu_a, cov=cov_matrix)*m
return num / den
return 0

Expand All @@ -101,20 +97,18 @@ def _rejection_sampling_2D_gfunction_minus(sigma, r_sample):
.. versionadded:: 0.4
"""
mu_b = np.array([-(sigma**2) / 2, sigma**2 / 2])
mu_b = np.array([(sigma**2) / 2, -sigma**2 / 2])
cov_matrix = (sigma**2) * np.eye(2)
m = np.pi * (sigma**2) * np.exp(sigma**2 / 4)
if r_sample[0] < r_sample[1]:
num = (
np.exp(-1 / (2 * sigma**2) * np.sum(r_sample**2))
* np.sinh((r_sample[1] - r_sample[0]) / 2)
)
num = _pdf_r(r_sample, sigma)
den = multivariate_normal.pdf(r_sample, mean=mu_b, cov=cov_matrix) * m
return num / den
return 0


def _rejection_sampling_2D(n_samples, sigma, random_state=None):
def _rejection_sampling_2D(n_samples, sigma, random_state=None,
return_acceptance_rate=False):
"""Rejection sampling algorithm for the 2D case.
Implementation of a rejection sampling algorithm. The implementation
Expand All @@ -129,25 +123,36 @@ def _rejection_sampling_2D(n_samples, sigma, random_state=None):
Dispersion of the Riemannian Gaussian distribution.
random_state : int, RandomState instance or None, default=None
Pass an int for reproducible output across multiple function calls.
return_acceptance_rate : boolean, default=False
Whether to return the acceptance rate with the sample (number of
samples obtained divided by the number of samples generated by
the algorithm).
.. versionadded:: 0.5
Returns
-------
r_samples : ndarray, shape (n_samples, n_dim)
Samples of the r parameters of the Riemannian Gaussian distribution.
acceptance_rate : float
Acceptance rate empirically computed for the generation of the sample.
Only returned if ``return_acceptance_rate=True``.
Notes
-----
.. versionadded:: 0.4
"""
mu_a = np.array([sigma**2 / 2, -(sigma**2) / 2])
mu_b = np.array([-(sigma**2) / 2, sigma**2 / 2])
mu_a = np.array([-sigma**2 / 2, (sigma**2) / 2])
mu_b = np.array([(sigma**2) / 2, -sigma**2 / 2])
cov_matrix = (sigma**2) * np.eye(2)
r_samples = []
cpt = 0
acc = 0
rs = check_random_state(random_state)
while cpt != n_samples:
if rs.binomial(1, 0.5, 1) == 1:
acc += 1
if (rs.binomial(1, 0.5, 1) == 1):
r_sample = multivariate_normal.rvs(mu_a, cov_matrix, 1, rs)
res = _rejection_sampling_2D_gfunction_plus(sigma, r_sample)
if rs.rand(1) < res:
Expand All @@ -159,6 +164,9 @@ def _rejection_sampling_2D(n_samples, sigma, random_state=None):
if rs.rand(1) < res:
r_samples.append(r_sample)
cpt += 1

if return_acceptance_rate:
return np.array(r_samples), n_samples / acc
return np.array(r_samples)


Expand Down
2 changes: 1 addition & 1 deletion tests/test_transfer.py
Expand Up @@ -99,7 +99,7 @@ def test_tlrotate(rndstate, metric):
"""Test pipeline for rotating the datasets"""
# check if the distance between the classes of each domain is reduced
X, y_enc = make_classification_transfer(
n_matrices=25, class_sep=5, class_disp=1.0, random_state=rndstate)
n_matrices=50, class_sep=3, class_disp=1.0, random_state=rndstate)
rct = TLCenter(target_domain='target_domain')
X_rct = rct.fit_transform(X, y_enc)
rot = TLRotate(target_domain='target_domain', metric=metric)
Expand Down

0 comments on commit e496485

Please sign in to comment.