Skip to content

Commit

Permalink
Refactored homogeneity module.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed May 18, 2022
1 parent 42819e2 commit 7af4e3f
Show file tree
Hide file tree
Showing 5 changed files with 441 additions and 221 deletions.
47 changes: 32 additions & 15 deletions dcor/_hypothesis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,23 @@

import warnings
from dataclasses import dataclass
from typing import Any, Callable, Iterator
from typing import Any, Callable, Generic, Iterator, TypeVar

import numpy as np
from dcor._utils import ArrayType
from joblib import Parallel, delayed

from ._utils import _random_state_init
from ._utils import ArrayType, RandomLike, _random_state_init, get_namespace

T = TypeVar("T", bound=ArrayType)


@dataclass
class HypothesisTest():
class HypothesisTest(Generic[T]):
"""
Class containing the results of an hypothesis test.
"""
pvalue: float
statistic: ArrayType
statistic: T

@property
def p_value(self) -> float:
Expand Down Expand Up @@ -42,28 +46,40 @@ def __len__(self) -> int:


def _permuted_statistic(
matrix: ArrayType,
statistic_function: Callable[[ArrayType], ArrayType],
matrix: T,
statistic_function: Callable[[T], T],
permutation: np.typing.NDArray[int],
) -> ArrayType:
) -> T:

xp = get_namespace(matrix)

# We implicitly convert to NumPy for permuting the array if we don't
# have a take function.
# take is probably going to be included in the final version of the
# standard, so not much to worry about.
take = getattr(xp, "take", np.take)

permuted_rows = take(matrix, permutation, axis=0)
permuted_matrix = take(permuted_rows, permutation, axis=1)

permuted_matrix = matrix[np.ix_(permutation, permutation)]
# Transform back to the original type if NumPy conversion was needed.
permuted_matrix = xp.asarray(permuted_matrix)

return statistic_function(permuted_matrix)


def _permutation_test_with_sym_matrix(
matrix: ArrayType,
matrix: T,
*,
statistic_function: Callable[[ArrayType], ArrayType],
statistic_function: Callable[[T], T],
num_resamples: int,
random_state: np.random.RandomState | np.random.Generator | int | None,
random_state: RandomLike,
n_jobs: int | None = None,
) -> HypothesisTest:
) -> HypothesisTest[T]:
"""
Execute a permutation test in a symmetric matrix.
Parameters:
Args:
matrix: Matrix that will perform the permutation test.
statistic_function: Function that computes the desired statistic from
the matrix.
Expand All @@ -76,7 +92,8 @@ def _permutation_test_with_sym_matrix(
Results of the hypothesis test.
"""
matrix = np.asarray(matrix)
xp = get_namespace(matrix)
matrix = xp.asarray(matrix)
random_state = _random_state_init(random_state)

statistic = statistic_function(matrix)
Expand Down
12 changes: 9 additions & 3 deletions dcor/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
from __future__ import annotations

import enum
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, Union

import numpy as np

# TODO: Change in the future
if TYPE_CHECKING:
ArrayType = np.typing.NDArray[float]
ArrayType = np.typing.NDArray[np.number]
else:
ArrayType = np.ndarray

T = TypeVar("T", bound=ArrayType)
RandomLike = Union[
np.random.RandomState,
np.random.Generator,
int,
None,
]


class CompileMode(enum.Enum):
Expand Down Expand Up @@ -147,7 +153,7 @@ def _can_be_numpy_double(x: ArrayType) -> bool:


def _random_state_init(
random_state: np.random.RandomState | np.random.Generator | int | None,
random_state: RandomLike,
) -> np.random.RandomState | np.random.Generator:
"""
Initialize a RandomState object.
Expand Down

0 comments on commit 7af4e3f

Please sign in to comment.