Skip to content

Commit

Permalink
artists.scatter: plot unstructured data (#1050)
Browse files Browse the repository at this point in the history
* working example

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* scatter_heatmap -> scatter

* parse_limits, handle broadcasting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* skip tests

* cleanup tests

* cleanup docs

* changelog

* refine scatter parameters, docstring

* Update _base.py

* ksunden comments

random seed, correcting arg/kwarg parsing

* Update _base.py

scatter: replace full arrays with sufficiently full arrays

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rst link

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kyle Sunden <git@ksunden.space>
  • Loading branch information
3 people committed Apr 16, 2022
1 parent e7e5459 commit 6d87f9d
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/).
## [Unreleased]

### Added
- `artists.Axes.scatter`: plot one variable against another, with scatter point color determined by a channel.
- Invalid `unit` conversions now throw a `pint` error.
- `data.from_LabRAM`: import Horiba LabRAM txt files

Expand Down
95 changes: 95 additions & 0 deletions WrightTools/artists/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import matplotlib
from matplotlib.projections import register_projection
from matplotlib.colors import Normalize
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

Expand Down Expand Up @@ -463,6 +464,100 @@ def imshow(self, *args, **kwargs):
super().invert_yaxis()
return out

def scatter(self, *args, **kwargs):
"""Scatter plot a channel against two _variables_.
Scatter point color reflects channel values.
Data need not be structured.
If data object is not provided, scatter reverts to the `matplotlib parent method <https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.scatter.html>`_.
args
---------
data : 2D WrightTools.data.Data object
Data to plot.
kwargs
----------
x : int or string (optional)
axis name or index for x (abscissa) axis. Default is 0.
If x does not match an axis, searches variable names for match.
y : int or string (optional)
axis name or index for y (ordinate) axis. Default is 1.
If y does not match an axis, searches variable names for match.
channel : int or string (optional)
Channel index or name. Default is 0.
autolabel : {'none', 'both', 'x', 'y'} (optional)
Parameterize application of labels directly from data object. Default is none.
xlabel : string (optional)
xlabel. Default is None.
ylabel : string (optional)
ylabel. Default is None.
**kwargs
matplotlib.axes.Axes.scatter__ optional keyword arguments.
__ https://matplotlib.org/api/_as_gen/matplotlib.pyplot.scatter.html
Returns
-------
matplotlib.collections.PathCollection
"""
args = list(args)
if isinstance(args[0], Data):
data = args.pop(0)

coords = []
for axis in [kwargs.pop("x", 0), kwargs.pop("y", 1)]:
try: # check axes
axis = wt_kit.get_index(data.axis_names, axis)
axis = data.axes[axis][:]
except (ValueError, IndexError): # check vars
axis = wt_kit.get_index(data.variable_names, axis)
axis = data.variables[axis][:]
# broadcast up to channel shape
coords.append(axis)

if "c" in kwargs.keys():
raise KeyError(
"'c' kwarg not allowed when data object provided. \
Use `cmap` instead to control colors."
)

channel = kwargs.pop("channel", 0)
channel_index = wt_kit.get_index(data.channel_names, channel)

limits = {}
limits = self._parse_limits(data=data, channel_index=channel_index, **limits)
norm = Normalize(**limits)

cmap = self._parse_cmap(data, channel_index=channel_index, **kwargs)["cmap"]

z = data.channels[channel_index][:]

# fill x, y, z to joint shape
shape = wt_kit.joint_shape(z, *coords)

def full(arr, shape):
for i in range(arr.ndim):
if arr.shape[i] == 1:
arr = np.repeat(arr, shape[i], axis=i)
return arr

args = [full(ax, shape).flatten() for ax in coords] + args

z = full(z, shape).flatten()
z = norm(z)
z = cmap(z)
kwargs["c"] = z

self._apply_labels(
autolabel=kwargs.pop("autolabel", False),
xlabel=kwargs.pop("xlabel", None),
ylabel=kwargs.pop("ylabel", None),
data=data,
channel_index=channel_index,
)

return super().scatter(*args, **kwargs)

def pcolormesh(self, *args, **kwargs):
"""Create a pseudocolor plot of a 2-D array.
Expand Down
13 changes: 11 additions & 2 deletions docs/artists.rst
Original file line number Diff line number Diff line change
Expand Up @@ -238,21 +238,30 @@ Plot
^^^^

Once you have axes with the :meth:`~matplotlib.pyplot.subplot` call, it can be used as you are used to using :class:`matplotlib.axes.Axes` objects (though some defaults, such as colormap, differ from bare matplotlib).
However, you can also pass :class:`WrightTools.data.Data` objects in directly (and there are some kwargs available when you do).
For certain plot methods, you can also pass :class:`WrightTools.data.Data` objects in directly (and there are some extra kwargs available when you do).
These :class:`WrightTools.artists.Axes` will extract out the proper arrays and plot the data.

.. code-block:: python
for indx, wigner, color in zip(indxs, wigners, wigner_colors):
ax = plt.subplot(gs[indx])
ax.pcolor(wigner, vmin=0, vmax=1) # global colormpa
ax.pcolor(wigner, vmin=0, vmax=1) # global colormap
ax.contour(wigner) # local contours
...
for indx, color, traces in zip(indxs, trace_colors, tracess):
ax = plt.subplot(gs[indx])
for trace, w_color in zip(traces, wigner_colors):
ax.plot(trace, color=w_color, linewidth=1.5)
Currently supported plot methods are:
- :meth:`~WrightTools.artists.Axes.contour`
- :meth:`~WrightTools.artists.Axes.contourf`
- :meth:`~WrightTools.artists.Axes.imshow`
- :meth:`~WrightTools.artists.Axes.pcolor`
- :meth:`~WrightTools.artists.Axes.pcolormesh`
- :meth:`~WrightTools.artists.Axes.plot`
- :meth:`~WrightTools.artists.Axes.scatter`

Beautify
^^^^^^^^

Expand Down
36 changes: 36 additions & 0 deletions tests/artists/test_scatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import numpy as np
import WrightTools as wt
import matplotlib.pyplot as plt


def test_scatter_broadcast_vars():
rng = np.random.default_rng(seed=42)
data = wt.data.Data(name="data")
a = rng.random((50, 1))
b = rng.random((1, 50))
data.create_variable("x", values=2 * a - 1)
data.create_variable("y", values=2 * b - 1)
data.create_channel("z", values=np.exp(-(data.x[:] ** 2 + data.y[:] ** 2)))
fig, gs = wt.artists.create_figure()
ax = plt.subplot(gs[0])
ax.scatter(data, x="x", y="y", channel="z", s=20, alpha=0.5)
data.close()


def test_scatter_signed_channel():
rng = np.random.default_rng(seed=42)
data = wt.data.Data(name="data")
a = rng.random((10**3))
b = rng.random((10**3))
data.create_variable("x", values=4 * np.pi * (a - 0.5))
data.create_variable("y", values=4 * np.pi * (b - 0.5))
data.create_channel("z", values=np.sin((data.x[:] ** 2 + data.y[:] ** 2) ** 0.5), signed=True)
fig, gs = wt.artists.create_figure()
ax = plt.subplot(gs[0])
ax.scatter(data, x="x", y="y", channel="z", s=20, alpha=0.5)
data.close()


if __name__ == "__main__":
test_scatter_broadcast_vars()
test_scatter_signed_channel()

0 comments on commit 6d87f9d

Please sign in to comment.