From be3ffd59fa1a5ee397ac27a6a8225f7a542c3f0f Mon Sep 17 00:00:00 2001 From: Jack Li Date: Tue, 20 May 2025 14:20:43 -0400 Subject: [PATCH 01/20] remove authentication for get methods --- polaris/dataset/_dataset.py | 6 +-- polaris/dataset/_dataset_v2.py | 6 +-- polaris/hub/client.py | 78 ++++++++++++++-------------------- polaris/hub/settings.py | 2 + polaris/hub/storage.py | 56 +++++++++--------------- 5 files changed, 60 insertions(+), 88 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 30383749..62aa229b 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -140,10 +140,10 @@ def load_zarr_root_from_hub(self): """ from polaris.hub.client import PolarisHubClient from polaris.hub.storage import StorageSession + import zarr - with PolarisHubClient() as client: - with StorageSession(client, "read", self.urn) as storage: - return zarr.open_consolidated(store=storage.store("extension")) + store = StorageSession.store(self.zarr_root_path) + return zarr.open_consolidated(store=store) @computed_field @property diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py index 8a757bdd..73e75df1 100644 --- a/polaris/dataset/_dataset_v2.py +++ b/polaris/dataset/_dataset_v2.py @@ -119,10 +119,10 @@ def load_zarr_root_from_hub(self): """ from polaris.hub.client import PolarisHubClient from polaris.hub.storage import StorageSession + import zarr - with PolarisHubClient() as client: - with StorageSession(client, "read", self.urn) as storage: - return zarr.open_consolidated(store=storage.store("root")) + store = StorageSession.store(self.zarr_root_path) + return zarr.open_consolidated(store=store) @property def zarr_manifest_path(self) -> str: diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 439ac4af..bb54d344 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -119,12 +119,7 @@ def __init__( ) def __enter__(self: Self) -> Self: - """ - When used as a context manager, automatically check that authentication is valid. - """ super().__enter__() - if not self.ensure_active_token(): - raise PolarisUnauthorizedError() return self @property @@ -263,10 +258,7 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: A list of dataset names in the format `owner/dataset_slug`. """ with track_progress(description="Fetching datasets", total=1): - # Step 1: Fetch enough v2 datasets to cover the offset and limit - v2_json_response = self._base_request_to_hub( - url="/v2/dataset", method="GET", params={"limit": limit, "offset": offset} - ).json() + v2_json_response = self.request(url="/v2/dataset", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}).json() v2_data = v2_json_response["data"] v2_datasets = [dataset["artifactId"] for dataset in v2_data] @@ -277,9 +269,10 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: # Step 2: Calculate the remaining limit and fetch v1 datasets remaining_limit = max(0, limit - len(v2_datasets)) - v1_json_response = self._base_request_to_hub( + v1_json_response = self.request( url="/v1/dataset", method="GET", + withhold_token=True, params={ "limit": remaining_limit, "offset": max(0, offset - v2_json_response["metadata"]["total"]), @@ -335,23 +328,19 @@ def _get_v1_dataset( A `Dataset` instance, if it exists. 
""" url = f"/v1/dataset/{owner}/{slug}" - response = self._base_request_to_hub(url=url, method="GET") + response = self.request(url=url, method="GET", withhold_token=True) response_data = response.json() - # Disregard the Zarr root in the response. We'll get it from the storage token instead. - response_data.pop("zarrRootPath", None) + # Prefer table_path and zarr_path from response metadata if available + metadata = response_data.get("metadata", {}) + table_path = metadata.get("table_path") + zarr_path = metadata.get("zarr_path") # Load the dataset table and optional Zarr archive with StorageSession(self, "read", Dataset.urn_for(owner, slug)) as storage: - table = pd.read_parquet(BytesIO(storage.get_file("root"))) - zarr_root_path = storage.paths.extension - - if zarr_root_path is not None: - # For V1 datasets, the Zarr Root is optional. - # It should be None if the dataset does not use pointer columns - zarr_root_path = str(zarr_root_path) - - dataset = DatasetV1(table=table, zarr_root_path=zarr_root_path, **response_data) + table = pd.read_parquet(BytesIO(storage.get_file(table_path))) + + dataset = DatasetV1(table=table, zarr_root_path=zarr_path, **response_data) md5sum = response_data["md5Sum"] if dataset.should_verify_checksum(verify_checksum): @@ -364,17 +353,14 @@ def _get_v1_dataset( def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2: """""" url = f"/v2/dataset/{owner}/{slug}" - response = self._base_request_to_hub(url=url, method="GET") + response = self.request(url=url, method="GET", withhold_token=True) response_data = response.json() - # Disregard the Zarr root in the response. We'll get it from the storage token instead. - response_data.pop("zarrRootPath", None) + metadata = response_data.get("metadata", {}) + zarr_path = metadata.get("zarr_path") + # For v2 datasets, the zarr_path always exists - # Load the Zarr archive - with StorageSession(self, "read", DatasetV2.urn_for(owner, slug)) as storage: - zarr_root_path = str(storage.paths.root) - - dataset = DatasetV2(zarr_root_path=zarr_root_path, **response_data) + dataset = DatasetV2(zarr_root_path=zarr_path, **response_data) return dataset def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: @@ -390,9 +376,7 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: """ with track_progress(description="Fetching benchmarks", total=1): # Step 1: Fetch enough v2 benchmarks to cover the offset and limit - v2_json_response = self._base_request_to_hub( - url="/v2/benchmark", method="GET", params={"limit": limit, "offset": offset} - ).json() + v2_json_response = self.request(url="/v2/benchmark", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}).json() v2_data = v2_json_response["data"] v2_benchmarks = [benchmark["artifactId"] for benchmark in v2_data] @@ -402,10 +386,10 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: # Step 2: Calculate the remaining limit and fetch v1 benchmarks remaining_limit = max(0, limit - len(v2_benchmarks)) - - v1_json_response = self._base_request_to_hub( + v1_json_response = self.request( url="/v1/benchmark", method="GET", + withhold_token=True, params={ "limit": remaining_limit, "offset": max(0, offset - v2_json_response["metadata"]["total"]), @@ -449,7 +433,7 @@ def _get_v1_benchmark( slug: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> BenchmarkV1Specification: - response = self._base_request_to_hub(url=f"/v1/benchmark/{owner}/{slug}", method="GET") + response = 
self.request(url=f"/v1/benchmark/{owner}/{slug}", method="GET", withhold_token=True) response_data = response.json() # TODO (jstlaurent): response["dataset"]["artifactId"] is the owner/name unique identifier, @@ -478,7 +462,7 @@ def _get_v1_benchmark( return benchmark def _get_v2_benchmark(self, owner: str | HubOwner, slug: str) -> BenchmarkV2Specification: - response = self._base_request_to_hub(url=f"/v2/benchmark/{owner}/{slug}", method="GET") + response = self.request(url=f"/v2/benchmark/{owner}/{slug}", method="GET", withhold_token=True) response_data = response.json() response_data["dataset"] = self.get_dataset(*response_data["dataset"]["artifactId"].split("/")) @@ -511,6 +495,8 @@ def upload_results( results: The results to upload. owner: Which Hub user or organization owns the artifact. Takes precedence over `results.owner`. """ + if not self.ensure_active_token(): + raise PolarisUnauthorizedError() with track_progress(description="Uploading results", total=1) as (progress, task): # Get the serialized model data-structure results.owner = HubOwner.normalize(owner or results.owner) @@ -607,15 +593,13 @@ def get_competition(self, artifact_id: str) -> CompetitionSpecification: A `CompetitionSpecification` instance, if it exists. """ url = f"/v1/competition/{artifact_id}" - response = self._base_request_to_hub(url=url, method="GET") + response = self.request(url=url, method="GET", withhold_token=True) response_data = response.json() - with StorageSession( - self, "read", CompetitionSpecification.urn_for(*artifact_id.split("/")) - ) as storage: - zarr_root_path = str(storage.paths.root) + metadata = response_data.get("metadata", {}) + zarr_path = metadata.get("zarr_path") - return CompetitionSpecification(zarr_root_path=zarr_root_path, **response_data) + return CompetitionSpecification(zarr_root_path=zarr_path, **response_data) def submit_competition_predictions( self, @@ -656,8 +640,8 @@ def list_models(self, limit: int = 100, offset: int = 0) -> list[str]: A list of models names in the format `owner/model_slug`. """ with track_progress(description="Fetching models", total=1): - json_response = self._base_request_to_hub( - url="/v2/model", method="GET", params={"limit": limit, "offset": offset} + json_response = self.request( + url="/v2/model", method="GET", withhold_token=True, params={"limit": limit, "offset": offset} ).json() models = [model["artifactId"] for model in json_response["data"]] @@ -665,7 +649,7 @@ def list_models(self, limit: int = 100, offset: int = 0) -> list[str]: def get_model(self, artifact_id: str) -> Model: url = f"/v2/model/{artifact_id}" - response = self._base_request_to_hub(url=url, method="GET") + response = self.request(url=url, method="GET", withhold_token=True) response_data = response.json() return Model(**response_data) @@ -693,6 +677,8 @@ def upload_model( owner: Which Hub user or organization owns the artifact. Takes precedence over `model.owner`. parent_artifact_id: The `owner/slug` of the parent model, if uploading a new version of a model. 
""" + if not self.ensure_active_token(): + raise PolarisUnauthorizedError() with track_progress(description="Uploading model", total=1) as (progress, task): # Get the serialized model data-structure model.owner = HubOwner.normalize(owner or model.owner) diff --git a/polaris/hub/settings.py b/polaris/hub/settings.py index 042b7d59..7950446f 100644 --- a/polaris/hub/settings.py +++ b/polaris/hub/settings.py @@ -30,6 +30,7 @@ class PolarisHubSettings(BaseSettings): A default value is generated based on the Hub URL, and this should not need to be overridden. username: The username for the Polaris Hub, for the optional password-based authentication. password: The password for the specified username. + public_artifact_url: The base URL for directly accessing public artifacts without authentication. """ # Configuration of the pydantic model @@ -41,6 +42,7 @@ class PolarisHubSettings(BaseSettings): hub_url: HttpUrlString = "https://polarishub.io/" api_url: HttpUrlString | None = None custom_metadata_prefix: str = "X-Amz-Meta-" + public_artifact_url: HttpUrlString = "https://data.polarishub.io/" # Hub authentication settings hub_token_url: HttpUrlString | None = None diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index 4e225e52..d0b83109 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -18,6 +18,7 @@ from zarr.context import Context from zarr.storage import Store from zarr.util import buffer_size +import fsspec from polaris.hub.oauth import BenchmarkV2Paths, DatasetV1Paths, DatasetV2Paths, HubStorageOAuth2Token from polaris.utils.context import track_progress @@ -476,15 +477,19 @@ class StorageSession(OAuth2Client): """ A context manager for managing a storage session, with token exchange and token refresh capabilities. Each session is associated with a specific scope and resource. + The `mode` parameter controls whether authentication is required: + - For 'read' mode, if a public_artifact_url is set, authentication is skipped. + - For 'write' mode, authentication is always required. """ polaris_protocol = "polarisfs" token_auth_class = StorageTokenAuth - def __init__(self, hub_client, scope: Scope, resource: ArtifactUrn): + def __init__(self, hub_client, scope: Scope, resource: ArtifactUrn, mode: str = "read"): self.hub_client = hub_client self.resource = resource + self.mode = mode super().__init__( # OAuth2Client @@ -496,8 +501,9 @@ def __init__(self, hub_client, scope: Scope, resource: ArtifactUrn): cert=hub_client.settings.ca_bundle, ) - def __enter__(self) -> Self: - self.ensure_active_token() + def __enter__(self) -> Self: # Only skip authentication for reads with a public URL + if not (self.mode == "read"): + self.ensure_active_token() return self def _prepare_token_endpoint_body(self, body, grant_type, **kwargs) -> str: @@ -581,42 +587,20 @@ def set_file(self, path: str, value: bytes | bytearray): endpoint_url=storage_data.endpoint, content_type=content_type, ) - store[relative_path.name] = value + # Use StorageSession with mode='write' for write operations + with StorageSession(self.hub_client, "write", self.resource, mode="write"): + store[relative_path.name] = value def get_file(self, path: str) -> bytes | bytearray: """ Get the value at the given path. """ - if path not in self.paths.files: - raise NotImplementedError( - f"{type(self.paths).__name__} only supports these files: {self.paths.files}." 
- ) - - relative_path = self._relative_path(getattr(self.paths, path)) - - storage_data = self.token.extra_data - store = S3Store( - path=relative_path.parent, - access_key=storage_data.key, - secret_key=storage_data.secret, - token=f"jwt/{self.token.access_token}", - endpoint_url=storage_data.endpoint, - ) - return store[relative_path.name] + # The path is now a full URL, so we use it directly + with fsspec.open(path, mode='rb') as f: + return f.read() - def store(self, path: str) -> S3Store: - if path not in self.paths.stores: - raise NotImplementedError( - f"{type(self.paths).__name__} only supports these stores: {self.paths.stores}." - ) - - relative_path = self._relative_path(getattr(self.paths, path)) - - storage_data = self.token.extra_data - return S3Store( - path=relative_path, - access_key=storage_data.key, - secret_key=storage_data.secret, - token=f"jwt/{self.token.access_token}", - endpoint_url=storage_data.endpoint, - ) + def store(self, path: str): + """ + Return a fsspec mapper for the given path, using the provided path as a full URL. + """ + return fsspec.get_mapper(path) From 5014dc43b8456014ce292cf1694d46641a438efc Mon Sep 17 00:00:00 2001 From: Jack Li Date: Tue, 20 May 2025 14:37:20 -0400 Subject: [PATCH 02/20] ruff fixes --- polaris/dataset/_dataset.py | 2 -- polaris/dataset/_dataset_v2.py | 1 - 2 files changed, 3 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 62aa229b..3b3fd3a5 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -7,7 +7,6 @@ import fsspec import numpy as np import pandas as pd -import zarr from datamol.utils import fs as dmfs from pydantic import PrivateAttr, computed_field, field_validator, model_validator from typing_extensions import Self, deprecated @@ -138,7 +137,6 @@ def load_zarr_root_from_hub(self): """ Loads a Zarr archive from the Hub. """ - from polaris.hub.client import PolarisHubClient from polaris.hub.storage import StorageSession import zarr diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py index 73e75df1..5b097be3 100644 --- a/polaris/dataset/_dataset_v2.py +++ b/polaris/dataset/_dataset_v2.py @@ -117,7 +117,6 @@ def load_zarr_root_from_hub(self): """ Loads a Zarr archive from the Hub. """ - from polaris.hub.client import PolarisHubClient from polaris.hub.storage import StorageSession import zarr From 58adbfb11cb78765b5eae5bf76fb7a58901b4daf Mon Sep 17 00:00:00 2001 From: Jack Li Date: Tue, 20 May 2025 14:40:38 -0400 Subject: [PATCH 03/20] remove public url field --- polaris/hub/settings.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/polaris/hub/settings.py b/polaris/hub/settings.py index 7950446f..042b7d59 100644 --- a/polaris/hub/settings.py +++ b/polaris/hub/settings.py @@ -30,7 +30,6 @@ class PolarisHubSettings(BaseSettings): A default value is generated based on the Hub URL, and this should not need to be overridden. username: The username for the Polaris Hub, for the optional password-based authentication. password: The password for the specified username. - public_artifact_url: The base URL for directly accessing public artifacts without authentication. 
""" # Configuration of the pydantic model @@ -42,7 +41,6 @@ class PolarisHubSettings(BaseSettings): hub_url: HttpUrlString = "https://polarishub.io/" api_url: HttpUrlString | None = None custom_metadata_prefix: str = "X-Amz-Meta-" - public_artifact_url: HttpUrlString = "https://data.polarishub.io/" # Hub authentication settings hub_token_url: HttpUrlString | None = None From bd4f07cc2766da8c2ec0b82d6db9f7e4081ea81c Mon Sep 17 00:00:00 2001 From: Jack Li Date: Tue, 20 May 2025 14:42:39 -0400 Subject: [PATCH 04/20] fix comment --- polaris/hub/storage.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index d0b83109..dacc9a32 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -478,7 +478,7 @@ class StorageSession(OAuth2Client): A context manager for managing a storage session, with token exchange and token refresh capabilities. Each session is associated with a specific scope and resource. The `mode` parameter controls whether authentication is required: - - For 'read' mode, if a public_artifact_url is set, authentication is skipped. + - For 'read' mode, authentication is optional. - For 'write' mode, authentication is always required. """ @@ -501,7 +501,7 @@ def __init__(self, hub_client, scope: Scope, resource: ArtifactUrn, mode: str = cert=hub_client.settings.ca_bundle, ) - def __enter__(self) -> Self: # Only skip authentication for reads with a public URL + def __enter__(self) -> Self: if not (self.mode == "read"): self.ensure_active_token() return self From 55f385b8e199253e9f6e8367bf4695aab700748b Mon Sep 17 00:00:00 2001 From: Jack Li Date: Tue, 20 May 2025 14:50:45 -0400 Subject: [PATCH 05/20] cleaned up some more code --- polaris/hub/client.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index bb54d344..74cc979e 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -179,10 +179,10 @@ def fetch_token(self, **kwargs): message=f"Could not obtain a token to access the Hub. Error was: {error.error} - {error.description}" ) from error - def _base_request_to_hub(self, url: str, method: str, **kwargs) -> Response: - """Utility function since most API methods follow the same pattern""" + def _base_request_to_hub(self, url: str, method: str, withhold_token: bool = False, **kwargs) -> Response: + """Utility function for making requests to the Hub with consistent error handling.""" try: - response = self.request(url=url, method=method, **kwargs) + response = self.request(url=url, method=method, withhold_token=withhold_token, **kwargs) response.raise_for_status() return response except HTTPStatusError as error: @@ -258,7 +258,7 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: A list of dataset names in the format `owner/dataset_slug`. 
""" with track_progress(description="Fetching datasets", total=1): - v2_json_response = self.request(url="/v2/dataset", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}).json() + v2_json_response = self._base_request_to_hub(url="/v2/dataset", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}).json() v2_data = v2_json_response["data"] v2_datasets = [dataset["artifactId"] for dataset in v2_data] @@ -269,7 +269,7 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: # Step 2: Calculate the remaining limit and fetch v1 datasets remaining_limit = max(0, limit - len(v2_datasets)) - v1_json_response = self.request( + v1_json_response = self._base_request_to_hub( url="/v1/dataset", method="GET", withhold_token=True, @@ -328,7 +328,7 @@ def _get_v1_dataset( A `Dataset` instance, if it exists. """ url = f"/v1/dataset/{owner}/{slug}" - response = self.request(url=url, method="GET", withhold_token=True) + response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() # Prefer table_path and zarr_path from response metadata if available @@ -353,7 +353,7 @@ def _get_v1_dataset( def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2: """""" url = f"/v2/dataset/{owner}/{slug}" - response = self.request(url=url, method="GET", withhold_token=True) + response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() metadata = response_data.get("metadata", {}) @@ -376,7 +376,7 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: """ with track_progress(description="Fetching benchmarks", total=1): # Step 1: Fetch enough v2 benchmarks to cover the offset and limit - v2_json_response = self.request(url="/v2/benchmark", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}).json() + v2_json_response = self._base_request_to_hub(url="/v2/benchmark", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}).json() v2_data = v2_json_response["data"] v2_benchmarks = [benchmark["artifactId"] for benchmark in v2_data] @@ -386,7 +386,7 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: # Step 2: Calculate the remaining limit and fetch v1 benchmarks remaining_limit = max(0, limit - len(v2_benchmarks)) - v1_json_response = self.request( + v1_json_response = self._base_request_to_hub( url="/v1/benchmark", method="GET", withhold_token=True, @@ -433,7 +433,7 @@ def _get_v1_benchmark( slug: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> BenchmarkV1Specification: - response = self.request(url=f"/v1/benchmark/{owner}/{slug}", method="GET", withhold_token=True) + response = self._base_request_to_hub(url=f"/v1/benchmark/{owner}/{slug}", method="GET", withhold_token=True) response_data = response.json() # TODO (jstlaurent): response["dataset"]["artifactId"] is the owner/name unique identifier, @@ -462,7 +462,7 @@ def _get_v1_benchmark( return benchmark def _get_v2_benchmark(self, owner: str | HubOwner, slug: str) -> BenchmarkV2Specification: - response = self.request(url=f"/v2/benchmark/{owner}/{slug}", method="GET", withhold_token=True) + response = self._base_request_to_hub(url=f"/v2/benchmark/{owner}/{slug}", method="GET", withhold_token=True) response_data = response.json() response_data["dataset"] = self.get_dataset(*response_data["dataset"]["artifactId"].split("/")) @@ -593,7 +593,7 @@ def get_competition(self, artifact_id: 
str) -> CompetitionSpecification: A `CompetitionSpecification` instance, if it exists. """ url = f"/v1/competition/{artifact_id}" - response = self.request(url=url, method="GET", withhold_token=True) + response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() metadata = response_data.get("metadata", {}) @@ -640,7 +640,7 @@ def list_models(self, limit: int = 100, offset: int = 0) -> list[str]: A list of models names in the format `owner/model_slug`. """ with track_progress(description="Fetching models", total=1): - json_response = self.request( + json_response = self._base_request_to_hub( url="/v2/model", method="GET", withhold_token=True, params={"limit": limit, "offset": offset} ).json() models = [model["artifactId"] for model in json_response["data"]] @@ -649,7 +649,7 @@ def list_models(self, limit: int = 100, offset: int = 0) -> list[str]: def get_model(self, artifact_id: str) -> Model: url = f"/v2/model/{artifact_id}" - response = self.request(url=url, method="GET", withhold_token=True) + response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() return Model(**response_data) From 0c05858cf693fe99a403aa945819229b39c5f24e Mon Sep 17 00:00:00 2001 From: Jack Li Date: Tue, 20 May 2025 15:25:39 -0400 Subject: [PATCH 06/20] zarr path fix --- polaris/dataset/_base.py | 7 +++++-- polaris/dataset/_dataset.py | 2 +- polaris/dataset/_dataset_v2.py | 2 +- polaris/hub/client.py | 5 ++++- polaris/hub/storage.py | 3 ++- 5 files changed, 13 insertions(+), 6 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index e9ddadd1..a7b8a63e 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -171,7 +171,7 @@ def zarr_root(self) -> zarr.Group | None: See also `dataset.load_to_memory()`. """ - from polaris.hub.storage import StorageSession + from urllib.parse import urlparse if self._zarr_root is not None: return self._zarr_root @@ -179,7 +179,10 @@ def zarr_root(self) -> zarr.Group | None: if self.zarr_root_path is None: return None - saved_on_hub = self.zarr_root_path.startswith(StorageSession.polaris_protocol) + print(f"zarr_root_path: {self.zarr_root_path}") + parsed = urlparse(self.zarr_root_path) + saved_on_hub = parsed.scheme in ("http", "https", "s3", "gs") + print(f"saved_on_hub: {saved_on_hub}") if self._warn_about_remote_zarr and saved_on_hub: # TODO (cwognum): The user now has no easy way of knowing whether the dataset is "small enough". 
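Note: the scheme check added above replaces the old test for the `polarisfs` protocol prefix. A minimal standalone sketch of the same detection logic (the helper name `is_saved_on_hub` is illustrative, not part of the codebase):

```python
from urllib.parse import urlparse

def is_saved_on_hub(zarr_root_path: str) -> bool:
    # Hub-hosted data is identified purely by URL scheme; bare filesystem
    # paths parse with an empty scheme and are treated as local.
    return urlparse(zarr_root_path).scheme in ("http", "https", "s3", "gs")

assert is_saved_on_hub("https://data.polarishub.io/some-dataset/data.zarr")
assert not is_saved_on_hub("/tmp/some-dataset/data.zarr")
```
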
diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 3b3fd3a5..f33bc1b8 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -140,7 +140,7 @@ def load_zarr_root_from_hub(self): from polaris.hub.storage import StorageSession import zarr - store = StorageSession.store(self.zarr_root_path) + store = StorageSession.store(self) return zarr.open_consolidated(store=store) @computed_field diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py index 5b097be3..20fbb2e3 100644 --- a/polaris/dataset/_dataset_v2.py +++ b/polaris/dataset/_dataset_v2.py @@ -120,7 +120,7 @@ def load_zarr_root_from_hub(self): from polaris.hub.storage import StorageSession import zarr - store = StorageSession.store(self.zarr_root_path) + store = StorageSession.store(self) return zarr.open_consolidated(store=store) @property diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 74cc979e..7b26923e 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -330,6 +330,8 @@ def _get_v1_dataset( url = f"/v1/dataset/{owner}/{slug}" response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() + response_data.pop("zarr_root_path", None) + response_data.pop("zarrRootPath", None) # Prefer table_path and zarr_path from response metadata if available metadata = response_data.get("metadata", {}) @@ -355,11 +357,12 @@ def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2: url = f"/v2/dataset/{owner}/{slug}" response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() + response_data.pop("zarr_root_path", None) + response_data.pop("zarrRootPath", None) metadata = response_data.get("metadata", {}) zarr_path = metadata.get("zarr_path") # For v2 datasets, the zarr_path always exists - dataset = DatasetV2(zarr_root_path=zarr_path, **response_data) return dataset diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index dacc9a32..6991b2f3 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -599,8 +599,9 @@ def get_file(self, path: str) -> bytes | bytearray: with fsspec.open(path, mode='rb') as f: return f.read() - def store(self, path: str): + def store(self): """ Return a fsspec mapper for the given path, using the provided path as a full URL. """ + path = self.zarr_root_path return fsspec.get_mapper(path) From 066cc8982ec410bd270f1b6908157a495809c5f7 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Tue, 20 May 2025 15:30:35 -0400 Subject: [PATCH 07/20] streamline hub check --- polaris/dataset/_base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index a7b8a63e..0691fbae 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -179,10 +179,8 @@ def zarr_root(self) -> zarr.Group | None: if self.zarr_root_path is None: return None - print(f"zarr_root_path: {self.zarr_root_path}") parsed = urlparse(self.zarr_root_path) - saved_on_hub = parsed.scheme in ("http", "https", "s3", "gs") - print(f"saved_on_hub: {saved_on_hub}") + saved_on_hub = parsed.scheme == "https" if self._warn_about_remote_zarr and saved_on_hub: # TODO (cwognum): The user now has no easy way of knowing whether the dataset is "small enough". 
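Note: with `store()` now returning a plain fsspec mapping, reading a public Zarr archive no longer requires a token exchange. A standalone sketch of the read path this enables, assuming a public, consolidated Zarr archive at a placeholder URL:

```python
import fsspec
import zarr

# Placeholder URL; any fsspec-supported URL (http(s), s3, gs, ...) works the same way.
zarr_root_path = "https://data.polarishub.io/datasets/org/name/data.zarr"

# get_mapper wraps the URL in a dict-like key/value store over HTTP range requests.
store = fsspec.get_mapper(zarr_root_path)

# open_consolidated fetches all group/array metadata in a single request,
# then reads chunks lazily on access.
root = zarr.open_consolidated(store=store)
print(root.tree())
```
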
From b3516292f5a55b2162a22ee476828b6e7d5d126d Mon Sep 17 00:00:00 2001 From: Jack Li Date: Tue, 20 May 2025 15:36:31 -0400 Subject: [PATCH 08/20] More ruff formatting --- polaris/hub/client.py | 24 +++++++++++++++++++----- polaris/hub/storage.py | 4 ++-- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 7b26923e..e7e42219 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -258,7 +258,12 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]: A list of dataset names in the format `owner/dataset_slug`. """ with track_progress(description="Fetching datasets", total=1): - v2_json_response = self._base_request_to_hub(url="/v2/dataset", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}).json() + v2_json_response = self._base_request_to_hub( + url="/v2/dataset", + method="GET", + withhold_token=True, + params={"limit": limit, "offset": offset}, + ).json() v2_data = v2_json_response["data"] v2_datasets = [dataset["artifactId"] for dataset in v2_data] @@ -341,7 +346,7 @@ def _get_v1_dataset( # Load the dataset table and optional Zarr archive with StorageSession(self, "read", Dataset.urn_for(owner, slug)) as storage: table = pd.read_parquet(BytesIO(storage.get_file(table_path))) - + dataset = DatasetV1(table=table, zarr_root_path=zarr_path, **response_data) md5sum = response_data["md5Sum"] @@ -379,7 +384,12 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: """ with track_progress(description="Fetching benchmarks", total=1): # Step 1: Fetch enough v2 benchmarks to cover the offset and limit - v2_json_response = self._base_request_to_hub(url="/v2/benchmark", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}).json() + v2_json_response = self._base_request_to_hub( + url="/v2/benchmark", + method="GET", + withhold_token=True, + params={"limit": limit, "offset": offset}, + ).json() v2_data = v2_json_response["data"] v2_benchmarks = [benchmark["artifactId"] for benchmark in v2_data] @@ -436,7 +446,9 @@ def _get_v1_benchmark( slug: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> BenchmarkV1Specification: - response = self._base_request_to_hub(url=f"/v1/benchmark/{owner}/{slug}", method="GET", withhold_token=True) + response = self._base_request_to_hub( + url=f"/v1/benchmark/{owner}/{slug}", method="GET", withhold_token=True + ) response_data = response.json() # TODO (jstlaurent): response["dataset"]["artifactId"] is the owner/name unique identifier, @@ -465,7 +477,9 @@ def _get_v1_benchmark( return benchmark def _get_v2_benchmark(self, owner: str | HubOwner, slug: str) -> BenchmarkV2Specification: - response = self._base_request_to_hub(url=f"/v2/benchmark/{owner}/{slug}", method="GET", withhold_token=True) + response = self._base_request_to_hub( + url=f"/v2/benchmark/{owner}/{slug}", method="GET", withhold_token=True + ) response_data = response.json() response_data["dataset"] = self.get_dataset(*response_data["dataset"]["artifactId"].split("/")) diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index 6991b2f3..cac6db3e 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -478,7 +478,7 @@ class StorageSession(OAuth2Client): A context manager for managing a storage session, with token exchange and token refresh capabilities. Each session is associated with a specific scope and resource. 
The `mode` parameter controls whether authentication is required: - - For 'read' mode, authentication is optional. + - For 'read' mode, authentication is optional. - For 'write' mode, authentication is always required. """ @@ -596,7 +596,7 @@ def get_file(self, path: str) -> bytes | bytearray: Get the value at the given path. """ # The path is now a full URL, so we use it directly - with fsspec.open(path, mode='rb') as f: + with fsspec.open(path, mode="rb") as f: return f.read() def store(self): From d370591c3e55b118702dc329c81d75e103a3b330 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Wed, 21 May 2025 12:49:06 -0400 Subject: [PATCH 09/20] response changes --- polaris/hub/client.py | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index e7e42219..704742f7 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -335,19 +335,17 @@ def _get_v1_dataset( url = f"/v1/dataset/{owner}/{slug}" response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() - response_data.pop("zarr_root_path", None) response_data.pop("zarrRootPath", None) + response_data.pop("zarr_root_path", None) - # Prefer table_path and zarr_path from response metadata if available - metadata = response_data.get("metadata", {}) - table_path = metadata.get("table_path") - zarr_path = metadata.get("zarr_path") + root_url = response_data.get("root") + extension_url = response_data.get("extension") # Load the dataset table and optional Zarr archive with StorageSession(self, "read", Dataset.urn_for(owner, slug)) as storage: - table = pd.read_parquet(BytesIO(storage.get_file(table_path))) + table = pd.read_parquet(BytesIO(storage.get_file(root_url))) - dataset = DatasetV1(table=table, zarr_root_path=zarr_path, **response_data) + dataset = DatasetV1(table=table, zarr_root_path=extension_url, **response_data) md5sum = response_data["md5Sum"] if dataset.should_verify_checksum(verify_checksum): @@ -362,13 +360,13 @@ def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2: url = f"/v2/dataset/{owner}/{slug}" response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() - response_data.pop("zarr_root_path", None) response_data.pop("zarrRootPath", None) + response_data.pop("zarr_root_path", None) - metadata = response_data.get("metadata", {}) - zarr_path = metadata.get("zarr_path") + root_url = response_data.get("root") + print(f"Root URL: {root_url}") # For v2 datasets, the zarr_path always exists - dataset = DatasetV2(zarr_root_path=zarr_path, **response_data) + dataset = DatasetV2(zarr_root_path=root_url, **response_data) return dataset def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: @@ -613,10 +611,9 @@ def get_competition(self, artifact_id: str) -> CompetitionSpecification: response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() - metadata = response_data.get("metadata", {}) - zarr_path = metadata.get("zarr_path") + root_url = response_data.get("root") - return CompetitionSpecification(zarr_root_path=zarr_path, **response_data) + return CompetitionSpecification(zarr_root_path=root_url, **response_data) def submit_competition_predictions( self, From 65fede08b406d8870e9ac46390b451f6d2d9f998 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Wed, 21 May 2025 15:35:28 -0400 Subject: [PATCH 10/20] Various fixes --- polaris/dataset/_dataset.py | 2 +- 
polaris/dataset/_dataset_v2.py | 2 +- polaris/hub/client.py | 16 +++++++--------- polaris/hub/storage.py | 10 +--------- 4 files changed, 10 insertions(+), 20 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index f33bc1b8..da791171 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -140,7 +140,7 @@ def load_zarr_root_from_hub(self): from polaris.hub.storage import StorageSession import zarr - store = StorageSession.store(self) + store = fsspec.get_mapper(self.zarr_root_path) return zarr.open_consolidated(store=store) @computed_field diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py index 20fbb2e3..8b23a83b 100644 --- a/polaris/dataset/_dataset_v2.py +++ b/polaris/dataset/_dataset_v2.py @@ -120,7 +120,7 @@ def load_zarr_root_from_hub(self): from polaris.hub.storage import StorageSession import zarr - store = StorageSession.store(self) + store = fsspec.get_mapper(self.zarr_root_path) return zarr.open_consolidated(store=store) @property diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 704742f7..b2a6a77c 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -179,8 +179,11 @@ def fetch_token(self, **kwargs): message=f"Could not obtain a token to access the Hub. Error was: {error.error} - {error.description}" ) from error - def _base_request_to_hub(self, url: str, method: str, withhold_token: bool = False, **kwargs) -> Response: + def _base_request_to_hub(self, url: str, method: str, withhold_token: bool, **kwargs) -> Response: """Utility function for making requests to the Hub with consistent error handling.""" + if not withhold_token: + if not self.ensure_active_token(): + raise PolarisUnauthorizedError() try: response = self.request(url=url, method=method, withhold_token=withhold_token, **kwargs) response.raise_for_status() @@ -336,7 +339,6 @@ def _get_v1_dataset( response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() response_data.pop("zarrRootPath", None) - response_data.pop("zarr_root_path", None) root_url = response_data.get("root") extension_url = response_data.get("extension") @@ -361,10 +363,8 @@ def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2: response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() response_data.pop("zarrRootPath", None) - response_data.pop("zarr_root_path", None) root_url = response_data.get("root") - print(f"Root URL: {root_url}") # For v2 datasets, the zarr_path always exists dataset = DatasetV2(zarr_root_path=root_url, **response_data) return dataset @@ -510,15 +510,13 @@ def upload_results( results: The results to upload. owner: Which Hub user or organization owns the artifact. Takes precedence over `results.owner`. """ - if not self.ensure_active_token(): - raise PolarisUnauthorizedError() with track_progress(description="Uploading results", total=1) as (progress, task): # Get the serialized model data-structure results.owner = HubOwner.normalize(owner or results.owner) result_json = results.model_dump(by_alias=True, exclude_none=True) # Make a request to the Hub - response = self._base_request_to_hub(url="/v2/result", method="POST", json=result_json) + response = self._base_request_to_hub(url="/v2/result", method="POST", withhold_token=False, json=result_json) # Inform the user about where to find their newly created artifact. 
result_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location")) @@ -639,6 +637,7 @@ def submit_competition_predictions( response = self._base_request_to_hub( url="/v1/competition-prediction", method="POST", + withhold_token=False, json=prediction_payload, ) return response @@ -691,8 +690,6 @@ def upload_model( owner: Which Hub user or organization owns the artifact. Takes precedence over `model.owner`. parent_artifact_id: The `owner/slug` of the parent model, if uploading a new version of a model. """ - if not self.ensure_active_token(): - raise PolarisUnauthorizedError() with track_progress(description="Uploading model", total=1) as (progress, task): # Get the serialized model data-structure model.owner = HubOwner.normalize(owner or model.owner) @@ -703,6 +700,7 @@ def upload_model( response = self._base_request_to_hub( url=url, method="PUT", + withhold_token=False, json={"parentArtifactId": parent_artifact_id, **model_json}, ) diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index cac6db3e..99bd4f19 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -588,8 +588,7 @@ def set_file(self, path: str, value: bytes | bytearray): content_type=content_type, ) # Use StorageSession with mode='write' for write operations - with StorageSession(self.hub_client, "write", self.resource, mode="write"): - store[relative_path.name] = value + store[relative_path.name] = value def get_file(self, path: str) -> bytes | bytearray: """ @@ -598,10 +597,3 @@ def get_file(self, path: str) -> bytes | bytearray: # The path is now a full URL, so we use it directly with fsspec.open(path, mode="rb") as f: return f.read() - - def store(self): - """ - Return a fsspec mapper for the given path, using the provided path as a full URL. - """ - path = self.zarr_root_path - return fsspec.get_mapper(path) From 470ddfd84841800fd14a36f028673b38031b4842 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Wed, 21 May 2025 15:44:26 -0400 Subject: [PATCH 11/20] added bucket url to polaris settings --- polaris/dataset/_base.py | 5 +++-- polaris/hub/client.py | 4 +++- polaris/hub/settings.py | 3 +++ 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index 0691fbae..f023c1b3 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -36,6 +36,7 @@ SupportedLicenseType, ZarrConflictResolution, ) +from polaris.hub.settings import PolarisHubSettings logger = logging.getLogger(__name__) @@ -179,8 +180,8 @@ def zarr_root(self) -> zarr.Group | None: if self.zarr_root_path is None: return None - parsed = urlparse(self.zarr_root_path) - saved_on_hub = parsed.scheme == "https" + settings = PolarisHubSettings() + saved_on_hub = self.zarr_root_path and self.zarr_root_path.startswith(settings.bucket_url) if self._warn_about_remote_zarr and saved_on_hub: # TODO (cwognum): The user now has no easy way of knowing whether the dataset is "small enough". diff --git a/polaris/hub/client.py b/polaris/hub/client.py index b2a6a77c..b5588d14 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -516,7 +516,9 @@ def upload_results( result_json = results.model_dump(by_alias=True, exclude_none=True) # Make a request to the Hub - response = self._base_request_to_hub(url="/v2/result", method="POST", withhold_token=False, json=result_json) + response = self._base_request_to_hub( + url="/v2/result", method="POST", withhold_token=False, json=result_json + ) # Inform the user about where to find their newly created artifact. 
result_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location")) diff --git a/polaris/hub/settings.py b/polaris/hub/settings.py index 042b7d59..035819a5 100644 --- a/polaris/hub/settings.py +++ b/polaris/hub/settings.py @@ -59,6 +59,9 @@ class PolarisHubSettings(BaseSettings): ca_bundle: str | bool | None = None default_timeout: TimeoutTypes = (10, 200) + # Bucket settings + bucket_url: str = "https://data.polarishub.io" + @field_validator("api_url", mode="before") def validate_api_url(cls, v, info: ValidationInfo): if v is None: From d5b5eae412eb1e64f126775fa07d12812e024868 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Thu, 22 May 2025 10:42:54 -0400 Subject: [PATCH 12/20] documentation change --- docs/quickstart.md | 4 ++-- docs/tutorials/submit_to_benchmark.ipynb | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index 33bb0add..5a82c792 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -25,7 +25,7 @@ Polaris explicitly distinguished **datasets** and **benchmarks**. One dataset can therefore be associated with multiple benchmarks. ## Login -To interact with the [Polaris Hub](https://polarishub.io/) from the client, you must first authenticate yourself. If you don't have an account yet, you can create one [here](https://polarishub.io/sign-up). +To submit or upload artifacts to the [Polaris Hub](https://polarishub.io/) from the client, you must first authenticate yourself. If you don't have an account yet, you can create one [here](https://polarishub.io/sign-up). You can do this via the following command in your terminal: @@ -66,7 +66,7 @@ predictions = [0.0 for x in test] results = benchmark.evaluate(predictions) # Submit your results -results.upload_to_hub(owner="dummy-user", access="public") +results.upload_to_hub(owner="dummy-user") ``` Through immutable datasets and standardized benchmarks, Polaris aims to serve as a source of truth for machine learning in drug discovery. The limited flexibility might differ from your typical experience, but this is by design to improve reproducibility. Learn more [here](https://polarishub.io/blog/reproducible-machine-learning-in-drug-discovery-how-polaris-serves-as-a-single-source-of-truth). 
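Note: the quickstart edit reflects the split this series introduces: GET-style loads work anonymously, while uploads still require `polaris login`. A sketch of the resulting end-to-end flow, reusing the quickstart's placeholder artifact and owner names:

```python
import polaris as po

# Anonymous read: after this series, loading public artifacts needs no login.
benchmark = po.load_benchmark("polaris/hello-world-benchmark")
train, test = benchmark.get_train_test_split()

# Authenticated write: uploading results still hits token-guarded endpoints.
results = benchmark.evaluate([0.0 for _ in test])
results.upload_to_hub(owner="dummy-user")
```
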
diff --git a/docs/tutorials/submit_to_benchmark.ipynb b/docs/tutorials/submit_to_benchmark.ipynb index cc05c00a..53c5c93a 100644 --- a/docs/tutorials/submit_to_benchmark.ipynb +++ b/docs/tutorials/submit_to_benchmark.ipynb @@ -334,7 +334,7 @@ }, "outputs": [], "source": [ - "results.upload_to_hub(owner=\"my-username\", access=\"public\")" + "results.upload_to_hub(owner=\"my-username\")" ] }, { From e161c565d5657e4e9f6e5fd40115f3bd5fb7815a Mon Sep 17 00:00:00 2001 From: Jack Li Date: Thu, 22 May 2025 16:28:25 -0400 Subject: [PATCH 13/20] Updates to client, imports and storagesession class --- polaris/dataset/_base.py | 10 +++++----- polaris/dataset/_dataset.py | 3 +-- polaris/dataset/_dataset_v2.py | 3 --- polaris/hub/client.py | 14 ++++++++------ polaris/hub/storage.py | 13 +------------ 5 files changed, 15 insertions(+), 28 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index f023c1b3..9d59e9c6 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -180,10 +180,10 @@ def zarr_root(self) -> zarr.Group | None: if self.zarr_root_path is None: return None - settings = PolarisHubSettings() - saved_on_hub = self.zarr_root_path and self.zarr_root_path.startswith(settings.bucket_url) - - if self._warn_about_remote_zarr and saved_on_hub: + fs, _ = fsspec.url_to_fs(self.zarr_root_path) + remote = 'local' in fs.protocol + + if self._warn_about_remote_zarr and remote: # TODO (cwognum): The user now has no easy way of knowing whether the dataset is "small enough". logger.warning( f"You're loading data from a remote location. " @@ -193,7 +193,7 @@ def zarr_root(self) -> zarr.Group | None: self._warn_about_remote_zarr = False try: - if saved_on_hub: + if remote: self._zarr_root = self.load_zarr_root_from_hub() else: self._zarr_root = self.load_zarr_root_from_local() diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index da791171..eabdf50c 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -6,6 +6,7 @@ import fsspec import numpy as np +import zarr import pandas as pd from datamol.utils import fs as dmfs from pydantic import PrivateAttr, computed_field, field_validator, model_validator @@ -137,8 +138,6 @@ def load_zarr_root_from_hub(self): """ Loads a Zarr archive from the Hub. """ - from polaris.hub.storage import StorageSession - import zarr store = fsspec.get_mapper(self.zarr_root_path) return zarr.open_consolidated(store=store) diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py index 8b23a83b..de5d7f4d 100644 --- a/polaris/dataset/_dataset_v2.py +++ b/polaris/dataset/_dataset_v2.py @@ -117,9 +117,6 @@ def load_zarr_root_from_hub(self): """ Loads a Zarr archive from the Hub. 
""" - from polaris.hub.storage import StorageSession - import zarr - store = fsspec.get_mapper(self.zarr_root_path) return zarr.open_consolidated(store=store) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index b5588d14..66e89745 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -11,6 +11,7 @@ from authlib.oauth2.rfc6749 import OAuth2Token from httpx import ConnectError, HTTPStatusError, Response from typing_extensions import Self +import fsspec from polaris.benchmark import ( BenchmarkV1Specification, @@ -344,8 +345,8 @@ def _get_v1_dataset( extension_url = response_data.get("extension") # Load the dataset table and optional Zarr archive - with StorageSession(self, "read", Dataset.urn_for(owner, slug)) as storage: - table = pd.read_parquet(BytesIO(storage.get_file(root_url))) + with fsspec.open(root_url, mode='rb') as f: + table = pd.read_parquet(f) dataset = DatasetV1(table=table, zarr_root_path=extension_url, **response_data) md5sum = response_data["md5Sum"] @@ -482,10 +483,11 @@ def _get_v2_benchmark(self, owner: str | HubOwner, slug: str) -> BenchmarkV2Spec response_data["dataset"] = self.get_dataset(*response_data["dataset"]["artifactId"].split("/")) - # Load the split index sets - with StorageSession(self, "read", BenchmarkV2Specification.urn_for(owner, slug)) as storage: - split = {label: storage.get_file(label) for label in response_data.get("split", {}).keys()} - + split = {} + for label, url in response_data.get("split", {}).items(): + with fsspec.open(url, mode="rb") as f: + split[label] = f.read() + return BenchmarkV2Specification(**{**response_data, "split": split}) def upload_results( diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index 99bd4f19..85af2f82 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -475,11 +475,8 @@ def set_token(self, token: dict[str, Any] | HubStorageOAuth2Token) -> None: class StorageSession(OAuth2Client): """ - A context manager for managing a storage session, with token exchange and token refresh capabilities. + A context manager for managing a storage session for upload/write operations, with token exchange and token refresh capabilities. Each session is associated with a specific scope and resource. - The `mode` parameter controls whether authentication is required: - - For 'read' mode, authentication is optional. - - For 'write' mode, authentication is always required. """ polaris_protocol = "polarisfs" @@ -589,11 +586,3 @@ def set_file(self, path: str, value: bytes | bytearray): ) # Use StorageSession with mode='write' for write operations store[relative_path.name] = value - - def get_file(self, path: str) -> bytes | bytearray: - """ - Get the value at the given path. 
-        """
-        # The path is now a full URL, so we use it directly
-        with fsspec.open(path, mode="rb") as f:
-            return f.read()

From b39427468d43d5e48c3f3471ce8fde1c4cbabadc Mon Sep 17 00:00:00 2001
From: Jack Li
Date: Thu, 22 May 2025 16:29:04 -0400
Subject: [PATCH 14/20] ruff formatting

---
 polaris/dataset/_base.py | 4 ++--
 polaris/hub/client.py | 4 ++--
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py
index 9d59e9c6..a760548c 100644
--- a/polaris/dataset/_base.py
+++ b/polaris/dataset/_base.py
@@ -181,8 +181,8 @@ def zarr_root(self) -> zarr.Group | None:
         return None

         fs, _ = fsspec.url_to_fs(self.zarr_root_path)
-        remote = 'local' in fs.protocol
-
+        remote = "local" in fs.protocol
+
         if self._warn_about_remote_zarr and remote:

diff --git a/polaris/hub/client.py b/polaris/hub/client.py
index 66e89745..6d821baf 100644
--- a/polaris/hub/client.py
+++ b/polaris/hub/client.py
@@ -345,7 +345,7 @@ def _get_v1_dataset(
         extension_url = response_data.get("extension")

         # Load the dataset table and optional Zarr archive
-        with fsspec.open(root_url, mode='rb') as f:
+        with fsspec.open(root_url, mode="rb") as f:
             table = pd.read_parquet(f)

@@ -487,7 +487,7 @@ def _get_v2_benchmark(self, owner: str | HubOwner, slug: str) -> BenchmarkV2Spec
         for label, url in response_data.get("split", {}).items():
             with fsspec.open(url, mode="rb") as f:
                 split[label] = f.read()
-
+
         return BenchmarkV2Specification(**{**response_data, "split": split})

From 282f14b4bbd24e2b527da5e59275ccee41c50d42 Mon Sep 17 00:00:00 2001
From: Jack Li
Date: Thu, 22 May 2025 16:36:07 -0400
Subject: [PATCH 15/20] fix remote check

---
 polaris/dataset/_base.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py
index a760548c..c49b0104 100644
--- a/polaris/dataset/_base.py
+++ b/polaris/dataset/_base.py
@@ -181,7 +181,9 @@ def zarr_root(self) -> zarr.Group | None:
         return None

         fs, _ = fsspec.url_to_fs(self.zarr_root_path)
-        remote = "local" in fs.protocol
+        remote = not (
+            fs.protocol == "file" or (isinstance(fs.protocol, (list, tuple)) and fs.protocol[0] == "file")
+        )

         if self._warn_about_remote_zarr and remote:
             # TODO (cwognum): The user now has no easy way of knowing whether the dataset is "small enough".
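Note: fsspec reports `protocol` as either a string or a tuple of aliases (recent versions use `("file", "local")` for the local filesystem), which is what the two-branch check above guards against. A standalone sketch of the corrected test, assuming fsspec's HTTP extras (aiohttp) are installed:

```python
import fsspec

def is_remote(path: str) -> bool:
    # url_to_fs resolves a URL or bare path to a concrete filesystem instance;
    # anything other than the local filesystem counts as remote.
    fs, _ = fsspec.url_to_fs(path)
    protocol = fs.protocol[0] if isinstance(fs.protocol, (list, tuple)) else fs.protocol
    return protocol != "file"

assert not is_remote("/tmp/data.zarr")
assert is_remote("https://data.polarishub.io/example.zarr")
```
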
From de6b4c3ca6222d85b7529a514436714ce8572724 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Thu, 22 May 2025 17:49:46 -0400 Subject: [PATCH 16/20] remove bucket url --- polaris/hub/settings.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/polaris/hub/settings.py b/polaris/hub/settings.py index 035819a5..042b7d59 100644 --- a/polaris/hub/settings.py +++ b/polaris/hub/settings.py @@ -59,9 +59,6 @@ class PolarisHubSettings(BaseSettings): ca_bundle: str | bool | None = None default_timeout: TimeoutTypes = (10, 200) - # Bucket settings - bucket_url: str = "https://data.polarishub.io" - @field_validator("api_url", mode="before") def validate_api_url(cls, v, info: ValidationInfo): if v is None: From 5d8fa7aff59a5ad21e92f70039852da9cd8074e7 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Fri, 23 May 2025 11:13:10 -0400 Subject: [PATCH 17/20] remove imports --- polaris/dataset/_base.py | 3 --- polaris/hub/client.py | 4 +--- polaris/hub/storage.py | 9 +++------ 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py index c49b0104..0076b03b 100644 --- a/polaris/dataset/_base.py +++ b/polaris/dataset/_base.py @@ -36,7 +36,6 @@ SupportedLicenseType, ZarrConflictResolution, ) -from polaris.hub.settings import PolarisHubSettings logger = logging.getLogger(__name__) @@ -172,8 +171,6 @@ def zarr_root(self) -> zarr.Group | None: See also `dataset.load_to_memory()`. """ - from urllib.parse import urlparse - if self._zarr_root is not None: return self._zarr_root diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 6d821baf..eee11c00 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -1,6 +1,5 @@ import json import logging -from io import BytesIO from urllib.parse import urljoin import httpx @@ -21,12 +20,11 @@ from polaris.benchmark._benchmark_v2 import BenchmarkV2Specification from polaris.competition import CompetitionSpecification from polaris.model import Model -from polaris.dataset import Dataset, DatasetV1, DatasetV2 +from polaris.dataset import DatasetV1, DatasetV2 from polaris.evaluate import BenchmarkResultsV1, BenchmarkResultsV2, CompetitionPredictions from polaris.hub.external_client import ExternalAuthClient from polaris.hub.oauth import CachedTokenAuth from polaris.hub.settings import PolarisHubSettings -from polaris.hub.storage import StorageSession from polaris.utils.context import track_progress from polaris.utils.errors import ( PolarisCreateArtifactError, diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index 85af2f82..8e72d0b8 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -18,7 +18,6 @@ from zarr.context import Context from zarr.storage import Store from zarr.util import buffer_size -import fsspec from polaris.hub.oauth import BenchmarkV2Paths, DatasetV1Paths, DatasetV2Paths, HubStorageOAuth2Token from polaris.utils.context import track_progress @@ -483,10 +482,9 @@ class StorageSession(OAuth2Client): token_auth_class = StorageTokenAuth - def __init__(self, hub_client, scope: Scope, resource: ArtifactUrn, mode: str = "read"): + def __init__(self, hub_client, scope: Scope, resource: ArtifactUrn): self.hub_client = hub_client self.resource = resource - self.mode = mode super().__init__( # OAuth2Client @@ -499,8 +497,7 @@ def __init__(self, hub_client, scope: Scope, resource: ArtifactUrn, mode: str = ) def __enter__(self) -> Self: - if not (self.mode == "read"): - self.ensure_active_token() + self.ensure_active_token() return self def _prepare_token_endpoint_body(self, 
body, grant_type, **kwargs) -> str: @@ -584,5 +581,5 @@ def set_file(self, path: str, value: bytes | bytearray): endpoint_url=storage_data.endpoint, content_type=content_type, ) - # Use StorageSession with mode='write' for write operations + store[relative_path.name] = value From f4d6417482ee8955f3521d5ad1c49ac8e46a1fea Mon Sep 17 00:00:00 2001 From: Jack Li Date: Fri, 23 May 2025 14:40:59 -0400 Subject: [PATCH 18/20] update test --- tests/test_hub_integration.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/tests/test_hub_integration.py b/tests/test_hub_integration.py index 9f52a030..6f22c92f 100644 --- a/tests/test_hub_integration.py +++ b/tests/test_hub_integration.py @@ -8,19 +8,10 @@ settings = PolarisHubSettings() -@pytest.mark.skipif( - settings.username is None or settings.password is None, - reason="This test case requires headless authentication to be set up", -) def test_load_dataset_flow(): dataset = po.load_dataset("polaris/hello-world") assert isinstance(dataset, BaseDataset) - -@pytest.mark.skipif( - settings.username is None or settings.password is None, - reason="This test case requires headless authentication to be set up", -) def test_load_benchmark_flow(): benchmark = po.load_benchmark("polaris/hello-world-benchmark") assert isinstance(benchmark, BenchmarkSpecification) From d34a6d7cfa591e5f680e39cacf4c9b91744cba1c Mon Sep 17 00:00:00 2001 From: Jack Li Date: Fri, 23 May 2025 14:49:43 -0400 Subject: [PATCH 19/20] update tests 2 --- polaris/hub/client.py | 11 +++++++++++ tests/test_hub_integration.py | 3 +-- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index eee11c00..fe30205c 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -342,6 +342,11 @@ def _get_v1_dataset( root_url = response_data.get("root") extension_url = response_data.get("extension") + if root_url is not None: + root_url = root_url.strip('"') + if extension_url is not None: + extension_url = extension_url.strip('"') + # Load the dataset table and optional Zarr archive with fsspec.open(root_url, mode="rb") as f: table = pd.read_parquet(f) @@ -364,6 +369,9 @@ def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2: response_data.pop("zarrRootPath", None) root_url = response_data.get("root") + + if root_url is not None: + root_url = root_url.strip('"') # For v2 datasets, the zarr_path always exists dataset = DatasetV2(zarr_root_path=root_url, **response_data) return dataset @@ -613,6 +621,9 @@ def get_competition(self, artifact_id: str) -> CompetitionSpecification: root_url = response_data.get("root") + if root_url is not None: + root_url = root_url.strip('"') + return CompetitionSpecification(zarr_root_path=root_url, **response_data) def submit_competition_predictions( diff --git a/tests/test_hub_integration.py b/tests/test_hub_integration.py index 6f22c92f..74ac7f8d 100644 --- a/tests/test_hub_integration.py +++ b/tests/test_hub_integration.py @@ -1,5 +1,3 @@ -import pytest - import polaris as po from polaris.benchmark._base import BenchmarkSpecification from polaris.dataset._base import BaseDataset @@ -12,6 +10,7 @@ def test_load_dataset_flow(): dataset = po.load_dataset("polaris/hello-world") assert isinstance(dataset, BaseDataset) + def test_load_benchmark_flow(): benchmark = po.load_benchmark("polaris/hello-world-benchmark") assert isinstance(benchmark, BenchmarkSpecification) From e8c0b8eb934b0e7db528c1e616bd7b3ed2540135 Mon Sep 17 00:00:00 2001 From: Jack Li Date: Fri, 23 May 2025 15:18:10 -0400 
Subject: [PATCH 20/20] test fix --- env.yml | 2 ++ polaris/hub/client.py | 10 ---------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/env.yml b/env.yml index d466ab92..dd829b69 100644 --- a/env.yml +++ b/env.yml @@ -17,6 +17,8 @@ dependencies: # Hub client - authlib - httpx + - requests + - aiohttp # Scientific - numpy < 3 diff --git a/polaris/hub/client.py b/polaris/hub/client.py index fe30205c..423d7168 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -342,11 +342,6 @@ def _get_v1_dataset( root_url = response_data.get("root") extension_url = response_data.get("extension") - if root_url is not None: - root_url = root_url.strip('"') - if extension_url is not None: - extension_url = extension_url.strip('"') - # Load the dataset table and optional Zarr archive with fsspec.open(root_url, mode="rb") as f: table = pd.read_parquet(f) @@ -370,8 +365,6 @@ def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2: root_url = response_data.get("root") - if root_url is not None: - root_url = root_url.strip('"') # For v2 datasets, the zarr_path always exists dataset = DatasetV2(zarr_root_path=root_url, **response_data) return dataset @@ -621,9 +614,6 @@ def get_competition(self, artifact_id: str) -> CompetitionSpecification: root_url = response_data.get("root") - if root_url is not None: - root_url = root_url.strip('"') - return CompetitionSpecification(zarr_root_path=root_url, **response_data) def submit_competition_predictions(
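
Note: taken together, the series reduces the anonymous read path to plain fsspec calls against the `root`/`extension` URLs the Hub returns. A condensed sketch of what `_get_v1_dataset` now does, with a placeholder for the Hub's `root` field:

```python
import fsspec
import pandas as pd

# Placeholder for the public URL returned as "root" by the Hub's dataset endpoint.
root_url = "https://data.polarishub.io/datasets/org/name/table.parquet"

# pandas accepts the file-like handle fsspec.open yields, so fetching the
# dataset table is a single unauthenticated GET.
with fsspec.open(root_url, mode="rb") as f:
    table = pd.read_parquet(f)

print(table.head())
```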