Skip to content

Commit

Permalink
Fix type errors.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed Dec 26, 2022
1 parent 8e063d2 commit 3fb5002
Show file tree
Hide file tree
Showing 4 changed files with 184 additions and 28 deletions.
42 changes: 41 additions & 1 deletion dcor/_dcor.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,46 @@ def __call__(
...


@overload
def _dcov_terms_auto(
__x: Array,
__y: Array,
*,
exponent: float,
compile_mode: CompileMode = CompileMode.AUTO,
return_var_terms: Literal[False] = False,
) -> Tuple[
Array,
Array,
Array,
Array,
Array,
None,
None,
]:
...


@overload
def _dcov_terms_auto(
__x: Array,
__y: Array,
*,
exponent: float,
compile_mode: CompileMode = CompileMode.AUTO,
return_var_terms: Literal[True],
) -> Tuple[
Array,
Array,
Array,
Array,
Array,
Array,
Array,
]:
...


def _dcov_terms_auto(
x: Array,
y: Array,
Expand Down Expand Up @@ -656,7 +696,7 @@ def u_distance_stats_sqr(
exponent: float = 1,
method: DistanceCovarianceMethodLike = DistanceCovarianceMethod.AUTO,
compile_mode: CompileMode = CompileMode.AUTO,
) -> Array:
) -> Stats[Array]:
"""
Unbiased statistics related with the squared distance covariance.
Expand Down
86 changes: 71 additions & 15 deletions dcor/_fast_dcov_avl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,20 @@

import math
import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Tuple,
TypeVar,
overload,
)

import numba
import numpy as np
from numba import boolean, float64, int64
from numba.types import Tuple
from numba.types import Tuple as NumbaTuple

from ._dcor_internals_numba import (
NumbaIntVectorReadOnly,
Expand All @@ -29,7 +37,7 @@
NumpyArrayType = np.ndarray


T = TypeVar("T", bound=NumpyArrayType)
Array = TypeVar("Array", bound=NumpyArrayType)


def _dyad_update(
Expand Down Expand Up @@ -278,7 +286,7 @@ def _get_impl_args(


_get_impl_args_compiled = numba.njit(
Tuple((
NumbaTuple((
NumbaVectorReadOnly,
NumbaVectorReadOnly,
NumbaIntVectorReadOnly,
Expand Down Expand Up @@ -412,7 +420,7 @@ def _distance_covariance_sqr_terms_avl_impl(
compiled=False,
)
_distance_covariance_sqr_terms_avl_impl_compiled = numba.njit(
Tuple((
NumbaTuple((
float64,
NumbaVector,
float64,
Expand Down Expand Up @@ -462,14 +470,62 @@ def _distance_covariance_sqr_terms_avl_impl(
}


@overload
def _distance_covariance_sqr_terms_avl(
__x: Array,
__y: Array,
*,
exponent: float,
compile_mode: CompileMode = CompileMode.AUTO,
return_var_terms: Literal[False] = False,
) -> Tuple[
Array,
Array,
Array,
Array,
Array,
None,
None,
]:
...


@overload
def _distance_covariance_sqr_terms_avl(
__x: Array,
__y: Array,
*,
exponent: float,
compile_mode: CompileMode = CompileMode.AUTO,
return_var_terms: Literal[True],
) -> Tuple[
Array,
Array,
Array,
Array,
Array,
Array,
Array,
]:
...


def _distance_covariance_sqr_terms_avl(
x: T,
y: T,
x: Array,
y: Array,
*,
exponent: float = 1,
compile_mode: CompileMode = CompileMode.AUTO,
return_var_terms: bool = False,
) -> T:
) -> Tuple[
Array,
Array,
Array,
Array,
Array,
None,
None,
]:
"""Fast algorithm for the squared distance covariance terms."""
if exponent != 1:
raise ValueError(f"Exponent should be 1 but is {exponent} instead.")
Expand Down Expand Up @@ -509,11 +565,11 @@ def _generate_rowwise_internal(
) -> Callable[..., NumpyArrayType]:

def _rowwise_distance_covariance_sqr_avl_generic_internal(
x: T,
y: T,
x: Array,
y: Array,
unbiased: bool,
res: T,
) -> T:
res: Array,
) -> Array:

res[0] = _distance_covariance_sqr_avl_impl_compiled(x, y, unbiased)

Expand Down Expand Up @@ -549,12 +605,12 @@ def _rowwise_distance_covariance_sqr_avl_generic_internal(


def _rowwise_distance_covariance_sqr_avl_generic(
x: T,
y: T,
x: Array,
y: Array,
exponent: float = 1,
unbiased: bool = False,
compile_mode: CompileMode = CompileMode.AUTO,
) -> T:
) -> Array:

if exponent != 1:
raise ValueError(f"Exponent should be 1 but is {exponent} instead.")
Expand Down
76 changes: 68 additions & 8 deletions dcor/_fast_dcov_mergesort.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Tuple,
TypeVar,
overload,
)

import numba
import numpy as np
from numba import boolean, float64
from numba.types import Tuple
from numba.types import Tuple as NumbaTuple

from ._dcor_internals_numba import (
NumbaMatrix,
Expand All @@ -26,7 +34,7 @@
else:
NumpyArrayType = np.ndarray

T = TypeVar("T", bound=NumpyArrayType)
Array = TypeVar("Array", bound=NumpyArrayType)


def _compute_weight_sums(
Expand Down Expand Up @@ -264,15 +272,19 @@ def _distance_covariance_sqr_terms_mergesort_impl(
)
)
_distance_covariance_sqr_terms_mergesort_impl_compiled = numba.njit(
Tuple((
NumbaTuple((
float64,
NumbaVector,
float64,
NumbaVector,
float64,
numba.optional(float64),
numba.optional(float64),
))(NumbaVectorReadOnlyNonContiguous, NumbaVectorReadOnlyNonContiguous, boolean),
))(
NumbaVectorReadOnlyNonContiguous,
NumbaVectorReadOnlyNonContiguous,
boolean,
),
cache=True,
)(
_generate_distance_covariance_sqr_terms_mergesort_impl(
Expand Down Expand Up @@ -317,14 +329,62 @@ def _distance_covariance_sqr_terms_mergesort_impl(
}


@overload
def _distance_covariance_sqr_terms_mergesort(
__x: Array,
__y: Array,
*,
exponent: float,
compile_mode: CompileMode = CompileMode.AUTO,
return_var_terms: Literal[False] = False,
) -> Tuple[
Array,
Array,
Array,
Array,
Array,
None,
None,
]:
...


@overload
def _distance_covariance_sqr_terms_mergesort(
__x: Array,
__y: Array,
*,
exponent: float,
compile_mode: CompileMode = CompileMode.AUTO,
return_var_terms: Literal[True],
) -> Tuple[
Array,
Array,
Array,
Array,
Array,
Array,
Array,
]:
...


def _distance_covariance_sqr_terms_mergesort(
x: T,
y: T,
x: Array,
y: Array,
*,
exponent: float = 1,
compile_mode: CompileMode = CompileMode.AUTO,
return_var_terms: bool = False,
) -> T:
) -> Tuple[
Array,
Array,
Array,
Array,
Array,
Array | None,
Array | None,
]:

if exponent != 1:
raise ValueError(f"Exponent should be 1 but is {exponent} instead.")
Expand Down
8 changes: 4 additions & 4 deletions dcor/_partial_dcor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def partial_distance_covariance(
y: ArrayType,
z: ArrayType,
) -> ArrayType:
"""
r"""
Partial distance covariance estimator.
Compute the estimator for the partial distance covariance of the
random vectors corresponding to :math:`x` and :math:`y` with respect
to the random variable corresponding to :math:`z`.
Warning:
Warning:
Partial distance covariance should be used carefully as it presents
some undesirable or counterintuitive properties. In particular, the
reader cannot assume that :math:`\mathcal{V}^{*}` characterizes
Expand Down Expand Up @@ -92,14 +92,14 @@ def partial_distance_correlation(
y: ArrayType,
z: ArrayType,
) -> ArrayType: # pylint:disable=too-many-locals
"""
r"""
Partial distance correlation estimator.
Compute the estimator for the partial distance correlation of the
random vectors corresponding to :math:`x` and :math:`y` with respect
to the random variable corresponding to :math:`z`.
Warning:
Warning:
Partial distance correlation should be used carefully as it presents
some undesirable or counterintuitive properties. In particular, the
reader cannot assume that :math:`\mathcal{R}^{*}` characterizes
Expand Down

0 comments on commit 3fb5002

Please sign in to comment.