diff --git a/docs/quickstart.md b/docs/quickstart.md
index 33bb0add..5a82c792 100644
--- a/docs/quickstart.md
+++ b/docs/quickstart.md
@@ -25,7 +25,7 @@
 Polaris explicitly distinguished **datasets** and **benchmarks**. One dataset can therefore be associated with multiple benchmarks.
 
 ## Login
-To interact with the [Polaris Hub](https://polarishub.io/) from the client, you must first authenticate yourself. If you don't have an account yet, you can create one [here](https://polarishub.io/sign-up).
+To submit or upload artifacts to the [Polaris Hub](https://polarishub.io/) from the client, you must first authenticate yourself. If you don't have an account yet, you can create one [here](https://polarishub.io/sign-up).
 
 You can do this via the following command in your terminal:
 
@@ -66,7 +66,7 @@ predictions = [0.0 for x in test]
 results = benchmark.evaluate(predictions)
 
 # Submit your results
-results.upload_to_hub(owner="dummy-user", access="public")
+results.upload_to_hub(owner="dummy-user")
 ```
 
 Through immutable datasets and standardized benchmarks, Polaris aims to serve as a source of truth for machine learning in drug discovery. The limited flexibility might differ from your typical experience, but this is by design to improve reproducibility. Learn more [here](https://polarishub.io/blog/reproducible-machine-learning-in-drug-discovery-how-polaris-serves-as-a-single-source-of-truth).
diff --git a/docs/tutorials/submit_to_benchmark.ipynb b/docs/tutorials/submit_to_benchmark.ipynb
index cc05c00a..53c5c93a 100644
--- a/docs/tutorials/submit_to_benchmark.ipynb
+++ b/docs/tutorials/submit_to_benchmark.ipynb
@@ -334,7 +334,7 @@
    },
    "outputs": [],
    "source": [
-    "results.upload_to_hub(owner=\"my-username\", access=\"public\")"
+    "results.upload_to_hub(owner=\"my-username\")"
    ]
   },
   {
diff --git a/env.yml b/env.yml
index d466ab92..dd829b69 100644
--- a/env.yml
+++ b/env.yml
@@ -17,6 +17,8 @@ dependencies:
   # Hub client
   - authlib
   - httpx
+  - requests
+  - aiohttp
 
   # Scientific
   - numpy < 3
diff --git a/polaris/dataset/_base.py b/polaris/dataset/_base.py
index e9ddadd1..0076b03b 100644
--- a/polaris/dataset/_base.py
+++ b/polaris/dataset/_base.py
@@ -171,17 +171,18 @@ def zarr_root(self) -> zarr.Group | None:
             See also `dataset.load_to_memory()`.
         """
-        from polaris.hub.storage import StorageSession
-
         if self._zarr_root is not None:
             return self._zarr_root
 
         if self.zarr_root_path is None:
             return None
 
-        saved_on_hub = self.zarr_root_path.startswith(StorageSession.polaris_protocol)
+        fs, _ = fsspec.url_to_fs(self.zarr_root_path)
+        remote = not (
+            fs.protocol == "file" or (isinstance(fs.protocol, (list, tuple)) and fs.protocol[0] == "file")
+        )
 
-        if self._warn_about_remote_zarr and saved_on_hub:
+        if self._warn_about_remote_zarr and remote:
             # TODO (cwognum): The user now has no easy way of knowing whether the dataset is "small enough".
             logger.warning(
                 f"You're loading data from a remote location. "
@@ -191,7 +192,7 @@ def zarr_root(self) -> zarr.Group | None:
             self._warn_about_remote_zarr = False
 
         try:
-            if saved_on_hub:
+            if remote:
                 self._zarr_root = self.load_zarr_root_from_hub()
             else:
                 self._zarr_root = self.load_zarr_root_from_local()
diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py
index 30383749..eabdf50c 100644
--- a/polaris/dataset/_dataset.py
+++ b/polaris/dataset/_dataset.py
@@ -6,8 +6,8 @@
 
 import fsspec
 import numpy as np
-import pandas as pd
 import zarr
+import pandas as pd
 from datamol.utils import fs as dmfs
 from pydantic import PrivateAttr, computed_field, field_validator, model_validator
 from typing_extensions import Self, deprecated
@@ -138,12 +138,9 @@ def load_zarr_root_from_hub(self):
         """
         Loads a Zarr archive from the Hub.
         """
-        from polaris.hub.client import PolarisHubClient
-        from polaris.hub.storage import StorageSession
-        with PolarisHubClient() as client:
-            with StorageSession(client, "read", self.urn) as storage:
-                return zarr.open_consolidated(store=storage.store("extension"))
+        store = fsspec.get_mapper(self.zarr_root_path)
+        return zarr.open_consolidated(store=store)
 
     @computed_field
     @property
diff --git a/polaris/dataset/_dataset_v2.py b/polaris/dataset/_dataset_v2.py
index 8a757bdd..de5d7f4d 100644
--- a/polaris/dataset/_dataset_v2.py
+++ b/polaris/dataset/_dataset_v2.py
@@ -117,12 +117,8 @@ def load_zarr_root_from_hub(self):
         """
         Loads a Zarr archive from the Hub.
         """
-        from polaris.hub.client import PolarisHubClient
-        from polaris.hub.storage import StorageSession
-
-        with PolarisHubClient() as client:
-            with StorageSession(client, "read", self.urn) as storage:
-                return zarr.open_consolidated(store=storage.store("root"))
+        store = fsspec.get_mapper(self.zarr_root_path)
+        return zarr.open_consolidated(store=store)
 
     @property
     def zarr_manifest_path(self) -> str:
diff --git a/polaris/hub/client.py b/polaris/hub/client.py
index 439ac4af..423d7168 100644
--- a/polaris/hub/client.py
+++ b/polaris/hub/client.py
@@ -1,6 +1,5 @@
 import json
 import logging
-from io import BytesIO
 from urllib.parse import urljoin
 
 import httpx
@@ -11,6 +10,7 @@
 from authlib.oauth2.rfc6749 import OAuth2Token
 from httpx import ConnectError, HTTPStatusError, Response
 from typing_extensions import Self
+import fsspec
 
 from polaris.benchmark import (
     BenchmarkV1Specification,
@@ -20,12 +20,11 @@
 from polaris.benchmark._benchmark_v2 import BenchmarkV2Specification
 from polaris.competition import CompetitionSpecification
 from polaris.model import Model
-from polaris.dataset import Dataset, DatasetV1, DatasetV2
+from polaris.dataset import DatasetV1, DatasetV2
 from polaris.evaluate import BenchmarkResultsV1, BenchmarkResultsV2, CompetitionPredictions
 from polaris.hub.external_client import ExternalAuthClient
 from polaris.hub.oauth import CachedTokenAuth
 from polaris.hub.settings import PolarisHubSettings
-from polaris.hub.storage import StorageSession
 from polaris.utils.context import track_progress
 from polaris.utils.errors import (
     PolarisCreateArtifactError,
@@ -119,12 +118,7 @@ def __init__(
         )
 
     def __enter__(self: Self) -> Self:
-        """
-        When used as a context manager, automatically check that authentication is valid.
-        """
         super().__enter__()
-        if not self.ensure_active_token():
-            raise PolarisUnauthorizedError()
         return self
 
     @property
@@ -184,10 +178,13 @@ def fetch_token(self, **kwargs):
                 message=f"Could not obtain a token to access the Hub. Error was: {error.error} - {error.description}"
             ) from error
 
-    def _base_request_to_hub(self, url: str, method: str, **kwargs) -> Response:
-        """Utility function since most API methods follow the same pattern"""
+    def _base_request_to_hub(self, url: str, method: str, withhold_token: bool, **kwargs) -> Response:
+        """Utility function for making requests to the Hub with consistent error handling."""
+        if not withhold_token:
+            if not self.ensure_active_token():
+                raise PolarisUnauthorizedError()
         try:
-            response = self.request(url=url, method=method, **kwargs)
+            response = self.request(url=url, method=method, withhold_token=withhold_token, **kwargs)
             response.raise_for_status()
             return response
         except HTTPStatusError as error:
@@ -263,9 +260,11 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
             A list of dataset names in the format `owner/dataset_slug`.
         """
         with track_progress(description="Fetching datasets", total=1):
-            # Step 1: Fetch enough v2 datasets to cover the offset and limit
             v2_json_response = self._base_request_to_hub(
-                url="/v2/dataset", method="GET", params={"limit": limit, "offset": offset}
+                url="/v2/dataset",
+                method="GET",
+                withhold_token=True,
+                params={"limit": limit, "offset": offset},
             ).json()
             v2_data = v2_json_response["data"]
             v2_datasets = [dataset["artifactId"] for dataset in v2_data]
@@ -280,6 +279,7 @@
             v1_json_response = self._base_request_to_hub(
                 url="/v1/dataset",
                 method="GET",
+                withhold_token=True,
                 params={
                     "limit": remaining_limit,
                     "offset": max(0, offset - v2_json_response["metadata"]["total"]),
@@ -335,23 +335,18 @@ def _get_v1_dataset(
             A `Dataset` instance, if it exists.
         """
         url = f"/v1/dataset/{owner}/{slug}"
-        response = self._base_request_to_hub(url=url, method="GET")
+        response = self._base_request_to_hub(url=url, method="GET", withhold_token=True)
         response_data = response.json()
-
-        # Disregard the Zarr root in the response. We'll get it from the storage token instead.
         response_data.pop("zarrRootPath", None)
 
-        # Load the dataset table and optional Zarr archive
-        with StorageSession(self, "read", Dataset.urn_for(owner, slug)) as storage:
-            table = pd.read_parquet(BytesIO(storage.get_file("root")))
-            zarr_root_path = storage.paths.extension
+        root_url = response_data.get("root")
+        extension_url = response_data.get("extension")
 
-        if zarr_root_path is not None:
-            # For V1 datasets, the Zarr Root is optional.
-            # It should be None if the dataset does not use pointer columns
-            zarr_root_path = str(zarr_root_path)
+        # Load the dataset table and optional Zarr archive
+        with fsspec.open(root_url, mode="rb") as f:
+            table = pd.read_parquet(f)
 
-        dataset = DatasetV1(table=table, zarr_root_path=zarr_root_path, **response_data)
+        dataset = DatasetV1(table=table, zarr_root_path=extension_url, **response_data)
 
         md5sum = response_data["md5Sum"]
         if dataset.should_verify_checksum(verify_checksum):
@@ -364,17 +359,14 @@
     def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2:
         """"""
         url = f"/v2/dataset/{owner}/{slug}"
-        response = self._base_request_to_hub(url=url, method="GET")
+        response = self._base_request_to_hub(url=url, method="GET", withhold_token=True)
         response_data = response.json()
-
-        # Disregard the Zarr root in the response. We'll get it from the storage token instead.
response_data.pop("zarrRootPath", None) - # Load the Zarr archive - with StorageSession(self, "read", DatasetV2.urn_for(owner, slug)) as storage: - zarr_root_path = str(storage.paths.root) + root_url = response_data.get("root") - dataset = DatasetV2(zarr_root_path=zarr_root_path, **response_data) + # For v2 datasets, the zarr_path always exists + dataset = DatasetV2(zarr_root_path=root_url, **response_data) return dataset def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: @@ -391,7 +383,10 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: with track_progress(description="Fetching benchmarks", total=1): # Step 1: Fetch enough v2 benchmarks to cover the offset and limit v2_json_response = self._base_request_to_hub( - url="/v2/benchmark", method="GET", params={"limit": limit, "offset": offset} + url="/v2/benchmark", + method="GET", + withhold_token=True, + params={"limit": limit, "offset": offset}, ).json() v2_data = v2_json_response["data"] v2_benchmarks = [benchmark["artifactId"] for benchmark in v2_data] @@ -402,10 +397,10 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]: # Step 2: Calculate the remaining limit and fetch v1 benchmarks remaining_limit = max(0, limit - len(v2_benchmarks)) - v1_json_response = self._base_request_to_hub( url="/v1/benchmark", method="GET", + withhold_token=True, params={ "limit": remaining_limit, "offset": max(0, offset - v2_json_response["metadata"]["total"]), @@ -449,7 +444,9 @@ def _get_v1_benchmark( slug: str, verify_checksum: ChecksumStrategy = "verify_unless_zarr", ) -> BenchmarkV1Specification: - response = self._base_request_to_hub(url=f"/v1/benchmark/{owner}/{slug}", method="GET") + response = self._base_request_to_hub( + url=f"/v1/benchmark/{owner}/{slug}", method="GET", withhold_token=True + ) response_data = response.json() # TODO (jstlaurent): response["dataset"]["artifactId"] is the owner/name unique identifier, @@ -478,14 +475,17 @@ def _get_v1_benchmark( return benchmark def _get_v2_benchmark(self, owner: str | HubOwner, slug: str) -> BenchmarkV2Specification: - response = self._base_request_to_hub(url=f"/v2/benchmark/{owner}/{slug}", method="GET") + response = self._base_request_to_hub( + url=f"/v2/benchmark/{owner}/{slug}", method="GET", withhold_token=True + ) response_data = response.json() response_data["dataset"] = self.get_dataset(*response_data["dataset"]["artifactId"].split("/")) - # Load the split index sets - with StorageSession(self, "read", BenchmarkV2Specification.urn_for(owner, slug)) as storage: - split = {label: storage.get_file(label) for label in response_data.get("split", {}).keys()} + split = {} + for label, url in response_data.get("split", {}).items(): + with fsspec.open(url, mode="rb") as f: + split[label] = f.read() return BenchmarkV2Specification(**{**response_data, "split": split}) @@ -517,7 +517,9 @@ def upload_results( result_json = results.model_dump(by_alias=True, exclude_none=True) # Make a request to the Hub - response = self._base_request_to_hub(url="/v2/result", method="POST", json=result_json) + response = self._base_request_to_hub( + url="/v2/result", method="POST", withhold_token=False, json=result_json + ) # Inform the user about where to find their newly created artifact. result_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location")) @@ -607,15 +609,12 @@ def get_competition(self, artifact_id: str) -> CompetitionSpecification: A `CompetitionSpecification` instance, if it exists. 
""" url = f"/v1/competition/{artifact_id}" - response = self._base_request_to_hub(url=url, method="GET") + response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() - with StorageSession( - self, "read", CompetitionSpecification.urn_for(*artifact_id.split("/")) - ) as storage: - zarr_root_path = str(storage.paths.root) + root_url = response_data.get("root") - return CompetitionSpecification(zarr_root_path=zarr_root_path, **response_data) + return CompetitionSpecification(zarr_root_path=root_url, **response_data) def submit_competition_predictions( self, @@ -641,6 +640,7 @@ def submit_competition_predictions( response = self._base_request_to_hub( url="/v1/competition-prediction", method="POST", + withhold_token=False, json=prediction_payload, ) return response @@ -657,7 +657,7 @@ def list_models(self, limit: int = 100, offset: int = 0) -> list[str]: """ with track_progress(description="Fetching models", total=1): json_response = self._base_request_to_hub( - url="/v2/model", method="GET", params={"limit": limit, "offset": offset} + url="/v2/model", method="GET", withhold_token=True, params={"limit": limit, "offset": offset} ).json() models = [model["artifactId"] for model in json_response["data"]] @@ -665,7 +665,7 @@ def list_models(self, limit: int = 100, offset: int = 0) -> list[str]: def get_model(self, artifact_id: str) -> Model: url = f"/v2/model/{artifact_id}" - response = self._base_request_to_hub(url=url, method="GET") + response = self._base_request_to_hub(url=url, method="GET", withhold_token=True) response_data = response.json() return Model(**response_data) @@ -703,6 +703,7 @@ def upload_model( response = self._base_request_to_hub( url=url, method="PUT", + withhold_token=False, json={"parentArtifactId": parent_artifact_id, **model_json}, ) diff --git a/polaris/hub/storage.py b/polaris/hub/storage.py index 4e225e52..8e72d0b8 100644 --- a/polaris/hub/storage.py +++ b/polaris/hub/storage.py @@ -474,7 +474,7 @@ def set_token(self, token: dict[str, Any] | HubStorageOAuth2Token) -> None: class StorageSession(OAuth2Client): """ - A context manager for managing a storage session, with token exchange and token refresh capabilities. + A context manager for managing a storage session for upload/write operations, with token exchange and token refresh capabilities. Each session is associated with a specific scope and resource. """ @@ -581,42 +581,5 @@ def set_file(self, path: str, value: bytes | bytearray): endpoint_url=storage_data.endpoint, content_type=content_type, ) - store[relative_path.name] = value - - def get_file(self, path: str) -> bytes | bytearray: - """ - Get the value at the given path. - """ - if path not in self.paths.files: - raise NotImplementedError( - f"{type(self.paths).__name__} only supports these files: {self.paths.files}." - ) - - relative_path = self._relative_path(getattr(self.paths, path)) - - storage_data = self.token.extra_data - store = S3Store( - path=relative_path.parent, - access_key=storage_data.key, - secret_key=storage_data.secret, - token=f"jwt/{self.token.access_token}", - endpoint_url=storage_data.endpoint, - ) - return store[relative_path.name] - - def store(self, path: str) -> S3Store: - if path not in self.paths.stores: - raise NotImplementedError( - f"{type(self.paths).__name__} only supports these stores: {self.paths.stores}." 
-            )
-
-        relative_path = self._relative_path(getattr(self.paths, path))
-        storage_data = self.token.extra_data
-        return S3Store(
-            path=relative_path,
-            access_key=storage_data.key,
-            secret_key=storage_data.secret,
-            token=f"jwt/{self.token.access_token}",
-            endpoint_url=storage_data.endpoint,
-        )
+        store[relative_path.name] = value
diff --git a/tests/test_hub_integration.py b/tests/test_hub_integration.py
index 9f52a030..74ac7f8d 100644
--- a/tests/test_hub_integration.py
+++ b/tests/test_hub_integration.py
@@ -1,5 +1,3 @@
-import pytest
-
 import polaris as po
 from polaris.benchmark._base import BenchmarkSpecification
 from polaris.dataset._base import BaseDataset
@@ -8,19 +6,11 @@
 settings = PolarisHubSettings()
 
 
-@pytest.mark.skipif(
-    settings.username is None or settings.password is None,
-    reason="This test case requires headless authentication to be set up",
-)
 def test_load_dataset_flow():
     dataset = po.load_dataset("polaris/hello-world")
     assert isinstance(dataset, BaseDataset)
 
 
-@pytest.mark.skipif(
-    settings.username is None or settings.password is None,
-    reason="This test case requires headless authentication to be set up",
-)
 def test_load_benchmark_flow():
     benchmark = po.load_benchmark("polaris/hello-world-benchmark")
     assert isinstance(benchmark, BenchmarkSpecification)