diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py
index 5200b2de..07317d55 100644
--- a/polaris/dataset/_dataset.py
+++ b/polaris/dataset/_dataset.py
@@ -219,10 +219,17 @@ def zarr_root(self):
         # We open the archive in read-only mode if it is saved on the Hub
         if self._zarr_root is None:
-            if saved_on_hub:
-                self._zarr_root = self.client.open_zarr_file(self.owner, self.name, self.zarr_root_path, "r+")
-            else:
-                self._zarr_root = zarr.open(self.zarr_root_path, "r+")
+            try:
+                if saved_on_hub:
+                    self._zarr_root = self.client.open_zarr_file(
+                        self.owner, self.name, self.zarr_root_path, "r+"
+                    )
+                else:
+                    self._zarr_root = zarr.open_consolidated(self.zarr_root_path, mode="r+")
+            except KeyError as error:
+                raise InvalidDatasetError(
+                    "A Zarr archive associated with a Polaris dataset has to be consolidated."
+                ) from error
         return self._zarr_root

     @computed_field
@@ -340,6 +347,13 @@ def to_json(self, destination: str) -> str:
         if self.zarr_root is not None:
             dest = zarr.open(zarr_archive, "w")
             zarr.copy_all(source=self.zarr_root, dest=dest)
+
+            # Copy the .zmetadata file
+            # To track discussions on whether this should be done by copy_all()
+            # see https://github.com/zarr-developers/zarr-python/issues/1731
+            zmetadata_content = self.zarr_root.store.store[".zmetadata"]
+            dest.store[".zmetadata"] = zmetadata_content
+
             serialized["zarr_root_path"] = zarr_archive

         self.table.to_parquet(table_path)
diff --git a/polaris/dataset/_factory.py b/polaris/dataset/_factory.py
index b6dd48e3..c0329e9d 100644
--- a/polaris/dataset/_factory.py
+++ b/polaris/dataset/_factory.py
@@ -215,6 +215,7 @@ def add_from_file(self, path: str):

     def build(self) -> Dataset:
         """Returns a Dataset based on the current state of the factory."""
+        zarr.consolidate_metadata(self.zarr_root.store)
         return Dataset(
             table=self._table,
             annotations=self._annotations,
diff --git a/polaris/hub/client.py b/polaris/hub/client.py
index 9f7e2292..c7a2feb2 100644
--- a/polaris/hub/client.py
+++ b/polaris/hub/client.py
@@ -358,7 +358,7 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: b
         return Dataset(**response)

     def open_zarr_file(
-        self, owner: Union[str, HubOwner], name: str, path: str, mode: IOMode
+        self, owner: Union[str, HubOwner], name: str, path: str, mode: IOMode, as_consolidated: bool = True
     ) -> zarr.hierarchy.Group:
         """Open a Zarr file from a Polaris dataset

         Args:
             name: Name of the dataset.
             path: Path to the Zarr file within the dataset.
             mode: The mode in which the file is opened.
+            as_consolidated: Whether to open the store with consolidated metadata for optimized reading. This is only applicable in 'r' and 'r+' modes.

         Returns:
             The Zarr object representing the dataset.
         """
+        if as_consolidated and mode not in ["r", "r+"]:
+            raise ValueError("Consolidated archives can only be used with 'r' or 'r+' mode.")
+
         polaris_fs = PolarisFileSystem(
             polaris_client=self,
             dataset_owner=owner,
@@ -379,6 +383,8 @@
         try:
             store = zarr.storage.FSStore(path, fs=polaris_fs)

+            if mode in ["r", "r+"] and as_consolidated:
+                return zarr.open_consolidated(store, mode=mode)
             return zarr.open(store, mode=mode)

         except Exception as e:
@@ -587,13 +593,21 @@ def upload_dataset(
         if dataset.zarr_root is not None:
             with tmp_attribute_change(self.settings, "default_timeout", timeout):
                 # Copy the Zarr archive to the hub
-                # This does not copy the consolidated data
                 dest = self.open_zarr_file(
                     owner=dataset.owner,
                     name=dataset.name,
                     path=dataset_json["zarrRootPath"],
                     mode="w",
+                    as_consolidated=False,
                 )
+
+                # Locally consolidate Zarr archive metadata. Future updates on handling consolidated
+                # metadata based on Zarr developers' recommendations can be tracked at:
+                # https://github.com/zarr-developers/zarr-python/issues/1731
+                zarr.consolidate_metadata(dataset.zarr_root.store)
+                zmetadata_content = dataset.zarr_root.store.store[".zmetadata"]
+                dest.store[".zmetadata"] = zmetadata_content
+
                 logger.info("Copying Zarr archive to the Hub. This may take a while.")
                 zarr.copy_all(source=dataset.zarr_root, dest=dest, log=logger.info)
diff --git a/tests/conftest.py b/tests/conftest.py
index 8874e473..8b9fd7cf 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -70,10 +70,11 @@ def test_dataset(test_data, test_org_owner):

 @pytest.fixture(scope="function")
 def zarr_archive(tmp_path):
-    tmp_path = fs.join(str(tmp_path), "data.zarr")
-    root = zarr.open_group(tmp_path, mode="w")
+    tmp_path = fs.join(tmp_path, "data.zarr")
+    root = zarr.open(tmp_path, mode="w")
     root.array("A", data=np.random.random((100, 2048)))
     root.array("B", data=np.random.random((100, 2048)))
+    zarr.consolidate_metadata(root.store)
     return tmp_path
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 26da2eb8..a1790095 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -47,6 +47,7 @@ def test_load_data(tmp_path, with_slice, with_caching):

     root = zarr.open(zarr_path, "w")
     root.array("A", data=arr)
+    zarr.consolidate_metadata(root.store)

     path = "A#0:5" if with_slice else "A#0"
     table = pd.DataFrame({"A": [path]}, index=[0])