diff --git a/yt/data_objects/tests/test_data_containers.py b/yt/data_objects/tests/test_data_containers.py index d0f99c6a79..8c3455fd10 100644 --- a/yt/data_objects/tests/test_data_containers.py +++ b/yt/data_objects/tests/test_data_containers.py @@ -126,8 +126,12 @@ def test_to_frb(self): data_source=dd, ) frb = proj.to_frb((1.0, "unitary"), 64) - assert_equal(frb.radius, (1.0, "unitary")) - assert_equal(frb.buff_size, 64) + assert_equal(frb.radius, ds.quan(1.0, "unitary")) + assert_equal(frb.buff_size, (64, 64)) + + # exercise field access + # see https://github.com/yt-project/yt/issues/4789 + frb["r"] def test_extract_isocontours(self): # Test isocontour properties for AMRGridData diff --git a/yt/visualization/fixed_resolution.py b/yt/visualization/fixed_resolution.py index 7e1123cdf6..04d869a06b 100644 --- a/yt/visualization/fixed_resolution.py +++ b/yt/visualization/fixed_resolution.py @@ -1,12 +1,14 @@ import sys import weakref from functools import partial -from typing import TYPE_CHECKING, Optional +from numbers import Real +from typing import TYPE_CHECKING, Optional, Union import numpy as np +from unyt import Unit, unyt_quantity from yt._maintenance.deprecation import issue_deprecation_warning -from yt._typing import FieldKey, MaskT +from yt._typing import FieldKey, MaskT, Quantity from yt.data_objects.image_array import ImageArray from yt.frontends.ytdata.utilities import save_as_dataset from yt.funcs import get_output_filename, iter_fields, mylog @@ -581,18 +583,51 @@ class CylindricalFixedResolutionBuffer(FixedResolutionBuffer): that supports non-aligned input data objects, primarily cutting planes. """ - def __init__(self, data_source, radius, buff_size, antialias=True, *, filters=None): + def __init__( + self, + data_source, + radius: Quantity, + buff_size: Union[int, tuple[int, int]], + antialias: bool = True, + *, + filters: Optional[list["FixedResolutionBufferFilter"]] = None, + ): + if data_source.ds is None: + raise TypeError( + "CylindricalFixedResolutionBuffer requires an actual dataset " + "be attached to data_source (got None)" + ) self.data_source = data_source self.ds = data_source.ds - self.radius = radius - self.buff_size = buff_size + self.radius = self._sanitize_radius(radius) + if isinstance(buff_size, int): + self.buff_size = (buff_size, buff_size) + else: + self.buff_size = buff_size + self.bounds = self._get_bounds() self.antialias = antialias - self.data = {} + self.data: dict[str, ImageArray] = {} + self.mask: dict[str, MaskT] = {} self._filters = filters if filters is not None else [] + self.ds.plots.append(weakref.proxy(self)) + + def _sanitize_radius(self, r: Quantity) -> unyt_quantity: + if isinstance(r, unyt_quantity): + return self.ds.quan(r).to("code_length") + elif ( + isinstance(r, tuple) + and len(r) == 2 + and (isinstance(r[0], Real) and isinstance(r[1], (str, Unit))) + ): + return self.ds.quan(*r).to("code_length") + else: + raise TypeError( + f"Got unparsable radius value {r!r}, expected a unyt_quantity-like)" + ) - ds = getattr(data_source, "ds", None) - if ds is not None: - ds.plots.append(weakref.proxy(self)) + def _get_bounds(self) -> tuple[float, float, float, float]: + dx = dy = self.radius.item() + return (0.0, dx, 0.0, dy) @override def _generate_image_and_mask(self, item) -> None: @@ -604,7 +639,7 @@ def _generate_image_and_mask(self, item) -> None: self.data_source["theta"], self.data_source["dtheta"], self.data_source[item].astype("float64"), - self.radius, + extents=self.bounds, return_mask=True, ) self.data[item] = ImageArray(