Skip to content

Commit

Permalink
BUG: fix FRB integration with CylindricalResolutionBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
neutrinoceros committed Jan 27, 2024
1 parent d5b3335 commit e9d0b3e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 12 deletions.
8 changes: 6 additions & 2 deletions yt/data_objects/tests/test_data_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,12 @@ def test_to_frb(self):
data_source=dd,
)
frb = proj.to_frb((1.0, "unitary"), 64)
assert_equal(frb.radius, (1.0, "unitary"))
assert_equal(frb.buff_size, 64)
assert_equal(frb.radius, ds.quan(1.0, "unitary"))
assert_equal(frb.buff_size, (64, 64))

# exercise field access
# see https://github.com/yt-project/yt/issues/4789
frb["r"]

def test_extract_isocontours(self):
# Test isocontour properties for AMRGridData
Expand Down
55 changes: 45 additions & 10 deletions yt/visualization/fixed_resolution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import sys
import weakref
from functools import partial
from typing import TYPE_CHECKING, Optional
from numbers import Real
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
from unyt import Unit, unyt_quantity

from yt._maintenance.deprecation import issue_deprecation_warning
from yt._typing import FieldKey, MaskT
from yt._typing import FieldKey, MaskT, Quantity
from yt.data_objects.image_array import ImageArray
from yt.frontends.ytdata.utilities import save_as_dataset
from yt.funcs import get_output_filename, iter_fields, mylog
Expand Down Expand Up @@ -581,18 +583,51 @@ class CylindricalFixedResolutionBuffer(FixedResolutionBuffer):
that supports non-aligned input data objects, primarily cutting planes.
"""

def __init__(self, data_source, radius, buff_size, antialias=True, *, filters=None):
def __init__(
self,
data_source,
radius: Quantity,
buff_size: Union[int, tuple[int, int]],
antialias: bool = True,
*,
filters: Optional[list["FixedResolutionBufferFilter"]] = None,
):
if data_source.ds is None:
raise TypeError(
"CylindricalFixedResolutionBuffer requires an actual dataset "
"be attached to data_source (got None)"
)
self.data_source = data_source
self.ds = data_source.ds
self.radius = radius
self.buff_size = buff_size
self.radius = self._sanitize_radius(radius)
if isinstance(buff_size, int):
self.buff_size = (buff_size, buff_size)
else:
self.buff_size = buff_size
self.bounds = self._get_bounds()
self.antialias = antialias
self.data = {}
self.data: dict[str, ImageArray] = {}
self.mask: dict[str, MaskT] = {}
self._filters = filters if filters is not None else []
self.ds.plots.append(weakref.proxy(self))

def _sanitize_radius(self, r: Quantity) -> unyt_quantity:
if isinstance(r, unyt_quantity):
return self.ds.quan(r).to("code_length")
elif (
isinstance(r, tuple)
and len(r) == 2
and (isinstance(r[0], Real) and isinstance(r[1], (str, Unit)))
):
return self.ds.quan(*r).to("code_length")
else:
raise TypeError(
f"Got unparsable radius value {r!r}, expected a unyt_quantity-like)"
)

ds = getattr(data_source, "ds", None)
if ds is not None:
ds.plots.append(weakref.proxy(self))
def _get_bounds(self) -> tuple[float, float, float, float]:
dx = dy = self.radius.item()
return (0.0, dx, 0.0, dy)

@override
def _generate_image_and_mask(self, item) -> None:
Expand All @@ -604,7 +639,7 @@ def _generate_image_and_mask(self, item) -> None:
self.data_source["theta"],
self.data_source["dtheta"],
self.data_source[item].astype("float64"),
self.radius,
extents=self.bounds,
return_mask=True,
)
self.data[item] = ImageArray(
Expand Down

0 comments on commit e9d0b3e

Please sign in to comment.