Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 68 additions & 22 deletions src/matplotgl/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@
from .widgets import ClickableHTML


def min_with_none(a, b):
return a if b is None else min(a, b)


def max_with_none(a, b):
return a if b is None else max(a, b)


class Axes(ipw.GridBox):
def __init__(self, *, ax: MplAxes, figure=None) -> None:
self.background_color = "#ffffff"
Expand Down Expand Up @@ -290,35 +298,43 @@ def height(self, h):
# self._margins["rightspine"].height = h

def autoscale(self):
xmin = np.inf
xmax = -np.inf
ymin = np.inf
ymax = -np.inf
xmin = None
xmax = None
ymin = None
ymax = None
for artist in self._artists:
lims = artist.get_bbox()
xmin = min(lims["left"], xmin)
xmax = max(lims["right"], xmax)
ymin = min(lims["bottom"], ymin)
ymax = max(lims["top"], ymax)
self._xmin = xmin
self._xmax = xmax
self._ymin = ymin
self._ymax = ymax

# self._background_mesh.geometry = p3.BoxGeometry(
# width=2 * (self._xmax - self._xmin),
# height=2 * (self._ymax - self._ymin),
# widthSegments=1,
# heightSegments=1,
# )
xmin = min_with_none(lims["left"], xmin)
xmax = max_with_none(lims["right"], xmax)
ymin = min_with_none(lims["bottom"], ymin)
ymax = max_with_none(lims["top"], ymax)
self._xmin = (
xmin
if xmin is not None
else (0.0 if self.get_xscale() == "linear" else 1.0)
)
self._xmax = (
xmax
if xmax is not None
else (1.0 if self.get_xscale() == "linear" else 10.0)
)
self._ymin = (
ymin
if ymin is not None
else (0.0 if self.get_yscale() == "linear" else 1.0)
)
self._ymax = (
ymax
if ymax is not None
else (1.0 if self.get_yscale() == "linear" else 10.0)
)

self._background_mesh.geometry = p3.PlaneGeometry(
width=2 * (self._xmax - self._xmin),
height=2 * (self._ymax - self._ymin),
widthSegments=1,
heightSegments=1,
)
# self._background_mesh.geometry.width = 2 * (self._xmax - self._xmin)
# self._background_mesh.geometry.height = 2 * (self._ymax - self._ymin)

self._background_mesh.position = [
0.5 * (self._xmin + self._xmax),
Expand Down Expand Up @@ -523,6 +539,10 @@ def get_xscale(self):
return self._ax.get_xscale()

def set_xscale(self, scale):
if scale not in ("linear", "log"):
raise ValueError("Scale must be 'linear' or 'log'")
if scale == self.get_xscale():
return
self._ax.set_xscale(scale)
for artist in self._artists:
artist._set_xscale(scale)
Expand All @@ -533,6 +553,10 @@ def get_yscale(self):
return self._ax.get_yscale()

def set_yscale(self, scale):
if scale not in ("linear", "log"):
raise ValueError("Scale must be 'linear' or 'log'")
if scale == self.get_yscale():
return
self._ax.set_yscale(scale)
for artist in self._artists:
artist._set_yscale(scale)
Expand Down Expand Up @@ -661,13 +685,35 @@ def get_title(self):
def plot(self, *args, color=None, **kwargs):
if color is None:
color = f"C{len(self.lines)}"
line = Line(*args, color=color, **kwargs)
line = Line(
*args,
color=color,
xscale=self.get_xscale(),
yscale=self.get_yscale(),
**kwargs,
)
line.axes = self
self.lines.append(line)
self.add_artist(line)
self.autoscale()
return line

def semilogx(self, *args, **kwargs):
out = self.plot(*args, **kwargs)
self.set_xscale("log")
return out

def semilogy(self, *args, **kwargs):
out = self.plot(*args, **kwargs)
self.set_yscale("log")
return out

def loglog(self, *args, **kwargs):
out = self.plot(*args, **kwargs)
self.set_xscale("log")
self.set_yscale("log")
return out

def scatter(self, *args, c=None, **kwargs):
if c is None:
c = f"C{len(self.collections)}"
Expand Down
45 changes: 24 additions & 21 deletions src/matplotgl/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,27 @@


class Line:
def __init__(self, x, y, fmt="-", color="C0", ls="solid", lw=1, ms=5, zorder=0):
def __init__(
self,
x,
y,
fmt="-",
color="C0",
ls="solid",
lw=1,
ms=5,
zorder=0,
xscale="linear",
yscale="linear",
):
self.axes = None
self._xscale = "linear"
self._yscale = "linear"
self._xscale = xscale
self._yscale = yscale
self._x = np.asarray(x)
self._y = np.asarray(y)
self._zorder = zorder
self._line_geometry = p3.LineGeometry(
positions=np.array(
[self._x, self._y, np.full_like(self._x, self._zorder - 50)],
dtype="float32",
).T
)
pos = self._make_positions()
self._line_geometry = p3.LineGeometry(positions=pos)

self._color = mplc.to_hex(color)
self._line = None
Expand All @@ -39,16 +47,7 @@ def __init__(self, x, y, fmt="-", color="C0", ls="solid", lw=1, ms=5, zorder=0):
if "o" in fmt:
self._vertices_geometry = p3.BufferGeometry(
attributes={
"position": p3.BufferAttribute(
array=np.array(
[
self._x,
self._y,
np.full_like(self._x, self._zorder - 50),
],
dtype="float32",
).T
),
"position": p3.BufferAttribute(array=pos),
}
)
self._vertices_material = p3.PointsMaterial(color=self._color, size=ms)
Expand All @@ -70,14 +69,18 @@ def get(self):
out.append(self._vertices)
return p3.Group(children=out) if len(out) > 1 else out[0]

def _update(self):
def _make_positions(self):
with warnings.catch_warnings(category=RuntimeWarning, action="ignore"):
xx = self._x if self._xscale == "linear" else np.log10(self._x)
yy = self._y if self._yscale == "linear" else np.log10(self._y)
pos = np.array(
[xx, yy, np.full_like(xx, self._zorder - 50)],
[xx, yy, np.full_like(xx, self._zorder)],
dtype="float32",
).T
return pos

def _update(self):
pos = self._make_positions()
if self._line is not None:
self._line_geometry.positions = pos
if self._vertices is not None:
Expand Down
88 changes: 88 additions & 0 deletions tests/plot_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import numpy as np
import pytest

import matplotgl.pyplot as plt

Expand Down Expand Up @@ -60,3 +61,90 @@ def test_imshow():
im = ax.images[0]
assert np.allclose(im._array, data)
assert im.get_extent() == [0, 10, 0, 5]


def test_set_xscale_log():
_, ax = plt.subplots()
x = np.arange(50.0)
y = np.sin(0.2 * x)

ax.plot(x, y, lw=2)
ax.set_xscale('log')

assert ax.get_xscale() == 'log'


def test_set_yscale_log():
_, ax = plt.subplots()
x = np.arange(50.0)
y = np.sin(0.2 * x)

ax.plot(x, y, lw=2)
ax.set_yscale('log')

assert ax.get_yscale() == 'log'


def test_set_xscale_invalid():
_, ax = plt.subplots()
with pytest.raises(ValueError, match="Scale must be 'linear' or 'log'"):
ax.set_xscale('invalid_scale')


def test_set_yscale_invalid():
_, ax = plt.subplots()
with pytest.raises(ValueError, match="Scale must be 'linear' or 'log'"):
ax.set_yscale('invalid_scale')


def test_set_xscale_log_before_plot():
_, ax = plt.subplots()
x = np.arange(50.0)
y = np.sin(0.2 * x)

ax.set_xscale('log')
ax.plot(x, y, lw=2)

assert ax.get_xscale() == 'log'


def test_set_yscale_log_before_plot():
_, ax = plt.subplots()
x = np.arange(50.0)
y = np.sin(0.2 * x)

ax.set_yscale('log')
ax.plot(x, y, lw=2)

assert ax.get_yscale() == 'log'


def test_semilogx():
_, ax = plt.subplots()
x = np.arange(1.0, 50.0)
y = np.sin(0.2 * x)

ax.semilogx(x, y, lw=2)

assert ax.get_xscale() == 'log'


def test_semilogy():
_, ax = plt.subplots()
x = np.arange(50.0)
y = np.exp(0.1 * x)

ax.semilogy(x, y, lw=2)

assert ax.get_yscale() == 'log'


def test_loglog():
_, ax = plt.subplots()
x = np.arange(1.0, 50.0)
y = np.exp(0.1 * x)

ax.loglog(x, y, lw=2)

assert ax.get_xscale() == 'log'
assert ax.get_yscale() == 'log'
Loading