Skip to content

Commit

Permalink
Merge pull request #33 from lemiceterieux/develop
Browse files Browse the repository at this point in the history
Parallelize the permutation test with joblib.
  • Loading branch information
vnmabus committed May 17, 2022
2 parents c8d10f6 + 205682a commit 841e08d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 12 deletions.
17 changes: 9 additions & 8 deletions dcor/_hypothesis.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import collections

import numpy as np

from joblib import Parallel, delayed
from ._utils import _random_state_init

HypothesisTest = collections.namedtuple('HypothesisTest', ['p_value',
'statistic'])


def _permutation_test_with_sym_matrix(matrix, statistic_function,
num_resamples, random_state):
num_resamples, random_state,n_jobs=1):
"""
Execute a permutation test in a symmetric matrix.
Expand All @@ -34,15 +34,16 @@ def _permutation_test_with_sym_matrix(matrix, statistic_function,

statistic = statistic_function(matrix)

bootstrap_statistics = np.ones(num_resamples, dtype=statistic.dtype)

for bootstrap in range(num_resamples):
permuted_index = random_state.permutation(matrix.shape[0])
def bootstrapPerms(mat):
permuted_index = random_state.permutation(mat.shape[0])

permuted_matrix = matrix[
permuted_matrix = mat[
np.ix_(permuted_index, permuted_index)]

bootstrap_statistics[bootstrap] = statistic_function(permuted_matrix)
return statistic_function(permuted_matrix)

bootstrap_statistics = Parallel(n_jobs=n_jobs)(delayed(bootstrapPerms)(matrix) for bootstrap in range(num_resamples))
bootstrap_statistics = np.array(bootstrap_statistics, dtype=statistic.dtype)

extreme_results = bootstrap_statistics > statistic
p_value = (np.sum(extreme_results) + 1.0) / (num_resamples + 1)
Expand Down
6 changes: 4 additions & 2 deletions dcor/homogeneity.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def energy_test(
exponent=1,
random_state=None,
average=None,
estimation_stat=_energy.EstimationStatistic.V_STATISTIC
estimation_stat=_energy.EstimationStatistic.V_STATISTIC,
n_jobs=1,
):
"""
Test of homogeneity based on the energy distance.
Expand Down Expand Up @@ -259,4 +260,5 @@ def statistic_function(distance_matrix):
sample_distances,
statistic_function=statistic_function,
num_resamples=num_resamples,
random_state=random_state)
random_state=random_state,
n_jobs=n_jobs)
8 changes: 6 additions & 2 deletions dcor/independence.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def distance_covariance_test(
num_resamples=0,
exponent=1,
random_state=None,
n_jobs=1,
):
"""
Test of distance covariance independence.
Expand Down Expand Up @@ -111,7 +112,8 @@ def statistic_function(distance_matrix):
u_x,
statistic_function=statistic_function,
num_resamples=num_resamples,
random_state=random_state)
random_state=random_state,
n_jobs=n_jobs)


def partial_distance_covariance_test(
Expand All @@ -122,6 +124,7 @@ def partial_distance_covariance_test(
num_resamples=0,
exponent=1,
random_state=None,
n_jobs=1,
):
"""
Test of partial distance covariance independence.
Expand Down Expand Up @@ -216,7 +219,8 @@ def statistic_function(distance_matrix):
p_xz,
statistic_function=statistic_function,
num_resamples=num_resamples,
random_state=random_state)
random_state=random_state,
n_jobs=n_jobs)


def distance_correlation_t_statistic(x, y):
Expand Down

0 comments on commit 841e08d

Please sign in to comment.