diff --git a/src/matplotgl/axes.py b/src/matplotgl/axes.py index a3e1479..2949f31 100644 --- a/src/matplotgl/axes.py +++ b/src/matplotgl/axes.py @@ -345,7 +345,7 @@ def autoscale(self): def add_artist(self, artist): self._artists.append(artist) - self.scene.add(artist.get()) + self.scene.add(artist._as_object3d()) def get_figure(self): return self._fig @@ -717,7 +717,9 @@ def loglog(self, *args, **kwargs): def scatter(self, *args, c=None, **kwargs): if c is None: c = f"C{len(self.collections)}" - coll = Points(*args, c=c, **kwargs) + coll = Points( + *args, c=c, xscale=self.get_xscale(), yscale=self.get_yscale(), **kwargs + ) coll.axes = self self.collections.append(coll) self.add_artist(coll) @@ -733,7 +735,7 @@ def imshow(self, *args, **kwargs): return image def pcolormesh(self, *args, **kwargs): - mesh = Mesh(*args, **kwargs) + mesh = Mesh(*args, xscale=self.get_xscale(), yscale=self.get_yscale(), **kwargs) mesh.axes = self self.collections.append(mesh) self.add_artist(mesh) diff --git a/src/matplotgl/image.py b/src/matplotgl/image.py index 3d28d7a..5845e3c 100644 --- a/src/matplotgl/image.py +++ b/src/matplotgl/image.py @@ -63,7 +63,7 @@ def get_bbox(self) -> dict[str, float]: def _update_colors(self) -> None: self._texture.data = self._make_colors() - def get(self) -> p3.Object3D: + def _as_object3d(self) -> p3.Object3D: return self._image def _set_xscale(self, scale: str) -> None: diff --git a/src/matplotgl/line.py b/src/matplotgl/line.py index 34efe7d..0cd0c18 100644 --- a/src/matplotgl/line.py +++ b/src/matplotgl/line.py @@ -61,7 +61,7 @@ def get_bbox(self): bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad)) return {"left": left, "right": right, "bottom": bottom, "top": top} - def get(self): + def _as_object3d(self) -> p3.Object3D: out = [] if self._line is not None: out.append(self._line) diff --git a/src/matplotgl/mesh.py b/src/matplotgl/mesh.py index 65e4d4f..5081d50 100644 --- a/src/matplotgl/mesh.py +++ b/src/matplotgl/mesh.py @@ -12,7 +12,14 @@ class Mesh: - def __init__(self, *args, cmap: str = "viridis", norm: str = "linear"): + def __init__( + self, + *args, + cmap: str = "viridis", + norm: str = "linear", + xscale="linear", + yscale="linear", + ): if len(args) not in (1, 3): raise ValueError( f"Invalid number of arguments: expected 1 or 3. Got {len(args)}" @@ -28,8 +35,8 @@ def __init__(self, *args, cmap: str = "viridis", norm: str = "linear"): self.axes = None self._colorbar = None - self._xscale = "linear" - self._yscale = "linear" + self._xscale = xscale + self._yscale = yscale self._x = np.asarray(x) self._y = np.asarray(y) @@ -126,7 +133,7 @@ def get_bbox(self) -> dict[str, float]: bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad)) return {"left": left, "right": right, "bottom": bottom, "top": top} - def get(self) -> p3.Object3D: + def _as_object3d(self) -> p3.Object3D: return self._mesh def get_xdata(self) -> np.ndarray: diff --git a/src/matplotgl/points.py b/src/matplotgl/points.py index b634794..72039c8 100644 --- a/src/matplotgl/points.py +++ b/src/matplotgl/points.py @@ -10,7 +10,22 @@ from .norm import Normalizer from .utils import find_limits, fix_empty_range -SHADER_LIBRARY = { +# Custom vertex shader for variable size and color +VERTEX_SHADER = """ +attribute float size; +attribute vec3 customColor; +varying vec3 vColor; + +void main() { + vColor = customColor; + vec4 mvPosition = modelViewMatrix * vec4(position, 1.0); + gl_PointSize = size; + gl_Position = projectionMatrix * mvPosition; +} +""" + +# Custom fragment shaders for different markers +FRAGMENT_SHADERS = { "o": """ varying vec3 vColor; @@ -62,29 +77,31 @@ def __init__( zorder=0, cmap="viridis", norm: str = "linear", + xscale="linear", + yscale="linear", ) -> None: self.axes = None self._x = np.asarray(x) self._y = np.asarray(y) - self._xscale = "linear" - self._yscale = "linear" + self._xscale = xscale + self._yscale = yscale self._zorder = zorder + self._geometry = p3.BufferGeometry( + attributes={"position": p3.BufferAttribute(array=self._make_positions())} + ) + if not isinstance(c, str) or not np.isscalar(s) or marker != "s": if isinstance(c, str): self._c = np.ones_like(self._x) self._norm = Normalizer(vmin=1, vmax=1) self._cmap = cm.LinearSegmentedColormap.from_list("tmp", [c, c]) - # ( - # np.ones_like(self._x) - # ) else: self._c = np.asarray(c) self._norm = Normalizer( vmin=np.min(self._c), vmax=np.max(self._c), norm=norm ) self._cmap = mpl.colormaps[cmap].copy() - # rgba = self.cmap(self.norm(self._c)) colors = self._make_colors() @@ -93,75 +110,45 @@ def __init__( else: sizes = np.asarray(s, dtype=np.float32) - # Custom vertex shader for variable size and color - vertex_shader = """ -attribute float size; -attribute vec3 customColor; -varying vec3 vColor; - -void main() { - vColor = customColor; - vec4 mvPosition = modelViewMatrix * vec4(position, 1.0); - gl_PointSize = size; - gl_Position = projectionMatrix * mvPosition; -} -""" - - self._geometry = p3.BufferGeometry( - attributes={ - "position": p3.BufferAttribute( - array=np.array( - [self._x, self._y, np.full_like(self._x, self._zorder)], - dtype="float32", - ).T - ), + self._geometry.attributes.update( + { "customColor": p3.BufferAttribute(array=colors), "size": p3.BufferAttribute(array=sizes), } ) # Create ShaderMaterial with custom shaders self._material = p3.ShaderMaterial( - vertexShader=vertex_shader, - fragmentShader=SHADER_LIBRARY[marker], + vertexShader=VERTEX_SHADER, + fragmentShader=FRAGMENT_SHADERS[marker], transparent=True, ) else: - self._geometry = p3.BufferGeometry( - attributes={ - "position": p3.BufferAttribute( - array=np.array( - [self._x, self._y, np.full_like(self._x, self._zorder)], - dtype="float32", - ).T - ), - } - ) - self._material = p3.PointsMaterial(color=cm.to_hex(c), size=s) self._points = p3.Points(geometry=self._geometry, material=self._material) + def _make_positions(self) -> np.ndarray: + with warnings.catch_warnings(category=RuntimeWarning, action="ignore"): + xx = self._x if self._xscale == "linear" else np.log10(self._x) + yy = self._y if self._yscale == "linear" else np.log10(self._y) + return np.array([xx, yy, np.full_like(xx, self._zorder)], dtype="float32").T + def _make_colors(self) -> np.ndarray: return self._cmap(self.norm(self._c))[..., :3].astype("float32") def _update_colors(self) -> None: self._geometry.attributes["customColor"].array = self._make_colors() + def _update_positions(self): + self._geometry.attributes["position"].array = self._make_positions() + def get_bbox(self): pad = 0.03 left, right = fix_empty_range(find_limits(self._x, scale=self._xscale, pad=pad)) bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad)) return {"left": left, "right": right, "bottom": bottom, "top": top} - def _update(self): - with warnings.catch_warnings(category=RuntimeWarning, action="ignore"): - xx = self._x if self._xscale == "linear" else np.log10(self._x) - yy = self._y if self._yscale == "linear" else np.log10(self._y) - self._geometry.attributes["position"].array = np.array( - [xx, yy, np.full_like(xx, self._zorder)], dtype="float32" - ).T - - def get(self): + def _as_object3d(self) -> p3.Object3D: return self._points def get_xdata(self) -> np.ndarray: @@ -169,27 +156,31 @@ def get_xdata(self) -> np.ndarray: def set_xdata(self, x): self._x = np.asarray(x) - self._update() + self._update_positions() def get_ydata(self) -> np.ndarray: return self._y def set_ydata(self, y): self._y = np.asarray(y) - self._update() + self._update_positions() def set_data(self, xy): self._x = np.asarray(xy[:, 0]) self._y = np.asarray(xy[:, 1]) - self._update() + self._update_positions() def _set_xscale(self, scale): self._xscale = scale - self._update() + self._update_positions() def _set_yscale(self, scale): self._yscale = scale - self._update() + self._update_positions() + + def set_array(self, c: np.ndarray): + self._c = np.asarray(c) + self._update_colors() def set_cmap(self, cmap: str) -> None: self._cmap = mpl.colormaps[cmap].copy()