diff --git a/src/matplotgl/axes.py b/src/matplotgl/axes.py index 5baa477..a3e1479 100644 --- a/src/matplotgl/axes.py +++ b/src/matplotgl/axes.py @@ -15,6 +15,14 @@ from .widgets import ClickableHTML +def min_with_none(a, b): + return a if b is None else min(a, b) + + +def max_with_none(a, b): + return a if b is None else max(a, b) + + class Axes(ipw.GridBox): def __init__(self, *, ax: MplAxes, figure=None) -> None: self.background_color = "#ffffff" @@ -290,35 +298,43 @@ def height(self, h): # self._margins["rightspine"].height = h def autoscale(self): - xmin = np.inf - xmax = -np.inf - ymin = np.inf - ymax = -np.inf + xmin = None + xmax = None + ymin = None + ymax = None for artist in self._artists: lims = artist.get_bbox() - xmin = min(lims["left"], xmin) - xmax = max(lims["right"], xmax) - ymin = min(lims["bottom"], ymin) - ymax = max(lims["top"], ymax) - self._xmin = xmin - self._xmax = xmax - self._ymin = ymin - self._ymax = ymax - - # self._background_mesh.geometry = p3.BoxGeometry( - # width=2 * (self._xmax - self._xmin), - # height=2 * (self._ymax - self._ymin), - # widthSegments=1, - # heightSegments=1, - # ) + xmin = min_with_none(lims["left"], xmin) + xmax = max_with_none(lims["right"], xmax) + ymin = min_with_none(lims["bottom"], ymin) + ymax = max_with_none(lims["top"], ymax) + self._xmin = ( + xmin + if xmin is not None + else (0.0 if self.get_xscale() == "linear" else 1.0) + ) + self._xmax = ( + xmax + if xmax is not None + else (1.0 if self.get_xscale() == "linear" else 10.0) + ) + self._ymin = ( + ymin + if ymin is not None + else (0.0 if self.get_yscale() == "linear" else 1.0) + ) + self._ymax = ( + ymax + if ymax is not None + else (1.0 if self.get_yscale() == "linear" else 10.0) + ) + self._background_mesh.geometry = p3.PlaneGeometry( width=2 * (self._xmax - self._xmin), height=2 * (self._ymax - self._ymin), widthSegments=1, heightSegments=1, ) - # self._background_mesh.geometry.width = 2 * (self._xmax - self._xmin) - # self._background_mesh.geometry.height = 2 * (self._ymax - self._ymin) self._background_mesh.position = [ 0.5 * (self._xmin + self._xmax), @@ -523,6 +539,10 @@ def get_xscale(self): return self._ax.get_xscale() def set_xscale(self, scale): + if scale not in ("linear", "log"): + raise ValueError("Scale must be 'linear' or 'log'") + if scale == self.get_xscale(): + return self._ax.set_xscale(scale) for artist in self._artists: artist._set_xscale(scale) @@ -533,6 +553,10 @@ def get_yscale(self): return self._ax.get_yscale() def set_yscale(self, scale): + if scale not in ("linear", "log"): + raise ValueError("Scale must be 'linear' or 'log'") + if scale == self.get_yscale(): + return self._ax.set_yscale(scale) for artist in self._artists: artist._set_yscale(scale) @@ -661,13 +685,35 @@ def get_title(self): def plot(self, *args, color=None, **kwargs): if color is None: color = f"C{len(self.lines)}" - line = Line(*args, color=color, **kwargs) + line = Line( + *args, + color=color, + xscale=self.get_xscale(), + yscale=self.get_yscale(), + **kwargs, + ) line.axes = self self.lines.append(line) self.add_artist(line) self.autoscale() return line + def semilogx(self, *args, **kwargs): + out = self.plot(*args, **kwargs) + self.set_xscale("log") + return out + + def semilogy(self, *args, **kwargs): + out = self.plot(*args, **kwargs) + self.set_yscale("log") + return out + + def loglog(self, *args, **kwargs): + out = self.plot(*args, **kwargs) + self.set_xscale("log") + self.set_yscale("log") + return out + def scatter(self, *args, c=None, **kwargs): if c is None: c = f"C{len(self.collections)}" diff --git a/src/matplotgl/line.py b/src/matplotgl/line.py index a0a912a..34efe7d 100644 --- a/src/matplotgl/line.py +++ b/src/matplotgl/line.py @@ -10,19 +10,27 @@ class Line: - def __init__(self, x, y, fmt="-", color="C0", ls="solid", lw=1, ms=5, zorder=0): + def __init__( + self, + x, + y, + fmt="-", + color="C0", + ls="solid", + lw=1, + ms=5, + zorder=0, + xscale="linear", + yscale="linear", + ): self.axes = None - self._xscale = "linear" - self._yscale = "linear" + self._xscale = xscale + self._yscale = yscale self._x = np.asarray(x) self._y = np.asarray(y) self._zorder = zorder - self._line_geometry = p3.LineGeometry( - positions=np.array( - [self._x, self._y, np.full_like(self._x, self._zorder - 50)], - dtype="float32", - ).T - ) + pos = self._make_positions() + self._line_geometry = p3.LineGeometry(positions=pos) self._color = mplc.to_hex(color) self._line = None @@ -39,16 +47,7 @@ def __init__(self, x, y, fmt="-", color="C0", ls="solid", lw=1, ms=5, zorder=0): if "o" in fmt: self._vertices_geometry = p3.BufferGeometry( attributes={ - "position": p3.BufferAttribute( - array=np.array( - [ - self._x, - self._y, - np.full_like(self._x, self._zorder - 50), - ], - dtype="float32", - ).T - ), + "position": p3.BufferAttribute(array=pos), } ) self._vertices_material = p3.PointsMaterial(color=self._color, size=ms) @@ -70,14 +69,18 @@ def get(self): out.append(self._vertices) return p3.Group(children=out) if len(out) > 1 else out[0] - def _update(self): + def _make_positions(self): with warnings.catch_warnings(category=RuntimeWarning, action="ignore"): xx = self._x if self._xscale == "linear" else np.log10(self._x) yy = self._y if self._yscale == "linear" else np.log10(self._y) pos = np.array( - [xx, yy, np.full_like(xx, self._zorder - 50)], + [xx, yy, np.full_like(xx, self._zorder)], dtype="float32", ).T + return pos + + def _update(self): + pos = self._make_positions() if self._line is not None: self._line_geometry.positions = pos if self._vertices is not None: diff --git a/tests/plot_test.py b/tests/plot_test.py index ef41ae7..a8c93ce 100644 --- a/tests/plot_test.py +++ b/tests/plot_test.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: BSD-3-Clause import numpy as np +import pytest import matplotgl.pyplot as plt @@ -60,3 +61,90 @@ def test_imshow(): im = ax.images[0] assert np.allclose(im._array, data) assert im.get_extent() == [0, 10, 0, 5] + + +def test_set_xscale_log(): + _, ax = plt.subplots() + x = np.arange(50.0) + y = np.sin(0.2 * x) + + ax.plot(x, y, lw=2) + ax.set_xscale('log') + + assert ax.get_xscale() == 'log' + + +def test_set_yscale_log(): + _, ax = plt.subplots() + x = np.arange(50.0) + y = np.sin(0.2 * x) + + ax.plot(x, y, lw=2) + ax.set_yscale('log') + + assert ax.get_yscale() == 'log' + + +def test_set_xscale_invalid(): + _, ax = plt.subplots() + with pytest.raises(ValueError, match="Scale must be 'linear' or 'log'"): + ax.set_xscale('invalid_scale') + + +def test_set_yscale_invalid(): + _, ax = plt.subplots() + with pytest.raises(ValueError, match="Scale must be 'linear' or 'log'"): + ax.set_yscale('invalid_scale') + + +def test_set_xscale_log_before_plot(): + _, ax = plt.subplots() + x = np.arange(50.0) + y = np.sin(0.2 * x) + + ax.set_xscale('log') + ax.plot(x, y, lw=2) + + assert ax.get_xscale() == 'log' + + +def test_set_yscale_log_before_plot(): + _, ax = plt.subplots() + x = np.arange(50.0) + y = np.sin(0.2 * x) + + ax.set_yscale('log') + ax.plot(x, y, lw=2) + + assert ax.get_yscale() == 'log' + + +def test_semilogx(): + _, ax = plt.subplots() + x = np.arange(1.0, 50.0) + y = np.sin(0.2 * x) + + ax.semilogx(x, y, lw=2) + + assert ax.get_xscale() == 'log' + + +def test_semilogy(): + _, ax = plt.subplots() + x = np.arange(50.0) + y = np.exp(0.1 * x) + + ax.semilogy(x, y, lw=2) + + assert ax.get_yscale() == 'log' + + +def test_loglog(): + _, ax = plt.subplots() + x = np.arange(1.0, 50.0) + y = np.exp(0.1 * x) + + ax.loglog(x, y, lw=2) + + assert ax.get_xscale() == 'log' + assert ax.get_yscale() == 'log'