Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions polaris/hub/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from polaris.utils import fs
from polaris.utils.constants import DEFAULT_CACHE_DIR
from polaris.utils.errors import PolarisHubError, PolarisUnauthorizedError
from polaris.utils.types import AccessType, HubOwner, TimeoutTypes, IOMode
from polaris.utils.types import AccessType, HubOwner, IOMode, TimeoutTypes

_HTTPX_SSL_ERROR_CODE = "[SSL: CERTIFICATE_VERIFY_FAILED]"

Expand Down Expand Up @@ -304,12 +304,13 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
dataset_list = [bm["artifactId"] for bm in response["data"]]
return dataset_list

def get_dataset(self, owner: Union[str, HubOwner], name: str) -> Dataset:
def get_dataset(self, owner: Union[str, HubOwner], name: str, verify_checksum: bool = True) -> Dataset:
"""Load a dataset from the Polaris Hub.

Args:
owner: The owner of the dataset. Can be either a user or organization from the Polaris Hub.
name: The name of the dataset.
verify_checksum: Whether to use the checksum to verify the integrity of the dataset.

Returns:
A `Dataset` instance, if it exists.
Expand All @@ -331,6 +332,9 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str) -> Dataset:

response["table"] = self._load_from_signed_url(url=url, headers=headers, load_fn=pd.read_parquet)

if not verify_checksum:
response.pop("md5Sum", None)

return Dataset(**response)

def open_zarr_file(
Expand Down Expand Up @@ -377,12 +381,15 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]:
benchmarks_list = [f"{HubOwner(**bm['owner'])}/{bm['name']}" for bm in response["data"]]
return benchmarks_list

def get_benchmark(self, owner: Union[str, HubOwner], name: str) -> BenchmarkSpecification:
def get_benchmark(
self, owner: Union[str, HubOwner], name: str, verify_checksum: bool = True
) -> BenchmarkSpecification:
"""Load a benchmark from the Polaris Hub.

Args:
owner: The owner of the benchmark. Can be either a user or organization from the Polaris Hub.
name: The name of the benchmark.
verify_checksum: Whether to use the checksum to verify the integrity of the benchmark and its underlying dataset.

Returns:
A `BenchmarkSpecification` instance, if it exists.
Expand All @@ -392,7 +399,9 @@ def get_benchmark(self, owner: Union[str, HubOwner], name: str) -> BenchmarkSpec

# TODO (cwognum): Currently, the benchmark endpoints do not return the owner info for the underlying dataset.
# TODO (jstlaurent): Use the same owner for now, until the benchmark returns a better dataset entity
response["dataset"] = self.get_dataset(owner, response["dataset"]["name"])
response["dataset"] = self.get_dataset(
owner, response["dataset"]["name"], verify_checksum=verify_checksum
)

# TODO (cwognum): As we get more complicated benchmarks, how do we still find the right subclass?
# Maybe through structural pattern matching, introduced in Py3.10, or Pydantic's discriminated unions?
Expand All @@ -401,6 +410,10 @@ def get_benchmark(self, owner: Union[str, HubOwner], name: str) -> BenchmarkSpec
if len(response["targetCols"]) == 1
else MultiTaskBenchmarkSpecification
)

if not verify_checksum:
response.pop("md5Sum", None)

return benchmark_cls(**response)

def upload_results(
Expand Down
8 changes: 4 additions & 4 deletions polaris/loader/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from polaris.utils import fs


def load_dataset(path: str) -> Dataset:
def load_dataset(path: str, verify_checksum: bool = True) -> Dataset:
"""
Loads a Polaris dataset.

Expand All @@ -35,14 +35,14 @@ def load_dataset(path: str) -> Dataset:
if not is_file:
# Load from the Hub
client = PolarisHubClient()
return client.get_dataset(*path.split("/"))
return client.get_dataset(*path.split("/"), verify_checksum=verify_checksum)

if extension == "json":
return Dataset.from_json(path)
return create_dataset_from_file(path)


def load_benchmark(path: str):
def load_benchmark(path: str, verify_checksum: bool = True):
"""
Loads a Polaris benchmark.

Expand All @@ -66,7 +66,7 @@ def load_benchmark(path: str):
if not is_file:
# Load from the Hub
client = PolarisHubClient()
return client.get_benchmark(*path.split("/"))
return client.get_benchmark(*path.split("/"), verify_checksum=verify_checksum)

with fsspec.open(path, "r") as fd:
data = json.load(fd)
Expand Down