Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Add threshold config for quiver plots #4742

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 53 additions & 1 deletion yt/visualization/plot_modifications.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from unyt import unyt_quantity

from yt._maintenance.deprecation import issue_deprecation_warning
from yt._typing import AnyFieldKey, FieldKey
from yt._typing import AnyFieldKey, FieldKey, Quantity
from yt.data_objects.data_containers import YTDataContainer
from yt.data_objects.level_sets.clump_handling import Clump
from yt.data_objects.selection_objects.cut_region import YTCutRegion
Expand Down Expand Up @@ -444,6 +444,9 @@ def __init__(
scale_units=None,
normalize=False,
plot_args=None,
threshold_field: Optional[FieldKey] = None,
lower_threshold: Optional[Quantity] = None,
upper_threshold: Optional[Quantity] = None,
**kwargs,
):
self.factor = _validate_factor_tuple(factor)
Expand All @@ -463,6 +466,13 @@ def __init__(

self.plot_args = plot_args

if upper_threshold is None and lower_threshold is None:
# Need to specify at least one value
threshold_field = None
self.threshold_field = threshold_field
self.lower_threshold = lower_threshold
self.upper_threshold = upper_threshold

def __call__(self, plot) -> "BaseQuiverCallback":
ftype = plot.data._current_fluid_type
# Instantiation of these is cheap
Expand Down Expand Up @@ -553,6 +563,9 @@ def __call__(self, plot) -> "BaseQuiverCallback":
normalize=self.normalize,
bv_x=bv_x,
bv_y=bv_y,
threshold_field=self.threshold_field,
lower_threshold=self.lower_threshold,
upper_threshold=self.upper_threshold,
**self.plot_args,
)
return qcb(plot)
Expand Down Expand Up @@ -700,6 +713,9 @@ def __init__(
scale_units=None,
normalize=False,
plot_args=None,
threshold_field: Optional[FieldKey] = None,
lower_threshold: Optional[Quantity] = None,
upper_threshold: Optional[Quantity] = None,
**kwargs,
):
self.field_x = field_x
Expand All @@ -721,6 +737,12 @@ def __init__(
plot_args.update(kwargs)

self.plot_args = plot_args
if upper_threshold is None and lower_threshold is None:
# Need to specify at least one value
threshold_field = None
self.threshold_field = threshold_field
self.lower_threshold = lower_threshold
self.upper_threshold = upper_threshold

@abstractmethod
def _get_quiver_data(self, plot, bounds: tuple, nx: int, ny: int):
Expand Down Expand Up @@ -807,6 +829,9 @@ def __init__(
bv_x=0,
bv_y=0,
plot_args=None,
threshold_field: Optional[FieldKey] = None,
lower_threshold: Optional[Quantity] = None,
upper_threshold: Optional[Quantity] = None,
**kwargs,
):
super().__init__(
Expand All @@ -818,6 +843,9 @@ def __init__(
scale_units=scale_units,
normalize=normalize,
plot_args=plot_args,
threshold_field=threshold_field,
lower_threshold=lower_threshold,
upper_threshold=upper_threshold,
**kwargs,
)
self.bv_x = bv_x
Expand Down Expand Up @@ -869,6 +897,25 @@ def _transformed_field(field, data):
periodic,
)

if self.threshold_field is not None:
pixT = plot.data.ds.coordinates.pixelize(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One more question :) Do you think it's worth catching the case when threshold_field is the same as field_x or field_y to avoid this pixelize call?

plot.data.axis,
plot.data,
self.threshold_field,
bounds,
(nx, ny),
False, # antialias
periodic,
)
mask = np.ones(pixT.shape, dtype="bool")
if self.lower_threshold is not None:
np.logical_and(mask, pixT > self.lower_threshold, mask)
if self.upper_threshold is not None:
np.logical_and(mask, pixT < self.upper_threshold, mask)
mask = ~mask
else:
mask = None

if self.field_c is not None:
pixC = plot.data.ds.coordinates.pixelize(
plot.data.axis,
Expand All @@ -882,6 +929,11 @@ def _transformed_field(field, data):
else:
pixC = None

if mask is not None:
pixX[mask] = np.nan
pixY[mask] = np.nan
if pixC:
pixC[mask] = np.nan
return pixX, pixY, pixC


Expand Down