Skip to content

Commit

Permalink
Merge pull request #2585 from mraspaud/fix-caching-error
Browse files Browse the repository at this point in the history
Make caching warn if some of the args are unhashable
  • Loading branch information
djhoese committed Nov 15, 2023
2 parents 6fb1d1e + f189402 commit ce41239
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 30 deletions.
68 changes: 38 additions & 30 deletions satpy/modifiers/angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,35 +138,43 @@ def _zarr_pattern(self, arg_hash, cache_version: Union[None, int, str] = None) -

def __call__(self, *args, cache_dir: Optional[str] = None) -> Any:
"""Call the decorated function."""
new_args = self._sanitize_args_func(*args) if self._sanitize_args_func is not None else args
arg_hash = _hash_args(*new_args, unhashable_types=self._uncacheable_arg_types)
should_cache, cache_dir = self._get_should_cache_and_cache_dir(new_args, cache_dir)
zarr_fn = self._zarr_pattern(arg_hash)
zarr_format = os.path.join(cache_dir, zarr_fn)
zarr_paths = glob(zarr_format.format("*"))
if not should_cache or not zarr_paths:
# use sanitized arguments if we are caching, otherwise use original arguments
args_to_use = new_args if should_cache else args
res = self._func(*args_to_use)
if should_cache and not zarr_paths:
self._warn_if_irregular_input_chunks(args, args_to_use)
self._cache_results(res, zarr_format)
# if we did any caching, let's load from the zarr files
if should_cache:
# re-calculate the cached paths
zarr_paths = sorted(glob(zarr_format.format("*")))
if not zarr_paths:
raise RuntimeError("Data was cached to disk but no files were found")
new_chunks = _get_output_chunks_from_func_arguments(args)
res = tuple(da.from_zarr(zarr_path, chunks=new_chunks) for zarr_path in zarr_paths)
should_cache: bool = satpy.config.get(self._cache_config_key, False)
if not should_cache:
return self._func(*args)

try:
return self._cache_and_read(args, cache_dir)
except TypeError as err:
warnings.warn("Cannot cache function because of unhashable argument: " + str(err), stacklevel=2)
return self._func(*args)

def _cache_and_read(self, args, cache_dir):
    """Return the wrapped function's results as dask arrays backed by the zarr cache.

    The arguments are first passed through the sanitizing function, if one
    was configured. When no zarr file matching these arguments exists yet,
    the wrapped function is called with the sanitized arguments and its
    results are written to disk. The results are then always read back
    from the zarr files.

    Args:
        args: Positional arguments the decorated function was called with.
        cache_dir: Directory passed on to the zarr file pattern lookup.

    Returns:
        A tuple with one ``da.from_zarr`` array per cached zarr file, in
        sorted path order.

    Raises:
        RuntimeError: If no zarr file can be found after the caching step.

    """
    if self._sanitize_args_func is None:
        clean_args = args
    else:
        clean_args = self._sanitize_args_func(*args)

    pattern = self._get_zarr_file_pattern(clean_args, cache_dir)

    if not glob(pattern.format("*")):
        # nothing on disk yet for these arguments: compute with the
        # sanitized arguments and store the results
        self._warn_if_irregular_input_chunks(args, clean_args)
        self._cache_results(self._func(*clean_args), pattern)

    # always load from the zarr files, so that future calls have the same name;
    # the paths are looked up again because the step above may have created them
    cached_paths = sorted(glob(pattern.format("*")))
    if not cached_paths:
        raise RuntimeError("Data was cached to disk but no files were found")

    # output chunking is derived from the original (unsanitized) arguments
    out_chunks = _get_output_chunks_from_func_arguments(args)
    return tuple(da.from_zarr(path, chunks=out_chunks) for path in cached_paths)

def _get_should_cache_and_cache_dir(self, args, cache_dir: Optional[str]) -> tuple[bool, str]:
should_cache: bool = satpy.config.get(self._cache_config_key, False)
can_cache = not any(isinstance(arg, self._uncacheable_arg_types) for arg in args)
should_cache = should_cache and can_cache
def _get_zarr_file_pattern(self, sanitized_args, cache_dir):
arg_hash = _hash_args(*sanitized_args, unhashable_types=self._uncacheable_arg_types)
zarr_filename = self._zarr_pattern(arg_hash)
cache_dir = self._get_cache_dir_from_config(cache_dir)
return should_cache, cache_dir
return os.path.join(cache_dir, zarr_filename)

@staticmethod
def _get_cache_dir_from_config(cache_dir: Optional[str]) -> str:
Expand All @@ -189,14 +197,14 @@ def _warn_if_irregular_input_chunks(args, modified_args):
stacklevel=3
)

def _cache_results(self, res, zarr_format):
os.makedirs(os.path.dirname(zarr_format), exist_ok=True)
def _cache_results(self, res, zarr_file_pattern):
os.makedirs(os.path.dirname(zarr_file_pattern), exist_ok=True)
new_res = []
for idx, sub_res in enumerate(res):
if not isinstance(sub_res, da.Array):
raise ValueError("Zarr caching currently only supports dask "
f"arrays. Got {type(sub_res)}")
zarr_path = zarr_format.format(idx)
zarr_path = zarr_file_pattern.format(idx)
# See https://github.com/dask/dask/issues/8380
with dask.config.set({"optimization.fuse.active": False}):
new_sub_res = sub_res.to_zarr(zarr_path, compute=False)
Expand Down Expand Up @@ -252,7 +260,7 @@ def _hash_args(*args, unhashable_types=DEFAULT_UNCACHE_TYPES):
hashable_args = []
for arg in args:
if isinstance(arg, unhashable_types):
continue
raise TypeError(f"Unhashable type ({type(arg)}).")
if isinstance(arg, HASHABLE_GEOMETRIES):
arg = hash(arg)
elif isinstance(arg, datetime):
Expand Down
24 changes: 24 additions & 0 deletions satpy/tests/modifier_tests/test_angles.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,30 @@ def _fake_func(shape, chunks):
satpy.config.set(cache_lonlats=True, cache_dir=str(tmp_path)):
_fake_func((5, 5), ((5,), (5,)))

def test_caching_with_array_in_args_warns(self, tmp_path):
    """Test that caching with a dask array argument warns instead of caching."""
    from satpy.modifiers.angles import cache_to_zarr_if

    @cache_to_zarr_if("cache_lonlats")
    def _fake_func(array):
        return array + 1

    # Match on the message so that an unrelated UserWarning cannot make
    # this test pass by accident.
    with pytest.warns(UserWarning, match="Cannot cache function because of unhashable argument"), \
            satpy.config.set(cache_lonlats=True, cache_dir=str(tmp_path)):
        _fake_func(da.zeros(100))

def test_caching_with_array_in_args_does_not_warn_when_caching_is_not_enabled(self, tmp_path, recwarn):
    """Test that a dask array argument triggers no warning when caching is disabled."""
    from satpy.modifiers.angles import cache_to_zarr_if

    def _add_one(array):
        return array + 1

    # apply the decorator explicitly rather than with the ``@`` syntax
    wrapped_func = cache_to_zarr_if("cache_lonlats")(_add_one)

    with satpy.config.set(cache_lonlats=False, cache_dir=str(tmp_path)):
        wrapped_func(da.zeros(100))

    # ``recwarn`` records every warning emitted during the test
    num_warnings = len(recwarn)
    assert num_warnings == 0

def test_no_cache_dir_fails(self, tmp_path):
"""Test that 'cache_dir' not being set fails."""
from satpy.modifiers.angles import _get_sensor_angles_from_sat_pos, get_angles
Expand Down

0 comments on commit ce41239

Please sign in to comment.