Skip to content

Commit

Permalink
Merge pull request #31 from ungarj/mpath_compatibility
Browse files — browse the repository at this point in the history
MPath compatibility
  • Loading branch information
ungarj committed Jun 13, 2023
2 parents c6dc42c + 3def665 commit 7396a27
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 72 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [ "3.7", "3.8", "3.9" ]
python-version: [ "3.8", "3.9", "3.10" ]
os: [ "ubuntu-20.04", "ubuntu-22.04" ]

steps:
Expand Down
6 changes: 5 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v1.2.3
hooks:
- id: flake8
- id: flake8
- repo: https://github.com/PyCQA/autoflake
rev: v2.1.1
hooks:
- id: autoflake
5 changes: 1 addition & 4 deletions mapchete_xarray/_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,10 @@

import numpy as np
import xarray as xr
import zarr
from mapchete.config import snap_bounds
from mapchete.formats import base, load_metadata
from mapchete.io import path_exists
from mapchete.io.vector import reproject_geometry
from shapely.geometry import box
from zarr.storage import FSStore


class InputData(base.InputData):
Expand All @@ -22,7 +19,7 @@ def __init__(self, input_params, **kwargs):
"""Initialize."""
super().__init__(input_params, **kwargs)
self.path = input_params["path"]
if not path_exists(self.path): # pragma: no cover
if not self.path.exists(): # pragma: no cover
raise FileNotFoundError(f"path {self.path} does not exist")
mapchete_params = self.ds.attrs.get("mapchete")
if mapchete_params is None: # pragma: no cover
Expand Down
37 changes: 12 additions & 25 deletions mapchete_xarray/_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import logging
import math
import os

import croniter
import dateutil
Expand All @@ -15,8 +14,8 @@
from mapchete.errors import MapcheteConfigError
from mapchete.formats import base
from mapchete.formats.tools import compare_metadata_params, dump_metadata, load_metadata
from mapchete.io import fs_from_path, path_exists
from mapchete.io.raster import bounds_to_ranges, create_mosaic, extract_from_array
from mapchete.path import MPath
from rasterio.transform import from_origin
from tilematrix import Bounds
from zarr.storage import FSStore
Expand Down Expand Up @@ -51,7 +50,6 @@ def __init__(self, output_params, *args, **kwargs):
self.path = output_params["path"]
if not self.path.endswith(self.file_extension):
raise MapcheteConfigError("output path must end with .zarr")
self.fs = fs_from_path(self.path)
self.output_params = output_params
self.zoom = output_params["delimiters"]["zoom"][0]

Expand Down Expand Up @@ -250,9 +248,9 @@ def __init__(self, output_params, *args, **kwargs):
super().__init__(output_params, *args, **kwargs)

def prepare(self, **kwargs):
if path_exists(self.path):
if self.path.exists():
# verify it is compatible with our output parameters / chunking
archive = zarr.open(FSStore(f"{self.path}"))
archive = zarr.open(FSStore(f"{self.path}", fs=self.path.fs))
mapchete_params = archive.attrs.get("mapchete")
if mapchete_params is None: # pragma: no cover
raise TypeError(
Expand Down Expand Up @@ -322,19 +320,10 @@ def tiles_exist(self, process_tile=None, output_tile=None):
for var in self.ds:

if self.time:

if path_exists(
os.path.join(
self.path,
var,
f"0.{zarr_chunk_row}.{zarr_chunk_col}",
)
):
if (self.path / var / f"0.{zarr_chunk_row}.{zarr_chunk_col}").exists():
return True
else:
if path_exists(
os.path.join(self.path, var, f"{zarr_chunk_row}.{zarr_chunk_col}")
):
if (self.path / var / f"{zarr_chunk_row}.{zarr_chunk_col}").exists():
return True
return False

Expand All @@ -359,7 +348,7 @@ def is_valid_with_config(self, config):
"when using a time axis, please specify the time stamps either through "
"'pattern' or 'steps'"
)
return validate_values(config, [("path", str)])
return validate_values(config, [("path", (str, MPath))])

def write(self, process_tile, data):
"""
Expand All @@ -385,7 +374,7 @@ def write(self, process_tile, data):

def write_zarr(ds, region):
ds.to_zarr(
FSStore(self.path),
FSStore(str(self.path), fs=self.path.fs),
mode="r+",
compute=True,
safe_chunks=True,
Expand Down Expand Up @@ -618,7 +607,8 @@ def initialize_zarr(
area_or_point="Area",
output_metadata=None,
):
if path_exists(path): # pragma: no cover
path = MPath.from_inp(path)
if path.exists(): # pragma: no cover
raise IOError(f"cannot initialize zarr storage as path already exists: {path}")

height, width = shape
Expand Down Expand Up @@ -688,15 +678,15 @@ def initialize_zarr(
# write zarr
ds = xr.Dataset(coords=coords)
ds.to_zarr(
FSStore(path),
FSStore(path, fs=path.fs),
compute=False,
encoding={var: {"_FillValue": fill_value} for var in ds.data_vars},
safe_chunks=True,
)

# add GDAL metadata for each band
for band_name in band_names:
store = FSStore(f"{path}/{band_name}")
store = FSStore(f"{path}/{band_name}", fs=path.fs)
zarr.creation.create(
shape=output_shape,
chunks=output_chunks,
Expand All @@ -716,8 +706,5 @@ def initialize_zarr(

except Exception: # pragma: no cover
# remove leftovers if something failed during initialization
try:
fs_from_path(path).rm(path, recursive=True)
except FileNotFoundError:
pass
path.rm(recursive=True, ignore_errors=True)
raise
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
dask
mapchete[s3]>=2022.7.0
mapchete[s3]>=2023.6.3
xarray
zarr
python-dateutil
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def parse_requirements(file):
"Intended Audience :: Developers",
"Topic :: Scientific/Engineering :: GIS",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
],
setup_requires=["pytest-runner"],
tests_require=["pytest"],
Expand Down
62 changes: 32 additions & 30 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,37 @@
from tempfile import TemporaryDirectory

import pytest
from mapchete.io import fs_from_path
from mapchete.path import MPath
from mapchete.testing import ProcessFixture

SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__))
TESTDATA_DIR = os.path.join(SCRIPT_DIR, "testdata")
S3_TEMP_DIR = "s3://mapchete-test/tmp/" + uuid.uuid4().hex
SCRIPT_DIR = MPath(os.path.dirname(os.path.realpath(__file__)))
TESTDATA_DIR = SCRIPT_DIR / "testdata"
S3_TEMP_DIR = MPath("s3://mapchete-test/tmp/" + uuid.uuid4().hex)


@pytest.fixture
def mp_s3_tmpdir():
"""Setup and teardown temporary directory."""
fs = fs_from_path(S3_TEMP_DIR)
S3_TEMP_DIR.rm(recursive=True, ignore_errors=True)
S3_TEMP_DIR.makedirs()
yield S3_TEMP_DIR
S3_TEMP_DIR.rm(recursive=True, ignore_errors=True)

def _cleanup():
try:
fs.rm(S3_TEMP_DIR, recursive=True)
except FileNotFoundError:
pass

_cleanup()
yield S3_TEMP_DIR
_cleanup()
@pytest.fixture(autouse=True)
def mp_tmpdir():
"""Setup and teardown temporary directory."""
with TemporaryDirectory() as tempdir_path:
tempdir = MPath(tempdir_path)
tempdir.makedirs()
yield tempdir


@pytest.fixture(scope="session")
def written_output():
with TemporaryDirectory() as tempdir:
with ProcessFixture(
os.path.join(TESTDATA_DIR, "output_4d.mapchete"), output_tempdir=tempdir
TESTDATA_DIR / "output_4d.mapchete", output_tempdir=tempdir
) as example:
data_tile = next(example.mp().get_process_tiles(5))
example.mp().batch_process(tile=data_tile.id)
Expand All @@ -41,40 +43,40 @@ def written_output():
@pytest.fixture
def convert_to_zarr_mapchete():
with ProcessFixture(
os.path.join(TESTDATA_DIR, "convert_to_zarr.mapchete"),
TESTDATA_DIR / "convert_to_zarr.mapchete",
) as example:
yield example


@pytest.fixture
def output_3d_mapchete():
with ProcessFixture(
os.path.join(TESTDATA_DIR, "output_3d.mapchete"),
TESTDATA_DIR / "output_3d.mapchete",
) as example:
yield example


@pytest.fixture
def output_3d_custom_band_names_mapchete():
with ProcessFixture(
os.path.join(TESTDATA_DIR, "output_3d_custom_band_names.mapchete"),
TESTDATA_DIR / "output_3d_custom_band_names.mapchete",
) as example:
yield example


@pytest.fixture
def output_3d_numpy_mapchete():
with ProcessFixture(
os.path.join(TESTDATA_DIR, "output_3d_numpy.mapchete"),
TESTDATA_DIR / "output_3d_numpy.mapchete",
) as example:
yield example


@pytest.fixture
def output_4d_s3_mapchete():
def output_4d_s3_mapchete(mp_s3_tmpdir):
with ProcessFixture(
os.path.join(TESTDATA_DIR, "output_4d.mapchete"),
output_tempdir=os.path.join(S3_TEMP_DIR),
TESTDATA_DIR / "output_4d.mapchete",
output_tempdir=mp_s3_tmpdir,
output_suffix=".zarr",
) as example:
yield example
Expand All @@ -83,39 +85,39 @@ def output_4d_s3_mapchete():
@pytest.fixture
def output_4d_mapchete():
with ProcessFixture(
os.path.join(TESTDATA_DIR, "output_4d.mapchete"),
TESTDATA_DIR / "output_4d.mapchete",
) as example:
yield example


@pytest.fixture
def output_4d_numpy_mapchete():
with ProcessFixture(
os.path.join(TESTDATA_DIR, "output_4d_numpy.mapchete"),
TESTDATA_DIR / "output_4d_numpy.mapchete",
) as example:
yield example


@pytest.fixture
def zarr_as_input_mapchete(tmp_path):
def zarr_as_input_mapchete(mp_tmpdir):
with ProcessFixture(
os.path.join(TESTDATA_DIR, "zarr_as_input.mapchete"),
output_tempdir=tmp_path,
TESTDATA_DIR / "zarr_as_input.mapchete",
output_tempdir=mp_tmpdir,
output_suffix=".zarr",
) as example:
yield example


@pytest.fixture
def zarr_process_output_as_input_mapchete(tmp_path):
def zarr_process_output_as_input_mapchete(mp_tmpdir):
with ProcessFixture(
os.path.join(TESTDATA_DIR, "zarr_process_output_as_input.mapchete"),
output_tempdir=tmp_path,
TESTDATA_DIR / "zarr_process_output_as_input.mapchete",
output_tempdir=mp_tmpdir,
output_suffix=".zarr",
) as example:
yield example


@pytest.fixture
def example_zarr():
return os.path.join(TESTDATA_DIR, "example.zarr")
return TESTDATA_DIR / "example.zarr"
15 changes: 6 additions & 9 deletions tests/test_zarr_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@

import dateutil
import zarr
from mapchete.io import fs_from_path
from rasterio.crs import CRS
from tilematrix import TilePyramid

from mapchete_xarray._output import initialize_zarr


def test_initialize_zarr(tmp_path):
out_path = os.path.join(tmp_path, "test.zarr")
def test_initialize_zarr(mp_tmpdir):
out_path = mp_tmpdir / "test.zarr"
tp = TilePyramid("geodetic")
initialize_zarr(
path=out_path,
Expand All @@ -23,11 +22,10 @@ def test_initialize_zarr(tmp_path):
band_names=["Band1", "Band2", "Band3"],
dtype="uint8",
)
fs = fs_from_path(out_path)
bands = ["Band1", "Band2", "Band3"]
axes = ["X", "Y"]
required_files = [".zgroup", ".zmetadata"] + bands + axes
ls = fs.ls(out_path)
ls = out_path.ls()
for required_file in required_files:
for file in ls:
if file.endswith(required_file):
Expand Down Expand Up @@ -59,8 +57,8 @@ def test_initialize_zarr(tmp_path):
assert 10 < coord < 20


def test_initialize_zarr_time(tmp_path):
out_path = os.path.join(tmp_path, "test.zarr")
def test_initialize_zarr_time(mp_tmpdir):
out_path = mp_tmpdir / "test.zarr"
tp = TilePyramid("geodetic")
initialize_zarr(
path=out_path,
Expand All @@ -78,11 +76,10 @@ def test_initialize_zarr_time(tmp_path):
band_names=["red", "green", "blue"],
dtype="uint8",
)
fs = fs_from_path(out_path)
bands = ["red", "green", "blue"]
axes = ["time", "X", "Y"]
required_files = [".zgroup", ".zmetadata"] + bands + axes
ls = fs.ls(out_path)
ls = out_path.ls()
for required_file in required_files:
for file in ls:
if file.endswith(required_file):
Expand Down

0 comments on commit 7396a27

Please sign in to comment.