Skip to content

Commit 240fb89

Browse files
authored
Enable runtime type checking in tests with typeguard (#478)
Re-enables the `typeguard` job that I disabled 2 years ago in #415.
1 parent d25e7e7 commit 240fb89

11 files changed

+62
-71
lines changed

.github/workflows/typeguard.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
name: typeguard
22

3-
# TODO: enable this once typeguard=4 is released and issues are fixed.
4-
# on:
5-
# - push
3+
on:
4+
pull_request:
5+
push:
6+
branches: [main]
67

78
jobs:
89
typeguard:

adaptive/_types.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Only used for static type checkers, should only be imported in `if TYPE_CHECKING` block
2+
# Workaround described in https://github.com/agronholm/typeguard/issues/456
3+
4+
import concurrent.futures as concurrent
5+
from typing import TypeAlias
6+
7+
import distributed
8+
import ipyparallel
9+
import loky
10+
import mpi4py.futures
11+
12+
from adaptive.utils import SequentialExecutor
13+
14+
ExecutorTypes: TypeAlias = (
15+
concurrent.ProcessPoolExecutor
16+
| concurrent.ThreadPoolExecutor
17+
| SequentialExecutor
18+
| loky.reusable_executor._ReusablePoolExecutor
19+
| distributed.Client
20+
| distributed.cfexecutor.ClientExecutor
21+
| mpi4py.futures.MPIPoolExecutor
22+
| ipyparallel.Client
23+
| ipyparallel.client.view.ViewExecutor
24+
)

adaptive/learner/balancing_learner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -269,17 +269,17 @@ def ask(
269269
return self._ask_and_tell(n)
270270

271271
def tell(self, x: tuple[Int, Any], y: Any) -> None:
272-
index, x = x
272+
index, x_ = x
273273
self._ask_cache.pop(index, None)
274274
self._loss.pop(index, None)
275275
self._pending_loss.pop(index, None)
276-
self.learners[index].tell(x, y)
276+
self.learners[index].tell(x_, y)
277277

278278
def tell_pending(self, x: tuple[Int, Any]) -> None:
279-
index, x = x
279+
index, x_ = x
280280
self._ask_cache.pop(index, None)
281281
self._loss.pop(index, None)
282-
self.learners[index].tell_pending(x)
282+
self.learners[index].tell_pending(x_)
283283

284284
def _losses(self, real: bool = True) -> list[float]:
285285
losses = []

adaptive/learner/learner1D.py

Lines changed: 22 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import math
66
from collections.abc import Callable, Sequence
77
from copy import copy, deepcopy
8-
from typing import TYPE_CHECKING, Any, TypeAlias
8+
from typing import Any, TypeAlias
99

1010
import cloudpickle
1111
import numpy as np
@@ -31,25 +31,21 @@
3131
except ModuleNotFoundError:
3232
with_pandas = False
3333

34-
if TYPE_CHECKING:
35-
# -- types --
3634

37-
# Commonly used types
38-
Interval: TypeAlias = tuple[float, float] | tuple[float, float, int]
39-
NeighborsType: TypeAlias = SortedDict[float, list[float | None]]
35+
# Commonly used types
36+
Interval: TypeAlias = tuple[float, float] | tuple[float, float, int]
37+
NeighborsType: TypeAlias = SortedDict[float, list[float | None]]
4038

41-
# Types for loss_per_interval functions
42-
XsType0: TypeAlias = tuple[float, float]
43-
YsType0: TypeAlias = tuple[float, float] | tuple[np.ndarray, np.ndarray]
44-
XsType1: TypeAlias = tuple[float | None, float | None, float | None, float | None]
45-
YsType1: TypeAlias = (
46-
tuple[float | None, float | None, float | None, float | None]
47-
| tuple[
48-
np.ndarray | None, np.ndarray | None, np.ndarray | None, np.ndarray | None
49-
]
50-
)
51-
XsTypeN: TypeAlias = tuple[float | None, ...]
52-
YsTypeN: TypeAlias = tuple[float | None, ...] | tuple[np.ndarray | None, ...]
39+
# Types for loss_per_interval functions
40+
XsType0: TypeAlias = tuple[float, float]
41+
YsType0: TypeAlias = tuple[float, float] | tuple[np.ndarray, np.ndarray]
42+
XsType1: TypeAlias = tuple[float | None, float | None, float | None, float | None]
43+
YsType1: TypeAlias = (
44+
tuple[float | None, float | None, float | None, float | None]
45+
| tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None, np.ndarray | None]
46+
)
47+
XsTypeN: TypeAlias = tuple[float | None, ...]
48+
YsTypeN: TypeAlias = tuple[float | None, ...] | tuple[np.ndarray | None, ...]
5349

5450

5551
__all__ = [
@@ -110,18 +106,18 @@ def abs_min_log_loss(xs: XsType0, ys: YsType0) -> Float:
110106
@uses_nth_neighbors(1)
111107
def triangle_loss(xs: XsType1, ys: YsType1) -> Float:
112108
assert len(xs) == 4
113-
xs = [x for x in xs if x is not None] # type: ignore[assignment]
114-
ys = [y for y in ys if y is not None] # type: ignore[assignment]
109+
x = [x for x in xs if x is not None]
110+
y = [y for y in ys if y is not None]
115111

116-
if len(xs) == 2: # we do not have enough points for a triangle
117-
return xs[1] - xs[0] # type: ignore[operator]
112+
if len(x) == 2: # we do not have enough points for a triangle
113+
return x[1] - x[0] # type: ignore[operator]
118114

119-
N = len(xs) - 2 # number of constructed triangles
120-
if isinstance(ys[0], collections.abc.Iterable):
121-
pts = [(x, *y) for x, y in zip(xs, ys)] # type: ignore[misc]
115+
N = len(x) - 2 # number of constructed triangles
116+
if isinstance(y[0], collections.abc.Iterable):
117+
pts = [(x, *y) for x, y in zip(x, y)] # type: ignore[misc]
122118
vol = simplex_volume_in_embedding
123119
else:
124-
pts = list(zip(xs, ys))
120+
pts = list(zip(x, y))
125121
vol = volume
126122
return sum(vol(pts[i : i + 3]) for i in range(N)) / N
127123

adaptive/learner/sequence_learner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,13 @@ def remove_unfinished(self) -> None:
134134
self.pending_points = set()
135135

136136
def tell(self, point: PointType, value: Any) -> None:
137-
index, point = point
137+
index, _ = point
138138
self.data[index] = value
139139
self.pending_points.discard(index)
140140
self._to_do_indices.discard(index)
141141

142142
def tell_pending(self, point: PointType) -> None:
143-
index, point = point
143+
index, _ = point
144144
self.pending_points.add(index)
145145
self._to_do_indices.discard(index)
146146

adaptive/runner.py

Lines changed: 5 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,32 +24,18 @@
2424
from adaptive.notebook_integration import in_ipynb, live_info, live_plot
2525
from adaptive.utils import SequentialExecutor
2626

27+
FutureTypes: TypeAlias = concurrent.Future | asyncio.Future
28+
2729
if TYPE_CHECKING:
2830
import holoviews
2931

32+
from ._types import ExecutorTypes
33+
3034

3135
with_ipyparallel = find_spec("ipyparallel") is not None
3236
with_distributed = find_spec("distributed") is not None
3337
with_mpi4py = find_spec("mpi4py") is not None
3438

35-
if TYPE_CHECKING:
36-
import distributed
37-
import ipyparallel
38-
import mpi4py.futures
39-
40-
ExecutorTypes: TypeAlias = (
41-
concurrent.ProcessPoolExecutor
42-
| concurrent.ThreadPoolExecutor
43-
| SequentialExecutor
44-
| loky.reusable_executor._ReusablePoolExecutor
45-
| distributed.Client
46-
| distributed.cfexecutor.ClientExecutor
47-
| mpi4py.futures.MPIPoolExecutor
48-
| ipyparallel.Client
49-
| ipyparallel.client.view.ViewExecutor
50-
)
51-
FutureTypes: TypeAlias = concurrent.Future | asyncio.Future
52-
5339

5440
with suppress(ModuleNotFoundError):
5541
import uvloop
@@ -906,7 +892,7 @@ def _info_text(runner, separator: str = "\n"):
906892
info.append(("# of samples", runner.learner.nsamples))
907893

908894
with suppress(Exception):
909-
info.append(("latest loss", f'{runner.learner._cache["loss"]:.3f}'))
895+
info.append(("latest loss", f"{runner.learner._cache['loss']:.3f}"))
910896

911897
width = 30
912898
formatted_info = [f"{k}: {v}".ljust(width) for i, (k, v) in enumerate(info)]

adaptive/tests/test_average_learner.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
import random
2-
from typing import TYPE_CHECKING
32

43
import flaky
54
import numpy as np
65

76
from adaptive.learner import AverageLearner
87
from adaptive.runner import simple
98

10-
if TYPE_CHECKING:
11-
pass
12-
139

1410
def f_unused(seed):
1511
raise NotImplementedError("This function shouldn't be used.")

adaptive/tests/test_average_learner1d.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from itertools import chain
2-
from typing import TYPE_CHECKING
32

43
import numpy as np
54

@@ -10,9 +9,6 @@
109
simple_run,
1110
)
1211

13-
if TYPE_CHECKING:
14-
pass
15-
1612

1713
def almost_equal_dicts(a, b):
1814
assert a.keys() == b.keys()

adaptive/tests/test_balancing_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_distribute_first_points_over_learners(strategy):
3535
learner = BalancingLearner(learners, strategy=strategy)
3636

3737
points = learner.ask(initial_points)[0]
38-
learner.tell_many(points, points)
38+
learner.tell_many(points, [x for i, x in points])
3939

4040
points, _ = learner.ask(100)
4141
i_learner, xs = zip(*points)

adaptive/tests/test_learner1d.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import random
44
import time
5-
from typing import TYPE_CHECKING
65

76
import flaky
87
import numpy as np
@@ -11,9 +10,6 @@
1110
from adaptive.learner.learner1D import curvature_loss_function
1211
from adaptive.runner import BlockingRunner, simple
1312

14-
if TYPE_CHECKING:
15-
pass
16-
1713

1814
def flat_middle(x):
1915
x *= 1e7

0 commit comments

Comments
 (0)