Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: fix FRB integration with CylindricalResolutionBuffer #4790

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions yt/data_objects/tests/test_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,12 @@ def test_to_frb(self):
data_source=dd,
)
frb = proj.to_frb((1.0, "unitary"), 64)
assert_equal(frb.radius, (1.0, "unitary"))
assert_equal(frb.buff_size, 64)
assert_equal(frb.radius, ds.quan(1.0, "unitary"))
assert_equal(frb.buff_size, (64, 64))

# exercise field access
# see https://github.com/yt-project/yt/issues/4789
frb["r"]

def test_extract_isocontours(self):
# Test isocontour properties for AMRGridData
Expand Down
55 changes: 45 additions & 10 deletions yt/visualization/fixed_resolution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import sys
import weakref
from functools import partial
from typing import TYPE_CHECKING, Optional
from numbers import Real
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
from unyt import Unit, unyt_quantity

from yt._maintenance.deprecation import issue_deprecation_warning
from yt._typing import FieldKey, MaskT
from yt._typing import FieldKey, MaskT, Quantity
from yt.data_objects.image_array import ImageArray
from yt.frontends.ytdata.utilities import save_as_dataset
from yt.funcs import get_output_filename, iter_fields, mylog
Expand Down Expand Up @@ -581,18 +583,51 @@ class CylindricalFixedResolutionBuffer(FixedResolutionBuffer):
that supports non-aligned input data objects, primarily cutting planes.
"""

def __init__(self, data_source, radius, buff_size, antialias=True, *, filters=None):
def __init__(
self,
data_source,
radius: Quantity,
buff_size: Union[int, tuple[int, int]],
antialias: bool = True,
*,
filters: Optional[list["FixedResolutionBufferFilter"]] = None,
):
if data_source.ds is None:
raise TypeError(
"CylindricalFixedResolutionBuffer requires an actual dataset "
"be attached to data_source (got None)"
)
self.data_source = data_source
self.ds = data_source.ds
self.radius = radius
self.buff_size = buff_size
self.radius = self._sanitize_radius(radius)
if isinstance(buff_size, int):
self.buff_size = (buff_size, buff_size)
else:
self.buff_size = buff_size
self.bounds = self._get_bounds()
self.antialias = antialias
self.data = {}
self.data: dict[str, ImageArray] = {}
self.mask: dict[str, MaskT] = {}
self._filters = filters if filters is not None else []
self.ds.plots.append(weakref.proxy(self))

def _sanitize_radius(self, r: Quantity) -> unyt_quantity:
if isinstance(r, unyt_quantity):
return self.ds.quan(r).to("code_length")
elif (
isinstance(r, tuple)
and len(r) == 2
and (isinstance(r[0], Real) and isinstance(r[1], (str, Unit)))
):
return self.ds.quan(*r).to("code_length")
else:
raise TypeError(
f"Got unparsable radius value {r!r}, expected a unyt_quantity-like)"
)

ds = getattr(data_source, "ds", None)
if ds is not None:
ds.plots.append(weakref.proxy(self))
def _get_bounds(self) -> tuple[float, float, float, float]:
dx = dy = self.radius.item()
return (0.0, dx, 0.0, dy)

@override
def _generate_image_and_mask(self, item) -> None:
Expand All @@ -604,7 +639,7 @@ def _generate_image_and_mask(self, item) -> None:
self.data_source["theta"],
self.data_source["dtheta"],
self.data_source[item].astype("float64"),
self.radius,
extents=self.bounds,
return_mask=True,
)
self.data[item] = ImageArray(
Expand Down