Skip to content

Commit

Permalink
Typing _utils.
Browse files Browse the repository at this point in the history
  • Loading branch information
vnmabus committed Mar 3, 2022
1 parent 6010fa1 commit 4fa3a47
Show file tree
Hide file tree
Showing 2 changed files with 201 additions and 35 deletions.
106 changes: 72 additions & 34 deletions dcor/_utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
"""Utility functions"""
"""Utility functions."""

from __future__ import annotations

import enum
from typing import TYPE_CHECKING, Any, TypeVar

import numba
import numpy as np

# TODO: Change in the future
if TYPE_CHECKING:
ArrayType = np.typing.NDArray[Any]
else:
ArrayType = np.ndarray

T = TypeVar("T", bound=ArrayType)


class CompileMode(enum.Enum):
"""
Compilation mode of the algorithm.
"""
"""Compilation mode of the algorithm."""

AUTO = enum.auto()
"""
Expand All @@ -33,9 +41,7 @@ class CompileMode(enum.Enum):


class RowwiseMode(enum.Enum):
"""
Rowwise mode of the algorithm.
"""
"""Rowwise mode of the algorithm."""

AUTO = enum.auto()
"""
Expand All @@ -53,43 +59,70 @@ class RowwiseMode(enum.Enum):
"""


def _sqrt(x):
# TODO: Change the return type in the future
def get_namespace(*xs: Any) -> Any:
# `xs` contains one or more arrays, or possibly Python scalars (accepting
# those is a matter of taste, but doesn't seem unreasonable).
namespaces = {
x.__array_namespace__()
for x in xs if hasattr(x, '__array_namespace__')
}

if not namespaces:
# one could special-case np.ndarray above or use np.asarray here if
# older numpy versions need to be supported.
return np

if len(namespaces) != 1:
raise ValueError(
f"Multiple namespaces for array inputs: {namespaces}")

xp, = namespaces
if xp is None:
raise ValueError("The input is not a supported array type")

return xp


def _sqrt(x: T) -> T:
"""
Return square root of an ndarray.
Return square root of an array.
This sqrt function for ndarrays tries to use the exponentiation operator
if the objects stored do not supply a sqrt method.
Args:
x: Input array.
Returns:
Square root of the input array.
"""
x = np.clip(x, a_min=0, a_max=None)
# Replace negative values with 0
x = x * (x > 0)

xp = get_namespace(x)
try:
return np.sqrt(x)
return xp.sqrt(x)
except (AttributeError, TypeError):
exponent = 0.5

try:
exponent = np.take(x, 0).from_float(exponent)
except AttributeError:
pass
return x**0.5

return x ** exponent


def _transform_to_2d(t):
def _transform_to_2d(t: T) -> T:
"""Convert vectors to column matrices, to always have a 2d shape."""
t = np.asarray(t)
xp = get_namespace(t)
t = xp.asarray(t)

dim = len(t.shape)
assert dim <= 2

if dim < 2:
t = np.atleast_2d(t).T
t = xp.expand_dims(t, axis=1)

return t


def _can_be_double(x):
def _can_be_double(x: np.typing.NDArray[Any]) -> bool:
"""
Return if the array can be safely converted to double.
Expand All @@ -98,13 +131,20 @@ def _can_be_double(x):
converted to double (if the roundtrip conversion works).
"""
return ((np.issubdtype(x.dtype, np.floating) and
x.dtype.itemsize <= np.dtype(float).itemsize) or
(np.issubdtype(x.dtype, np.signedinteger) and
np.can_cast(x, float)))
return (
(
np.issubdtype(x.dtype, np.floating)
and x.dtype.itemsize <= np.dtype(float).itemsize
) or (
np.issubdtype(x.dtype, np.signedinteger)
and np.can_cast(x, float)
)
)


def _random_state_init(random_state):
def _random_state_init(
random_state: np.random.RandomState | np.random.Generator | int | None,
) -> np.random.RandomState | np.random.Generator:
"""
Initialize a RandomState object.
Expand All @@ -113,9 +153,7 @@ def _random_state_init(random_state):
and returned.
"""
try:
random_state = np.random.RandomState(random_state)
except TypeError:
pass
if isinstance(random_state, (np.random.RandomState, np.random.Generator)):
return random_state

return random_state
return np.random.RandomState(random_state)
130 changes: 129 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,132 @@ test=pytest

[tool:pytest]
addopts = --doctest-modules --cov=dcor
norecursedirs = '.*' 'build' 'dist' '*.egg' 'venv' .svn _build docs
norecursedirs = '.*' 'build' 'dist' '*.egg' 'venv' .svn _build docs

[flake8]
ignore =
# No docstring for magic methods
D105,
# No docstrings in __init__
D107,
# Ignore until https://github.com/terrencepreilly/darglint/issues/54 is closed
DAR202,
# Ignore until https://github.com/terrencepreilly/darglint/issues/144 is closed
DAR401,
# Non-explicit exceptions may be documented in raises
DAR402,
# Uppercase arguments like X are common in scikit-learn
N803,
# Uppercase variables like X are common in scikit-learn
N806,
# There are no bad quotes
Q000,
# Google Python style is not RST until after processed by Napoleon
# See https://github.com/peterjc/flake8-rst-docstrings/issues/17
RST201, RST203, RST301,
# assert is used by pytest tests
S101,
# Line break occurred before a binary operator (antipattern)
W503,
# Utils is used as a module name
WPS100,
# Short names like X or y are common in scikit-learn
WPS111,
# We do not like this underscored numbers convention
WPS114,
# Attributes in uppercase are used in enums
WPS115,
# Trailing underscores are a scikit-learn convention
WPS120,
# Cognitive complexity cannot be avoided at some modules
WPS232,
# The number of imported things may be large, especially for typing
WPS235,
# We like local imports, thanks
WPS300,
# Dotted imports are ok
WPS301,
# We love f-strings
WPS305,
# Implicit string concatenation is useful for exception messages
WPS306,
# No base class needed
WPS326,
# We allow multiline conditions
WPS337,
# We order methods differently
WPS338,
# We need multine loops
WPS352,
# Assign to a subcript slice is normal behaviour in numpy
WPS362,
# All keywords are beautiful
WPS420,
# We use nested imports sometimes, and it is not THAT bad
WPS433,
# We use list multiplication to allocate list with immutable values (None or numbers)
WPS435,
# Our private modules are fine to import
# (check https://github.com/wemake-services/wemake-python-styleguide/issues/1441)
WPS436,
# Our private objects are fine to import
WPS450,
# Numpy mixes bitwise and comparison operators
WPS465,
# Explicit len compare is better than implicit
WPS507,
# Comparison with not is not the same as with equality
WPS520,

per-file-ignores =
__init__.py:
# Unused modules are allowed in `__init__.py`, to reduce imports
F401,
# Import multiple names is allowed in `__init__.py`
WPS235,
# Logic is allowec in `__init__.py`
WPS412

# There are many datasets
_real_datasets.py: WPS202

# Tests benefit from meaningless zeros, magic numbers and fixtures
test_*.py: WPS339, WPS432, WPS442

# Examples are allowed to have imports in the middle, "commented code", call print and have magic numbers
plot_*.py: E402, E800, WPS421, WPS432

rst-directives =
# These are sorted alphabetically - but that does not matter
autosummary,data,currentmodule,deprecated,
footbibliography,glossary,
jupyter-execute,
moduleauthor,plot,testcode,
versionadded,versionchanged,

rst-roles =
attr,class,doc,footcite,footcite:ts,func,meth,mod,obj,ref,term,

allowed-domain-names = data, obj, result, results, val, value, values, var

# Needs to be tuned
max-arguments = 10
max-attributes = 10
max-cognitive-score = 30
max-expressions = 15
max-imports = 20
max-line-complexity = 30
max-local-variables = 15
max-methods = 30
max-module-expressions = 15
max-module-members = 15
max-string-usages = 10
max-try-body-length = 4

ignore-decorators = (property)|(overload)

strictness = long

# Beautify output and make it more informative
format = wemake
show-source = true

0 comments on commit 4fa3a47

Please sign in to comment.