Skip to content

Commit

Permalink
Merge pull request #27 from wankoelias/faster_tiles_exist
Browse files Browse the repository at this point in the history
Faster tiles exist
  • Loading branch information
ungarj committed Oct 21, 2022
2 parents efac98f + e470f1c commit 30c6cff
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 7 deletions.
69 changes: 62 additions & 7 deletions mapchete_xarray/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import logging
import math
import os

import croniter
import dateutil
Expand All @@ -29,6 +30,8 @@
"file_extensions": ["zarr"],
}

DEFAULT_TIME_CHUNKSIZE = 8


class OutputDataReader(base.SingleFileOutputReader):

Expand Down Expand Up @@ -277,6 +280,27 @@ def prepare(self, **kwargs):
output_metadata=dump_metadata(self.output_params),
)

def _zarr_chunk_from_xy(self, x, y, on_edge_use="rb"):

# determine row
pixel_y_size = _pixel_y_size(self.bounds.top, self.bounds.bottom, self.shape[0])
tile_y_size = round(
pixel_y_size * self.pyramid.tile_size * self.pyramid.metatiling, 20
)
row = abs(int((self.ds[self.y_axis_name].max() - y) / tile_y_size))
if on_edge_use in ["rt", "lt"] and (self.ds.Y.max() - y) % tile_y_size == 0.0:
row -= 1

pixel_x_size = _pixel_x_size(self.bounds.right, self.bounds.left, self.shape[1])
# determine column
tile_x_size = round(
pixel_x_size * self.pyramid.tile_size * self.pyramid.metatiling, 20
)

col = abs(int((x - self.ds[self.x_axis_name].min()) / tile_x_size))

return row, col

def tiles_exist(self, process_tile=None, output_tile=None):
"""
Check whether output tiles of a tile (either process or output) exists.
Expand All @@ -292,10 +316,29 @@ def tiles_exist(self, process_tile=None, output_tile=None):
-------
exists : bool
"""
bounds = process_tile.bounds if process_tile else output_tile.bounds
for var in self._read(bounds=bounds).values():
if np.any(var != self.nodata):
return True

tile = process_tile or output_tile
zarr_chunk_row, zarr_chunk_col = self._zarr_chunk_from_xy(
tile.bbox.centroid.x, tile.bbox.centroid.y
)

for var in self.ds:

if self.time:

if path_exists(
os.path.join(
self.path,
var,
f"0.{zarr_chunk_row}.{zarr_chunk_col}",
)
):
return True
else:
if path_exists(
os.path.join(self.path, var, f"{zarr_chunk_row}.{zarr_chunk_col}")
):
return True
return False

def is_valid_with_config(self, config):
Expand Down Expand Up @@ -554,6 +597,14 @@ def _get_indexes(self, indexes=None):
return out


def _pixel_x_size(right, left, width):
return (right - left) / width


def _pixel_y_size(top, bottom, height):
return (top - bottom) / -height


def initialize_zarr(
path=None,
bounds=None,
Expand All @@ -575,8 +626,8 @@ def initialize_zarr(

height, width = shape
bounds = Bounds(*bounds)
pixel_x_size = (bounds.right - bounds.left) / width
pixel_y_size = (bounds.top - bounds.bottom) / -height
pixel_x_size = _pixel_x_size(bounds.right, bounds.left, width)
pixel_y_size = _pixel_y_size(bounds.top, bounds.bottom, height)

coord_x = [bounds.left + pixel_x_size / 2 + i * pixel_x_size for i in range(width)]
coord_y = [bounds.top + pixel_y_size / 2 + i * pixel_y_size for i in range(height)]
Expand Down Expand Up @@ -625,7 +676,11 @@ def initialize_zarr(
coords[time_axis_name] = coord_time

output_shape = (len(coord_time), *shape)
output_chunks = (time.get("chunksize", 8), chunksize, chunksize)
output_chunks = (
time.get("chunksize", DEFAULT_TIME_CHUNKSIZE),
chunksize,
chunksize,
)
axis_names = [time_axis_name] + axis_names

else:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@ def test_zarr_process_output_as_input(zarr_process_output_as_input_mapchete):
list(zarr_process_output_as_input_mapchete.mp().compute(concurrency=None))


def test_zarr_process_output_as_input_tile_exists(
zarr_process_output_as_input_mapchete,
):
first_run = list(
zarr_process_output_as_input_mapchete.mp().compute(concurrency=None)
)
assert first_run[0]._result.written is True

second_run = list(
zarr_process_output_as_input_mapchete.mp().compute(concurrency=None)
)
assert second_run[0]._result.written is False


def test_custom_band_names_read_kwargs_no_indexes(output_3d_custom_band_names_mapchete):
mp = output_3d_custom_band_names_mapchete.mp()
tile = output_3d_custom_band_names_mapchete.first_process_tile()
Expand Down

0 comments on commit 30c6cff

Please sign in to comment.