Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/matplotgl/axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def autoscale(self):

def add_artist(self, artist):
self._artists.append(artist)
self.scene.add(artist.get())
self.scene.add(artist._as_object3d())

def get_figure(self):
return self._fig
Expand Down Expand Up @@ -717,7 +717,9 @@ def loglog(self, *args, **kwargs):
def scatter(self, *args, c=None, **kwargs):
if c is None:
c = f"C{len(self.collections)}"
coll = Points(*args, c=c, **kwargs)
coll = Points(
*args, c=c, xscale=self.get_xscale(), yscale=self.get_yscale(), **kwargs
)
coll.axes = self
self.collections.append(coll)
self.add_artist(coll)
Expand All @@ -733,7 +735,7 @@ def imshow(self, *args, **kwargs):
return image

def pcolormesh(self, *args, **kwargs):
mesh = Mesh(*args, **kwargs)
mesh = Mesh(*args, xscale=self.get_xscale(), yscale=self.get_yscale(), **kwargs)
mesh.axes = self
self.collections.append(mesh)
self.add_artist(mesh)
Expand Down
2 changes: 1 addition & 1 deletion src/matplotgl/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_bbox(self) -> dict[str, float]:
def _update_colors(self) -> None:
self._texture.data = self._make_colors()

def get(self) -> p3.Object3D:
def _as_object3d(self) -> p3.Object3D:
return self._image

def _set_xscale(self, scale: str) -> None:
Expand Down
2 changes: 1 addition & 1 deletion src/matplotgl/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_bbox(self):
bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad))
return {"left": left, "right": right, "bottom": bottom, "top": top}

def get(self):
def _as_object3d(self) -> p3.Object3D:
out = []
if self._line is not None:
out.append(self._line)
Expand Down
15 changes: 11 additions & 4 deletions src/matplotgl/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@


class Mesh:
def __init__(self, *args, cmap: str = "viridis", norm: str = "linear"):
def __init__(
self,
*args,
cmap: str = "viridis",
norm: str = "linear",
xscale="linear",
yscale="linear",
):
if len(args) not in (1, 3):
raise ValueError(
f"Invalid number of arguments: expected 1 or 3. Got {len(args)}"
Expand All @@ -28,8 +35,8 @@ def __init__(self, *args, cmap: str = "viridis", norm: str = "linear"):

self.axes = None
self._colorbar = None
self._xscale = "linear"
self._yscale = "linear"
self._xscale = xscale
self._yscale = yscale

self._x = np.asarray(x)
self._y = np.asarray(y)
Expand Down Expand Up @@ -126,7 +133,7 @@ def get_bbox(self) -> dict[str, float]:
bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad))
return {"left": left, "right": right, "bottom": bottom, "top": top}

def get(self) -> p3.Object3D:
def _as_object3d(self) -> p3.Object3D:
return self._mesh

def get_xdata(self) -> np.ndarray:
Expand Down
103 changes: 47 additions & 56 deletions src/matplotgl/points.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,22 @@
from .norm import Normalizer
from .utils import find_limits, fix_empty_range

SHADER_LIBRARY = {
# Custom vertex shader for variable size and color
VERTEX_SHADER = """
attribute float size;
attribute vec3 customColor;
varying vec3 vColor;

void main() {
vColor = customColor;
vec4 mvPosition = modelViewMatrix * vec4(position, 1.0);
gl_PointSize = size;
gl_Position = projectionMatrix * mvPosition;
}
"""

# Custom fragment shaders for different markers
FRAGMENT_SHADERS = {
"o": """
varying vec3 vColor;

Expand Down Expand Up @@ -62,29 +77,31 @@ def __init__(
zorder=0,
cmap="viridis",
norm: str = "linear",
xscale="linear",
yscale="linear",
) -> None:
self.axes = None
self._x = np.asarray(x)
self._y = np.asarray(y)
self._xscale = "linear"
self._yscale = "linear"
self._xscale = xscale
self._yscale = yscale
self._zorder = zorder

self._geometry = p3.BufferGeometry(
attributes={"position": p3.BufferAttribute(array=self._make_positions())}
)

if not isinstance(c, str) or not np.isscalar(s) or marker != "s":
if isinstance(c, str):
self._c = np.ones_like(self._x)
self._norm = Normalizer(vmin=1, vmax=1)
self._cmap = cm.LinearSegmentedColormap.from_list("tmp", [c, c])
# (
# np.ones_like(self._x)
# )
else:
self._c = np.asarray(c)
self._norm = Normalizer(
vmin=np.min(self._c), vmax=np.max(self._c), norm=norm
)
self._cmap = mpl.colormaps[cmap].copy()
# rgba = self.cmap(self.norm(self._c))

colors = self._make_colors()

Expand All @@ -93,103 +110,77 @@ def __init__(
else:
sizes = np.asarray(s, dtype=np.float32)

# Custom vertex shader for variable size and color
vertex_shader = """
attribute float size;
attribute vec3 customColor;
varying vec3 vColor;

void main() {
vColor = customColor;
vec4 mvPosition = modelViewMatrix * vec4(position, 1.0);
gl_PointSize = size;
gl_Position = projectionMatrix * mvPosition;
}
"""

self._geometry = p3.BufferGeometry(
attributes={
"position": p3.BufferAttribute(
array=np.array(
[self._x, self._y, np.full_like(self._x, self._zorder)],
dtype="float32",
).T
),
self._geometry.attributes.update(
{
"customColor": p3.BufferAttribute(array=colors),
"size": p3.BufferAttribute(array=sizes),
}
)
# Create ShaderMaterial with custom shaders
self._material = p3.ShaderMaterial(
vertexShader=vertex_shader,
fragmentShader=SHADER_LIBRARY[marker],
vertexShader=VERTEX_SHADER,
fragmentShader=FRAGMENT_SHADERS[marker],
transparent=True,
)
else:
self._geometry = p3.BufferGeometry(
attributes={
"position": p3.BufferAttribute(
array=np.array(
[self._x, self._y, np.full_like(self._x, self._zorder)],
dtype="float32",
).T
),
}
)

self._material = p3.PointsMaterial(color=cm.to_hex(c), size=s)

self._points = p3.Points(geometry=self._geometry, material=self._material)

def _make_positions(self) -> np.ndarray:
with warnings.catch_warnings(category=RuntimeWarning, action="ignore"):
xx = self._x if self._xscale == "linear" else np.log10(self._x)
yy = self._y if self._yscale == "linear" else np.log10(self._y)
return np.array([xx, yy, np.full_like(xx, self._zorder)], dtype="float32").T

def _make_colors(self) -> np.ndarray:
return self._cmap(self.norm(self._c))[..., :3].astype("float32")

def _update_colors(self) -> None:
self._geometry.attributes["customColor"].array = self._make_colors()

def _update_positions(self):
self._geometry.attributes["position"].array = self._make_positions()

def get_bbox(self):
pad = 0.03
left, right = fix_empty_range(find_limits(self._x, scale=self._xscale, pad=pad))
bottom, top = fix_empty_range(find_limits(self._y, scale=self._yscale, pad=pad))
return {"left": left, "right": right, "bottom": bottom, "top": top}

def _update(self):
with warnings.catch_warnings(category=RuntimeWarning, action="ignore"):
xx = self._x if self._xscale == "linear" else np.log10(self._x)
yy = self._y if self._yscale == "linear" else np.log10(self._y)
self._geometry.attributes["position"].array = np.array(
[xx, yy, np.full_like(xx, self._zorder)], dtype="float32"
).T

def get(self):
def _as_object3d(self) -> p3.Object3D:
return self._points

def get_xdata(self) -> np.ndarray:
return self._x

def set_xdata(self, x):
self._x = np.asarray(x)
self._update()
self._update_positions()

def get_ydata(self) -> np.ndarray:
return self._y

def set_ydata(self, y):
self._y = np.asarray(y)
self._update()
self._update_positions()

def set_data(self, xy):
self._x = np.asarray(xy[:, 0])
self._y = np.asarray(xy[:, 1])
self._update()
self._update_positions()

def _set_xscale(self, scale):
self._xscale = scale
self._update()
self._update_positions()

def _set_yscale(self, scale):
self._yscale = scale
self._update()
self._update_positions()

def set_array(self, c: np.ndarray):
self._c = np.asarray(c)
self._update_colors()

def set_cmap(self, cmap: str) -> None:
self._cmap = mpl.colormaps[cmap].copy()
Expand Down
Loading