From 4e37683720539ddb4440898181d84e423a3b3b00 Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 8 May 2024 15:37:33 -0400 Subject: [PATCH 01/29] First implementation of the zarr checksum Mostly uses the code from the zarr-checksum library --- LICENSE | 2 +- NOTICE | 13 ++ docs/api/dataset.md | 6 + polaris/dataset/_dataset.py | 21 +- polaris/dataset/_subset.py | 6 +- polaris/dataset/zarr/__init__.py | 3 +- polaris/dataset/zarr/_checksum.py | 314 ++++++++++++++++++++++++++++++ polaris/dataset/zarr/_memmap.py | 2 +- polaris/utils/errors.py | 4 + tests/test_dataset.py | 15 +- tests/test_zarr_checksum.py | 117 +++++++++++ 11 files changed, 485 insertions(+), 18 deletions(-) create mode 100644 NOTICE create mode 100644 polaris/dataset/zarr/_checksum.py create mode 100644 tests/test_zarr_checksum.py diff --git a/LICENSE b/LICENSE index f048b6a9..7dd1e2be 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2021 Valence + Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/NOTICE b/NOTICE new file mode 100644 index 00000000..564a6a7c --- /dev/null +++ b/NOTICE @@ -0,0 +1,13 @@ +Copyright 2023 Valence Labs + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/docs/api/dataset.md b/docs/api/dataset.md index ae107961..ec1087e6 100644 --- a/docs/api/dataset.md +++ b/docs/api/dataset.md @@ -8,4 +8,10 @@ options: filters: ["!^_"] +--- + +::: polaris.dataset.zarr + options: + filters: ["!^_"] + --- \ No newline at end of file diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 596db4ec..041fd2f8 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -20,7 +20,7 @@ from polaris._artifact import BaseArtifactModel from polaris.dataset._adapters import Adapter from polaris.dataset._column import ColumnAnnotation -from polaris.dataset.zarr import MemoryMappedDirectoryStore +from polaris.dataset.zarr import MemoryMappedDirectoryStore, compute_zarr_checksum from polaris.hub.polarisfs import PolarisFileSystem from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.dict2html import dict2html @@ -143,7 +143,7 @@ def _validate_model(cls, m: "Dataset"): # Verify the checksum # NOTE (cwognum): Is it still reasonable to always verify this as the dataset size grows? actual = m.md5sum - expected = cls._compute_checksum(m.table) + expected = cls._compute_checksum(m.table, m.zarr_root_path) if actual is None: m.md5sum = expected @@ -171,16 +171,17 @@ def _serialize_adapters(self, value: List[Adapter]): return {k: v.name for k, v in value.items()} @staticmethod - def _compute_checksum(table): + def _compute_checksum( + table: pd.DataFrame, + zarr_root_path: Optional[str] = None, + ): """Computes a hash of the dataset. This is meant to uniquely identify the dataset and can be used to verify the version. 1. 
Is not sensitive to the ordering of the columns or rows in the table. 2. Purposefully does not include the meta-data (source, description, name, annotations). - 3. For any pointer column, it uses a hash of the path instead of the file contents. - This is a limitation, but probably a reasonable assumption that helps practicality. - A big downside is that as the dataset is saved elsewhere, the hash changes. + 3. Includes a hash for the Zarr column """ hash_fn = md5() @@ -191,6 +192,14 @@ def _compute_checksum(table): # Use the sum of the row-wise hashes s.t. the hash is insensitive to the row-ordering table_hash = pd.util.hash_pandas_object(df, index=False).sum() hash_fn.update(table_hash) + print(hash_fn.hexdigest()) + + # If the Zarr arhive exists, we hash its contents too. + if zarr_root_path is not None: + zarr_hash = compute_zarr_checksum(zarr_root_path) + hash_fn.update(zarr_hash.encode()) + print(zarr_root_path, zarr_hash) + print(hash_fn.hexdigest()) checksum = hash_fn.hexdigest() return checksum diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 0b0e02db..2cacc4dd 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -76,6 +76,10 @@ def __init__( self._adapters = adapters self._featurization_fn = featurization_fn + + # NOTE (cwognum): Note to future self. As we're starting to think about competition-style benchmarks, + # we will likely split up datasets. In that case, this default iloc_to_loc mapping won't work. + # By that time, we should probably be able to overwrite this mapping. self._iloc_to_loc = self.dataset.table.index # For the iterator implementation @@ -226,4 +230,4 @@ def __next__(self): item = self[self._pointer] self._pointer += 1 - return item + return item # diff --git a/polaris/dataset/zarr/__init__.py b/polaris/dataset/zarr/__init__.py index cb984e02..78579ce4 100644 --- a/polaris/dataset/zarr/__init__.py +++ b/polaris/dataset/zarr/__init__.py @@ -1,3 +1,4 @@ +from ._checksum import compute_zarr_checksum from ._memmap import MemoryMappedDirectoryStore -__all__ = ["MemoryMappedDirectoryStore"] +__all__ = ["MemoryMappedDirectoryStore", "compute_zarr_checksum"] diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py new file mode 100644 index 00000000..04be933b --- /dev/null +++ b/polaris/dataset/zarr/_checksum.py @@ -0,0 +1,314 @@ +""" +The code in this file is based on the zarr-checksum package + +Mainted by Jacob Nesbitt, released under the DANDI org on Github +and with Kitware, Inc. credited as the author. This code is released +with the Apache 2.0 license. + +See also: https://github.com/dandi/zarr_checksum + +Instead of adding the package as a dependency, we opted to copy over the code +because it is a small and self-contained module that we will want to alter to +support our Polaris code base. + +NOTE: We have made some modifications to the original code. + +---- + +Copyright 2023 Kitware, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +import hashlib +import heapq +import os +import re +from dataclasses import asdict, dataclass, field +from functools import total_ordering +from json import dumps +from pathlib import Path +from typing import Optional + +import datamol as dm +import zarr +import zarr.errors +from fsspec import AbstractFileSystem +from tqdm import tqdm + +from polaris.utils.errors import InvalidZarrChecksum + +ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)--([0-9]+)" + + +def compute_zarr_checksum(zarr_root_path: str, fs: Optional[AbstractFileSystem] = None) -> str: + r""" + Implements an algorithm to compute the Zarr checksum. + + Warning: This checksum is sensitive to Zarr configuration. + This checksum is not insensitive to change in the Zarr structure. For example, if you change the chunk size, + the checksum will also change. + + To understand how this works, consider the following directory structure: + + . (root) + / \ + a c + / + b + + Within zarr, this would for example be: + + - `root`: A Zarr Group with a single Array. + - `a`: A Zarr Array + - `b`: A single chunk of the Zarr Array + - `c`: A metadata file (i.e. .zarray, .zattrs or .zgroup) + + To compute the checksum, we first find all the trees in the node, in this case b and c. + We compute the hash of the content (the raw bytes) for each of these files. + + We then work our way up the tree. For any node (directory), we find all children of that node. + In an sorted order, we then serialize a list with - for each of the children - the checksum, size, and number of children. + The hash of the directory is then equal to the hash of the serialized JSON. + + The Polaris implementation is heavily based on the [`zarr-checksum` package](https://github.com/dandi/zarr_checksum). + This method is the biggest deviation of the original code. 
+ """ + + if fs is None: + # Try guess the filesystem if it's not specified + fs = dm.utils.fs.get_mapper(zarr_root_path).fs + + # Get the protocol of the path + protocol = dm.utils.fs.get_protocol(zarr_root_path, fs) + + # For a local path, we extend the path to an absolute path + # Otherwise, we assume the path is already absolute + if protocol == "file": + zarr_root_path = os.path.expandvars(zarr_root_path) + zarr_root_path = os.path.expanduser(zarr_root_path) + zarr_root_path = os.path.abspath(zarr_root_path) + + # Make sure the path exists and is a Zarr archive + zarr.open_group(zarr_root_path, mode="r") + + # Generate the checksum + tree = ZarrChecksumTree() + + # Find all files in the root + leaves = fs.find(zarr_root_path, detail=True) + + for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"): + path = file["name"] + + relpath = path.removeprefix(zarr_root_path) + relpath = relpath.lstrip("/") + relpath = Path(relpath) + + size = file["size"] + + # Compute md5sum of file + md5sum = hashlib.md5() + with fs.open(path, "rb") as f: + for chunk in iter(lambda: f.read(8192), b""): + md5sum.update(chunk) + digest = md5sum.hexdigest() + + # Yield file + tree.add_leaf( + path=relpath, + size=size, + digest=digest, + ) + + # Compute digest + return tree.process().digest + + +# Pydantic models aren't used for performance reasons +@dataclass +class ZarrChecksumNode: + """Represents the aggregation of zarr files at a specific path in the tree.""" + + path: Path + checksums: "ZarrChecksumManifest" + + def __lt__(self, other: "ZarrChecksumNode") -> bool: + return str(self.path) < str(other.path) + + +class ZarrChecksumTree: + """A tree that represents the checksummed files in a zarr.""" + + def __init__(self) -> None: + self._heap: list[tuple[int, ZarrChecksumNode]] = [] + self._path_map: dict[Path, ZarrChecksumNode] = {} + + @property + def empty(self) -> bool: + return len(self._heap) == 0 + + def _add_path(self, key: Path) -> None: + node = ZarrChecksumNode(path=key, checksums=ZarrChecksumManifest()) + + # Add link to node + self._path_map[key] = node + + # Add node to heap with length (negated to representa max heap) + length = len(key.parents) + heapq.heappush(self._heap, (-1 * length, node)) + + def _get_path(self, key: Path) -> ZarrChecksumNode: + if key not in self._path_map: + self._add_path(key) + + return self._path_map[key] + + def add_leaf(self, path: Path, size: int, digest: str) -> None: + """Add a leaf file to the tree.""" + parent_node = self._get_path(path.parent) + parent_node.checksums.files.append(ZarrChecksum(name=path.name, size=size, digest=digest)) + + def add_node(self, path: Path, size: int, digest: str) -> None: + """Add an internal node to the tree.""" + parent_node = self._get_path(path.parent) + parent_node.checksums.directories.append( + ZarrChecksum( + name=path.name, + size=size, + digest=digest, + ) + ) + + def pop_deepest(self) -> ZarrChecksumNode: + """Find the deepest node in the tree, and return it.""" + _, node = heapq.heappop(self._heap) + del self._path_map[node.path] + + return node + + def process(self) -> "ZarrDirectoryDigest": + """Process the tree, returning the resulting top level digest.""" + # Begin with empty root node, so if no files are present, the empty checksum is returned + node = ZarrChecksumNode(path=Path("."), checksums=ZarrChecksumManifest()) + while not self.empty: + # Pop the deepest directory available + node = self.pop_deepest() + + # If we have reached the root node, then we're done. 
+ if node.path == Path(".") or node.path == Path("/"): + break + + # Add the parent of this node to the tree + directory_digest = node.checksums.generate_digest() + self.add_node( + path=node.path, + size=directory_digest.size, + digest=directory_digest.digest, + ) + + # Return digest + return node.checksums.generate_digest() + + +@dataclass +class ZarrDirectoryDigest: + """The data that can be serialized to / deserialized from a checksum string.""" + + md5: str + count: int + size: int + + @classmethod + def parse(cls, checksum: str | None) -> "ZarrDirectoryDigest": + if checksum is None: + return cls.parse(EMPTY_CHECKSUM) + + match = re.match(ZARR_DIGEST_PATTERN, checksum) + if match is None: + raise InvalidZarrChecksum() + + md5, count, size = match.groups() + return cls(md5=md5, count=int(count), size=int(size)) + + def __str__(self) -> str: + return self.digest + + @property + def digest(self) -> str: + return f"{self.md5}-{self.count}--{self.size}" + + +@total_ordering +@dataclass +class ZarrChecksum: + """ + A checksum for a single file/directory in a zarr file. + + Every file and directory in a zarr archive has a name, digest, and size. + Leaf nodes are created by providing an md5 digest. + Internal nodes (directories) have a digest field that is a zarr directory digest + + This class is serialized to JSON, and as such, key order should not be modified. + """ + + digest: str + name: str + size: int + + # To make this class sortable + def __lt__(self, other: "ZarrChecksum") -> bool: + return self.name < other.name + + +@dataclass +class ZarrChecksumManifest: + """ + A set of file and directory checksums. + + This is the data hashed to calculate the checksum of a directory. + """ + + directories: list[ZarrChecksum] = field(default_factory=list) + files: list[ZarrChecksum] = field(default_factory=list) + + @property + def is_empty(self) -> bool: + return not (self.files or self.directories) + + def generate_digest(self) -> ZarrDirectoryDigest: + """Generate an aggregated digest for the provided files/directories.""" + # Ensure sorted first + self.files.sort() + self.directories.sort() + + # Aggregate total file count + count = len(self.files) + sum( + ZarrDirectoryDigest.parse(checksum.digest).count for checksum in self.directories + ) + + # Aggregate total size + size = sum(file.size for file in self.files) + sum(directory.size for directory in self.directories) + + # Serialize json without any spacing + json = dumps(asdict(self), separators=(",", ":")) + + # Generate digest + md5 = hashlib.md5(json.encode("utf-8")).hexdigest() + + # Construct and return + return ZarrDirectoryDigest(md5=md5, count=count, size=size) + + +# The "null" zarr checksum +EMPTY_CHECKSUM = ZarrChecksumManifest().generate_digest().digest diff --git a/polaris/dataset/zarr/_memmap.py b/polaris/dataset/zarr/_memmap.py index 55d13d94..6a3f2da3 100644 --- a/polaris/dataset/zarr/_memmap.py +++ b/polaris/dataset/zarr/_memmap.py @@ -6,7 +6,7 @@ class MemoryMappedDirectoryStore(zarr.DirectoryStore): """ A Zarr Store to open chunks as memory-mapped files. - See https://github.com/zarr-developers/zarr-python/issues/1245 + See also [this Github issue](https://github.com/zarr-developers/zarr-python/issues/1245). Memory mapping leverages low-level OS functionality to reduce the time it takes to read the content of a file by directly mapping to memory. 
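For reference, a minimal sketch of how the new `compute_zarr_checksum` helper can be exercised against a local archive (the path and the printed digest are illustrative):

    from polaris.dataset.zarr import compute_zarr_checksum

    checksum = compute_zarr_checksum("data.zarr")
    print(checksum)  # e.g. "c228464f432c4376f0de6ddaea32650c-37481--38757151179"
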
diff --git a/polaris/utils/errors.py b/polaris/utils/errors.py index 1cb46581..5d4494c1 100644 --- a/polaris/utils/errors.py +++ b/polaris/utils/errors.py @@ -33,3 +33,7 @@ class TestAccessError(Exception): __test__ = False pass + + +class InvalidZarrChecksum(Exception): + pass diff --git a/tests/test_dataset.py b/tests/test_dataset.py index a78919f2..47a34091 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -132,10 +132,10 @@ def test_dataset_from_json(test_dataset, tmpdir): path = fs.join(str(tmpdir), "dataset.json") new_dataset = Dataset.from_json(path) - assert _equality_test(test_dataset, new_dataset) + assert test_dataset == new_dataset new_dataset = load_dataset(path) - assert _equality_test(test_dataset, new_dataset) + assert test_dataset == new_dataset def test_dataset_from_zarr_to_json_and_back(zarr_archive, tmpdir): @@ -152,24 +152,23 @@ def test_dataset_from_zarr_to_json_and_back(zarr_archive, tmpdir): path = dataset.to_json(json_dir) new_dataset = Dataset.from_json(path) - assert _equality_test(dataset, new_dataset) + assert dataset == new_dataset new_dataset = load_dataset(path) - assert _equality_test(dataset, new_dataset) + assert dataset == new_dataset def test_dataset_caching(zarr_archive, tmpdir): """Test whether the dataset remains the same after caching.""" - archive = zarr_archive - original_dataset = create_dataset_from_file(archive, tmpdir.join("original1")) - cached_dataset = create_dataset_from_file(archive, tmpdir.join("original2")) + original_dataset = create_dataset_from_file(zarr_archive, tmpdir.join("original1")) + cached_dataset = create_dataset_from_file(zarr_archive, tmpdir.join("original2")) assert original_dataset == cached_dataset cache_dir = cached_dataset.cache(tmpdir.join("cached").strpath) assert cached_dataset.zarr_root_path.startswith(cache_dir) - assert _equality_test(cached_dataset, original_dataset) + assert cached_dataset == original_dataset def test_dataset_index(): diff --git a/tests/test_zarr_checksum.py b/tests/test_zarr_checksum.py new file mode 100644 index 00000000..7926d6ab --- /dev/null +++ b/tests/test_zarr_checksum.py @@ -0,0 +1,117 @@ +""" +The code in this file is based on the zarr-checksum package + +Mainted by Jacob Nesbitt, released under the DANDI org on Github +and with Kitware, Inc. credited as the author. This code is released +with the Apache 2.0 license. + +See also: https://github.com/dandi/zarr_checksum + +Instead of adding the package as a dependency, we opted to copy over the code +because it is a small and self-contained module that we will want to alter to +support our Polaris code base. + +NOTE: We have made some modifications to the original code. + +---- + +Copyright 2023 Kitware, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+""" + +from pathlib import Path +from shutil import copytree + +import pytest +import zarr + +from polaris.dataset.zarr._checksum import ( + EMPTY_CHECKSUM, + InvalidZarrChecksum, + ZarrChecksum, + ZarrChecksumManifest, + ZarrChecksumTree, + ZarrDirectoryDigest, + compute_zarr_checksum, +) + + +def test_generate_digest() -> None: + manifest = ZarrChecksumManifest( + directories=[ZarrChecksum(digest="a7e86136543b019d72468ceebf71fb8e-1--1", name="a/b", size=1)], + files=[ZarrChecksum(digest="92eb5ffee6ae2fec3ad71c777531578f-1--1", name="b", size=1)], + ) + assert manifest.generate_digest().digest == "2ed39fd5ae56fd4177c4eb503d163528-2--2" + + +def test_zarr_checksum_sort_order() -> None: + # The a < b in the name should take precedence over z > y in the md5 + a = ZarrChecksum(name="a", digest="z", size=3) + b = ZarrChecksum(name="b", digest="y", size=4) + assert sorted([b, a]) == [a, b] + + +def test_parse_zarr_directory_digest() -> None: + # Parse valid + ZarrDirectoryDigest.parse("c228464f432c4376f0de6ddaea32650c-37481--38757151179") + ZarrDirectoryDigest.parse(None) + + # Ensure exception is raised + with pytest.raises(InvalidZarrChecksum): + ZarrDirectoryDigest.parse("asd") + with pytest.raises(InvalidZarrChecksum): + ZarrDirectoryDigest.parse("asd-0--0") + + +def test_pop_deepest() -> None: + tree = ZarrChecksumTree() + tree.add_leaf(Path("a/b"), size=1, digest="asd") + tree.add_leaf(Path("a/b/c"), size=1, digest="asd") + node = tree.pop_deepest() + + # Assert popped node is a/b/c, not a/b + assert str(node.path) == "a/b" + assert len(node.checksums.files) == 1 + assert len(node.checksums.directories) == 0 + assert node.checksums.files[0].name == "c" + + +def test_process_empty_tree() -> None: + tree = ZarrChecksumTree() + assert tree.process().digest == EMPTY_CHECKSUM + + +def test_process_tree() -> None: + tree = ZarrChecksumTree() + tree.add_leaf(Path("a/b"), size=1, digest="9dd4e461268c8034f5c8564e155c67a6") + tree.add_leaf(Path("c"), size=1, digest="415290769594460e2e485922904f345d") + checksum = tree.process() + + # This zarr checksum was computed against the same file structure using the previous + # zarr checksum implementation + # Assert the current implementation produces a matching checksum + assert checksum.digest == "26054e501f570a8bfa69a2bc75e7c82d-2--2" + + +def test_checksum_for_zarr_archive(zarr_archive, tmpdir): + # NOTE: This test was not in the original code base of the zarr-checksum package. + checksum = compute_zarr_checksum(zarr_archive) + + path = tmpdir.join("copy") + copytree(zarr_archive, path) + assert checksum == compute_zarr_checksum(path) + + root = zarr.open(path) + root["A"][0:10] = 0 + assert checksum != compute_zarr_checksum(path) From 22156f849062a560e1f1d2384f3ed5acf2a2a05d Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 8 May 2024 15:58:29 -0400 Subject: [PATCH 02/29] Removed left-over print statements --- polaris/dataset/_dataset.py | 3 --- tests/test_integration.py | 3 +-- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 041fd2f8..0e3bbd21 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -192,14 +192,11 @@ def _compute_checksum( # Use the sum of the row-wise hashes s.t. the hash is insensitive to the row-ordering table_hash = pd.util.hash_pandas_object(df, index=False).sum() hash_fn.update(table_hash) - print(hash_fn.hexdigest()) # If the Zarr arhive exists, we hash its contents too. 
if zarr_root_path is not None: zarr_hash = compute_zarr_checksum(zarr_root_path) hash_fn.update(zarr_hash.encode()) - print(zarr_root_path, zarr_hash) - print(hash_fn.hexdigest()) checksum = hash_fn.hexdigest() return checksum diff --git a/tests/test_integration.py b/tests/test_integration.py index 5d9a983b..5a6c1f69 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,6 +1,6 @@ import datamol as dm import numpy as np -from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor from polaris.evaluate import BenchmarkResults @@ -57,7 +57,6 @@ def test_single_task_benchmark_clf_loop_with_multiple_test_sets( y_prob = {} y_pred = {} for k, test_subset in test.items(): - print(k, test_subset) x_test = np.array([dm.to_fp(dm.to_mol(smi)) for smi in test_subset.inputs]) y_prob[k] = model.predict_proba(x_test)[:, :1] # for binary classification y_pred[k] = model.predict(x_test) From cefde269500a85323c4b8eb244d851002c004f56 Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 8 May 2024 16:38:23 -0400 Subject: [PATCH 03/29] Minor changes to docs --- polaris/dataset/_dataset.py | 2 +- polaris/dataset/_subset.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 0e3bbd21..16cfe213 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -181,7 +181,7 @@ def _compute_checksum( 1. Is not sensitive to the ordering of the columns or rows in the table. 2. Purposefully does not include the meta-data (source, description, name, annotations). - 3. Includes a hash for the Zarr column + 3. Includes a hash for the Zarr archive. """ hash_fn = md5() diff --git a/polaris/dataset/_subset.py b/polaris/dataset/_subset.py index 2cacc4dd..448abbff 100644 --- a/polaris/dataset/_subset.py +++ b/polaris/dataset/_subset.py @@ -230,4 +230,4 @@ def __next__(self): item = self[self._pointer] self._pointer += 1 - return item # + return item From 269886a24580e8d5d0b99f1acc3c801816c98d42 Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 8 May 2024 16:42:20 -0400 Subject: [PATCH 04/29] Removed unused method --- tests/test_dataset.py | 24 ------------------------ 1 file changed, 24 deletions(-) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 47a34091..06cce186 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -12,30 +12,6 @@ from polaris.utils.errors import PolarisChecksumError -def _equality_test(dataset_1, dataset_2): - """ - Utility function. - - When saving a dataset to a different location, it should be considered the same dataset - but currently the dataset checksum is used for equality and with pointer columns, - the checksum uses the file path, not the file content (which thus changes when saving). 
- - See also: https://github.com/polaris-hub/polaris/issues/16 - """ - if dataset_1 == dataset_2: - return True - if len(dataset_1) != len(dataset_2): - return False - if (dataset_1.table.columns != dataset_2.table.columns).all(): - return False - - for i in range(len(dataset_1)): - for col in dataset_1.table.columns: - if (dataset_1.get_data(row=i, col=col) != dataset_2.get_data(row=i, col=col)).all(): - return False - return True - - @pytest.mark.parametrize("with_caching", [True, False]) @pytest.mark.parametrize("with_slice", [True, False]) def test_load_data(tmp_path, with_slice, with_caching): From 9bc808685010ad19d2a77f05d5807d636e8699f3 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Thu, 9 May 2024 10:38:20 -0400 Subject: [PATCH 05/29] Update polaris/dataset/_dataset.py Co-authored-by: Andrew Quirke <75542075+Andrewq11@users.noreply.github.com> --- polaris/dataset/_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 16cfe213..1de22b3c 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -193,7 +193,7 @@ def _compute_checksum( table_hash = pd.util.hash_pandas_object(df, index=False).sum() hash_fn.update(table_hash) - # If the Zarr arhive exists, we hash its contents too. + # If the Zarr archive exists, we hash its contents too. if zarr_root_path is not None: zarr_hash = compute_zarr_checksum(zarr_root_path) hash_fn.update(zarr_hash.encode()) From 69aea30c4f80e0d272bcd8ed36761c53c0b74539 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Thu, 9 May 2024 10:38:27 -0400 Subject: [PATCH 06/29] Update polaris/dataset/zarr/_checksum.py Co-authored-by: Andrew Quirke <75542075+Andrewq11@users.noreply.github.com> --- polaris/dataset/zarr/_checksum.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index 04be933b..d6fce4d6 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -56,7 +56,7 @@ def compute_zarr_checksum(zarr_root_path: str, fs: Optional[AbstractFileSystem] Implements an algorithm to compute the Zarr checksum. Warning: This checksum is sensitive to Zarr configuration. - This checksum is not insensitive to change in the Zarr structure. For example, if you change the chunk size, + This checksum is sensitive to change in the Zarr structure. For example, if you change the chunk size, the checksum will also change. 
To understand how this works, consider the following directory structure: From ad7aac66b87851a39324fee0f42d1cb9b3fa5779 Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 26 Jun 2024 18:19:05 -0400 Subject: [PATCH 07/29] Lazily compute the checksum --- polaris/benchmark/_base.py | 48 ++++++++++++++----------------- polaris/dataset/_dataset.py | 34 +++++++++------------- polaris/dataset/zarr/_checksum.py | 14 ++++----- polaris/hub/client.py | 34 +++++++++++++++++----- tests/test_benchmark.py | 18 +++++------- tests/test_dataset.py | 13 +++------ tests/test_zarr_checksum.py | 4 +-- 7 files changed, 81 insertions(+), 84 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index ac3e210b..660ab240 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -8,6 +8,7 @@ from datamol.utils import fs from pydantic import ( Field, + PrivateAttr, ValidationInfo, computed_field, field_serializer, @@ -22,7 +23,7 @@ from polaris.hub.settings import PolarisHubSettings from polaris.utils.context import tmp_attribute_change from polaris.utils.dict2html import dict2html -from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError +from polaris.utils.errors import InvalidBenchmarkError from polaris.utils.misc import listit from polaris.utils.types import ( AccessType, @@ -102,7 +103,6 @@ class BenchmarkSpecification(BaseArtifactModel): split: SplitType metrics: Union[str, Metric, list[Union[str, Metric]]] main_metric: Optional[Union[str, Metric]] = None - md5sum: Optional[str] = None # Additional meta-data readme: str = "" @@ -110,6 +110,9 @@ class BenchmarkSpecification(BaseArtifactModel): default_factory=dict, validate_default=True ) + # Private attributes + _md5sum: Optional[str] = PrivateAttr(None) + @field_validator("dataset") def _validate_dataset(cls, v): """ @@ -214,6 +217,12 @@ def _validate_target_types(cls, v, info: ValidationInfo): for target in target_cols: if target not in v: val = dataset[:, target] + + # Non numeric columns can be targets (e.g. prediction molecular reactions), + # but in that case we currently don't infer the target type. + if not np.issubdtype(val.dtype, np.number): + continue + # remove the nans for mutiple task dataset when the table is sparse target_type = type_of_target(val[~np.isnan(val)]) if target_type == "continuous": @@ -230,34 +239,11 @@ def _validate_target_types(cls, v, info: ValidationInfo): @classmethod def _validate_model(cls, m: "BenchmarkSpecification"): """ - If a checksum is provided, verify it matches what the checksum should be. - If no checksum is provided, make sure it is set. - Also sets a default metric if missing. + Sets a default metric if missing. """ - - # Validate checksum - checksum = m.md5sum - - expected = cls._compute_checksum( - dataset=m.dataset, - target_cols=m.target_cols, - input_cols=m.input_cols, - split=m.split, - metrics=m.metrics, - ) - - if checksum is None: - m.md5sum = expected - elif checksum != expected: - raise PolarisChecksumError( - "The dataset checksum does not match what was specified in the meta-data. 
" - f"{checksum} != {expected}" - ) - # Set a default main metric if not set yet if m.main_metric is None: m.main_metric = m.metrics[0] - return m @field_serializer("metrics", "main_metric") @@ -310,6 +296,16 @@ def _compute_checksum(dataset, target_cols, input_cols, split, metrics): checksum = hash_fn.hexdigest() return checksum + @computed_field + @property + def md5sum(self) -> Optional[str]: + """Lazily compute the checksum once needed.""" + if self._md5sum is None: + self._md5sum = self._compute_checksum( + self.dataset, self.target_cols, self.input_cols, self.split, self.metrics + ) + return self._md5sum + @computed_field @property def n_train_datapoints(self) -> int: diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 1de22b3c..6b92f854 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -1,4 +1,5 @@ import json +import uuid from hashlib import md5 from typing import Dict, List, MutableMapping, Optional, Tuple, Union @@ -24,7 +25,7 @@ from polaris.hub.polarisfs import PolarisFileSystem from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.dict2html import dict2html -from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError +from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import AccessType, HttpUrlString, HubOwner, SupportedLicenseType # Constants @@ -73,7 +74,6 @@ class Dataset(BaseArtifactModel): table: Union[pd.DataFrame, str] default_adapters: Dict[str, Adapter] = Field(default_factory=dict) zarr_root_path: Optional[str] = None - md5sum: Optional[str] = None # Additional meta-data readme: str = "" @@ -88,6 +88,7 @@ class Dataset(BaseArtifactModel): # Private attributes _zarr_root: Optional[zarr.Group] = PrivateAttr(None) _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) + _md5sum: Optional[str] = PrivateAttr(None) _client = PrivateAttr(None) # Optional[PolarisHubClient] @field_validator("table") @@ -113,10 +114,7 @@ def _validate_table(cls, v): @model_validator(mode="after") @classmethod def _validate_model(cls, m: "Dataset"): - """If a checksum is provided, verify it matches what the checksum should be. - If no checksum is provided, make sure it is set. - If no cache_dir is provided, set it to the default cache dir and make sure it exists - """ + """Verifies some dependencies between properties""" # Verify that all annotations are for columns that exist if any(k not in m.table.columns for k in m.annotations): @@ -140,22 +138,10 @@ def _validate_model(cls, m: "Dataset"): m.annotations[c] = ColumnAnnotation() m.annotations[c].dtype = m.table[c].dtype - # Verify the checksum - # NOTE (cwognum): Is it still reasonable to always verify this as the dataset size grows? - actual = m.md5sum - expected = cls._compute_checksum(m.table, m.zarr_root_path) - - if actual is None: - m.md5sum = expected - elif actual != expected: - raise PolarisChecksumError( - "The dataset md5sum does not match what was specified in the meta-data. 
" - f"{actual} != {expected}" - ) - # Set the default cache dir if none and make sure it exists if m.cache_dir is None: - m.cache_dir = fs.join(DEFAULT_CACHE_DIR, _CACHE_SUBDIR, m.name, m.md5sum) + m.cache_dir = fs.join(DEFAULT_CACHE_DIR, _CACHE_SUBDIR, str(uuid.uuid4())) + fs.mkdir(m.cache_dir, exist_ok=True) return m @@ -201,6 +187,14 @@ def _compute_checksum( checksum = hash_fn.hexdigest() return checksum + @computed_field + @property + def md5sum(self) -> Optional[str]: + """Lazily compute the checksum once needed.""" + if self._md5sum is None: + self._md5sum = self._compute_checksum(self.table, self.zarr_root_path) + return self._md5sum + @property def client(self): """The Polaris Hub client used to interact with the Polaris Hub.""" diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index d6fce4d6..6a91f506 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -38,12 +38,11 @@ from functools import total_ordering from json import dumps from pathlib import Path -from typing import Optional -import datamol as dm +import fsspec +import fsspec.utils import zarr import zarr.errors -from fsspec import AbstractFileSystem from tqdm import tqdm from polaris.utils.errors import InvalidZarrChecksum @@ -51,7 +50,7 @@ ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)--([0-9]+)" -def compute_zarr_checksum(zarr_root_path: str, fs: Optional[AbstractFileSystem] = None) -> str: +def compute_zarr_checksum(zarr_root_path: str) -> str: r""" Implements an algorithm to compute the Zarr checksum. @@ -85,12 +84,9 @@ def compute_zarr_checksum(zarr_root_path: str, fs: Optional[AbstractFileSystem] This method is the biggest deviation of the original code. """ - if fs is None: - # Try guess the filesystem if it's not specified - fs = dm.utils.fs.get_mapper(zarr_root_path).fs - # Get the protocol of the path - protocol = dm.utils.fs.get_protocol(zarr_root_path, fs) + protocol = fsspec.utils.get_protocol(zarr_root_path) + fs, zarr_root_path = fsspec.url_to_fs(zarr_root_path) # For a local path, we extend the path to an absolute path # Otherwise, we assume the path is already absolute diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 36f0dcdd..691b3a52 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -31,7 +31,12 @@ from polaris.hub.settings import PolarisHubSettings from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.context import tmp_attribute_change -from polaris.utils.errors import InvalidDatasetError, PolarisHubError, PolarisUnauthorizedError +from polaris.utils.errors import ( + InvalidDatasetError, + PolarisChecksumError, + PolarisHubError, + PolarisUnauthorizedError, +) from polaris.utils.types import ( AccessType, HubOwner, @@ -354,10 +359,18 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: b response["table"] = self._load_from_signed_url(url=url, headers=headers, load_fn=pd.read_parquet) - if not verify_checksum: - response.pop("md5Sum", None) + dataset = Dataset(**response) + checksum = response.pop("md5Sum", None) + + if verify_checksum and checksum is not None and checksum != dataset.md5sum: + raise PolarisChecksumError( + "The dataset checksum does not match what was specified in the meta-data. 
" + f"{checksum} != {dataset.md5sum}" + ) + elif not verify_checksum: + dataset._md5sum = checksum - return Dataset(**response) + return dataset def open_zarr_file( self, owner: Union[str, HubOwner], name: str, path: str, mode: IOMode, as_consolidated: bool = True @@ -440,10 +453,17 @@ def get_benchmark( else MultiTaskBenchmarkSpecification ) - if not verify_checksum: - response.pop("md5Sum", None) + benchmark = benchmark_cls(**response) + checksum = response.pop("md5Sum", None) - return benchmark_cls(**response) + if verify_checksum and checksum is not None and checksum != benchmark.md5sum: + raise PolarisChecksumError( + "The benchmark checksum does not match what was specified in the meta-data. " + f"{checksum} != {benchmark.md5sum}" + ) + elif not verify_checksum: + benchmark._md5sum = checksum + return def upload_results( self, diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index d3ccca2e..33b6ab3c 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -6,7 +6,6 @@ MultiTaskBenchmarkSpecification, SingleTaskBenchmarkSpecification, ) -from polaris.utils.errors import PolarisChecksumError @pytest.mark.parametrize("is_single_task", [True, False]) @@ -140,34 +139,31 @@ def test_benchmark_checksum(is_single_task, test_single_task_benchmark, test_mul # Without any changes, same hash kwargs = obj.model_dump() - cls(**kwargs) + assert cls(**kwargs).md5sum == original # With a different ordering of the target columns kwargs["target_cols"] = kwargs["target_cols"][::-1] - cls(**kwargs) + assert cls(**kwargs).md5sum == original # With a different ordering of the metrics kwargs["metrics"] = kwargs["metrics"][::-1] - cls(**kwargs) + assert cls(**kwargs).md5sum == original # With a different ordering of the split kwargs["split"] = kwargs["split"][0][::-1], kwargs["split"][1] - cls(**kwargs) + assert cls(**kwargs).md5sum == original # --- Test that the checksum is NOT the same --- def _check_for_failure(_kwargs): - with pytest.raises((ValidationError, TypeError)) as error: - cls(**_kwargs) - assert error.error_count() == 1 # noqa - assert isinstance(error.errors()[0], PolarisChecksumError) # noqa + assert cls(**_kwargs).md5sum != _kwargs["md5sum"] # Split kwargs = obj.model_dump() - kwargs["split"] = kwargs["split"][0][1:] + [-1], kwargs["split"][1] + kwargs["split"] = kwargs["split"][0][1:], kwargs["split"][1] _check_for_failure(kwargs) kwargs = obj.model_dump() - kwargs["split"] = kwargs["split"][0], kwargs["split"][1][1:] + [-1] + kwargs["split"] = kwargs["split"][0], kwargs["split"][1][1:] _check_for_failure(kwargs) # Metrics diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 06cce186..4493b90c 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -5,11 +5,9 @@ import pytest import zarr from datamol.utils import fs -from pydantic import ValidationError from polaris.dataset import Dataset, Subset, create_dataset_from_file from polaris.loader import load_dataset -from polaris.utils.errors import PolarisChecksumError @pytest.mark.parametrize("with_caching", [True, False]) @@ -56,24 +54,21 @@ def test_dataset_checksum(test_dataset): # Without any changes, same hash kwargs = test_dataset.model_dump() - Dataset(**kwargs) + assert Dataset(**kwargs).md5sum == original # With unimportant changes, same hash kwargs["name"] = "changed" kwargs["description"] = "changed" kwargs["source"] = "https://changed.com" - Dataset(**kwargs) + assert Dataset(**kwargs).md5sum == original # Check sensitivity to the row and column ordering kwargs["table"] = 
kwargs["table"].iloc[::-1] kwargs["table"] = kwargs["table"][kwargs["table"].columns[::-1]] - Dataset(**kwargs) + assert Dataset(**kwargs).md5sum == original def _check_for_failure(_kwargs): - with pytest.raises(ValidationError) as error: - Dataset(**_kwargs) - assert error.error_count() == 1 # noqa - assert isinstance(error.errors()[0], PolarisChecksumError) # noqa + assert Dataset(**_kwargs).md5sum != _kwargs["md5sum"] # Without any changes, but different hash kwargs["md5sum"] = "invalid" diff --git a/tests/test_zarr_checksum.py b/tests/test_zarr_checksum.py index 7926d6ab..59964df2 100644 --- a/tests/test_zarr_checksum.py +++ b/tests/test_zarr_checksum.py @@ -110,8 +110,8 @@ def test_checksum_for_zarr_archive(zarr_archive, tmpdir): path = tmpdir.join("copy") copytree(zarr_archive, path) - assert checksum == compute_zarr_checksum(path) + assert checksum == compute_zarr_checksum(str(path)) root = zarr.open(path) root["A"][0:10] = 0 - assert checksum != compute_zarr_checksum(path) + assert checksum != compute_zarr_checksum(str(path)) From 35246cec55f3fc643af32f4a6b3cbb9ada9fa84a Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 26 Jun 2024 19:00:34 -0400 Subject: [PATCH 08/29] Save the checksum per file --- polaris/dataset/_dataset.py | 21 +++++++++++++++++---- polaris/dataset/zarr/_checksum.py | 10 ++++++++-- polaris/hub/client.py | 1 + tests/conftest.py | 4 ++-- tests/test_benchmark.py | 4 ++++ tests/test_dataset.py | 4 ++++ tests/test_zarr_checksum.py | 15 ++++++++++++--- 7 files changed, 48 insertions(+), 11 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 6b92f854..d7c775e6 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -89,6 +89,7 @@ class Dataset(BaseArtifactModel): _zarr_root: Optional[zarr.Group] = PrivateAttr(None) _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) _md5sum: Optional[str] = PrivateAttr(None) + _leaf_to_md5sum: Optional[Dict[str, str]] = PrivateAttr(None) _client = PrivateAttr(None) # Optional[PolarisHubClient] @field_validator("table") @@ -180,21 +181,33 @@ def _compute_checksum( hash_fn.update(table_hash) # If the Zarr archive exists, we hash its contents too. + leaf_to_md5sum = None if zarr_root_path is not None: - zarr_hash = compute_zarr_checksum(zarr_root_path) + zarr_hash, leaf_to_md5sum = compute_zarr_checksum(zarr_root_path) hash_fn.update(zarr_hash.encode()) checksum = hash_fn.hexdigest() - return checksum + return checksum, leaf_to_md5sum @computed_field @property - def md5sum(self) -> Optional[str]: + def md5sum(self) -> str: """Lazily compute the checksum once needed.""" if self._md5sum is None: - self._md5sum = self._compute_checksum(self.table, self.zarr_root_path) + self._md5sum, self._leaf_to_md5sum = self._compute_checksum(self.table, self.zarr_root_path) return self._md5sum + @computed_field + @property + def leaf_to_md5sum(self) -> Optional[Dict[str, str]]: + """ + For Zarr archives, the mapping from all files to their checksum is used by the Hub + to verify data integrity on upload. 
+ """ + if self._leaf_to_md5sum is None and self._md5sum is None: + self._md5sum, self._leaf_to_md5sum = self._compute_checksum(self.table, self.zarr_root_path) + return self._leaf_to_md5sum + @property def client(self): """The Polaris Hub client used to interact with the Polaris Hub.""" diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index 6a91f506..a5b9d7df 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -38,6 +38,7 @@ from functools import total_ordering from json import dumps from pathlib import Path +from typing import Dict, Tuple import fsspec import fsspec.utils @@ -50,7 +51,7 @@ ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)--([0-9]+)" -def compute_zarr_checksum(zarr_root_path: str) -> str: +def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: r""" Implements an algorithm to compute the Zarr checksum. @@ -103,6 +104,7 @@ def compute_zarr_checksum(zarr_root_path: str) -> str: # Find all files in the root leaves = fs.find(zarr_root_path, detail=True) + leaf_to_md5sum = {} for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"): path = file["name"] @@ -127,8 +129,12 @@ def compute_zarr_checksum(zarr_root_path: str) -> str: digest=digest, ) + # We persist the checksums for leaf nodes separately, + # because this is what the Hub needs to verify data integrity. + leaf_to_md5sum[str(relpath)] = digest + # Compute digest - return tree.process().digest + return tree.process().digest, leaf_to_md5sum # Pydantic models aren't used for performance reasons diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 691b3a52..19371992 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -369,6 +369,7 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: b ) elif not verify_checksum: dataset._md5sum = checksum + dataset._leaf_to_md5sum = response.get("leafToMd5Sum", None) return dataset diff --git a/tests/conftest.py b/tests/conftest.py index aa02c3f7..b034d464 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -81,8 +81,8 @@ def test_dataset(test_data, test_org_owner): def zarr_archive(tmp_path): tmp_path = fs.join(tmp_path, "data.zarr") root = zarr.open(tmp_path, mode="w") - root.array("A", data=np.random.random((100, 2048))) - root.array("B", data=np.random.random((100, 2048))) + root.array("A", data=np.random.random((100, 2048)), chunks=(1, None)) + root.array("B", data=np.random.random((100, 2048)), chunks=(1, None)) zarr.consolidate_metadata(root.store) return tmp_path diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 33b6ab3c..45e6814a 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -132,6 +132,10 @@ def test_benchmark_checksum(is_single_task, test_single_task_benchmark, test_mul obj = test_single_task_benchmark if is_single_task else test_multi_task_benchmark cls = SingleTaskBenchmarkSpecification if is_single_task else MultiTaskBenchmarkSpecification + # Make sure the `md5sum` is part of the model dump even if not initiated yet. + # This is important for uploads to the Hub. 
+ assert obj._md5sum is None and "md5sum" in obj.model_dump() + original = obj.md5sum assert original is not None diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4493b90c..3e1dccf8 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -49,6 +49,10 @@ def test_load_data(tmp_path, with_slice, with_caching): def test_dataset_checksum(test_dataset): """Test whether the checksum is a good indicator of whether the dataset has changed in a meaningful way.""" + # Make sure the `md5sum` is part of the model dump even if not initiated yet. + # This is important for uploads to the Hub. + assert test_dataset._md5sum is None and "md5sum" in test_dataset.model_dump() + original = test_dataset.md5sum assert original is not None diff --git a/tests/test_zarr_checksum.py b/tests/test_zarr_checksum.py index 59964df2..b987bbfe 100644 --- a/tests/test_zarr_checksum.py +++ b/tests/test_zarr_checksum.py @@ -106,12 +106,21 @@ def test_process_tree() -> None: def test_checksum_for_zarr_archive(zarr_archive, tmpdir): # NOTE: This test was not in the original code base of the zarr-checksum package. - checksum = compute_zarr_checksum(zarr_archive) + checksum, _ = compute_zarr_checksum(zarr_archive) path = tmpdir.join("copy") copytree(zarr_archive, path) - assert checksum == compute_zarr_checksum(str(path)) + assert checksum == compute_zarr_checksum(str(path))[0] root = zarr.open(path) root["A"][0:10] = 0 - assert checksum != compute_zarr_checksum(str(path)) + assert checksum != compute_zarr_checksum(str(path))[0] + + +def test_zarr_leaf_to_checksum(zarr_archive): + _, leaf_to_checksum = compute_zarr_checksum(zarr_archive) + root = zarr.open(zarr_archive) + + # Check the basic structure - Each key corresponds to a file in the zarr archive + assert len(leaf_to_checksum) == len(root.store) + assert all(k in root.store for k in leaf_to_checksum.keys()) From b2aae93f4eaa2c78b108e82366adcfd379f52907 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Thu, 27 Jun 2024 13:59:59 -0400 Subject: [PATCH 09/29] Improved docs because I kept forgetting how it works --- polaris/dataset/zarr/_checksum.py | 210 ++++++++++++++++++++---------- tests/test_zarr_checksum.py | 11 +- 2 files changed, 144 insertions(+), 77 deletions(-) diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index a5b9d7df..c86f635a 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -48,7 +48,7 @@ from polaris.utils.errors import InvalidZarrChecksum -ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)--([0-9]+)" +ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)" def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: @@ -102,7 +102,7 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: # Generate the checksum tree = ZarrChecksumTree() - # Find all files in the root + # Find all files below the root leaves = fs.find(zarr_root_path, detail=True) leaf_to_md5sum = {} @@ -137,43 +137,77 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: return tree.process().digest, leaf_to_md5sum -# Pydantic models aren't used for performance reasons -@dataclass -class ZarrChecksumNode: - """Represents the aggregation of zarr files at a specific path in the tree.""" +# ================================ +# Overview of the data structures +# ================================ - path: Path - checksums: "ZarrChecksumManifest" +# NOTE (cwognum): I kept forgetting how this works, so I'm writing it down +# - The 
ZarrChecksumTree is a binary tree (heap queue). It determines the order in which to process the nodes. +# - The ZarrChecksumNode is a node in the ZarrChecksumTree queue. It represents a directory in the Zarr archive and +# stores a manifest with all the data needed to compute the checksum for that node. +# - The ZarrChecksumManifest is a collection of checksums for all direct (non-recursive) children of a directory. +# - The ZarrChecksum is the data used to compute the checksum for a file or directory in a Zarr Archive. +# This is the object that the ZarrChecksumManifest stores a collection of. +# - A ZarrDirectoryDigest is the result of processing a directory. Once completed, +# it is added to the ZarrChecksumManifest of its parent as part of a ZarrChecksum. - def __lt__(self, other: "ZarrChecksumNode") -> bool: - return str(self.path) < str(other.path) +# NOTE (cwognum): As a first impression, it seems there is some redundancy in the data structures. +# My feeling is that we could reduce the redundancy to simplify things and improve maintainability. +# However, for the time being, let's stick close to the original code. +# ================================ + +# Pydantic models aren't used for performance reasons class ZarrChecksumTree: - """A tree that represents the checksummed files in a zarr.""" + """ + The ZarrChecksumTree is a tree structure that maintains the state of the checksum algorithm. + + Initialized with a set of leafs (i.e. files), the nodes in this tree correspond to all directories + that are above those leafs and below the Zarr Root. + + The tree then implements the logic for retrieving the next node (i.e. directory) to process, + and for computing the checksum for that node based on its children. + Once it reaches the root, it has computed the checksum for the entire Zarr archive. + """ def __init__(self) -> None: + # Queue to prioritize the next node to process self._heap: list[tuple[int, ZarrChecksumNode]] = [] + + # Map of (relative) paths to nodes. self._path_map: dict[Path, ZarrChecksumNode] = {} @property def empty(self) -> bool: + """Check if the tree is empty.""" + # This is used as an exit condition in the process() method return len(self._heap) == 0 def _add_path(self, key: Path) -> None: - node = ZarrChecksumNode(path=key, checksums=ZarrChecksumManifest()) + """Adds a new entry to the heap queue for which we need to compute the checksum.""" - # Add link to node + # Create a new node + # A node represents a file or directory. + # A node refers to a node in the heap queue (i.e. binary tree) + # The structure of the heap is thus _not_ the same as the structure of the file system! + node = ZarrChecksumNode(path=key, checksums=ZarrChecksumManifest()) self._path_map[key] = node - # Add node to heap with length (negated to representa max heap) + # Add node to heap with length (negated to represent a max heap) + # We use the length of the parents (relative to the Zarr root) to structure the heap. + # The node with the longest path is the deepest node in the tree. + # This node will be prioritized for processing next. length = len(key.parents) heapq.heappush(self._heap, (-1 * length, node)) - def _get_path(self, key: Path) -> ZarrChecksumNode: + def _get_path(self, key: Path) -> "ZarrChecksumNode": + """ + If an entry for this path already exists, return it. + Otherwise create a new one and return that. 
+ """ if key not in self._path_map: self._add_path(key) - return self._path_map[key] def add_leaf(self, path: Path, size: int, digest: str) -> None: @@ -181,7 +215,7 @@ def add_leaf(self, path: Path, size: int, digest: str) -> None: parent_node = self._get_path(path.parent) parent_node.checksums.files.append(ZarrChecksum(name=path.name, size=size, digest=digest)) - def add_node(self, path: Path, size: int, digest: str) -> None: + def add_node(self, path: Path, size: int, digest: str, count: int) -> None: """Add an internal node to the tree.""" parent_node = self._get_path(path.parent) parent_node.checksums.directories.append( @@ -189,22 +223,31 @@ def add_node(self, path: Path, size: int, digest: str) -> None: name=path.name, size=size, digest=digest, + count=count, ) ) - def pop_deepest(self) -> ZarrChecksumNode: - """Find the deepest node in the tree, and return it.""" + def pop_deepest(self) -> "ZarrChecksumNode": + """ + Returns the node with the highest priority for processing next. + + Returns (one of the) node(s) with the most parent directories + (i.e. the deepest directory in the file system) + """ _, node = heapq.heappop(self._heap) del self._path_map[node.path] - return node def process(self) -> "ZarrDirectoryDigest": """Process the tree, returning the resulting top level digest.""" - # Begin with empty root node, so if no files are present, the empty checksum is returned + + # Begin with empty root node, so that if no files are present, the empty checksum is returned node = ZarrChecksumNode(path=Path("."), checksums=ZarrChecksumManifest()) + while not self.empty: - # Pop the deepest directory available + # Get the next directory to process + # Priority is based on the number of parents a directory has + # In other word, the depth of the directory in the file system. node = self.pop_deepest() # If we have reached the root node, then we're done. @@ -217,6 +260,7 @@ def process(self) -> "ZarrDirectoryDigest": path=node.path, size=directory_digest.size, digest=directory_digest.digest, + count=directory_digest.count, ) # Return digest @@ -224,80 +268,50 @@ def process(self) -> "ZarrDirectoryDigest": @dataclass -class ZarrDirectoryDigest: - """The data that can be serialized to / deserialized from a checksum string.""" - - md5: str - count: int - size: int - - @classmethod - def parse(cls, checksum: str | None) -> "ZarrDirectoryDigest": - if checksum is None: - return cls.parse(EMPTY_CHECKSUM) - - match = re.match(ZARR_DIGEST_PATTERN, checksum) - if match is None: - raise InvalidZarrChecksum() - - md5, count, size = match.groups() - return cls(md5=md5, count=int(count), size=int(size)) - - def __str__(self) -> str: - return self.digest - - @property - def digest(self) -> str: - return f"{self.md5}-{self.count}--{self.size}" - - -@total_ordering -@dataclass -class ZarrChecksum: +class ZarrChecksumNode: """ - A checksum for a single file/directory in a zarr file. + A node in the ZarrChecksumTree. - Every file and directory in a zarr archive has a name, digest, and size. - Leaf nodes are created by providing an md5 digest. - Internal nodes (directories) have a digest field that is a zarr directory digest + This node represents a file or directory in the Zarr archive, + but "node" here refers to a node in the heap queue (i.e. binary tree). + The structure of the heap is thus _not_ the same as the structure of the file system! - This class is serialized to JSON, and as such, key order should not be modified. + The node stores a manifest of checksums for all files and directories below it. 
""" - digest: str - name: str - size: int + path: Path + checksums: "ZarrChecksumManifest" - # To make this class sortable - def __lt__(self, other: "ZarrChecksum") -> bool: - return self.name < other.name + def __lt__(self, other: "ZarrChecksumNode") -> bool: + return str(self.path) < str(other.path) @dataclass class ZarrChecksumManifest: """ - A set of file and directory checksums. + For a directory in the Zarr archive (i.e. a node in the heap queue), + we maintain a manifest of the checksums for all files and directories + below that directory. - This is the data hashed to calculate the checksum of a directory. + This data is then used to calculate the checksum of a directory. """ - directories: list[ZarrChecksum] = field(default_factory=list) - files: list[ZarrChecksum] = field(default_factory=list) + directories: list["ZarrChecksum"] = field(default_factory=list) + files: list["ZarrChecksum"] = field(default_factory=list) @property def is_empty(self) -> bool: return not (self.files or self.directories) - def generate_digest(self) -> ZarrDirectoryDigest: + def generate_digest(self) -> "ZarrDirectoryDigest": """Generate an aggregated digest for the provided files/directories.""" - # Ensure sorted first + + # Sort everything to ensure the checksum is deterministic self.files.sort() self.directories.sort() # Aggregate total file count - count = len(self.files) + sum( - ZarrDirectoryDigest.parse(checksum.digest).count for checksum in self.directories - ) + count = len(self.files) + sum(checksum.count for checksum in self.directories) # Aggregate total size size = sum(file.size for file in self.files) + sum(directory.size for directory in self.directories) @@ -312,5 +326,57 @@ def generate_digest(self) -> ZarrDirectoryDigest: return ZarrDirectoryDigest(md5=md5, count=count, size=size) +@total_ordering +@dataclass +class ZarrChecksum: + """ + The data used to compute the checksum for a file or directory in a Zarr Archive. + + This class is serialized to JSON, and as such, key order should not be modified. + """ + + digest: str + name: str + size: int + count: int = 0 + + # To make this class sortable + def __lt__(self, other: "ZarrChecksum") -> bool: + return self.name < other.name + + +@dataclass +class ZarrDirectoryDigest: + """ + The digest for a directory in a Zarr Archive. + + The digest is a string representation that serves as a checksum for the directory. + This is a utility class to (de)serialize that string. 
+ """ + + md5: str + count: int + size: int + + @classmethod + def parse(cls, checksum: str | None) -> "ZarrDirectoryDigest": + if checksum is None: + return cls.parse(EMPTY_CHECKSUM) + + match = re.match(ZARR_DIGEST_PATTERN, checksum) + if match is None: + raise InvalidZarrChecksum() + + md5, count, size = match.groups() + return cls(md5=md5, count=int(count), size=int(size)) + + def __str__(self) -> str: + return self.digest + + @property + def digest(self) -> str: + return f"{self.md5}-{self.count}-{self.size}" + + # The "null" zarr checksum EMPTY_CHECKSUM = ZarrChecksumManifest().generate_digest().digest diff --git a/tests/test_zarr_checksum.py b/tests/test_zarr_checksum.py index b987bbfe..b3242b9c 100644 --- a/tests/test_zarr_checksum.py +++ b/tests/test_zarr_checksum.py @@ -49,10 +49,10 @@ def test_generate_digest() -> None: manifest = ZarrChecksumManifest( - directories=[ZarrChecksum(digest="a7e86136543b019d72468ceebf71fb8e-1--1", name="a/b", size=1)], - files=[ZarrChecksum(digest="92eb5ffee6ae2fec3ad71c777531578f-1--1", name="b", size=1)], + directories=[ZarrChecksum(digest="a7e86136543b019d72468ceebf71fb8e-1-1", name="a/b", size=1)], + files=[ZarrChecksum(digest="92eb5ffee6ae2fec3ad71c777531578f-0-1", name="b", size=1)], ) - assert manifest.generate_digest().digest == "2ed39fd5ae56fd4177c4eb503d163528-2--2" + assert manifest.generate_digest().digest == "9c5294e46908cf397cb7ef53ffc12efc-1-2" def test_zarr_checksum_sort_order() -> None: @@ -64,7 +64,7 @@ def test_zarr_checksum_sort_order() -> None: def test_parse_zarr_directory_digest() -> None: # Parse valid - ZarrDirectoryDigest.parse("c228464f432c4376f0de6ddaea32650c-37481--38757151179") + ZarrDirectoryDigest.parse("c228464f432c4376f0de6ddaea32650c-37481-38757151179") ZarrDirectoryDigest.parse(None) # Ensure exception is raised @@ -101,7 +101,7 @@ def test_process_tree() -> None: # This zarr checksum was computed against the same file structure using the previous # zarr checksum implementation # Assert the current implementation produces a matching checksum - assert checksum.digest == "26054e501f570a8bfa69a2bc75e7c82d-2--2" + assert checksum.digest == "e53fcb7b5c36b2f4647fbf826a44bdc9-2-2" def test_checksum_for_zarr_archive(zarr_archive, tmpdir): @@ -118,6 +118,7 @@ def test_checksum_for_zarr_archive(zarr_archive, tmpdir): def test_zarr_leaf_to_checksum(zarr_archive): + # NOTE: This test was not in the original code base of the zarr-checksum package. 
_, leaf_to_checksum = compute_zarr_checksum(zarr_archive) root = zarr.open(zarr_archive) From a7d6aef11bd0b2eaa316497f634599858014d372 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Fri, 28 Jun 2024 12:33:04 -0400 Subject: [PATCH 10/29] Only support running the checksum algorithm locally --- polaris/dataset/zarr/_checksum.py | 24 ++++++++++++------- tests/test_zarr_checksum.py | 38 ++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index c86f635a..0bddcffd 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -87,14 +87,21 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: # Get the protocol of the path protocol = fsspec.utils.get_protocol(zarr_root_path) - fs, zarr_root_path = fsspec.url_to_fs(zarr_root_path) - # For a local path, we extend the path to an absolute path - # Otherwise, we assume the path is already absolute - if protocol == "file": - zarr_root_path = os.path.expandvars(zarr_root_path) - zarr_root_path = os.path.expanduser(zarr_root_path) - zarr_root_path = os.path.abspath(zarr_root_path) + # NOTE (cwognum): The original Zarr Checksum implementation also seem to work for S3, + # but I'm not sure yet how + if protocol != "file": + raise RuntimeError( + "You can only compute the checksum for a local Zarr archive. " + "You can cache a dataset to your local machine with `dataset.cache()`." + ) + + # Normalize the path + zarr_root_path = os.path.expandvars(zarr_root_path) + zarr_root_path = os.path.expanduser(zarr_root_path) + zarr_root_path = os.path.abspath(zarr_root_path) + + fs, zarr_root_path = fsspec.url_to_fs(zarr_root_path) # Make sure the path exists and is a Zarr archive zarr.open_group(zarr_root_path, mode="r") @@ -122,7 +129,8 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: md5sum.update(chunk) digest = md5sum.hexdigest() - # Yield file + # Add a leaf to the tree + # (This actually adds the file's checksum to the parent directory's manifest) tree.add_leaf( path=relpath, size=size, diff --git a/tests/test_zarr_checksum.py b/tests/test_zarr_checksum.py index b3242b9c..42dc3916 100644 --- a/tests/test_zarr_checksum.py +++ b/tests/test_zarr_checksum.py @@ -30,8 +30,10 @@ limitations under the License. """ +import os +import uuid from pathlib import Path -from shutil import copytree +from shutil import copytree, rmtree import pytest import zarr @@ -125,3 +127,37 @@ def test_zarr_leaf_to_checksum(zarr_archive): # Check the basic structure - Each key corresponds to a file in the zarr archive assert len(leaf_to_checksum) == len(root.store) assert all(k in root.store for k in leaf_to_checksum.keys()) + + +def test_zarr_checksum_fails_for_remote_storage(zarr_archive): + # NOTE: This test was not in the original code base of the zarr-checksum package. + with pytest.raises(RuntimeError): + compute_zarr_checksum("s3://bucket/data.zarr") + with pytest.raises(RuntimeError): + compute_zarr_checksum("gs://bucket/data.zarr") + + +def test_zarr_checksum_with_path_normalization(zarr_archive): + # NOTE: This test was not in the original code base of the zarr-checksum package. 
+ + baseline = compute_zarr_checksum(zarr_archive)[0] + rootdir = os.path.dirname(zarr_archive) + + # Test a relative path + copytree(zarr_archive, os.path.join(rootdir, "relative", "data.zarr")) + compute_zarr_checksum(f"{zarr_archive}/../relative/data.zarr")[0] == baseline + + # Test with variables + rng_id = str(uuid.uuid4()) + os.environ["TMP_TEST_DIR"] = rng_id + copytree(zarr_archive, os.path.join(rootdir, "vars", rng_id)) + compute_zarr_checksum(f"{rootdir}/vars/${{TMP_TEST_DIR}}")[0] == baseline # Format ${...} + compute_zarr_checksum(f"{rootdir}/vars/$TMP_TEST_DIR")[0] == baseline # Format $... + + # And with the user abbreviation + try: + path = os.path.expanduser("~/data.zarr") + copytree(zarr_archive, path) + compute_zarr_checksum("~/data.zarr")[0] == baseline + finally: + rmtree(path) From 10a9b531e1b4115063208853caa37de06839a21b Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Fri, 28 Jun 2024 13:27:54 -0400 Subject: [PATCH 11/29] Add a verify_checksum method and use it by default when caching a dataset to local --- polaris/dataset/_dataset.py | 81 ++++++++++++++++++++++--------- polaris/dataset/zarr/_checksum.py | 6 ++- polaris/hub/client.py | 9 ++-- 3 files changed, 65 insertions(+), 31 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index d7c775e6..fc30a78e 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -4,10 +4,11 @@ from typing import Dict, List, MutableMapping, Optional, Tuple, Union import fsspec +import fsspec.utils import numpy as np import pandas as pd import zarr -from datamol.utils import fs +from datamol.utils import fs as dmfs from loguru import logger from pydantic import ( Field, @@ -25,7 +26,7 @@ from polaris.hub.polarisfs import PolarisFileSystem from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.dict2html import dict2html -from polaris.utils.errors import InvalidDatasetError +from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError from polaris.utils.types import AccessType, HttpUrlString, HubOwner, SupportedLicenseType # Constants @@ -101,7 +102,7 @@ def _validate_table(cls, v): """ # Load from path if not a dataframe if not isinstance(v, pd.DataFrame): - if not fs.is_file(v) or fs.get_extension(v) not in _SUPPORTED_TABLE_EXTENSIONS: + if not dmfs.is_file(v) or dmfs.get_extension(v) not in _SUPPORTED_TABLE_EXTENSIONS: raise InvalidDatasetError(f"{v} is not a valid DataFrame or .parquet path.") v = pd.read_parquet(v) # Check if there are any duplicate columns @@ -141,9 +142,9 @@ def _validate_model(cls, m: "Dataset"): # Set the default cache dir if none and make sure it exists if m.cache_dir is None: - m.cache_dir = fs.join(DEFAULT_CACHE_DIR, _CACHE_SUBDIR, str(uuid.uuid4())) + m.cache_dir = dmfs.join(DEFAULT_CACHE_DIR, _CACHE_SUBDIR, str(uuid.uuid4())) - fs.mkdir(m.cache_dir, exist_ok=True) + dmfs.mkdir(m.cache_dir, exist_ok=True) return m @@ -189,6 +190,34 @@ def _compute_checksum( checksum = hash_fn.hexdigest() return checksum, leaf_to_md5sum + def verify_checksum(self, md5sum: Optional[str] = None): + """ + Recomputes the checksum and verifies whether it matches the stored checksum. + + Warning: Slow operation + This operation can be slow for large datasets. + + Info: Only works for locally stored datasets + The checksum verification only works for datasets that are stored locally in its entirety. + We don't have to verify the checksum for datasets stored on the Hub, as the Hub will do this on upload. 
+ And if you're streaming the data from the Hub, we will check the checksum of each chunk on download. + """ + if md5sum is None: + md5sum = self._md5sum + if md5sum is None: + raise RuntimeError( + "No checksum to verify against. Specify either the md5sum parameter or " + "store the checksum in the dataset._md5sum attribute." + ) + + # Temporarily reset + # Calling self.md5sum will recompute the checksum and set it again + self._md5sum = None + if self.md5sum != md5sum: + raise PolarisChecksumError( + f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" + ) + @computed_field @property def md5sum(self) -> str: @@ -219,6 +248,11 @@ def client(self): self._client = PolarisHubClient() return self._client + @property + def uses_zarr(self) -> str: + """Whether any of the data in this dataset is stored in a Zarr Archive.""" + return self.zarr_root_path is not None + @property def zarr_data(self): """Get the Zarr data. @@ -254,7 +288,7 @@ def zarr_root(self): # We open the archive in read-only mode if it is saved on the Hub saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) - saved_remote = saved_on_hub or not fs.is_local_path(self.zarr_root_path) + saved_remote = saved_on_hub or not dmfs.is_local_path(self.zarr_root_path) if saved_remote: logger.warning( @@ -401,27 +435,22 @@ def to_json(self, destination: str) -> str: Returns: The path to the JSON file. """ - fs.mkdir(destination, exist_ok=True) - table_path = fs.join(destination, "table.parquet") - dataset_path = fs.join(destination, "dataset.json") - zarr_archive = fs.join(destination, "data.zarr") + dmfs.mkdir(destination, exist_ok=True) + table_path = dmfs.join(destination, "table.parquet") + dataset_path = dmfs.join(destination, "dataset.json") + new_zarr_root_path = dmfs.join(destination, "data.zarr") # Lu: Avoid serilizing and sending None to hub app. serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True) serialized["table"] = table_path # Copy over Zarr data to the destination - if self.zarr_root is not None: - dest = zarr.open(zarr_archive, "w") - zarr.copy_all(source=self.zarr_root, dest=dest) - - # Copy the .zmetadata file - # To track discussions on whether this should be done by copy_all() - # see https://github.com/zarr-developers/zarr-python/issues/1731 - zmetadata_content = self.zarr_root.store.store[".zmetadata"] - dest.store[".zmetadata"] = zmetadata_content - - serialized["zarr_root_path"] = zarr_archive + if self.uses_zarr: + # Zarr has the `copy_all` function, but this does not copy the .zmetadata file + # and creates .zattrs files even if there aren't any user attributes. This + # messes with our checksums. So we copy the files manually. + dmfs.copy_dir(self.zarr_root_path, new_zarr_root_path) + serialized["zarr_root_path"] = new_zarr_root_path self.table.to_parquet(table_path) with fsspec.open(dataset_path, "w") as f: @@ -429,12 +458,13 @@ def to_json(self, destination: str) -> str: return dataset_path - def cache(self, cache_dir: Optional[str] = None) -> str: + def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) -> str: """Caches the dataset by downloading all additional data for pointer columns to a local directory. Args: cache_dir: The directory to cache the data to. If not provided, this will fall back to the `Dataset.cache_dir` attribute + verify_checksum: Whether to verify the checksum of the dataset after caching. Returns: The path to the cache directory. 
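As a usage sketch (hypothetical, not part of the patch; the artifact id is a placeholder), caching now doubles as an integrity check for datasets whose checksum is known:

    from polaris import load_dataset

    dataset = load_dataset("my-org/my-dataset")
    cache_dir = dataset.cache()  # downloads pointer data and, when a checksum is known, re-verifies it
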
@@ -445,10 +475,13 @@ def cache(self, cache_dir: Optional[str] = None) -> str: self.to_json(self.cache_dir) - if self.zarr_root_path is not None: - self.zarr_root_path = fs.join(self.cache_dir, "data.zarr") + if self.uses_zarr: + self.zarr_root_path = dmfs.join(self.cache_dir, "data.zarr") self._zarr_root = None + if verify_checksum and self._md5sum is not None: + self.verify_checksum(md5sum=self._md5sum) + return self.cache_dir def size(self): diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index 0bddcffd..e20a1ea7 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -88,8 +88,10 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: # Get the protocol of the path protocol = fsspec.utils.get_protocol(zarr_root_path) - # NOTE (cwognum): The original Zarr Checksum implementation also seem to work for S3, - # but I'm not sure yet how + # We only support computing checksum for local datasets. + # NOTE (cwognum): We don't have to verify the checksum for datasets stored on the Hub, + # as the Hub will do this on upload. And if you're streaming the data from the Hub, + # we will check the checksum of each chunk on download. if protocol != "file": raise RuntimeError( "You can only compute the checksum for a local Zarr archive. " diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 19371992..9ca356b2 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -362,11 +362,10 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: b dataset = Dataset(**response) checksum = response.pop("md5Sum", None) - if verify_checksum and checksum is not None and checksum != dataset.md5sum: - raise PolarisChecksumError( - "The dataset checksum does not match what was specified in the meta-data. 
" - f"{checksum} != {dataset.md5sum}" - ) + if verify_checksum and checksum is not None: + if dataset.uses_zarr: + logger.info("Skipping checksum verification, because the dataset is stored remotely.") + dataset.verify_checksum(md5sum=checksum) elif not verify_checksum: dataset._md5sum = checksum dataset._leaf_to_md5sum = response.get("leafToMd5Sum", None) From eaaa961ac57be7ed5ad3f8c75934cfada0b4b71b Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Fri, 28 Jun 2024 13:46:49 -0400 Subject: [PATCH 12/29] Added serialization to the checksum manifest on the client --- polaris/dataset/_dataset.py | 18 ++++---- polaris/dataset/zarr/__init__.py | 4 +- polaris/dataset/zarr/_checksum.py | 68 +++++++++++++++++++------------ polaris/hub/client.py | 2 +- tests/test_zarr_checksum.py | 32 +++++++-------- 5 files changed, 69 insertions(+), 55 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index fc30a78e..762ccb39 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -22,7 +22,7 @@ from polaris._artifact import BaseArtifactModel from polaris.dataset._adapters import Adapter from polaris.dataset._column import ColumnAnnotation -from polaris.dataset.zarr import MemoryMappedDirectoryStore, compute_zarr_checksum +from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum, compute_zarr_checksum from polaris.hub.polarisfs import PolarisFileSystem from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.dict2html import dict2html @@ -90,7 +90,7 @@ class Dataset(BaseArtifactModel): _zarr_root: Optional[zarr.Group] = PrivateAttr(None) _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) _md5sum: Optional[str] = PrivateAttr(None) - _leaf_to_md5sum: Optional[Dict[str, str]] = PrivateAttr(None) + _zarr_md5sum_manifest: Optional[Dict[str, ZarrFileChecksum]] = PrivateAttr(None) _client = PrivateAttr(None) # Optional[PolarisHubClient] @field_validator("table") @@ -182,13 +182,13 @@ def _compute_checksum( hash_fn.update(table_hash) # If the Zarr archive exists, we hash its contents too. - leaf_to_md5sum = None + zarr_md5sum_manifest = None if zarr_root_path is not None: - zarr_hash, leaf_to_md5sum = compute_zarr_checksum(zarr_root_path) + zarr_hash, zarr_md5sum_manifest = compute_zarr_checksum(zarr_root_path) hash_fn.update(zarr_hash.encode()) checksum = hash_fn.hexdigest() - return checksum, leaf_to_md5sum + return checksum, zarr_md5sum_manifest def verify_checksum(self, md5sum: Optional[str] = None): """ @@ -228,14 +228,14 @@ def md5sum(self) -> str: @computed_field @property - def leaf_to_md5sum(self) -> Optional[Dict[str, str]]: + def zarr_md5sum_manifest(self) -> Optional[Dict[str, ZarrFileChecksum]]: """ For Zarr archives, the mapping from all files to their checksum is used by the Hub to verify data integrity on upload. 
""" - if self._leaf_to_md5sum is None and self._md5sum is None: - self._md5sum, self._leaf_to_md5sum = self._compute_checksum(self.table, self.zarr_root_path) - return self._leaf_to_md5sum + if self._zarr_md5sum_manifest is None and self._md5sum is None: + self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum(self.table, self.zarr_root_path) + return self._zarr_md5sum_manifest @property def client(self): diff --git a/polaris/dataset/zarr/__init__.py b/polaris/dataset/zarr/__init__.py index 78579ce4..57f500ed 100644 --- a/polaris/dataset/zarr/__init__.py +++ b/polaris/dataset/zarr/__init__.py @@ -1,4 +1,4 @@ -from ._checksum import compute_zarr_checksum +from ._checksum import ZarrFileChecksum, compute_zarr_checksum from ._memmap import MemoryMappedDirectoryStore -__all__ = ["MemoryMappedDirectoryStore", "compute_zarr_checksum"] +__all__ = ["MemoryMappedDirectoryStore", "compute_zarr_checksum", "ZarrFileChecksum"] diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index e20a1ea7..620def51 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -44,6 +44,7 @@ import fsspec.utils import zarr import zarr.errors +from pydantic import BaseModel from tqdm import tqdm from polaris.utils.errors import InvalidZarrChecksum @@ -109,11 +110,11 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: zarr.open_group(zarr_root_path, mode="r") # Generate the checksum - tree = ZarrChecksumTree() + tree = _ZarrChecksumTree() # Find all files below the root leaves = fs.find(zarr_root_path, detail=True) - leaf_to_md5sum = {} + zarr_md5sum_manifest = {} for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"): path = file["name"] @@ -141,10 +142,23 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: # We persist the checksums for leaf nodes separately, # because this is what the Hub needs to verify data integrity. - leaf_to_md5sum[str(relpath)] = digest + zarr_md5sum_manifest[str(relpath)] = ZarrFileChecksum(md5sum=digest, size=size) # Compute digest - return tree.process().digest, leaf_to_md5sum + return tree.process().digest, zarr_md5sum_manifest + + +class ZarrFileChecksum(BaseModel): + """ + This data is sent to the Hub to verify the integrity of the Zarr archive on upload. + + Attributes: + md5sum: The md5sum of the file. + size: The size of the file in bytes. + """ + + md5sum: str + size: int # ================================ @@ -169,7 +183,7 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: # Pydantic models aren't used for performance reasons -class ZarrChecksumTree: +class _ZarrChecksumTree: """ The ZarrChecksumTree is a tree structure that maintains the state of the checksum algorithm. @@ -183,10 +197,10 @@ class ZarrChecksumTree: def __init__(self) -> None: # Queue to prioritize the next node to process - self._heap: list[tuple[int, ZarrChecksumNode]] = [] + self._heap: list[tuple[int, _ZarrChecksumNode]] = [] # Map of (relative) paths to nodes. - self._path_map: dict[Path, ZarrChecksumNode] = {} + self._path_map: dict[Path, _ZarrChecksumNode] = {} @property def empty(self) -> bool: @@ -201,7 +215,7 @@ def _add_path(self, key: Path) -> None: # A node represents a file or directory. # A node refers to a node in the heap queue (i.e. binary tree) # The structure of the heap is thus _not_ the same as the structure of the file system! 
- node = ZarrChecksumNode(path=key, checksums=ZarrChecksumManifest()) + node = _ZarrChecksumNode(path=key, checksums=_ZarrChecksumManifest()) self._path_map[key] = node # Add node to heap with length (negated to represent a max heap) @@ -211,7 +225,7 @@ def _add_path(self, key: Path) -> None: length = len(key.parents) heapq.heappush(self._heap, (-1 * length, node)) - def _get_path(self, key: Path) -> "ZarrChecksumNode": + def _get_path(self, key: Path) -> "_ZarrChecksumNode": """ If an entry for this path already exists, return it. Otherwise create a new one and return that. @@ -223,13 +237,13 @@ def _get_path(self, key: Path) -> "ZarrChecksumNode": def add_leaf(self, path: Path, size: int, digest: str) -> None: """Add a leaf file to the tree.""" parent_node = self._get_path(path.parent) - parent_node.checksums.files.append(ZarrChecksum(name=path.name, size=size, digest=digest)) + parent_node.checksums.files.append(_ZarrChecksum(name=path.name, size=size, digest=digest)) def add_node(self, path: Path, size: int, digest: str, count: int) -> None: """Add an internal node to the tree.""" parent_node = self._get_path(path.parent) parent_node.checksums.directories.append( - ZarrChecksum( + _ZarrChecksum( name=path.name, size=size, digest=digest, @@ -237,7 +251,7 @@ def add_node(self, path: Path, size: int, digest: str, count: int) -> None: ) ) - def pop_deepest(self) -> "ZarrChecksumNode": + def pop_deepest(self) -> "_ZarrChecksumNode": """ Returns the node with the highest priority for processing next. @@ -248,11 +262,11 @@ def pop_deepest(self) -> "ZarrChecksumNode": del self._path_map[node.path] return node - def process(self) -> "ZarrDirectoryDigest": + def process(self) -> "_ZarrDirectoryDigest": """Process the tree, returning the resulting top level digest.""" # Begin with empty root node, so that if no files are present, the empty checksum is returned - node = ZarrChecksumNode(path=Path("."), checksums=ZarrChecksumManifest()) + node = _ZarrChecksumNode(path=Path("."), checksums=_ZarrChecksumManifest()) while not self.empty: # Get the next directory to process @@ -278,7 +292,7 @@ def process(self) -> "ZarrDirectoryDigest": @dataclass -class ZarrChecksumNode: +class _ZarrChecksumNode: """ A node in the ZarrChecksumTree. @@ -290,14 +304,14 @@ class ZarrChecksumNode: """ path: Path - checksums: "ZarrChecksumManifest" + checksums: "_ZarrChecksumManifest" - def __lt__(self, other: "ZarrChecksumNode") -> bool: + def __lt__(self, other: "_ZarrChecksumNode") -> bool: return str(self.path) < str(other.path) @dataclass -class ZarrChecksumManifest: +class _ZarrChecksumManifest: """ For a directory in the Zarr archive (i.e. a node in the heap queue), we maintain a manifest of the checksums for all files and directories @@ -306,14 +320,14 @@ class ZarrChecksumManifest: This data is then used to calculate the checksum of a directory. 
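A minimal sketch of how a manifest rolls up into a directory digest, mirroring the updated `test_generate_digest`; illustrative only, not part of the patch:

    manifest = _ZarrChecksumManifest(
        directories=[_ZarrChecksum(digest="a7e86136543b019d72468ceebf71fb8e-1-1", name="a/b", size=1)],
        files=[_ZarrChecksum(digest="92eb5ffee6ae2fec3ad71c777531578f-0-1", name="b", size=1)],
    )
    print(manifest.generate_digest().digest)  # "9c5294e46908cf397cb7ef53ffc12efc-1-2" per the tests
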
""" - directories: list["ZarrChecksum"] = field(default_factory=list) - files: list["ZarrChecksum"] = field(default_factory=list) + directories: list["_ZarrChecksum"] = field(default_factory=list) + files: list["_ZarrChecksum"] = field(default_factory=list) @property def is_empty(self) -> bool: return not (self.files or self.directories) - def generate_digest(self) -> "ZarrDirectoryDigest": + def generate_digest(self) -> "_ZarrDirectoryDigest": """Generate an aggregated digest for the provided files/directories.""" # Sort everything to ensure the checksum is deterministic @@ -333,12 +347,12 @@ def generate_digest(self) -> "ZarrDirectoryDigest": md5 = hashlib.md5(json.encode("utf-8")).hexdigest() # Construct and return - return ZarrDirectoryDigest(md5=md5, count=count, size=size) + return _ZarrDirectoryDigest(md5=md5, count=count, size=size) @total_ordering @dataclass -class ZarrChecksum: +class _ZarrChecksum: """ The data used to compute the checksum for a file or directory in a Zarr Archive. @@ -351,12 +365,12 @@ class ZarrChecksum: count: int = 0 # To make this class sortable - def __lt__(self, other: "ZarrChecksum") -> bool: + def __lt__(self, other: "_ZarrChecksum") -> bool: return self.name < other.name @dataclass -class ZarrDirectoryDigest: +class _ZarrDirectoryDigest: """ The digest for a directory in a Zarr Archive. @@ -369,7 +383,7 @@ class ZarrDirectoryDigest: size: int @classmethod - def parse(cls, checksum: str | None) -> "ZarrDirectoryDigest": + def parse(cls, checksum: str | None) -> "_ZarrDirectoryDigest": if checksum is None: return cls.parse(EMPTY_CHECKSUM) @@ -389,4 +403,4 @@ def digest(self) -> str: # The "null" zarr checksum -EMPTY_CHECKSUM = ZarrChecksumManifest().generate_digest().digest +EMPTY_CHECKSUM = _ZarrChecksumManifest().generate_digest().digest diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 9ca356b2..af3355db 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -622,7 +622,7 @@ def upload_dataset( hub_response.raise_for_status() # Step 3: Upload any associated Zarr archive - if dataset.zarr_root is not None: + if dataset.uses_zarr: with tmp_attribute_change(self.settings, "default_timeout", timeout): # Copy the Zarr archive to the hub dest = self.open_zarr_file( diff --git a/tests/test_zarr_checksum.py b/tests/test_zarr_checksum.py index 42dc3916..f6a3fc4d 100644 --- a/tests/test_zarr_checksum.py +++ b/tests/test_zarr_checksum.py @@ -41,43 +41,43 @@ from polaris.dataset.zarr._checksum import ( EMPTY_CHECKSUM, InvalidZarrChecksum, - ZarrChecksum, - ZarrChecksumManifest, - ZarrChecksumTree, - ZarrDirectoryDigest, + _ZarrChecksum, + _ZarrChecksumManifest, + _ZarrChecksumTree, + _ZarrDirectoryDigest, compute_zarr_checksum, ) def test_generate_digest() -> None: - manifest = ZarrChecksumManifest( - directories=[ZarrChecksum(digest="a7e86136543b019d72468ceebf71fb8e-1-1", name="a/b", size=1)], - files=[ZarrChecksum(digest="92eb5ffee6ae2fec3ad71c777531578f-0-1", name="b", size=1)], + manifest = _ZarrChecksumManifest( + directories=[_ZarrChecksum(digest="a7e86136543b019d72468ceebf71fb8e-1-1", name="a/b", size=1)], + files=[_ZarrChecksum(digest="92eb5ffee6ae2fec3ad71c777531578f-0-1", name="b", size=1)], ) assert manifest.generate_digest().digest == "9c5294e46908cf397cb7ef53ffc12efc-1-2" def test_zarr_checksum_sort_order() -> None: # The a < b in the name should take precedence over z > y in the md5 - a = ZarrChecksum(name="a", digest="z", size=3) - b = ZarrChecksum(name="b", digest="y", size=4) + a = _ZarrChecksum(name="a", digest="z", 
size=3) + b = _ZarrChecksum(name="b", digest="y", size=4) assert sorted([b, a]) == [a, b] def test_parse_zarr_directory_digest() -> None: # Parse valid - ZarrDirectoryDigest.parse("c228464f432c4376f0de6ddaea32650c-37481-38757151179") - ZarrDirectoryDigest.parse(None) + _ZarrDirectoryDigest.parse("c228464f432c4376f0de6ddaea32650c-37481-38757151179") + _ZarrDirectoryDigest.parse(None) # Ensure exception is raised with pytest.raises(InvalidZarrChecksum): - ZarrDirectoryDigest.parse("asd") + _ZarrDirectoryDigest.parse("asd") with pytest.raises(InvalidZarrChecksum): - ZarrDirectoryDigest.parse("asd-0--0") + _ZarrDirectoryDigest.parse("asd-0--0") def test_pop_deepest() -> None: - tree = ZarrChecksumTree() + tree = _ZarrChecksumTree() tree.add_leaf(Path("a/b"), size=1, digest="asd") tree.add_leaf(Path("a/b/c"), size=1, digest="asd") node = tree.pop_deepest() @@ -90,12 +90,12 @@ def test_pop_deepest() -> None: def test_process_empty_tree() -> None: - tree = ZarrChecksumTree() + tree = _ZarrChecksumTree() assert tree.process().digest == EMPTY_CHECKSUM def test_process_tree() -> None: - tree = ZarrChecksumTree() + tree = _ZarrChecksumTree() tree.add_leaf(Path("a/b"), size=1, digest="9dd4e461268c8034f5c8564e155c67a6") tree.add_leaf(Path("c"), size=1, digest="415290769594460e2e485922904f345d") checksum = tree.process() From 18fd500407469d57ec8c645fd98c7bc18d5961fe Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Wed, 3 Jul 2024 14:53:07 -0400 Subject: [PATCH 13/29] WIP: Integration with Hub --- polaris/dataset/_dataset.py | 15 ++------------- polaris/dataset/zarr/_checksum.py | 6 ++++-- polaris/hub/client.py | 9 ++++++--- 3 files changed, 12 insertions(+), 18 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 762ccb39..2153b5c9 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -90,7 +90,7 @@ class Dataset(BaseArtifactModel): _zarr_root: Optional[zarr.Group] = PrivateAttr(None) _zarr_data: Optional[MutableMapping[str, np.ndarray]] = PrivateAttr(None) _md5sum: Optional[str] = PrivateAttr(None) - _zarr_md5sum_manifest: Optional[Dict[str, ZarrFileChecksum]] = PrivateAttr(None) + _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) _client = PrivateAttr(None) # Optional[PolarisHubClient] @field_validator("table") @@ -223,19 +223,8 @@ def verify_checksum(self, md5sum: Optional[str] = None): def md5sum(self) -> str: """Lazily compute the checksum once needed.""" if self._md5sum is None: - self._md5sum, self._leaf_to_md5sum = self._compute_checksum(self.table, self.zarr_root_path) - return self._md5sum - - @computed_field - @property - def zarr_md5sum_manifest(self) -> Optional[Dict[str, ZarrFileChecksum]]: - """ - For Zarr archives, the mapping from all files to their checksum is used by the Hub - to verify data integrity on upload. 
- """ - if self._zarr_md5sum_manifest is None and self._md5sum is None: self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum(self.table, self.zarr_root_path) - return self._zarr_md5sum_manifest + return self._md5sum @property def client(self): diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index 620def51..44ee6a95 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -114,7 +114,7 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: # Find all files below the root leaves = fs.find(zarr_root_path, detail=True) - zarr_md5sum_manifest = {} + zarr_md5sum_manifest = [] for file in tqdm(leaves.values(), desc="Finding all files in the Zarr archive"): path = file["name"] @@ -142,7 +142,7 @@ def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: # We persist the checksums for leaf nodes separately, # because this is what the Hub needs to verify data integrity. - zarr_md5sum_manifest[str(relpath)] = ZarrFileChecksum(md5sum=digest, size=size) + zarr_md5sum_manifest.append(ZarrFileChecksum(path=str(relpath), md5sum=digest, size=size)) # Compute digest return tree.process().digest, zarr_md5sum_manifest @@ -153,10 +153,12 @@ class ZarrFileChecksum(BaseModel): This data is sent to the Hub to verify the integrity of the Zarr archive on upload. Attributes: + path: The path of the file relative to the Zarr root. md5sum: The md5sum of the file. size: The size of the file in bytes. """ + path: str md5sum: str size: int diff --git a/polaris/hub/client.py b/polaris/hub/client.py index af3355db..72bdc340 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -564,19 +564,20 @@ def upload_dataset( dataset_json["zarrRootPath"] = f"{PolarisFileSystem.protocol}://data.zarr" # Uploading a dataset is a three-step process. - # 1. Upload the dataset meta data to the hub and prepare the hub to receive the parquet file + # 1. Upload the dataset meta data to the hub and prepare the hub to receive the data # 2. Upload the parquet file to the hub # 3. Upload the associated Zarr archive # TODO: Revert step 1 in case step 2 fails - Is this needed? Or should this be taken care of by the hub? - # Write the parquet file directly to a buffer + # Prepare the parquet file buffer = BytesIO() dataset.table.to_parquet(buffer, engine="auto") parquet_size = len(buffer.getbuffer()) parquet_md5 = md5(buffer.getbuffer()).hexdigest() # Step 1: Upload meta-data - # Instead of directly uploading the table, we announce to the hub that we intend to upload one. + # Instead of directly uploading the data, we announce to the hub that we intend to upload it. + # We do so separately for the Zarr archive and Parquet file. url = f"/dataset/{dataset.artifact_id}" response = self._base_request_to_hub( url=url, @@ -587,6 +588,7 @@ def upload_dataset( "fileType": "parquet", "md5sum": parquet_md5, }, + "zarrContent": [md5sum.model_dump() for md5sum in dataset._zarr_md5sum_manifest], "access": access, **dataset_json, }, @@ -607,6 +609,7 @@ def upload_dataset( if hub_response.status_code == 307: # If the hub returns a 307 redirect, we need to follow it to get the signed URL hub_response_body = hub_response.json() + # Upload the data to the cloudflare url bucket_response = self.request( url=hub_response_body["url"], From bcdb10d56838db1a5858875e2e5b1a5e02f8813a Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 3 Jul 2024 17:22:19 -0400 Subject: [PATCH 14/29] WIP: Trying to get Zarr up- and downloads to work again... 
Seems to be related to the use of the MemoryMappedStore! --- polaris/dataset/_dataset.py | 24 ++++++++++++++++-------- polaris/dataset/converters/_zarr.py | 2 +- polaris/dataset/zarr/_checksum.py | 4 ++-- polaris/hub/client.py | 25 +++++++++++++++---------- polaris/hub/polarisfs.py | 11 +---------- tests/test_zarr_checksum.py | 2 +- 6 files changed, 36 insertions(+), 32 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 2153b5c9..30955240 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -27,7 +27,12 @@ from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError -from polaris.utils.types import AccessType, HttpUrlString, HubOwner, SupportedLicenseType +from polaris.utils.types import ( + AccessType, + HttpUrlString, + HubOwner, + SupportedLicenseType, +) # Constants _SUPPORTED_TABLE_EXTENSIONS = ["parquet"] @@ -92,6 +97,7 @@ class Dataset(BaseArtifactModel): _md5sum: Optional[str] = PrivateAttr(None) _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) _client = PrivateAttr(None) # Optional[PolarisHubClient] + _warn_about_remote_zarr: bool = PrivateAttr(True) # Optional[PolarisHubClient] @field_validator("table") def _validate_table(cls, v): @@ -277,14 +283,16 @@ def zarr_root(self): # We open the archive in read-only mode if it is saved on the Hub saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) - saved_remote = saved_on_hub or not dmfs.is_local_path(self.zarr_root_path) - if saved_remote: - logger.warning( - f"You're loading data from a remote location. " - f"To speed up this process, consider caching the dataset first " - f"using {self.__class__.__name__}.cache()" - ) + if self._warn_about_remote_zarr: + saved_remote = saved_on_hub or not dmfs.is_local_path(self.zarr_root_path) + + if saved_remote: + logger.warning( + f"You're loading data from a remote location. " + f"To speed up this process, consider caching the dataset first " + f"using {self.__class__.__name__}.cache()" + ) try: if saved_on_hub: diff --git a/polaris/dataset/converters/_zarr.py b/polaris/dataset/converters/_zarr.py index 5ed706d0..4380325c 100644 --- a/polaris/dataset/converters/_zarr.py +++ b/polaris/dataset/converters/_zarr.py @@ -35,7 +35,7 @@ def convert(self, path: str, factory: "DatasetFactory") -> FactoryProduct: raise ValueError("The root of the zarr hierarchy should only contain arrays.") # Copy to the source zarr, so everything is in one place - zarr.copy_all(source=src, dest=factory.zarr_root) + zarr.copy_store(source=src.store, dest=factory.zarr_root.store, if_exists="skip") # Construct the table # Parse any group into a column diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index 44ee6a95..c68f6a3d 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -38,7 +38,7 @@ from functools import total_ordering from json import dumps from pathlib import Path -from typing import Dict, Tuple +from typing import List, Tuple import fsspec import fsspec.utils @@ -52,7 +52,7 @@ ZARR_DIGEST_PATTERN = "([0-9a-f]{32})-([0-9]+)-([0-9]+)" -def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, Dict[str, str]]: +def compute_zarr_checksum(zarr_root_path: str) -> Tuple[str, List["ZarrFileChecksum"]]: r""" Implements an algorithm to compute the Zarr checksum. 
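For context, a hypothetical usage sketch of the updated entry point (not part of the patch; the path is a placeholder for a local Zarr archive):

    from polaris.dataset.zarr import compute_zarr_checksum

    digest, manifest = compute_zarr_checksum("./data.zarr")
    # digest is the top-level "{md5}-{count}-{size}" string;
    # manifest holds one ZarrFileChecksum (path, md5sum, size) per file in the archive.
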
diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 72bdc340..ad81beb7 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -365,10 +365,11 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: b if verify_checksum and checksum is not None: if dataset.uses_zarr: logger.info("Skipping checksum verification, because the dataset is stored remotely.") - dataset.verify_checksum(md5sum=checksum) - elif not verify_checksum: - dataset._md5sum = checksum - dataset._leaf_to_md5sum = response.get("leafToMd5Sum", None) + else: + dataset.verify_checksum(md5sum=checksum) + + dataset._md5sum = checksum + dataset._zarr_md5sum_manifest = response.get("zarrContent", None) return dataset @@ -456,14 +457,18 @@ def get_benchmark( benchmark = benchmark_cls(**response) checksum = response.pop("md5Sum", None) - if verify_checksum and checksum is not None and checksum != benchmark.md5sum: - raise PolarisChecksumError( - "The benchmark checksum does not match what was specified in the meta-data. " - f"{checksum} != {benchmark.md5sum}" - ) + if verify_checksum and checksum is not None: + if benchmark.dataset.uses_zarr: + logger.info("Skipping checksum verification, because the dataset is stored remotely.") + elif checksum != benchmark.md5sum: + raise PolarisChecksumError( + "The benchmark checksum does not match what was specified in the meta-data. " + f"{checksum} != {benchmark.md5sum}" + ) elif not verify_checksum: benchmark._md5sum = checksum - return + + return benchmark def upload_results( self, diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py index ea201385..c3a799d3 100644 --- a/polaris/hub/polarisfs.py +++ b/polaris/hub/polarisfs.py @@ -1,5 +1,3 @@ -import hashlib -from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import fsspec @@ -202,14 +200,7 @@ def pipe_file( hub_response_body = response.json() signed_url = hub_response_body["url"] - sha256_hash = hashlib.sha256(content).hexdigest() - - headers = { - "Content-Type": "application/octet-stream", - "x-amz-content-sha256": sha256_hash, - "x-amz-date": datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ"), - **hub_response_body["headers"], - } + headers = {"Content-Type": "application/octet-stream", **hub_response_body["headers"]} response = self.polaris_client.request( url=signed_url, diff --git a/tests/test_zarr_checksum.py b/tests/test_zarr_checksum.py index f6a3fc4d..9bcf1765 100644 --- a/tests/test_zarr_checksum.py +++ b/tests/test_zarr_checksum.py @@ -126,7 +126,7 @@ def test_zarr_leaf_to_checksum(zarr_archive): # Check the basic structure - Each key corresponds to a file in the zarr archive assert len(leaf_to_checksum) == len(root.store) - assert all(k in root.store for k in leaf_to_checksum.keys()) + assert all(k.path in root.store for k in leaf_to_checksum) def test_zarr_checksum_fails_for_remote_storage(zarr_archive): From 12874c5494e4a1aa3476da14fdc5fa2423e92d19 Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 3 Jul 2024 17:41:18 -0400 Subject: [PATCH 15/29] WIP: Further debugging of Zarr datasets --- polaris/dataset/_dataset.py | 28 ++++++++++++++++++++++------ polaris/hub/client.py | 2 +- 2 files changed, 23 insertions(+), 7 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 30955240..19c68a43 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -25,6 +25,7 @@ from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum, 
compute_zarr_checksum from polaris.hub.polarisfs import PolarisFileSystem from polaris.utils.constants import DEFAULT_CACHE_DIR +from polaris.utils.context import tmp_attribute_change from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError from polaris.utils.types import ( @@ -32,6 +33,7 @@ HttpUrlString, HubOwner, SupportedLicenseType, + ZarrConflictResolution, ) # Constants @@ -413,7 +415,11 @@ def from_json(cls, path: str): data.pop("cache_dir", None) return cls.model_validate(data) - def to_json(self, destination: str) -> str: + def to_json( + self, + destination: str, + if_exists: ZarrConflictResolution = "replace", + ) -> str: """ Save the dataset to a destination directory as a JSON file. @@ -428,6 +434,8 @@ def to_json(self, destination: str) -> str: Args: destination: The _directory_ to save the associated data to. + if_exists: Action for handling existing files in the Zarr archive. Options are 'raise' to throw + an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files. Returns: The path to the JSON file. @@ -443,11 +451,19 @@ def to_json(self, destination: str) -> str: # Copy over Zarr data to the destination if self.uses_zarr: - # Zarr has the `copy_all` function, but this does not copy the .zmetadata file - # and creates .zattrs files even if there aren't any user attributes. This - # messes with our checksums. So we copy the files manually. - dmfs.copy_dir(self.zarr_root_path, new_zarr_root_path) - serialized["zarr_root_path"] = new_zarr_root_path + with tmp_attribute_change(self, "_warn_about_remote_zarr", False): + logger.info(f"Copying Zarr archive to {new_zarr_root_path}. This may take a while.") + + dest = zarr.open(new_zarr_root_path, "w") + + zarr.copy_store( + source=self.zarr_root.store.store, + dest=dest.store, + log=logger.debug, + if_exists=if_exists, + ) + + self._warn_about_remote_zarr = True self.table.to_parquet(table_path) with fsspec.open(dataset_path, "w") as f: diff --git a/polaris/hub/client.py b/polaris/hub/client.py index ad81beb7..9b02231e 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -652,7 +652,7 @@ def upload_dataset( zarr.copy_store( source=dataset.zarr_root.store.store, dest=dest.store, - log=logger.info, + log=logger.debug, if_exists=if_exists, ) From 3fb065c245ed93d31039f2eef43c5ae5d654a526 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Wed, 3 Jul 2024 22:48:03 -0400 Subject: [PATCH 16/29] Removed caching from the PolarisFS ls() endpoint and changed verify_checksum logic --- polaris/__init__.py | 10 +++++++ polaris/benchmark/_base.py | 22 +++++++++++++++- polaris/hub/client.py | 53 +++++++++++++++++++------------------- polaris/hub/polarisfs.py | 4 --- polaris/loader/load.py | 26 +++++++++++++++---- 5 files changed, 79 insertions(+), 36 deletions(-) diff --git a/polaris/__init__.py b/polaris/__init__.py index a9266859..9e51b460 100644 --- a/polaris/__init__.py +++ b/polaris/__init__.py @@ -1,4 +1,14 @@ +import os +import sys + +from loguru import logger + from ._version import __version__ from .loader import load_benchmark, load_dataset __all__ = ["load_dataset", "load_benchmark", "__version__"] + +# Configure the default logging level +os.environ["LOGURU_LEVEL"] = os.environ.get("LOGURU_LEVEL", "INFO") +logger.remove() +logger.add(sys.stderr) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 15489560..7adfba26 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -23,7 +23,7 @@ 
from polaris.hub.settings import PolarisHubSettings from polaris.utils.context import tmp_attribute_change from polaris.utils.dict2html import dict2html -from polaris.utils.errors import InvalidBenchmarkError +from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError from polaris.utils.misc import listit from polaris.utils.types import ( AccessType, @@ -296,6 +296,26 @@ def _compute_checksum(dataset, target_cols, input_cols, split, metrics): checksum = hash_fn.hexdigest() return checksum + def verify_checksum(self, md5sum: Optional[str] = None): + """ + Recomputes the checksum and verifies whether it matches the stored checksum. + """ + if md5sum is None: + md5sum = self._md5sum + if md5sum is None: + raise RuntimeError( + "No checksum to verify against. Specify either the md5sum parameter or " + "store the checksum in the dataset._md5sum attribute." + ) + + # Temporarily reset + # Calling self.md5sum will recompute the checksum and set it again + self._md5sum = None + if self.md5sum != md5sum: + raise PolarisChecksumError( + f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" + ) + @computed_field @property def md5sum(self) -> Optional[str]: diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 9b02231e..6babacea 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -33,7 +33,6 @@ from polaris.utils.context import tmp_attribute_change from polaris.utils.errors import ( InvalidDatasetError, - PolarisChecksumError, PolarisHubError, PolarisUnauthorizedError, ) @@ -179,6 +178,15 @@ def _normalize_owner( return artifact_owner if isinstance(artifact_owner, HubOwner) else HubOwner(slug=artifact_owner) + @staticmethod + def _normalize_verify_checksum( + verify_checksum: Optional[bool], + dataset: Dataset, + ): + if verify_checksum is not None: + return verify_checksum + return dataset._md5sum is not None and not dataset.uses_zarr + # ========================= # Overrides # ========================= @@ -331,13 +339,16 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: dataset_list = [bm["artifactId"] for bm in response["data"]] return dataset_list - def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: bool = True) -> Dataset: + def get_dataset( + self, owner: Union[str, HubOwner], name: str, verify_checksum: Optional[bool] = None + ) -> Dataset: """Load a dataset from the Polaris Hub. Args: owner: The owner of the dataset. Can be either a user or organization from the Polaris Hub. name: The name of the dataset. - verify_checksum: Whether to use the checksum to verify the integrity of the dataset. + verify_checksum: Whether to use the checksum to verify the integrity of the dataset. If None, + will infer a practical default based on the dataset's storage location. Returns: A `Dataset` instance, if it exists. 
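A hypothetical usage sketch (not part of the patch; owner and dataset name are placeholders) showing the new default:

    client = PolarisHubClient()
    dataset = client.get_dataset("my-org", "my-dataset")
    # With verify_checksum=None, a practical default is inferred by
    # _normalize_verify_checksum based on how the dataset is stored.
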
@@ -360,17 +371,12 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: b response["table"] = self._load_from_signed_url(url=url, headers=headers, load_fn=pd.read_parquet) dataset = Dataset(**response) - checksum = response.pop("md5Sum", None) + verify_checksum = self._normalize_verify_checksum(verify_checksum, dataset) - if verify_checksum and checksum is not None: - if dataset.uses_zarr: - logger.info("Skipping checksum verification, because the dataset is stored remotely.") - else: - dataset.verify_checksum(md5sum=checksum) + if verify_checksum: + dataset.verify_checksum() - dataset._md5sum = checksum dataset._zarr_md5sum_manifest = response.get("zarrContent", None) - return dataset def open_zarr_file( @@ -383,7 +389,8 @@ def open_zarr_file( name: Name of the dataset. path: Path to the Zarr file within the dataset. mode: The mode in which the file is opened. - as_consolidated: Whether to open the store with consolidated metadata for optimized reading. This is only applicable in 'r' and 'r+' modes. + as_consolidated: Whether to open the store with consolidated metadata for optimized reading. + This is only applicable in 'r' and 'r+' modes. Returns: The Zarr object representing the dataset. @@ -425,14 +432,15 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: return benchmarks_list def get_benchmark( - self, owner: Union[str, HubOwner], name: str, verify_checksum: bool = True + self, owner: Union[str, HubOwner], name: str, verify_checksum: Optional[bool] = None ) -> BenchmarkSpecification: """Load a benchmark from the Polaris Hub. Args: owner: The owner of the benchmark. Can be either a user or organization from the Polaris Hub. name: The name of the benchmark. - verify_checksum: Whether to use the checksum to verify the integrity of the dataset. + verify_checksum: Whether to use the checksum to verify the integrity of the dataset. If None, + will infer a practical default based on the dataset's storage location. Returns: A `BenchmarkSpecification` instance, if it exists. @@ -455,18 +463,11 @@ def get_benchmark( ) benchmark = benchmark_cls(**response) - checksum = response.pop("md5Sum", None) - - if verify_checksum and checksum is not None: - if benchmark.dataset.uses_zarr: - logger.info("Skipping checksum verification, because the dataset is stored remotely.") - elif checksum != benchmark.md5sum: - raise PolarisChecksumError( - "The benchmark checksum does not match what was specified in the meta-data. 
" - f"{checksum} != {benchmark.md5sum}" - ) - elif not verify_checksum: - benchmark._md5sum = checksum + + verify_checksum = self._normalize_verify_checksum(verify_checksum, benchmark.dataset) + + if verify_checksum: + benchmark.verify_checksum() return benchmark diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py index c3a799d3..a43dd014 100644 --- a/polaris/hub/polarisfs.py +++ b/polaris/hub/polarisfs.py @@ -85,10 +85,6 @@ def ls( if timeout is None: timeout = self.default_timeout - cached_listings = self._ls_from_cache(path) - if cached_listings is not None: - return cached_listings if detail else [d["name"] for d in cached_listings] - ls_path = self.sep.join([self.base_path, "ls", path]) # GET request to Polaris Hub to list objects in path diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 5bebf291..dc5cacd2 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -1,4 +1,5 @@ import json +from typing import Optional import fsspec from datamol.utils import fs @@ -11,7 +12,7 @@ from polaris.hub.client import PolarisHubClient -def load_dataset(path: str, verify_checksum: bool = True) -> Dataset: +def load_dataset(path: str, verify_checksum: Optional[bool] = None) -> Dataset: """ Loads a Polaris dataset. @@ -37,12 +38,20 @@ def load_dataset(path: str, verify_checksum: bool = True) -> Dataset: client = PolarisHubClient() return client.get_dataset(*path.split("/"), verify_checksum=verify_checksum) + # Load from local file if extension == "json": - return Dataset.from_json(path) - return create_dataset_from_file(path) + dataset = Dataset.from_json(path) + else: + dataset = create_dataset_from_file(path) + # Verify checksum if requested + if verify_checksum: + dataset.verify_checksum() -def load_benchmark(path: str, verify_checksum: bool = True): + return dataset + + +def load_benchmark(path: str, verify_checksum: Optional[bool] = None): """ Loads a Polaris benchmark. @@ -75,4 +84,11 @@ def load_benchmark(path: str, verify_checksum: bool = True): # e.g. we might end up with a single class per benchmark. 
is_single_task = isinstance(data["target_cols"], str) or len(data["target_cols"]) == 1 cls = SingleTaskBenchmarkSpecification if is_single_task else MultiTaskBenchmarkSpecification - return cls.from_json(path) + + benchmark = cls.from_json(path) + + # Verify checksum if requested + if verify_checksum: + benchmark.verify_checksum() + + return benchmark From 6909de40d6798c5780affdbf22e8be254a26ad70 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Thu, 4 Jul 2024 10:59:02 -0400 Subject: [PATCH 17/29] Minor changes in line with Hub changes --- polaris/dataset/_adapters.py | 1 + polaris/dataset/_dataset.py | 8 ++++++++ polaris/dataset/zarr/_checksum.py | 5 ++++- 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/polaris/dataset/_adapters.py b/polaris/dataset/_adapters.py index b97b55af..89fcee14 100644 --- a/polaris/dataset/_adapters.py +++ b/polaris/dataset/_adapters.py @@ -1,4 +1,5 @@ from enum import Enum, auto, unique + import datamol as dm # Map of conversion operations which can be applied to dataset columns diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 19c68a43..1cc7d867 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -234,6 +234,14 @@ def md5sum(self) -> str: self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum(self.table, self.zarr_root_path) return self._md5sum + @computed_field + @property + def zarr_md5sum_manifest(self) -> str: + """Lazily compute the checksum once needed.""" + if self._zarr_md5sum_manifest is None and self._md5sum is None: + self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum(self.table, self.zarr_root_path) + return self._zarr_md5sum_manifest + @property def client(self): """The Polaris Hub client used to interact with the Polaris Hub.""" diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index c68f6a3d..56b80dbf 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -44,7 +44,8 @@ import fsspec.utils import zarr import zarr.errors -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel from tqdm import tqdm from polaris.utils.errors import InvalidZarrChecksum @@ -158,6 +159,8 @@ class ZarrFileChecksum(BaseModel): size: The size of the file in bytes. 
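For illustration (not part of the patch; the path value is a placeholder), this is the per-file record that the client later serializes and sends to the Hub:

    entry = ZarrFileChecksum(path="group/0.0", md5sum="a7e86136543b019d72468ceebf71fb8e", size=1)
    payload = entry.model_dump()  # {"path": "group/0.0", "md5sum": "...", "size": 1}
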
""" + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, arbitrary_types_allowed=True) + path: str md5sum: str size: int From 012d8c244f01a16a005986542dd2ce3776f76acb Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Thu, 4 Jul 2024 11:49:31 -0400 Subject: [PATCH 18/29] Set md5sum from the Hub --- polaris/hub/client.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 6babacea..46567015 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -375,6 +375,8 @@ def get_dataset( if verify_checksum: dataset.verify_checksum() + else: + dataset._md5sum = response["md5Sum"] dataset._zarr_md5sum_manifest = response.get("zarrContent", None) return dataset @@ -468,6 +470,8 @@ def get_benchmark( if verify_checksum: benchmark.verify_checksum() + else: + benchmark._md5sum = response["md5Sum"] return benchmark From 367b43df4a3e555d85d8413630f10e693b6d0199 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Thu, 4 Jul 2024 17:47:35 -0400 Subject: [PATCH 19/29] Fixed bug in saving the md5Sum --- polaris/hub/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 46567015..b48aa122 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -596,7 +596,7 @@ def upload_dataset( "tableContent": { "size": parquet_size, "fileType": "parquet", - "md5sum": parquet_md5, + "md5Sum": parquet_md5, }, "zarrContent": [md5sum.model_dump() for md5sum in dataset._zarr_md5sum_manifest], "access": access, From 71bf63fc8a63f6d65033b7ca401b6bfcc55b6a47 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Thu, 4 Jul 2024 18:10:37 -0400 Subject: [PATCH 20/29] Use request instead of fsspec.open to support custom headers in signed URL --- polaris/hub/polarisfs.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py index a43dd014..70632414 100644 --- a/polaris/hub/polarisfs.py +++ b/polaris/hub/polarisfs.py @@ -139,11 +139,20 @@ def cat_file( if response.status_code != 307: raise PolarisHubError("Could not get signed URL from Polaris Hub.") - signed_url = response.json()["url"] + hub_response_body = response.json() + signed_url = hub_response_body["url"] + + headers = {"Content-Type": "application/octet-stream", **hub_response_body["headers"]} - with fsspec.open(signed_url, "rb", **kwargs) as f: - data = f.read() - return data[start:end] + response = self.polaris_client.request( + url=signed_url, + method="GET", + auth=None, + headers=headers, + timeout=timeout, + ) + response.raise_for_status() + return response.content[start:end] def rm(self, path: str, recursive: bool = False, maxdepth: Optional[int] = None) -> None: """Remove a file or directory from the Polaris dataset. 
From c67b6b584074d83509d46e1114dfaf689d7e22af Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Fri, 5 Jul 2024 10:04:46 -0400 Subject: [PATCH 21/29] Verify checksum on downloading a single chunk --- polaris/__init__.py | 2 +- polaris/dataset/_dataset.py | 2 +- polaris/hub/client.py | 9 ++++++--- polaris/hub/polarisfs.py | 20 +++++++++++++++++++- polaris/hub/settings.py | 1 + 5 files changed, 28 insertions(+), 6 deletions(-) diff --git a/polaris/__init__.py b/polaris/__init__.py index 9e51b460..ddb0f44a 100644 --- a/polaris/__init__.py +++ b/polaris/__init__.py @@ -11,4 +11,4 @@ # Configure the default logging level os.environ["LOGURU_LEVEL"] = os.environ.get("LOGURU_LEVEL", "INFO") logger.remove() -logger.add(sys.stderr) +logger.add(sys.stderr, level=os.environ["LOGURU_LEVEL"]) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 1cc7d867..92bd4fb4 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -236,7 +236,7 @@ def md5sum(self) -> str: @computed_field @property - def zarr_md5sum_manifest(self) -> str: + def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: """Lazily compute the checksum once needed.""" if self._zarr_md5sum_manifest is None and self._md5sum is None: self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum(self.table, self.zarr_root_path) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index b48aa122..9b81db08 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -16,7 +16,7 @@ from authlib.integrations.httpx_client import OAuth2Client, OAuthError from authlib.oauth2.client import OAuth2Client as _OAuth2Client from datamol.utils import fs -from httpx import HTTPStatusError +from httpx import HTTPStatusError, Response from httpx._types import HeaderTypes, URLTypes from loguru import logger @@ -187,6 +187,11 @@ def _normalize_verify_checksum( return verify_checksum return dataset._md5sum is not None and not dataset.uses_zarr + def get_metadata_from_response(self, response: Response, key: str) -> Optional[str]: + """Get custom metadata saved to the R2 object from the headers.""" + key = f"{self.settings.custom_metadata_prefix}{key}" + return response.headers.get(key) + # ========================= # Overrides # ========================= @@ -377,8 +382,6 @@ def get_dataset( dataset.verify_checksum() else: dataset._md5sum = response["md5Sum"] - - dataset._zarr_md5sum_manifest = response.get("zarrContent", None) return dataset def open_zarr_file( diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py index 70632414..9bc8ab85 100644 --- a/polaris/hub/polarisfs.py +++ b/polaris/hub/polarisfs.py @@ -1,6 +1,8 @@ +from hashlib import md5 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import fsspec +from loguru import logger from polaris.utils.errors import PolarisHubError from polaris.utils.types import TimeoutTypes @@ -152,7 +154,23 @@ def cat_file( timeout=timeout, ) response.raise_for_status() - return response.content[start:end] + response_content = response.content + + # Verify the checksum on download + expected_md5sum = self.polaris_client.get_metadata_from_response(response, "md5sum") + if expected_md5sum is None: + raise PolarisHubError("MD5 checksum not found in response headers.") + logger.debug(f"MD5 checksum found in response headers: {expected_md5sum}.") + + md5sum = md5(response_content).hexdigest() + logger.debug(f"MD5 checksum computed for response content: {md5sum}.") + + if md5sum != expected_md5sum: + raise PolarisHubError( + f"MD5 checksum 
verification failed. Expected {expected_md5sum}, got {md5sum}." + ) + + return response_content[start:end] def rm(self, path: str, recursive: bool = False, maxdepth: Optional[int] = None) -> None: """Remove a file or directory from the Polaris dataset. diff --git a/polaris/hub/settings.py b/polaris/hub/settings.py index c23dc7ee..9c2343b4 100644 --- a/polaris/hub/settings.py +++ b/polaris/hub/settings.py @@ -37,6 +37,7 @@ class PolarisHubSettings(BaseSettings): scopes: str = "profile email" client_id: str = "agQP2xVM6JqMHvGc" ca_bundle: Optional[Union[str, bool]] = None + custom_metadata_prefix: str = "X-Amz-Meta-" default_timeout: TimeoutTypes = (10, 200) From 1e37a464fe0e7f87aceb960f0b7f15ba52fba47e Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Fri, 5 Jul 2024 13:57:20 -0400 Subject: [PATCH 22/29] Self review --- polaris/benchmark/_base.py | 4 +--- polaris/dataset/_dataset.py | 9 ++++----- polaris/dataset/zarr/_checksum.py | 1 - polaris/loader/load.py | 7 +++---- tests/conftest.py | 2 +- tests/test_dataset.py | 29 ++++++++++------------------- 6 files changed, 19 insertions(+), 33 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 7adfba26..8e535cf5 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -86,8 +86,6 @@ class BenchmarkSpecification(BaseArtifactModel): split: The predefined train-test split to use for evaluation. metrics: The metrics to use for evaluating performance main_metric: The main metric used to rank methods. If `None`, the first of the `metrics` field. - md5sum: The checksum is used to verify the version of the dataset specification. If specified, it will - raise an error if the specified checksum doesn't match the computed checksum. readme: Markdown text that can be used to provide a formatted description of the benchmark. If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI as it provides a rich text editor for writing markdown. @@ -305,7 +303,7 @@ def verify_checksum(self, md5sum: Optional[str] = None): if md5sum is None: raise RuntimeError( "No checksum to verify against. Specify either the md5sum parameter or " - "store the checksum in the dataset._md5sum attribute." + "store the checksum in the benchmark._md5sum attribute." 
) # Temporarily reset diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 92bd4fb4..751eb528 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -4,7 +4,6 @@ from typing import Dict, List, MutableMapping, Optional, Tuple, Union import fsspec -import fsspec.utils import numpy as np import pandas as pd import zarr @@ -99,7 +98,7 @@ class Dataset(BaseArtifactModel): _md5sum: Optional[str] = PrivateAttr(None) _zarr_md5sum_manifest: List[ZarrFileChecksum] = PrivateAttr(default_factory=list) _client = PrivateAttr(None) # Optional[PolarisHubClient] - _warn_about_remote_zarr: bool = PrivateAttr(True) # Optional[PolarisHubClient] + _warn_about_remote_zarr: bool = PrivateAttr(True) @field_validator("table") def _validate_table(cls, v): @@ -150,7 +149,8 @@ def _validate_model(cls, m: "Dataset"): # Set the default cache dir if none and make sure it exists if m.cache_dir is None: - m.cache_dir = dmfs.join(DEFAULT_CACHE_DIR, _CACHE_SUBDIR, str(uuid.uuid4())) + dataset_id = m._md5sum if m._md5sum is not None else str(uuid.uuid4()) + m.cache_dir = dmfs.join(DEFAULT_CACHE_DIR, _CACHE_SUBDIR, dataset_id) dmfs.mkdir(m.cache_dir, exist_ok=True) @@ -303,6 +303,7 @@ def zarr_root(self): f"To speed up this process, consider caching the dataset first " f"using {self.__class__.__name__}.cache()" ) + self._warn_about_remote_zarr = False try: if saved_on_hub: @@ -471,8 +472,6 @@ def to_json( if_exists=if_exists, ) - self._warn_about_remote_zarr = True - self.table.to_parquet(table_path) with fsspec.open(dataset_path, "w") as f: json.dump(serialized, f) diff --git a/polaris/dataset/zarr/_checksum.py b/polaris/dataset/zarr/_checksum.py index 56b80dbf..a06f9491 100644 --- a/polaris/dataset/zarr/_checksum.py +++ b/polaris/dataset/zarr/_checksum.py @@ -41,7 +41,6 @@ from typing import List, Tuple import fsspec -import fsspec.utils import zarr import zarr.errors from pydantic import BaseModel, ConfigDict diff --git a/polaris/loader/load.py b/polaris/loader/load.py index dc5cacd2..1389d99c 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -45,7 +45,7 @@ def load_dataset(path: str, verify_checksum: Optional[bool] = None) -> Dataset: dataset = create_dataset_from_file(path) # Verify checksum if requested - if verify_checksum: + if PolarisHubClient._normalize_verify_checksum(verify_checksum, dataset): dataset.verify_checksum() return dataset @@ -88,7 +88,6 @@ def load_benchmark(path: str, verify_checksum: Optional[bool] = None): benchmark = cls.from_json(path) # Verify checksum if requested - if verify_checksum: - benchmark.verify_checksum() - + if PolarisHubClient._normalize_verify_checksum(verify_checksum, benchmark.dataset): + benchmark.verify_checksum(md5sum=data["md5sum"]) return benchmark diff --git a/tests/conftest.py b/tests/conftest.py index b034d464..2170a62a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,7 +60,7 @@ def test_user_owner(): return HubOwner(userId="test-user", slug="test-user") -@pytest.fixture(scope="module") +@pytest.fixture(scope="function") def test_dataset(test_data, test_org_owner): dataset = Dataset( table=test_data, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 3e1dccf8..7369f22a 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -51,42 +51,33 @@ def test_dataset_checksum(test_dataset): # Make sure the `md5sum` is part of the model dump even if not initiated yet. # This is important for uploads to the Hub. 
- assert test_dataset._md5sum is None and "md5sum" in test_dataset.model_dump() - - original = test_dataset.md5sum - assert original is not None + assert test_dataset._md5sum is None + assert "md5sum" in test_dataset.model_dump() # Without any changes, same hash kwargs = test_dataset.model_dump() - assert Dataset(**kwargs).md5sum == original + assert Dataset(**kwargs) == test_dataset # With unimportant changes, same hash kwargs["name"] = "changed" kwargs["description"] = "changed" kwargs["source"] = "https://changed.com" - assert Dataset(**kwargs).md5sum == original + assert Dataset(**kwargs) == test_dataset # Check sensitivity to the row and column ordering kwargs["table"] = kwargs["table"].iloc[::-1] kwargs["table"] = kwargs["table"][kwargs["table"].columns[::-1]] - assert Dataset(**kwargs).md5sum == original - - def _check_for_failure(_kwargs): - assert Dataset(**_kwargs).md5sum != _kwargs["md5sum"] + assert Dataset(**kwargs) == test_dataset # Without any changes, but different hash - kwargs["md5sum"] = "invalid" - _check_for_failure(kwargs) + dataset = Dataset(**kwargs) + dataset._md5sum = "invalid" + assert dataset != test_dataset # With changes, but same hash - kwargs["md5sum"] = original + kwargs["md5sum"] = test_dataset.md5sum kwargs["table"] = kwargs["table"].iloc[:-1] - _check_for_failure(kwargs) - - # With changes, but no hash - kwargs["md5sum"] = None - dataset = Dataset(**kwargs) - assert dataset.md5sum is not None + assert Dataset(**kwargs) != test_dataset def test_dataset_from_zarr(zarr_archive, tmpdir): From 1c10763d6902fe3327da072eda247dee923a6922 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Fri, 5 Jul 2024 14:02:41 -0400 Subject: [PATCH 23/29] Trigger CICD --- polaris/loader/load.py | 1 + 1 file changed, 1 insertion(+) diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 1389d99c..79ccf9db 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -90,4 +90,5 @@ def load_benchmark(path: str, verify_checksum: Optional[bool] = None): # Verify checksum if requested if PolarisHubClient._normalize_verify_checksum(verify_checksum, benchmark.dataset): benchmark.verify_checksum(md5sum=data["md5sum"]) + return benchmark From 7fb571a9fabfbc5179643860e7b37c00ec2389eb Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 10 Jul 2024 19:05:25 -0400 Subject: [PATCH 24/29] Address PR feedback --- polaris/benchmark/_base.py | 39 ++++++++++++------- polaris/dataset/_dataset.py | 75 +++++++++++++++++++++---------------- polaris/hub/client.py | 41 ++++++++++---------- polaris/hub/polarisfs.py | 7 +++- polaris/loader/load.py | 13 ++++--- polaris/utils/misc.py | 19 +++++++++- polaris/utils/types.py | 5 +++ 7 files changed, 122 insertions(+), 77 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 8e535cf5..982618c9 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd from datamol.utils import fs +from loguru import logger from pydantic import ( Field, PrivateAttr, @@ -261,8 +262,7 @@ def _serialize_target_types(self, v): """Convert from enum to string to make sure it's serializable""" return {k: v.value for k, v in self.target_types.items()} - @staticmethod - def _compute_checksum(dataset, target_cols, input_cols, split, metrics): + def _compute_checksum(self): """ Computes a hash of the benchmark. 
@@ -270,16 +270,16 @@ def _compute_checksum(dataset, target_cols, input_cols, split, metrics): """ hash_fn = md5() - hash_fn.update(dataset.md5sum.encode("utf-8")) - for c in sorted(target_cols): + hash_fn.update(self.dataset.md5sum.encode("utf-8")) + for c in sorted(self.target_cols): hash_fn.update(c.encode("utf-8")) - for c in sorted(input_cols): + for c in sorted(self.input_cols): hash_fn.update(c.encode("utf-8")) - for m in sorted(metrics, key=lambda k: k.name): + for m in sorted(self.metrics, key=lambda k: k.name): hash_fn.update(m.name.encode("utf-8")) - if not isinstance(split[1], dict): - split = split[0], {"test": split[1]} + if not isinstance(self.split[1], dict): + split = self.split[0], {"test": self.split[1]} # Train set s = json.dumps(sorted(split[0])) @@ -301,10 +301,11 @@ def verify_checksum(self, md5sum: Optional[str] = None): if md5sum is None: md5sum = self._md5sum if md5sum is None: - raise RuntimeError( + logger.warning( "No checksum to verify against. Specify either the md5sum parameter or " - "store the checksum in the benchmark._md5sum attribute." + "store the checksum in the benchmark.md5sum attribute. Skipping!" ) + return # Temporarily reset # Calling self.md5sum will recompute the checksum and set it again @@ -318,12 +319,22 @@ def verify_checksum(self, md5sum: Optional[str] = None): @property def md5sum(self) -> Optional[str]: """Lazily compute the checksum once needed.""" - if self._md5sum is None: - self._md5sum = self._compute_checksum( - self.dataset, self.target_cols, self.input_cols, self.split, self.metrics - ) + if not self.has_md5sum: + self._md5sum = self._compute_checksum() return self._md5sum + @md5sum.setter + def md5sum(self, value: str): + """Set the checksum.""" + if len(value) != 32 or not all(c in "0123456789abcdef" for c in value): + raise ValueError("The checksum should be a 32-character long MD5 hash.") + self._md5sum = value + + @property + def has_md5sum(self) -> Optional[str]: + """Lazily compute the checksum once needed.""" + return self._md5sum is not None + @computed_field @property def n_train_datapoints(self) -> int: diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 751eb528..6388dd48 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -1,6 +1,7 @@ import json import uuid from hashlib import md5 +from pathlib import Path from typing import Dict, List, MutableMapping, Optional, Tuple, Union import fsspec @@ -24,7 +25,6 @@ from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum, compute_zarr_checksum from polaris.hub.polarisfs import PolarisFileSystem from polaris.utils.constants import DEFAULT_CACHE_DIR -from polaris.utils.context import tmp_attribute_change from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError from polaris.utils.types import ( @@ -90,7 +90,7 @@ class Dataset(BaseArtifactModel): curation_reference: Optional[HttpUrlString] = None # Config - cache_dir: Optional[str] = None # Where to cache the data to if cache() is called. + cache_dir: Optional[Path] = None # Where to cache the data to if cache() is called. 
# Private attributes _zarr_root: Optional[zarr.Group] = PrivateAttr(None) @@ -149,11 +149,10 @@ def _validate_model(cls, m: "Dataset"): # Set the default cache dir if none and make sure it exists if m.cache_dir is None: - dataset_id = m._md5sum if m._md5sum is not None else str(uuid.uuid4()) - m.cache_dir = dmfs.join(DEFAULT_CACHE_DIR, _CACHE_SUBDIR, dataset_id) - - dmfs.mkdir(m.cache_dir, exist_ok=True) + dataset_id = m._md5sum if m.has_md5sum else str(uuid.uuid4()) + m.cache_dir = Path(DEFAULT_CACHE_DIR) / _CACHE_SUBDIR / dataset_id + m.cache_dir.mkdir(parents=True, exist_ok=True) return m @field_validator("default_adapters", mode="before") @@ -166,11 +165,7 @@ def _serialize_adapters(self, value: List[Adapter]): """Serializes the adapters""" return {k: v.name for k, v in value.items()} - @staticmethod - def _compute_checksum( - table: pd.DataFrame, - zarr_root_path: Optional[str] = None, - ): + def _compute_checksum(self): """Computes a hash of the dataset. This is meant to uniquely identify the dataset and can be used to verify the version. @@ -182,7 +177,7 @@ def _compute_checksum( hash_fn = md5() # Sort the columns s.t. the checksum is not sensitive to the column-ordering - df = table.copy(deep=True) + df = self.table.copy(deep=True) df = df[sorted(df.columns.tolist())] # Use the sum of the row-wise hashes s.t. the hash is insensitive to the row-ordering @@ -191,8 +186,8 @@ def _compute_checksum( # If the Zarr archive exists, we hash its contents too. zarr_md5sum_manifest = None - if zarr_root_path is not None: - zarr_hash, zarr_md5sum_manifest = compute_zarr_checksum(zarr_root_path) + if self.zarr_root_path is not None: + zarr_hash, zarr_md5sum_manifest = compute_zarr_checksum(self.zarr_root_path) hash_fn.update(zarr_hash.encode()) checksum = hash_fn.hexdigest() @@ -213,10 +208,11 @@ def verify_checksum(self, md5sum: Optional[str] = None): if md5sum is None: md5sum = self._md5sum if md5sum is None: - raise RuntimeError( + logger.warning( "No checksum to verify against. Specify either the md5sum parameter or " - "store the checksum in the dataset._md5sum attribute." + "store the checksum in the dataset.md5sum attribute." 
) + return # Temporarily reset # Calling self.md5sum will recompute the checksum and set it again @@ -230,16 +226,28 @@ def verify_checksum(self, md5sum: Optional[str] = None): @property def md5sum(self) -> str: """Lazily compute the checksum once needed.""" - if self._md5sum is None: - self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum(self.table, self.zarr_root_path) + if not self.has_md5sum: + self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum() return self._md5sum + @md5sum.setter + def md5sum(self, value: str): + """Set the checksum.""" + if len(value) != 32 or not all(c in "0123456789abcdef" for c in value): + raise ValueError("The checksum should be a 32-character long MD5 hash.") + self._md5sum = value + + @property + def has_md5sum(self) -> bool: + """Whether the md5sum for this class has been computed and stored.""" + return self._md5sum is not None + @computed_field @property def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: """Lazily compute the checksum once needed.""" - if self._zarr_md5sum_manifest is None and self._md5sum is None: - self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum(self.table, self.zarr_root_path) + if self._zarr_md5sum_manifest is None and not self.has_md5sum: + self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum() return self._zarr_md5sum_manifest @property @@ -254,7 +262,7 @@ def client(self): return self._client @property - def uses_zarr(self) -> str: + def uses_zarr(self) -> bool: """Whether any of the data in this dataset is stored in a Zarr Archive.""" return self.zarr_root_path is not None @@ -295,7 +303,7 @@ def zarr_root(self): saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path) if self._warn_about_remote_zarr: - saved_remote = saved_on_hub or not dmfs.is_local_path(self.zarr_root_path) + saved_remote = saved_on_hub or not Path(self.zarr_root_path).exists() if saved_remote: logger.warning( @@ -460,17 +468,18 @@ def to_json( # Copy over Zarr data to the destination if self.uses_zarr: - with tmp_attribute_change(self, "_warn_about_remote_zarr", False): - logger.info(f"Copying Zarr archive to {new_zarr_root_path}. This may take a while.") + self._warn_about_remote_zarr = False - dest = zarr.open(new_zarr_root_path, "w") + logger.info(f"Copying Zarr archive to {new_zarr_root_path}. 
This may take a while.") - zarr.copy_store( - source=self.zarr_root.store.store, - dest=dest.store, - log=logger.debug, - if_exists=if_exists, - ) + dest = zarr.open(new_zarr_root_path, "w") + + zarr.copy_store( + source=self.zarr_root.store.store, + dest=dest.store, + log=logger.debug, + if_exists=if_exists, + ) self.table.to_parquet(table_path) with fsspec.open(dataset_path, "w") as f: @@ -499,8 +508,8 @@ def cache(self, cache_dir: Optional[str] = None, verify_checksum: bool = True) - self.zarr_root_path = dmfs.join(self.cache_dir, "data.zarr") self._zarr_root = None - if verify_checksum and self._md5sum is not None: - self.verify_checksum(md5sum=self._md5sum) + if verify_checksum: + self.verify_checksum() return self.cache_dir diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 9b81db08..7f813831 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -36,8 +36,10 @@ PolarisHubError, PolarisUnauthorizedError, ) +from polaris.utils.misc import should_verify_checksum from polaris.utils.types import ( AccessType, + ChecksumStrategy, HubOwner, IOMode, SupportedLicenseType, @@ -178,15 +180,6 @@ def _normalize_owner( return artifact_owner if isinstance(artifact_owner, HubOwner) else HubOwner(slug=artifact_owner) - @staticmethod - def _normalize_verify_checksum( - verify_checksum: Optional[bool], - dataset: Dataset, - ): - if verify_checksum is not None: - return verify_checksum - return dataset._md5sum is not None and not dataset.uses_zarr - def get_metadata_from_response(self, response: Response, key: str) -> Optional[str]: """Get custom metadata saved to the R2 object from the headers.""" key = f"{self.settings.custom_metadata_prefix}{key}" @@ -345,7 +338,10 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: return dataset_list def get_dataset( - self, owner: Union[str, HubOwner], name: str, verify_checksum: Optional[bool] = None + self, + owner: Union[str, HubOwner], + name: str, + verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> Dataset: """Load a dataset from the Polaris Hub. @@ -376,12 +372,11 @@ def get_dataset( response["table"] = self._load_from_signed_url(url=url, headers=headers, load_fn=pd.read_parquet) dataset = Dataset(**response) - verify_checksum = self._normalize_verify_checksum(verify_checksum, dataset) - if verify_checksum: + if should_verify_checksum(verify_checksum, dataset): dataset.verify_checksum() else: - dataset._md5sum = response["md5Sum"] + dataset.md5sum = response["md5Sum"] return dataset def open_zarr_file( @@ -437,15 +432,17 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: return benchmarks_list def get_benchmark( - self, owner: Union[str, HubOwner], name: str, verify_checksum: Optional[bool] = None + self, + owner: Union[str, HubOwner], + name: str, + verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> BenchmarkSpecification: """Load a benchmark from the Polaris Hub. Args: owner: The owner of the benchmark. Can be either a user or organization from the Polaris Hub. name: The name of the benchmark. - verify_checksum: Whether to use the checksum to verify the integrity of the dataset. If None, - will infer a practical default based on the dataset's storage location. + verify_checksum: Whether to use the checksum to verify the integrity of the benchmark. Returns: A `BenchmarkSpecification` instance, if it exists. 
@@ -469,12 +466,10 @@ def get_benchmark( benchmark = benchmark_cls(**response) - verify_checksum = self._normalize_verify_checksum(verify_checksum, benchmark.dataset) - - if verify_checksum: + if should_verify_checksum(verify_checksum, benchmark.dataset): benchmark.verify_checksum() else: - benchmark._md5sum = response["md5Sum"] + benchmark.md5sum = response["md5Sum"] return benchmark @@ -627,7 +622,11 @@ def upload_dataset( bucket_response = self.request( url=hub_response_body["url"], method=hub_response_body["method"], - headers={"Content-type": "application/vnd.apache.parquet", **hub_response_body["headers"]}, + headers={ + "Content-type": "application/vnd.apache.parquet", + **hub_response_body["headers"], + "Content-MD5": parquet_md5, + }, content=buffer.getvalue(), auth=None, timeout=timeout, # required for large size dataset diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py index 9bc8ab85..f206a8e4 100644 --- a/polaris/hub/polarisfs.py +++ b/polaris/hub/polarisfs.py @@ -223,7 +223,12 @@ def pipe_file( hub_response_body = response.json() signed_url = hub_response_body["url"] - headers = {"Content-Type": "application/octet-stream", **hub_response_body["headers"]} + headers = { + "Content-Type": "application/octet-stream", + **hub_response_body["headers"], + # By adding this header, R2 will verify the MD5 checksum of the content on upload + "Content-MD5": self.polaris_client.get_metadata_from_response(response, "md5sum"), + } response = self.polaris_client.request( url=signed_url, diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 79ccf9db..49fea5bb 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -1,5 +1,4 @@ import json -from typing import Optional import fsspec from datamol.utils import fs @@ -10,9 +9,11 @@ ) from polaris.dataset import Dataset, create_dataset_from_file from polaris.hub.client import PolarisHubClient +from polaris.utils.misc import should_verify_checksum +from polaris.utils.types import ChecksumStrategy -def load_dataset(path: str, verify_checksum: Optional[bool] = None) -> Dataset: +def load_dataset(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr") -> Dataset: """ Loads a Polaris dataset. @@ -45,13 +46,13 @@ def load_dataset(path: str, verify_checksum: Optional[bool] = None) -> Dataset: dataset = create_dataset_from_file(path) # Verify checksum if requested - if PolarisHubClient._normalize_verify_checksum(verify_checksum, dataset): + if should_verify_checksum(verify_checksum, dataset): dataset.verify_checksum() return dataset -def load_benchmark(path: str, verify_checksum: Optional[bool] = None): +def load_benchmark(path: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr"): """ Loads a Polaris benchmark. 
@@ -88,7 +89,7 @@ def load_benchmark(path: str, verify_checksum: Optional[bool] = None): benchmark = cls.from_json(path) # Verify checksum if requested - if PolarisHubClient._normalize_verify_checksum(verify_checksum, benchmark.dataset): - benchmark.verify_checksum(md5sum=data["md5sum"]) + if should_verify_checksum(verify_checksum, benchmark.dataset): + benchmark.verify_checksum() return benchmark diff --git a/polaris/utils/misc.py b/polaris/utils/misc.py index 98eb0be6..9a8199eb 100644 --- a/polaris/utils/misc.py +++ b/polaris/utils/misc.py @@ -1,6 +1,9 @@ -from typing import Any +from typing import TYPE_CHECKING, Any -from polaris.utils.types import SlugCompatibleStringType +from polaris.utils.types import ChecksumStrategy, SlugCompatibleStringType + +if TYPE_CHECKING: + from polaris.dataset import Dataset def listit(t: Any): @@ -16,3 +19,15 @@ def sluggify(sluggable: SlugCompatibleStringType): Converts a string to a slug-compatible string. """ return sluggable.lower().replace("_", "-") + + +def should_verify_checksum(strategy: ChecksumStrategy, dataset: "Dataset") -> bool: + """ + Determines whether a checksum should be verified. + """ + if strategy == "ignore": + return False + elif strategy == "verify": + return True + else: + return not dataset.uses_zarr diff --git a/polaris/utils/types.py b/polaris/utils/types.py index 0c67484b..b5d5913b 100644 --- a/polaris/utils/types.py +++ b/polaris/utils/types.py @@ -107,6 +107,11 @@ Type to specify which action to take when encountering existing files within a Zarr archive. """ +ChecksumStrategy: TypeAlias = Literal["verify", "verify_unless_zarr", "ignore"] +""" +Type to specify which action to take to verify the data integrity of an artifact through a checksum. +""" + class HubOwner(BaseModel): """An owner of an artifact on the Polaris Hub From dfe03bc18f9a85c20c5113b6a59e201c7efd94b3 Mon Sep 17 00:00:00 2001 From: cwognum Date: Thu, 11 Jul 2024 12:34:45 -0400 Subject: [PATCH 25/29] Fixed import error --- polaris/hub/client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index d2c36e5f..8af7f0fa 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -179,8 +179,8 @@ def _base_request_to_hub(self, url: str, method: str, **kwargs): pass return response - - def get_metadata_from_response(self, response: Response, key: str) -> Optional[str]: + + def get_metadata_from_response(self, response: Response, key: str) -> str | None: """Get custom metadata saved to the R2 object from the headers.""" key = f"{self.settings.custom_metadata_prefix}{key}" return response.headers.get(key) From 53303161806a1a999d3c64b6012c02cbe138c8a0 Mon Sep 17 00:00:00 2001 From: cwognum Date: Thu, 11 Jul 2024 16:08:16 -0400 Subject: [PATCH 26/29] Remove Content-MD5 header from client --- polaris/hub/client.py | 6 +----- polaris/hub/polarisfs.py | 7 +------ 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 4d581dce..62fad481 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -529,11 +529,7 @@ def upload_dataset( bucket_response = self.request( url=hub_response_body["url"], method=hub_response_body["method"], - headers={ - "Content-type": "application/vnd.apache.parquet", - **hub_response_body["headers"], - "Content-MD5": parquet_md5, - }, + headers={"Content-type": "application/vnd.apache.parquet", **hub_response_body["headers"]}, content=buffer.getvalue(), auth=None, timeout=timeout, # required for large size 
dataset diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py index f206a8e4..9bc8ab85 100644 --- a/polaris/hub/polarisfs.py +++ b/polaris/hub/polarisfs.py @@ -223,12 +223,7 @@ def pipe_file( hub_response_body = response.json() signed_url = hub_response_body["url"] - headers = { - "Content-Type": "application/octet-stream", - **hub_response_body["headers"], - # By adding this header, R2 will verify the MD5 checksum of the content on upload - "Content-MD5": self.polaris_client.get_metadata_from_response(response, "md5sum"), - } + headers = {"Content-Type": "application/octet-stream", **hub_response_body["headers"]} response = self.polaris_client.request( url=signed_url, From 83162d49459365eb765be92e1798b3d93b35ca61 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Thu, 11 Jul 2024 20:06:05 -0400 Subject: [PATCH 27/29] Addressed feedback from PR --- polaris/_mixins.py | 70 ++++++++++++++++++++++++++++++++++ polaris/benchmark/_base.py | 51 ++----------------------- polaris/dataset/_dataset.py | 75 +++++++------------------------------ polaris/utils/httpx.py | 35 ----------------- pyproject.toml | 6 +++ tests/test_benchmark.py | 14 +++++++ tests/test_dataset.py | 14 +++++++ 7 files changed, 121 insertions(+), 144 deletions(-) create mode 100644 polaris/_mixins.py delete mode 100644 polaris/utils/httpx.py diff --git a/polaris/_mixins.py b/polaris/_mixins.py new file mode 100644 index 00000000..e4fb2794 --- /dev/null +++ b/polaris/_mixins.py @@ -0,0 +1,70 @@ +import abc + +from loguru import logger +from pydantic import BaseModel, PrivateAttr, computed_field + +from polaris.utils.errors import PolarisChecksumError + + +class ChecksumMixin(BaseModel, abc.ABC): + """ + Mixin class to add checksum functionality to a class. + """ + + _md5sum: str | None = PrivateAttr(None) + + @abc.abstractmethod + def _compute_checksum(self) -> str: + """Compute the checksum and return it along with manifest.""" + raise NotImplementedError + + @computed_field + @property + def md5sum(self) -> str: + """Lazily compute the checksum once needed.""" + if not self.has_md5sum: + logger.info("Computing the checksum. This can be slow for large datasets.") + self.md5sum = self._compute_checksum() + return self._md5sum + + @md5sum.setter + def md5sum(self, value: str): + """Set the checksum.""" + if len(value) != 32 or not all(c in "0123456789abcdef" for c in value): + raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.") + self._md5sum = value + + @property + def has_md5sum(self) -> bool: + """Whether the md5sum for this class has been computed and stored.""" + return self._md5sum is not None + + def verify_checksum(self, md5sum: str | None = None): + """ + Recomputes the checksum and verifies whether it matches the stored checksum. + + Warning: Slow operation + This operation can be slow for large datasets. + + Info: Only works for locally stored datasets + The checksum verification only works for datasets that are stored locally in its entirety. + We don't have to verify the checksum for datasets stored on the Hub, as the Hub will do this on upload. + And if you're streaming the data from the Hub, we will check the checksum of each chunk on download. + """ + if md5sum is None: + md5sum = self._md5sum + if md5sum is None: + logger.warning( + "No checksum to verify against. Specify either the md5sum parameter or " + "store the checksum in the dataset.md5sum attribute." + ) + return + + # Recompute the checksum + logger.info("To verify the checksum, we need to recompute it. 
This can be slow for large datasets.") + self.md5sum = self._compute_checksum() + + if self.md5sum != md5sum: + raise PolarisChecksumError( + f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" + ) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 982618c9..7af669fb 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -6,10 +6,8 @@ import numpy as np import pandas as pd from datamol.utils import fs -from loguru import logger from pydantic import ( Field, - PrivateAttr, ValidationInfo, computed_field, field_serializer, @@ -19,12 +17,13 @@ from sklearn.utils.multiclass import type_of_target from polaris._artifact import BaseArtifactModel +from polaris._mixins import ChecksumMixin from polaris.dataset import Dataset, Subset from polaris.evaluate import BenchmarkResults, Metric, ResultsType from polaris.hub.settings import PolarisHubSettings from polaris.utils.context import tmp_attribute_change from polaris.utils.dict2html import dict2html -from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError +from polaris.utils.errors import InvalidBenchmarkError from polaris.utils.misc import listit from polaris.utils.types import ( AccessType, @@ -38,7 +37,7 @@ ColumnsType = Union[str, list[str]] -class BenchmarkSpecification(BaseArtifactModel): +class BenchmarkSpecification(BaseArtifactModel, ChecksumMixin): """This class wraps a [`Dataset`][polaris.dataset.Dataset] with additional data to specify the evaluation logic. @@ -109,9 +108,6 @@ class BenchmarkSpecification(BaseArtifactModel): default_factory=dict, validate_default=True ) - # Private attributes - _md5sum: Optional[str] = PrivateAttr(None) - @field_validator("dataset") def _validate_dataset(cls, v): """ @@ -294,47 +290,6 @@ def _compute_checksum(self): checksum = hash_fn.hexdigest() return checksum - def verify_checksum(self, md5sum: Optional[str] = None): - """ - Recomputes the checksum and verifies whether it matches the stored checksum. - """ - if md5sum is None: - md5sum = self._md5sum - if md5sum is None: - logger.warning( - "No checksum to verify against. Specify either the md5sum parameter or " - "store the checksum in the benchmark.md5sum attribute. Skipping!" 
- ) - return - - # Temporarily reset - # Calling self.md5sum will recompute the checksum and set it again - self._md5sum = None - if self.md5sum != md5sum: - raise PolarisChecksumError( - f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" - ) - - @computed_field - @property - def md5sum(self) -> Optional[str]: - """Lazily compute the checksum once needed.""" - if not self.has_md5sum: - self._md5sum = self._compute_checksum() - return self._md5sum - - @md5sum.setter - def md5sum(self, value: str): - """Set the checksum.""" - if len(value) != 32 or not all(c in "0123456789abcdef" for c in value): - raise ValueError("The checksum should be a 32-character long MD5 hash.") - self._md5sum = value - - @property - def has_md5sum(self) -> Optional[str]: - """Lazily compute the checksum once needed.""" - return self._md5sum is not None - @computed_field @property def n_train_datapoints(self) -> int: diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 6388dd48..745485b3 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -20,13 +20,14 @@ ) from polaris._artifact import BaseArtifactModel +from polaris._mixins import ChecksumMixin from polaris.dataset._adapters import Adapter from polaris.dataset._column import ColumnAnnotation from polaris.dataset.zarr import MemoryMappedDirectoryStore, ZarrFileChecksum, compute_zarr_checksum from polaris.hub.polarisfs import PolarisFileSystem from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.dict2html import dict2html -from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError +from polaris.utils.errors import InvalidDatasetError from polaris.utils.types import ( AccessType, HttpUrlString, @@ -41,7 +42,7 @@ _INDEX_SEP = "#" -class Dataset(BaseArtifactModel): +class Dataset(BaseArtifactModel, ChecksumMixin): """Basic data-model for a Polaris dataset, implemented as a [Pydantic](https://docs.pydantic.dev/latest/) model. At its core, a dataset in Polaris is a tabular data structure that stores data-points in a row-wise manner. @@ -59,8 +60,6 @@ class Dataset(BaseArtifactModel): default_adapters: The adapters that the Dataset recommends to use by default to change the format of the data for specific columns. zarr_root_path: The data for any pointer column should be saved in the Zarr archive this path points to. - md5sum: The checksum is used to verify the version of the dataset specification. If specified, it will - raise an error if the specified checksum doesn't match the computed checksum. readme: Markdown text that can be used to provide a formatted description of the dataset. If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI as it provides a rich text editor for writing markdown. @@ -73,7 +72,6 @@ class Dataset(BaseArtifactModel): Raises: InvalidDatasetError: If the dataset does not conform to the Pydantic data-model specification. - PolarisChecksumError: If the specified checksum does not match the computed checksum. """ # Public attributes @@ -185,69 +183,24 @@ def _compute_checksum(self): hash_fn.update(table_hash) # If the Zarr archive exists, we hash its contents too. 
- zarr_md5sum_manifest = None - if self.zarr_root_path is not None: - zarr_hash, zarr_md5sum_manifest = compute_zarr_checksum(self.zarr_root_path) + if self.uses_zarr: + zarr_hash, self._zarr_md5sum_manifest = compute_zarr_checksum(self.zarr_root_path) hash_fn.update(zarr_hash.encode()) checksum = hash_fn.hexdigest() - return checksum, zarr_md5sum_manifest - - def verify_checksum(self, md5sum: Optional[str] = None): - """ - Recomputes the checksum and verifies whether it matches the stored checksum. - - Warning: Slow operation - This operation can be slow for large datasets. - - Info: Only works for locally stored datasets - The checksum verification only works for datasets that are stored locally in its entirety. - We don't have to verify the checksum for datasets stored on the Hub, as the Hub will do this on upload. - And if you're streaming the data from the Hub, we will check the checksum of each chunk on download. - """ - if md5sum is None: - md5sum = self._md5sum - if md5sum is None: - logger.warning( - "No checksum to verify against. Specify either the md5sum parameter or " - "store the checksum in the dataset.md5sum attribute." - ) - return - - # Temporarily reset - # Calling self.md5sum will recompute the checksum and set it again - self._md5sum = None - if self.md5sum != md5sum: - raise PolarisChecksumError( - f"The specified checksum {md5sum} does not match the computed checksum {self.md5sum}" - ) - - @computed_field - @property - def md5sum(self) -> str: - """Lazily compute the checksum once needed.""" - if not self.has_md5sum: - self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum() - return self._md5sum - - @md5sum.setter - def md5sum(self, value: str): - """Set the checksum.""" - if len(value) != 32 or not all(c in "0123456789abcdef" for c in value): - raise ValueError("The checksum should be a 32-character long MD5 hash.") - self._md5sum = value - - @property - def has_md5sum(self) -> bool: - """Whether the md5sum for this class has been computed and stored.""" - return self._md5sum is not None + return checksum @computed_field @property def zarr_md5sum_manifest(self) -> List[ZarrFileChecksum]: - """Lazily compute the checksum once needed.""" - if self._zarr_md5sum_manifest is None and not self.has_md5sum: - self._md5sum, self._zarr_md5sum_manifest = self._compute_checksum() + """ + The Zarr Checksum manifest stores the checksums of all files in a Zarr archive. + If the dataset doesn't use Zarr, this will simply return an empty list. + """ + if len(self._zarr_md5sum_manifest) == 0 and not self.has_md5sum: + # The manifest is set as an instance variable + # as a side-effect of the compute_checksum method + self.md5sum = self._compute_checksum() return self._zarr_md5sum_manifest @property diff --git a/polaris/utils/httpx.py b/polaris/utils/httpx.py deleted file mode 100644 index 8c2a3709..00000000 --- a/polaris/utils/httpx.py +++ /dev/null @@ -1,35 +0,0 @@ -from httpx import Response - - -def _log_response(response: Response) -> str: - """ - Fully logs a request/response pair for HTTPX. - Used for debugging purposes. 
- """ - req_prefix = "< " - res_prefix = "> " - request = response.request - output = [f"{req_prefix}{request.method} {request.url}"] - - for name, value in request.headers.items(): - output.append(f"{req_prefix}{name}: {value}") - - output.append(req_prefix) - - if isinstance(request.content, (str, bytes)): - output.append(f"{req_prefix}{request.content}") - else: - output.append("<< Request body is not a string-like type >>") - - output.append("") - - output.append(f"{res_prefix} {response.status_code} {response.reason_phrase}") - - for name, value in response.headers.items(): - output.append(f"{res_prefix}{name}: {value}") - - output.append(res_prefix) - - output.append(f"{res_prefix}{response.text}") - - return "\n".join(output) diff --git a/pyproject.toml b/pyproject.toml index e640063c..6f42c52a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,8 +113,14 @@ data_file = ".coverage/coverage" omit = [ "polaris/__init__.py", "polaris/_version.py", + # We cannot yet test the interaction with the Hub. + # See e.g. https://github.com/polaris-hub/polaris/issues/30 "polaris/hub/client.py", + "polaris/hub/external_auth_client.py", + "polaris/hub/oauth2.py", "polaris/hub/settings.py", + "polaris/hub/polarisfs.py", + "polaris/hub/__init__.py", "polaris/hub/__init__.py", ] diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 45e6814a..b959739c 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -188,3 +188,17 @@ def _check_for_failure(_kwargs): kwargs["md5sum"] = None dataset = cls(**kwargs) assert dataset.md5sum is not None + + +def test_setting_an_invalid_checksum(test_single_task_benchmark): + """Test whether setting an invalid checksum raises an error.""" + with pytest.raises(ValueError): + test_single_task_benchmark.md5sum = "invalid" + + +def test_checksum_verification(test_single_task_benchmark): + """Test whether setting an invalid checksum raises an error.""" + test_single_task_benchmark.verify_checksum() + test_single_task_benchmark.md5sum = "0" * 32 + with pytest.raises(ValueError): + test_single_task_benchmark.verify_checksum() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 7369f22a..db4336e1 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -163,3 +163,17 @@ def test_dataset_in_memory_optimization(zarr_archive, tmpdir): d2 = perf_counter() - t2 assert d2 < d1 + + +def test_setting_an_invalid_checksum(test_dataset): + """Test whether setting an invalid checksum raises an error.""" + with pytest.raises(ValueError): + test_dataset.md5sum = "invalid" + + +def test_checksum_verification(test_dataset): + """Test whether setting an invalid checksum raises an error.""" + test_dataset.verify_checksum() + test_dataset.md5sum = "0" * 32 + with pytest.raises(ValueError): + test_dataset.verify_checksum() From ce3fce4513e125241c5e23804d89476e44c23543 Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Fri, 12 Jul 2024 10:43:26 -0400 Subject: [PATCH 28/29] Use RE to match checksum Co-authored-by: Julien St-Laurent --- polaris/_mixins.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/polaris/_mixins.py b/polaris/_mixins.py index e4fb2794..33948b3b 100644 --- a/polaris/_mixins.py +++ b/polaris/_mixins.py @@ -30,7 +30,7 @@ def md5sum(self) -> str: @md5sum.setter def md5sum(self, value: str): """Set the checksum.""" - if len(value) != 32 or not all(c in "0123456789abcdef" for c in value): + if not re.fullmatch(r"^[a-f0-9]{32}$", value): raise ValueError("The checksum should be the 32-character hexdigest of a 128 
bit MD5 hash.") self._md5sum = value From fba8ce38353c74a3d1d2c8d8e7a4142176f67a1b Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Fri, 12 Jul 2024 10:47:27 -0400 Subject: [PATCH 29/29] Clarify docs --- polaris/_mixins.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/polaris/_mixins.py b/polaris/_mixins.py index 33948b3b..8fccac35 100644 --- a/polaris/_mixins.py +++ b/polaris/_mixins.py @@ -1,4 +1,5 @@ import abc +import re from loguru import logger from pydantic import BaseModel, PrivateAttr, computed_field @@ -15,7 +16,7 @@ class ChecksumMixin(BaseModel, abc.ABC): @abc.abstractmethod def _compute_checksum(self) -> str: - """Compute the checksum and return it along with manifest.""" + """Compute the checksum of the dataset.""" raise NotImplementedError @computed_field
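
Taken together, this series leaves datasets and benchmarks with lazily computed, verifiable checksums and gives the loaders a `ChecksumStrategy` switch. A minimal usage sketch of the end state is shown below; the artifact name is a placeholder, the snippet assumes the standard `polaris` entry points, and it is not part of the diffs above.

import polaris as po

# "verify_unless_zarr" is the default strategy: Zarr-backed datasets skip the slow full
# recomputation, since each chunk is already verified against its MD5 checksum on download.
benchmark = po.load_benchmark("some-org/some-benchmark", verify_checksum="verify_unless_zarr")
dataset = benchmark.dataset

print(dataset.md5sum)                # computed lazily on first access
print(dataset.zarr_md5sum_manifest)  # per-file checksums for the Zarr archive, empty if no Zarr is used

# Explicit verification recomputes the checksum and raises PolarisChecksumError on a mismatch.
dataset.verify_checksum()
benchmark.verify_checksum()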