Skip to content

Commit

Permalink
Merge pull request #25 from ungarj/band_names
Browse files Browse the repository at this point in the history
enable defining custom band names
  • Loading branch information
ungarj authored Jul 12, 2022
2 parents 1bd75d5 + 8a4bb50 commit 83a3aef
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 58 deletions.
59 changes: 41 additions & 18 deletions mapchete_xarray/_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,49 @@
class InputData(base.InputData):
"""In case this driver is used when being a readonly input to another process."""

_ds = None

def __init__(self, input_params, **kwargs):
"""Initialize."""
super().__init__(input_params, **kwargs)
self.path = input_params["path"]
if not path_exists(self.path): # pragma: no cover
raise FileNotFoundError(f"path {self.path} does not exist")
archive = zarr.open(FSStore(f"{self.path}"))
mapchete_params = archive.attrs.get("mapchete")
mapchete_params = self.ds.attrs.get("mapchete")
if mapchete_params is None: # pragma: no cover
raise TypeError(
f"zarr archive at {self.path} exists but does not hold mapchete metadata"
)
metadata = load_metadata(mapchete_params)
self.zarr_pyramid = metadata["pyramid"]
mapchete_metadata = load_metadata(mapchete_params)
self.zarr_pyramid = mapchete_metadata["pyramid"]
if self.zarr_pyramid.crs != self.pyramid.crs: # pragma: no cover
raise NotImplementedError(
f"single zarr output ({self.zarr_pyramid.crs}) cannot be reprojected to different CRS ({self.pyramid.crs})"
f"single zarr output ({self.zarr_pyramid.crs}) cannot be reprojected to "
f"different CRS ({self.pyramid.crs})"
)
self._bounds = snap_bounds(
bounds=metadata["driver"]["delimiters"]["process_bounds"],
bounds=mapchete_metadata["driver"]["delimiters"]["process_bounds"],
pyramid=self.zarr_pyramid,
zoom=metadata["driver"]["delimiters"]["zoom"][0],
zoom=mapchete_metadata["driver"]["delimiters"]["zoom"][0],
)
self.x_axis_name = mapchete_metadata["driver"].get("x_axis_name", "X")
self.y_axis_name = mapchete_metadata["driver"].get("y_axis_name", "Y")
self.time_axis_name = mapchete_metadata["driver"].get("time_axis_name", "time")
self.time = mapchete_metadata["driver"].get("time", {})
self.band_names = mapchete_metadata["driver"].get(
"band_names", [v for v in self.ds.data_vars]
)
self.x_axis_name = metadata["driver"].get("x_axis_name", "X")
self.y_axis_name = metadata["driver"].get("y_axis_name", "Y")
self.time = metadata["driver"].get("time", {})

@property
def ds(self):
if self._ds is None:
self._ds = xr.open_zarr(
self.path,
mask_and_scale=False,
consolidated=True,
chunks=None,
)
return self._ds

def open(self, tile, **kwargs):
"""
Expand All @@ -61,7 +78,9 @@ def open(self, tile, **kwargs):
path=self.path,
x_axis_name=self.x_axis_name,
y_axis_name=self.y_axis_name,
time_axis_name=self.time_axis_name,
time=self.time,
band_names=self.band_names,
bbox=self.bbox(),
**kwargs,
)
Expand Down Expand Up @@ -105,7 +124,9 @@ def __init__(
path=None,
x_axis_name=None,
y_axis_name=None,
time_axis_name=None,
time=None,
band_names=None,
bbox=None,
**kwargs,
):
Expand All @@ -114,7 +135,9 @@ def __init__(
self.tile = tile
self.x_axis_name = x_axis_name
self.y_axis_name = y_axis_name
self.time_axis_name = time_axis_name
self.time = time
self.band_names = band_names
self.bbox = bbox
self._ds = None

Expand All @@ -131,11 +154,12 @@ def ds(self):

@property
def bands(self):
return [v for v in self.ds.data_vars]
"""Return band names in correct order."""
return self.band_names

def _get_indexes(self, indexes=None):
"""Return a list of band names (i.e. Zarr data variable names)."""
if indexes is None: # pragma: no cover
if indexes is None:
return self.bands
indexes = indexes if isinstance(indexes, list) else [indexes]
out = []
Expand Down Expand Up @@ -168,17 +192,16 @@ def read(

if self.time:
if start_time or end_time:
selector["time"] = slice(
selector[self.time_axis_name] = slice(
start_time or self.time.get("start"),
end_time or self.time.get("end"),
)
elif timestamps:
selector["time"] = np.array(timestamps, dtype=np.datetime64)
selector[self.time_axis_name] = np.array(
timestamps, dtype=np.datetime64
)

if indexes:
return self.ds[self._get_indexes(indexes)].sel(**selector)
else:
return self.ds.sel(**selector)
return self.ds[self._get_indexes(indexes)].sel(**selector)

def is_empty(self): # pragma: no cover
"""
Expand Down
Loading

0 comments on commit 83a3aef

Please sign in to comment.