Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replaced some assert statements with proper checks #541

Merged
merged 1 commit into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions pde/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,8 @@

else:
raise RuntimeError(
"Multiple data fields were found in the "
"file but no FieldCollection is expected"
"Multiple data fields were found in the file but no "
"`FieldCollection` is expected."
)
return obj

Expand Down Expand Up @@ -684,7 +684,8 @@
else:
raise TypeError("`func` must be string or callable")

assert isinstance(out, FieldBase)
if not isinstance(out, FieldBase):
raise TypeError("`out` must be of type `FieldBase`")

Check warning on line 688 in pde/fields/base.py

View check run for this annotation

Codecov / codecov/patch

pde/fields/base.py#L688

Added line #L688 was not covered by tests
if label:
out.label = label
return out # type: ignore
Expand Down Expand Up @@ -1659,7 +1660,8 @@
rank_b = b.ndim - num_axes
if rank_a < 1 or rank_b < 1:
raise TypeError("Fields in dot product must have rank >= 1")
assert a.shape[rank_a:] == b.shape[rank_b:]
if a.shape[rank_a:] != b.shape[rank_b:]:
raise ValueError("Shapes of fields are not compatible for dot product")

Check warning on line 1664 in pde/fields/base.py

View check run for this annotation

Codecov / codecov/patch

pde/fields/base.py#L1664

Added line #L1664 was not covered by tests

if rank_a == 1 and rank_b == 1: # result is scalar field
return np.einsum("i...,i...->...", a, maybe_conj(b), out=out)
Expand Down
6 changes: 4 additions & 2 deletions pde/fields/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@
field._data_flat = self._data_full[self._slices[i]]

# check whether the field data is based on our data field
assert field.data.shape == field_shape
assert np.may_share_memory(field._data_full, self._data_full)
if field.data.shape != field_shape:
raise RuntimeError("Field shapes have changed!")

Check warning on line 136 in pde/fields/collection.py

View check run for this annotation

Codecov / codecov/patch

pde/fields/collection.py#L136

Added line #L136 was not covered by tests
if not np.may_share_memory(field._data_full, self._data_full):
raise RuntimeError("Spurious copy of data detected!")

Check warning on line 138 in pde/fields/collection.py

View check run for this annotation

Codecov / codecov/patch

pde/fields/collection.py#L138

Added line #L138 was not covered by tests

if labels is not None:
self.labels = labels # type: ignore
Expand Down
6 changes: 4 additions & 2 deletions pde/fields/tensorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,8 @@
self._data_full = data_full

# ensure that no copying happend
assert np.may_share_memory(self.data, value)
if not np.may_share_memory(self.data, value):
raise RuntimeError("Spurious copy detected!")

Check warning on line 157 in pde/fields/tensorial.py

View check run for this annotation

Codecov / codecov/patch

pde/fields/tensorial.py#L157

Added line #L157 was not covered by tests

def dot(
self,
Expand Down Expand Up @@ -192,7 +193,8 @@
if out is None:
out = other.__class__(self.grid, dtype=get_common_dtype(self, other))
else:
assert isinstance(out, other.__class__), f"`out` must be {other.__class__}"
if not isinstance(out, other.__class__):
raise TypeError(f"`out` must be of type `{other.__class__}`")

Check warning on line 197 in pde/fields/tensorial.py

View check run for this annotation

Codecov / codecov/patch

pde/fields/tensorial.py#L196-L197

Added lines #L196 - L197 were not covered by tests
self.grid.assert_grid_compatible(out.grid)

# calculate the result
Expand Down
6 changes: 4 additions & 2 deletions pde/fields/vectorial.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@

data = []
for field in fields:
assert field.grid.compatible_with(grid)
if not field.grid.compatible_with(grid):
raise ValueError("Grids are incompatible")

Check warning on line 78 in pde/fields/vectorial.py

View check run for this annotation

Codecov / codecov/patch

pde/fields/vectorial.py#L78

Added line #L78 was not covered by tests
data.append(field.data)

return cls(grid, data, label=label, dtype=dtype)
Expand Down Expand Up @@ -209,7 +210,8 @@
if out is None:
out = result_type(self.grid, dtype=get_common_dtype(self, other))
else:
assert isinstance(out, result_type), f"`out` must be {result_type}"
if not isinstance(out, result_type):
raise TypeError(f"`out` must be of type `{result_type}`")

Check warning on line 214 in pde/fields/vectorial.py

View check run for this annotation

Codecov / codecov/patch

pde/fields/vectorial.py#L213-L214

Added lines #L213 - L214 were not covered by tests
self.grid.assert_grid_compatible(out.grid)

# calculate the result
Expand Down
3 changes: 2 additions & 1 deletion pde/grids/coordinates/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,5 +268,6 @@

# convert the basis of the vectors to Cartesian
basis = self.basis_rotation(points)
assert basis.shape == (self.dim, self.dim) + shape
if basis.shape != (self.dim, self.dim) + shape:
raise DimensionError("Incompatible dimensions in rotation matrix")

Check warning on line 272 in pde/grids/coordinates/base.py

View check run for this annotation

Codecov / codecov/patch

pde/grids/coordinates/base.py#L272

Added line #L272 was not covered by tests
return np.einsum("j...,ji...->i...", components, basis) # type: ignore
3 changes: 2 additions & 1 deletion pde/pdes/allen_cahn.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
:class:`~pde.fields.ScalarField`:
Scalar field describing the evolution rate of the PDE
"""
assert isinstance(state, ScalarField), "`state` must be ScalarField"
if not isinstance(state, ScalarField):
raise ValueError("`state` must be ScalarField")

Check warning on line 77 in pde/pdes/allen_cahn.py

View check run for this annotation

Codecov / codecov/patch

pde/pdes/allen_cahn.py#L77

Added line #L77 was not covered by tests
laplace = state.laplace(bc=self.bc, label="evolution rate", args={"t": t})
return self.interface_width * laplace - state**3 + state # type: ignore

Expand Down
3 changes: 2 additions & 1 deletion pde/pdes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,8 @@
cell_volume: Callable[[int], float] = state.grid.make_cell_volume_compiled(
flat_index=True
)
assert state.dtype == float
if state.dtype != float:
raise TypeError("Noise is only supported for float types")

Check warning on line 371 in pde/pdes/base.py

View check run for this annotation

Codecov / codecov/patch

pde/pdes/base.py#L371

Added line #L371 was not covered by tests

if isinstance(state, FieldCollection):
# different noise strengths, assuming one for each field
Expand Down
3 changes: 2 additions & 1 deletion pde/pdes/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@
:class:`~pde.fields.ScalarField`:
Scalar field describing the evolution rate of the PDE
"""
assert isinstance(state, ScalarField), "`state` must be ScalarField"
if not isinstance(state, ScalarField):
raise ValueError("`state` must be ScalarField")

Check warning on line 88 in pde/pdes/diffusion.py

View check run for this annotation

Codecov / codecov/patch

pde/pdes/diffusion.py#L88

Added line #L88 was not covered by tests
laplace = state.laplace(bc=self.bc, label="evolution rate", args={"t": t})
return self.diffusivity * laplace # type: ignore

Expand Down
3 changes: 2 additions & 1 deletion pde/pdes/kpz_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@
:class:`~pde.fields.ScalarField`:
Scalar field describing the evolution rate of the PDE
"""
assert isinstance(state, ScalarField), "`state` must be ScalarField"
if not isinstance(state, ScalarField):
raise ValueError("`state` must be ScalarField")

Check warning on line 96 in pde/pdes/kpz_interface.py

View check run for this annotation

Codecov / codecov/patch

pde/pdes/kpz_interface.py#L96

Added line #L96 was not covered by tests
result = self.nu * state.laplace(bc=self.bc, args={"t": t})
result += self.lmbda * state.gradient_squared(bc=self.bc, args={"t": t})
result.label = "evolution rate"
Expand Down
3 changes: 2 additions & 1 deletion pde/pdes/kuramoto_sivashinsky.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@
:class:`~pde.fields.ScalarField`:
Scalar field describing the evolution rate of the PDE
"""
assert isinstance(state, ScalarField), "`state` must be ScalarField"
if not isinstance(state, ScalarField):
raise ValueError("`state` must be ScalarField")

Check warning on line 99 in pde/pdes/kuramoto_sivashinsky.py

View check run for this annotation

Codecov / codecov/patch

pde/pdes/kuramoto_sivashinsky.py#L99

Added line #L99 was not covered by tests
state_lap = state.laplace(bc=self.bc, args={"t": t})
result = (
-self.nu * state_lap.laplace(bc=self.bc_lap, args={"t": t})
Expand Down
3 changes: 2 additions & 1 deletion pde/pdes/swift_hohenberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@
:class:`~pde.fields.ScalarField`:
Scalar field describing the evolution rate of the PDE
"""
assert isinstance(state, ScalarField), "`state` must be ScalarField"
if not isinstance(state, ScalarField):
raise ValueError("`state` must be ScalarField")

Check warning on line 98 in pde/pdes/swift_hohenberg.py

View check run for this annotation

Codecov / codecov/patch

pde/pdes/swift_hohenberg.py#L98

Added line #L98 was not covered by tests
state_laplace = state.laplace(bc=self.bc, args={"t": t})
state_laplace2 = state_laplace.laplace(bc=self.bc_lap, args={"t": t})

Expand Down
6 changes: 4 additions & 2 deletions pde/pdes/wave.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@
:class:`~pde.fields.FieldCollection`:
Fields describing the evolution rates of the PDE
"""
assert isinstance(state, FieldCollection), "`state` must be FieldCollection"
assert len(state) == 2, "`state` must contain two fields"
if not isinstance(state, FieldCollection):
raise ValueError("`state` must be FieldCollection")

Check warning on line 95 in pde/pdes/wave.py

View check run for this annotation

Codecov / codecov/patch

pde/pdes/wave.py#L95

Added line #L95 was not covered by tests
if len(state) != 2:
raise ValueError("`state` must contain two fields")

Check warning on line 97 in pde/pdes/wave.py

View check run for this annotation

Codecov / codecov/patch

pde/pdes/wave.py#L97

Added line #L97 was not covered by tests
u, v = state
u_t = v.copy()
v_t = self.speed**2 * u.laplace(self.bc, args={"t": t}) # type: ignore
Expand Down
3 changes: 2 additions & 1 deletion pde/storage/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,8 @@
data (:class:`~numpy.ndarray`): The actual data
time (float, optional): The time point associated with the data
"""
assert data.shape == self.data_shape, f"Data must have shape {self.data_shape}"
if data.shape != self.data_shape:
raise ValueError(f"Data must have shape {self.data_shape}")

Check warning on line 206 in pde/storage/memory.py

View check run for this annotation

Codecov / codecov/patch

pde/storage/memory.py#L206

Added line #L206 was not covered by tests
self.data.append(np.array(data)) # store copy of the data
self.times.append(time)

Expand Down
3 changes: 2 additions & 1 deletion pde/tools/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,7 +876,8 @@
Whether the compiled function expects all arguments as a single array
or whether they are supplied individually.
"""
assert isinstance(self._sympy_expr, sympy.Array), "Expression must be an array"
if not isinstance(self._sympy_expr, sympy.Array):
raise TypeError("Expression must be an array")

Check warning on line 880 in pde/tools/expressions.py

View check run for this annotation

Codecov / codecov/patch

pde/tools/expressions.py#L880

Added line #L880 was not covered by tests
variables = ", ".join(v for v in self.vars)
shape = self._sympy_expr.shape

Expand Down
3 changes: 2 additions & 1 deletion pde/trackers/interrupts.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class FixedInterrupts(InterruptsBase):

def __init__(self, interrupts: np.ndarray | Sequence[float]):
self.interrupts = np.atleast_1d(interrupts)
assert self.interrupts.ndim == 1
if self.interrupts.ndim != 1:
raise ValueError("`interrupts` must be a 1d sequence")

def __repr__(self):
return f"{self.__class__.__name__}(interrupts={self.interrupts})"
Expand Down
9 changes: 6 additions & 3 deletions pde/visualization/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,8 @@
:class:`~pde.visualization.plotting.ScalarFieldPlot`
"""
fields = storage[0]
assert isinstance(fields, FieldBase), "Storage must contain fields"
if not isinstance(fields, FieldBase):
raise RuntimeError("Storage must contain fields")

Check warning on line 209 in pde/visualization/plotting.py

View check run for this annotation

Codecov / codecov/patch

pde/visualization/plotting.py#L209

Added line #L209 was not covered by tests

# prepare the data that needs to be plotted
quantities = cls._prepare_quantities(fields, quantities=quantities, scale=scale)
Expand Down Expand Up @@ -393,7 +394,8 @@
The title of this view. If `None`, the current title is not
changed.
"""
assert isinstance(fields, FieldBase), "Fields must inherit from FieldBase"
if not isinstance(fields, FieldBase):
raise TypeError("`fields` must be of type FieldBase")

Check warning on line 398 in pde/visualization/plotting.py

View check run for this annotation

Codecov / codecov/patch

pde/visualization/plotting.py#L398

Added line #L398 was not covered by tests

if title:
self.sup_title.set_text(title)
Expand Down Expand Up @@ -520,7 +522,8 @@
"""
if quantities is None:
fields = storage[0]
assert isinstance(fields, FieldBase), "Storage must contain fields"
if not isinstance(fields, FieldBase):
raise TypeError("`fields` must be of type FieldBase")

Check warning on line 526 in pde/visualization/plotting.py

View check run for this annotation

Codecov / codecov/patch

pde/visualization/plotting.py#L526

Added line #L526 was not covered by tests
if fields.label:
label_base = fields.label
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/trackers/test_interrupts.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,5 +113,5 @@ def test_interrupt_fixed():
ival = FixedInterrupts(1)
assert ival.initialize(0) == pytest.approx(1)
assert np.isinf(ival.next(0))
with pytest.raises(AssertionError):
with pytest.raises(ValueError):
ival = FixedInterrupts([[1]])
Loading