Skip to content

Commit

Permalink
Make tests pytest parametrized
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 committed Apr 25, 2024
1 parent 4b9515d commit 67881f7
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 42 deletions.
4 changes: 3 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
numpydoc_validation_exclude = {
r"\.__weakref__$",
r"\.__repr__$",
r"janus_core\.janus_types",
r"janus_core\.helpers\.janus_types",
}
numpydoc_class_members_toctree = False

Expand Down Expand Up @@ -192,4 +192,6 @@
nitpick_ignore = [
("py:class", "Logger"),
("py:class", "numpy.float64"),
("py:class", "ellipsis"),
("py:class", "T")
]
102 changes: 87 additions & 15 deletions janus_core/helpers/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@

from functools import singledispatchmethod
import re
from types import EllipsisType
from typing import Iterator, TypeVar

from numpy import float64, genfromtxt, zeros
from numpy.typing import NDArray

from janus_core.helpers.janus_types import PathLike

try: # Python >=3.10
from types import EllipsisType
except ImportError:
EllipsisType = type(...)

T = TypeVar("T")


class Stats:
"""
Expand Down Expand Up @@ -41,29 +48,82 @@ def __init__(self, source: PathLike) -> None:
self.read()

@singledispatchmethod
def __getitem__(self, ind):
def _getind(self, lab: T) -> T: # pylint: disable=no-self-use
"""
Convert an index label from str to int if present in labels.
Otherwise return the input.
Parameters
----------
lab : str
Label to find.
Returns
-------
int
Index of label in self or input if not string.
Raises
------
IndexError
Label not found in labels.
"""
return lab

@_getind.register
def _(self, lab: str) -> int:
# Case-insensitive fuzzy match, only has to be `in` the labels
index = next(
(
index
for index, label in enumerate(self.labels)
if lab.lower() in label.lower()
),
None,
)
if index is None:
raise IndexError(f"{lab} not found in labels")
return index

@singledispatchmethod
def __getitem__(self, ind) -> NDArray[float64]:
"""
Get member of stats data by label or index.
Parameters
----------
ind : Any
Index or label to find.
Returns
-------
NDArray[float64]
Columns of data by label.
Raises
------
IndexError
Invalid index type or label not found in labels.
"""
raise IndexError(f"Unknown index {ind}")

@__getitem__.register(int)
@__getitem__.register(slice)
@__getitem__.register(EllipsisType)
def _(self, ind):
def _(self, ind) -> NDArray[float64]:
return self.data[:, ind]

@__getitem__.register(list)
@__getitem__.register(tuple)
def _(self, ind):
return self.data[ind]
def _(self, ind) -> NDArray[float64]:
ind = list(map(self._getind, ind))
return self.data[:, ind]

@__getitem__.register(str)
def _(self, ind):
# Case-insensitive fuzzy match, only has to be `in` the labels
index = next((index
for index, label in enumerate(self.labels)
if ind.lower() in label.lower()),
None)
if index is None:
raise IndexError(f"{ind} not found in labels")
return self[index]
def _(self, ind) -> NDArray[float64]: # numpydoc ignore=GL08
ind = self._getind(ind)
return self[ind]

@property
def rows(self) -> int:
Expand Down Expand Up @@ -137,6 +197,18 @@ def data(self) -> NDArray[float64]:
"""
return self._data

@property
def data_tags(self) -> Iterator[tuple[str, str]]:
"""
Return the labels and their units together.
Returns
-------
Iterator[tuple[str, str]]
Zipped labels and units
"""
return zip(self.labels, self.units)

def read(self) -> None:
"""
Read MD stats and store them in `data`.
Expand All @@ -161,6 +233,6 @@ def __repr__(self) -> str:

header = f"contains {self.columns} timeseries, each with {self.rows} elements"
header += "\nindex label units"
for index, (label, unit) in enumerate(zip(self.labels, self.units)):
for index, (label, unit) in enumerate(self.data_tags):
header += f"\n{index} {label} {unit}"
return header
96 changes: 70 additions & 26 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,81 @@

from pathlib import Path

import pytest
from pytest import approx

from janus_core.helpers.stats import Stats

DATA_PATH = Path(__file__).parent / "data"


def test_stats(capsys):
"""Test reading md stats."""
data_path = DATA_PATH / "md-stats.dat"

stat_data = Stats(data_path)

assert stat_data.rows == 100
assert stat_data.columns == 18
assert stat_data.data[99, 17] == approx(300.0)
assert stat_data.units[0] == ""
assert stat_data.units[17] == "K"
assert stat_data.labels[0] == "# Step"
assert stat_data.labels[17] == "Target T"

# Check getitem form
assert (stat_data["target t"] == stat_data.data[:, 17]).all()
assert (stat_data[:, 17] == stat_data.data[:, 17]).all()
assert (stat_data[17] == stat_data.data[:, 17]).all()

print(stat_data)
std_out_err = capsys.readouterr()
assert std_out_err.err == ""
assert "index label units" in std_out_err.out
assert (
f"contains {stat_data.columns} timeseries, each with {stat_data.rows} elements"
in std_out_err.out
class TestStats:
"""Tests for the stats type"""

data = Stats(DATA_PATH / "md-stats.dat")

@pytest.mark.parametrize(
"attr,expected",
(
("rows", 100),
("columns", 18),
),
ids=(
"get rows",
"get_cols",
),
)
def test_props(self, attr, expected):
assert getattr(self.data, attr) == expected

@pytest.mark.parametrize(
"attr,ind,expected",
(
("data", (99, 17), approx(300.0)),
("units", 0, ""),
("units", 17, "K"),
("labels", 0, "# Step"),
("labels", 17, "Target T"),
),
ids=(
"data value",
"Step units",
"Target T units",
"Step label",
"Target T label",
),
)
def test_data_index(self, attr, ind, expected):
assert getattr(self.data, attr)[ind] == expected

@pytest.mark.parametrize(
"ind,expectedcol",
(
("target t", 17),
(17, 17),
(slice(3, 7, 2), (3, 5)),
((1, 3), (1, 3)),
(("target t", "step"), (17, 0)),
(("target t", 0), (17, 0)),
),
ids=(
"str",
"int",
"slice",
"tuple[int]",
"tuple[str]",
"tuple[mixed]",
),
)
def test_getitem(self, ind, expectedcol):
assert (self.data[ind] == self.data.data[:, expectedcol]).all()

def test_repr(self, capsys):
print(self.data)
std_out_err = capsys.readouterr()
assert std_out_err.err == ""
assert "index label units" in std_out_err.out
assert (
f"contains {self.data.columns} timeseries, each with {self.data.rows} elements"
in std_out_err.out
)

0 comments on commit 67881f7

Please sign in to comment.