Skip to content

Commit

Permalink
Make tests pytest parametrized
Browse files Browse the repository at this point in the history
  • Loading branch information
oerc0122 committed Apr 25, 2024
1 parent 4b9515d commit 33a987b
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 42 deletions.
4 changes: 3 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
numpydoc_validation_exclude = {
r"\.__weakref__$",
r"\.__repr__$",
r"janus_core\.janus_types",
r"janus_core\.helpers\.janus_types",
}
numpydoc_class_members_toctree = False

Expand Down Expand Up @@ -192,4 +192,6 @@
nitpick_ignore = [
("py:class", "Logger"),
("py:class", "numpy.float64"),
("py:class", "ellipsis"),
("py:class", "janus_core.helpers.stats.T"),
]
103 changes: 88 additions & 15 deletions janus_core/helpers/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,23 @@
Module that reads the md stats output timeseries.
"""

from collections.abc import Iterator
from functools import singledispatchmethod
import re
from types import EllipsisType
from typing import TypeVar

from numpy import float64, genfromtxt, zeros
from numpy.typing import NDArray

from janus_core.helpers.janus_types import PathLike

try: # Python >=3.10
from types import EllipsisType
except ImportError:
EllipsisType = type(...)

T = TypeVar("T")


class Stats:
"""
Expand Down Expand Up @@ -41,29 +49,82 @@ def __init__(self, source: PathLike) -> None:
self.read()

@singledispatchmethod
def __getitem__(self, ind):
def _getind(self, lab: T) -> T:
"""
Convert an index label from str to int if present in labels.
Otherwise return the input.
Parameters
----------
lab : str
Label to find.
Returns
-------
int
Index of label in self or input if not string.
Raises
------
IndexError
Label not found in labels.
"""
return lab

@_getind.register
def _(self, lab: str) -> int: # numpydoc ignore=GL08
# Case-insensitive fuzzy match, only has to be `in` the labels
index = next(
(
index
for index, label in enumerate(self.labels)
if lab.lower() in label.lower()
),
None,
)
if index is None:
raise IndexError(f"{lab} not found in labels")
return index

@singledispatchmethod
def __getitem__(self, ind) -> NDArray[float64]:
"""
Get member of stats data by label or index.
Parameters
----------
ind : Any
Index or label to find.
Returns
-------
NDArray[float64]
Columns of data by label.
Raises
------
IndexError
Invalid index type or label not found in labels.
"""
raise IndexError(f"Unknown index {ind}")

@__getitem__.register(int)
@__getitem__.register(slice)
@__getitem__.register(EllipsisType)
def _(self, ind):
def _(self, ind) -> NDArray[float64]: # numpydoc ignore=GL08
return self.data[:, ind]

@__getitem__.register(list)
@__getitem__.register(tuple)
def _(self, ind):
return self.data[ind]
def _(self, ind) -> NDArray[float64]: # numpydoc ignore=GL08
ind = list(map(self._getind, ind))
return self.data[:, ind]

@__getitem__.register(str)
def _(self, ind):
# Case-insensitive fuzzy match, only has to be `in` the labels
index = next((index
for index, label in enumerate(self.labels)
if ind.lower() in label.lower()),
None)
if index is None:
raise IndexError(f"{ind} not found in labels")
return self[index]
def _(self, ind) -> NDArray[float64]: # numpydoc ignore=GL08
ind = self._getind(ind)
return self[ind]

@property
def rows(self) -> int:
Expand Down Expand Up @@ -137,6 +198,18 @@ def data(self) -> NDArray[float64]:
"""
return self._data

@property
def data_tags(self) -> Iterator[tuple[str, str]]:
"""
Return the labels and their units together.
Returns
-------
Iterator[tuple[str, str]]
Zipped labels and units.
"""
return zip(self.labels, self.units)

def read(self) -> None:
"""
Read MD stats and store them in `data`.
Expand All @@ -161,6 +234,6 @@ def __repr__(self) -> str:

header = f"contains {self.columns} timeseries, each with {self.rows} elements"
header += "\nindex label units"
for index, (label, unit) in enumerate(zip(self.labels, self.units)):
for index, (label, unit) in enumerate(self.data_tags):
header += f"\n{index} {label} {unit}"
return header
100 changes: 74 additions & 26 deletions tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,85 @@

from pathlib import Path

import pytest
from pytest import approx

from janus_core.helpers.stats import Stats

DATA_PATH = Path(__file__).parent / "data"


def test_stats(capsys):
"""Test reading md stats."""
data_path = DATA_PATH / "md-stats.dat"

stat_data = Stats(data_path)

assert stat_data.rows == 100
assert stat_data.columns == 18
assert stat_data.data[99, 17] == approx(300.0)
assert stat_data.units[0] == ""
assert stat_data.units[17] == "K"
assert stat_data.labels[0] == "# Step"
assert stat_data.labels[17] == "Target T"

# Check getitem form
assert (stat_data["target t"] == stat_data.data[:, 17]).all()
assert (stat_data[:, 17] == stat_data.data[:, 17]).all()
assert (stat_data[17] == stat_data.data[:, 17]).all()

print(stat_data)
std_out_err = capsys.readouterr()
assert std_out_err.err == ""
assert "index label units" in std_out_err.out
assert (
f"contains {stat_data.columns} timeseries, each with {stat_data.rows} elements"
in std_out_err.out
class TestStats:
"""Tests for the stats type."""

data = Stats(DATA_PATH / "md-stats.dat")

@pytest.mark.parametrize(
"attr,expected",
(
("rows", 100),
("columns", 18),
),
ids=(
"get rows",
"get_cols",
),
)
def test_props(self, attr, expected):
"""Test props are being set correctly."""
assert getattr(self.data, attr) == expected

@pytest.mark.parametrize(
"attr,ind,expected",
(
("data", (99, 17), approx(300.0)),
("units", 0, ""),
("units", 17, "K"),
("labels", 0, "# Step"),
("labels", 17, "Target T"),
),
ids=(
"data value",
"Step units",
"Target T units",
"Step label",
"Target T label",
),
)
def test_data_index(self, attr, ind, expected):
"""Test data indexing working correctly."""
assert getattr(self.data, attr)[ind] == expected

@pytest.mark.parametrize(
"ind,expectedcol",
(
("target t", 17),
(17, 17),
(slice(3, 7, 2), (3, 5)),
((1, 3), (1, 3)),
(("target t", "step"), (17, 0)),
(("target t", 0), (17, 0)),
),
ids=(
"str",
"int",
"slice",
"tuple[int]",
"tuple[str]",
"tuple[mixed]",
),
)
def test_getitem(self, ind, expectedcol):
"""Test getitem indexing working correctly."""
assert (self.data[ind] == self.data.data[:, expectedcol]).all()

def test_repr(self, capsys):
"""Test repr working correctly."""
print(self.data)
std_out_err = capsys.readouterr()
assert std_out_err.err == ""
assert "index label units" in std_out_err.out
assert (
f"contains {self.data.columns} timeseries, "
f"each with {self.data.rows} elements" in std_out_err.out
)

0 comments on commit 33a987b

Please sign in to comment.