diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ad7647d..e4ec206 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -59,7 +59,6 @@ jobs: git+https://github.com/dask/distributed \ git+https://github.com/rasterio/rasterio \ git+https://github.com/pydata/bottleneck \ - git+https://github.com/zarr-developers/zarr \ git+https://github.com/pydata/xarray; python -m pip install -e . --no-deps --no-build-isolation; diff --git a/pyproject.toml b/pyproject.toml index 9012fc2..14638ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,8 @@ xfail_strict = true filterwarnings = [ "error", "ignore::rasterio.errors.NotGeoreferencedWarning", + # remove when fixed by xarray + "ignore:__array_wrap__ must accept context and return_scalar arguments:DeprecationWarning:numpy", # dateutil needs a new release # https://github.com/dateutil/dateutil/issues/1314 'ignore:datetime.datetime.utcfromtimestamp\(\) is deprecated and scheduled for removal:DeprecationWarning:dateutil', diff --git a/trollimage/xrimage.py b/trollimage/xrimage.py index 5dba8a5..a4f30c1 100644 --- a/trollimage/xrimage.py +++ b/trollimage/xrimage.py @@ -673,6 +673,9 @@ def _scale_to_dtype(self, data, dtype, fill_value=None): """ attrs = data.attrs.copy() if np.issubdtype(dtype, np.integer): + if np.issubdtype(data, np.bool_): + # convert boolean masks to floats so they can be scaled to the output integer dtype + data = data.astype(np.float64) if np.issubdtype(data, np.integer): # preserve integer data type data = data.clip(np.iinfo(dtype).min, np.iinfo(dtype).max) @@ -1109,6 +1112,10 @@ def crude_stretch(self, min_stretch=None, max_stretch=None): attrs = self.data.attrs offset = -min_stretch * scale_factor + try: + offset = offset.astype(scale_factor.dtype) + except AttributeError: + offset = scale_factor.dtype.type(offset) self.data = np.multiply(self.data, scale_factor, dtype=scale_factor.dtype) + offset self.data.attrs = attrs @@ -1225,7 +1232,7 @@ def stretch_logarithmic(self, factor=100., base="e", min_stretch=None, max_stret log_func = np.log if base == "e" else getattr(np, "log" + base) min_stretch, max_stretch = self._convert_log_minmax_stretch(min_stretch, max_stretch) - b__ = float(crange[1] - crange[0]) / log_func(factor) + b__ = float(crange[1] - crange[0]) / self.data.dtype.type(log_func(factor)) c__ = float(crange[0]) def _band_log(arr, min_input, max_input):