From e0c6e2f82b273f44357af6f4b9ebb26e6388e0f8 Mon Sep 17 00:00:00 2001 From: cwognum Date: Wed, 20 Mar 2024 18:31:55 -0400 Subject: [PATCH] Added an option to disable the checksum verification --- polaris/hub/client.py | 21 +++++++++++++++++---- polaris/loader/load.py | 8 ++++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 935cd40d..2aff2cb7 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -32,7 +32,7 @@ from polaris.utils import fs from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.errors import PolarisHubError, PolarisUnauthorizedError -from polaris.utils.types import AccessType, HubOwner, TimeoutTypes, IOMode +from polaris.utils.types import AccessType, HubOwner, IOMode, TimeoutTypes _HTTPX_SSL_ERROR_CODE = "[SSL: CERTIFICATE_VERIFY_FAILED]" @@ -304,12 +304,13 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: dataset_list = [bm["artifactId"] for bm in response["data"]] return dataset_list - def get_dataset(self, owner: Union[str, HubOwner], name: str) -> Dataset: + def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: bool = True) -> Dataset: """Load a dataset from the Polaris Hub. Args: owner: The owner of the dataset. Can be either a user or organization from the Polaris Hub. name: The name of the dataset. + verify_checksum: Whether to use the checksum to verify the integrity of the dataset. Returns: A `Dataset` instance, if it exists. 
@@ -331,6 +332,9 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str) -> Dataset: response["table"] = self._load_from_signed_url(url=url, headers=headers, load_fn=pd.read_parquet) + if not verify_checksum: + response.pop("md5Sum", None) + return Dataset(**response) def open_zarr_file( @@ -377,12 +381,15 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: benchmarks_list = [f"{HubOwner(**bm['owner'])}/{bm['name']}" for bm in response["data"]] return benchmarks_list - def get_benchmark(self, owner: Union[str, HubOwner], name: str) -> BenchmarkSpecification: + def get_benchmark( + self, owner: Union[str, HubOwner], name: str, verify_checksum: bool = True + ) -> BenchmarkSpecification: """Load a benchmark from the Polaris Hub. Args: owner: The owner of the benchmark. Can be either a user or organization from the Polaris Hub. name: The name of the benchmark. + verify_checksum: Whether to use the checksum to verify the integrity of the benchmark. Returns: A `BenchmarkSpecification` instance, if it exists. @@ -392,7 +399,9 @@ def get_benchmark(self, owner: Union[str, HubOwner], name: str) -> BenchmarkSpec # TODO (cwognum): Currently, the benchmark endpoints do not return the owner info for the underlying dataset. # TODO (jstlaurent): Use the same owner for now, until the benchmark returns a better dataset entity - response["dataset"] = self.get_dataset(owner, response["dataset"]["name"]) + response["dataset"] = self.get_dataset( + owner, response["dataset"]["name"], verify_checksum=verify_checksum + ) # TODO (cwognum): As we get more complicated benchmarks, how do we still find the right subclass? # Maybe through structural pattern matching, introduced in Py3.10, or Pydantic's discriminated unions? 
@@ -401,6 +410,10 @@ def get_benchmark(self, owner: Union[str, HubOwner], name: str) -> BenchmarkSpec if len(response["targetCols"]) == 1 else MultiTaskBenchmarkSpecification ) + + if not verify_checksum: + response.pop("md5Sum", None) + return benchmark_cls(**response) def upload_results( diff --git a/polaris/loader/load.py b/polaris/loader/load.py index 2bec36dd..4f4d10c8 100644 --- a/polaris/loader/load.py +++ b/polaris/loader/load.py @@ -11,7 +11,7 @@ from polaris.utils import fs -def load_dataset(path: str) -> Dataset: +def load_dataset(path: str, verify_checksum: bool = True) -> Dataset: """ Loads a Polaris dataset. @@ -35,14 +35,14 @@ def load_dataset(path: str) -> Dataset: if not is_file: # Load from the Hub client = PolarisHubClient() - return client.get_dataset(*path.split("/")) + return client.get_dataset(*path.split("/"), verify_checksum=verify_checksum) if extension == "json": return Dataset.from_json(path) return create_dataset_from_file(path) -def load_benchmark(path: str): +def load_benchmark(path: str, verify_checksum: bool = True): """ Loads a Polaris benchmark. @@ -66,7 +66,7 @@ def load_benchmark(path: str): if not is_file: # Load from the Hub client = PolarisHubClient() - return client.get_benchmark(*path.split("/")) + return client.get_benchmark(*path.split("/"), verify_checksum=verify_checksum) with fsspec.open(path, "r") as fd: data = json.load(fd)