4 changes: 2 additions & 2 deletions docs/quickstart.md
@@ -25,7 +25,7 @@ Polaris explicitly distinguished **datasets** and **benchmarks**.
 One dataset can therefore be associated with multiple benchmarks.
 
 ## Login
-To interact with the [Polaris Hub](https://polarishub.io/) from the client, you must first authenticate yourself. If you don't have an account yet, you can create one [here](https://polarishub.io/sign-up).
+To submit or upload artifacts to the [Polaris Hub](https://polarishub.io/) from the client, you must first authenticate yourself. If you don't have an account yet, you can create one [here](https://polarishub.io/sign-up).
 
 You can do this via the following command in your terminal:
 
@@ -66,7 +66,7 @@ predictions = [0.0 for x in test]
 results = benchmark.evaluate(predictions)
 
 # Submit your results
-results.upload_to_hub(owner="dummy-user", access="public")
+results.upload_to_hub(owner="dummy-user")
 ```
 
 Through immutable datasets and standardized benchmarks, Polaris aims to serve as a source of truth for machine learning in drug discovery. The limited flexibility might differ from your typical experience, but this is by design to improve reproducibility. Learn more [here](https://polarishub.io/blog/reproducible-machine-learning-in-drug-discovery-how-polaris-serves-as-a-single-source-of-truth).
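For context, the updated call no longer passes an explicit `access` argument. A minimal end-to-end sketch of the quickstart flow under the new signature — the benchmark slug and owner below are placeholders, not part of this diff:

```python
import polaris as po

# Load a benchmark and its predefined train/test split (placeholder slug)
benchmark = po.load_benchmark("polaris/hello-world-benchmark")
train, test = benchmark.get_train_test_split()

# A trivial constant baseline, as in the quickstart
predictions = [0.0 for x in test]

# Evaluate locally, then submit; visibility is no longer set here
results = benchmark.evaluate(predictions)
results.upload_to_hub(owner="dummy-user")
```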
2 changes: 1 addition & 1 deletion docs/tutorials/submit_to_benchmark.ipynb
@@ -334,7 +334,7 @@
 },
 "outputs": [],
 "source": [
-    "results.upload_to_hub(owner=\"my-username\", access=\"public\")"
+    "results.upload_to_hub(owner=\"my-username\")"
 ]
 },
 {
2 changes: 2 additions & 0 deletions env.yml
@@ -17,6 +17,8 @@ dependencies:
   # Hub client
   - authlib
   - httpx
+  - requests
+  - aiohttp
 
   # Scientific
   - numpy < 3
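Presumably `requests` and `aiohttp` are pulled in because the client now reads Hub artifacts over plain HTTP(S) through fsspec, whose HTTP filesystem is implemented on top of aiohttp. A quick sanity check of that backend (the URL is a placeholder):

```python
import fsspec

# fsspec's "https" protocol resolves to HTTPFileSystem, which imports aiohttp;
# without aiohttp installed, resolving the URL raises an ImportError.
fs, path = fsspec.url_to_fs("https://example.com/data/table.parquet")
print(type(fs).__name__)  # HTTPFileSystem

# Reading a remote file goes through the same machinery:
# with fsspec.open("https://example.com/data/table.parquet", mode="rb") as f:
#     head = f.read(4)
```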
11 changes: 6 additions & 5 deletions polaris/dataset/_base.py
@@ -171,17 +171,18 @@ def zarr_root(self) -> zarr.Group | None:
         See also `dataset.load_to_memory()`.
         """
 
-        from polaris.hub.storage import StorageSession
-
         if self._zarr_root is not None:
             return self._zarr_root
 
         if self.zarr_root_path is None:
             return None
 
-        saved_on_hub = self.zarr_root_path.startswith(StorageSession.polaris_protocol)
+        fs, _ = fsspec.url_to_fs(self.zarr_root_path)
+        remote = not (
+            fs.protocol == "file" or (isinstance(fs.protocol, (list, tuple)) and fs.protocol[0] == "file")
+        )
 
-        if self._warn_about_remote_zarr and saved_on_hub:
+        if self._warn_about_remote_zarr and remote:
             # TODO (cwognum): The user now has no easy way of knowing whether the dataset is "small enough".
             logger.warning(
                 f"You're loading data from a remote location. "
@@ -191,7 +192,7 @@ def zarr_root(self) -> zarr.Group | None:
             self._warn_about_remote_zarr = False
 
         try:
-            if saved_on_hub:
+            if remote:
                 self._zarr_root = self.load_zarr_root_from_hub()
             else:
                 self._zarr_root = self.load_zarr_root_from_local()
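The new check swaps the Hub-specific `polaris://` prefix test for a generic fsspec protocol probe. A standalone sketch of the same logic, with illustrative paths:

```python
import fsspec

def is_remote(path: str) -> bool:
    """Mirror the check above: remote unless fsspec resolves to the local filesystem."""
    fs, _ = fsspec.url_to_fs(path)
    # fs.protocol may be a single string or a tuple of aliases, e.g. ("file", "local")
    return not (
        fs.protocol == "file"
        or (isinstance(fs.protocol, (list, tuple)) and fs.protocol[0] == "file")
    )

print(is_remote("/tmp/archive.zarr"))           # False: plain local path
print(is_remote("file:///tmp/archive.zarr"))    # False: explicit local protocol
print(is_remote("https://example.com/a.zarr"))  # True: HTTP-backed store
```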
9 changes: 3 additions & 6 deletions polaris/dataset/_dataset.py
@@ -6,8 +6,8 @@
 
 import fsspec
 import numpy as np
-import pandas as pd
 import zarr
+import pandas as pd
 from datamol.utils import fs as dmfs
 from pydantic import PrivateAttr, computed_field, field_validator, model_validator
 from typing_extensions import Self, deprecated
@@ -138,12 +138,9 @@ def load_zarr_root_from_hub(self):
         """
         Loads a Zarr archive from the Hub.
         """
-        from polaris.hub.client import PolarisHubClient
-        from polaris.hub.storage import StorageSession
 
-        with PolarisHubClient() as client:
-            with StorageSession(client, "read", self.urn) as storage:
-                return zarr.open_consolidated(store=storage.store("extension"))
+        store = fsspec.get_mapper(self.zarr_root_path)
+        return zarr.open_consolidated(store=store)
 
     @computed_field
     @property
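With the `StorageSession` indirection gone, the loader is just a generic fsspec mapping over whatever URL `zarr_root_path` holds. A minimal sketch of the same pattern against an arbitrary store, assuming consolidated metadata (`.zmetadata`) was written at upload time — the URL is a placeholder:

```python
import fsspec
import zarr

zarr_root_path = "https://example.com/datasets/my-dataset.zarr"  # placeholder URL

# get_mapper exposes the URL as the key-value store interface zarr expects
store = fsspec.get_mapper(zarr_root_path)

# open_consolidated reads all group/array metadata from the single
# ".zmetadata" key, avoiding one round-trip per array on remote stores
root = zarr.open_consolidated(store=store)
print(root.tree())
```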
8 changes: 2 additions & 6 deletions polaris/dataset/_dataset_v2.py
@@ -117,12 +117,8 @@ def load_zarr_root_from_hub(self):
         """
         Loads a Zarr archive from the Hub.
         """
-        from polaris.hub.client import PolarisHubClient
-        from polaris.hub.storage import StorageSession
-
-        with PolarisHubClient() as client:
-            with StorageSession(client, "read", self.urn) as storage:
-                return zarr.open_consolidated(store=storage.store("root"))
+        store = fsspec.get_mapper(self.zarr_root_path)
+        return zarr.open_consolidated(store=store)
 
     @property
     def zarr_manifest_path(self) -> str:
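Both dataset versions now share this loader, so downstream code is unchanged: the lazy `zarr_root` property shown earlier picks between the local and remote paths. Typical usage, with a placeholder dataset slug:

```python
import polaris as po

# Fetching the dataset metadata does not open the Zarr archive yet
dataset = po.load_dataset("polaris/hello-world")  # placeholder slug

# First access to zarr_root runs the remote-vs-local check from _base.py
# and, for remote URLs, opens the consolidated store over fsspec
root = dataset.zarr_root
if root is not None:  # None for V1 datasets without pointer columns
    print(list(root.group_keys()))
```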
97 changes: 49 additions & 48 deletions polaris/hub/client.py
@@ -1,6 +1,5 @@
 import json
 import logging
-from io import BytesIO
 from urllib.parse import urljoin
 
 import httpx
@@ -11,6 +10,7 @@
 from authlib.oauth2.rfc6749 import OAuth2Token
 from httpx import ConnectError, HTTPStatusError, Response
 from typing_extensions import Self
+import fsspec
 
 from polaris.benchmark import (
     BenchmarkV1Specification,
@@ -20,12 +20,11 @@
 from polaris.benchmark._benchmark_v2 import BenchmarkV2Specification
 from polaris.competition import CompetitionSpecification
 from polaris.model import Model
-from polaris.dataset import Dataset, DatasetV1, DatasetV2
+from polaris.dataset import DatasetV1, DatasetV2
 from polaris.evaluate import BenchmarkResultsV1, BenchmarkResultsV2, CompetitionPredictions
 from polaris.hub.external_client import ExternalAuthClient
 from polaris.hub.oauth import CachedTokenAuth
 from polaris.hub.settings import PolarisHubSettings
-from polaris.hub.storage import StorageSession
 from polaris.utils.context import track_progress
 from polaris.utils.errors import (
     PolarisCreateArtifactError,
@@ -119,12 +118,7 @@ def __init__(
         )
 
     def __enter__(self: Self) -> Self:
-        """
-        When used as a context manager, automatically check that authentication is valid.
-        """
         super().__enter__()
-        if not self.ensure_active_token():
-            raise PolarisUnauthorizedError()
         return self
 
     @property
@@ -184,10 +178,13 @@ def fetch_token(self, **kwargs):
                 message=f"Could not obtain a token to access the Hub. Error was: {error.error} - {error.description}"
             ) from error
 
-    def _base_request_to_hub(self, url: str, method: str, **kwargs) -> Response:
-        """Utility function since most API methods follow the same pattern"""
+    def _base_request_to_hub(self, url: str, method: str, withhold_token: bool, **kwargs) -> Response:
+        """Utility function for making requests to the Hub with consistent error handling."""
+        if not withhold_token:
+            if not self.ensure_active_token():
+                raise PolarisUnauthorizedError()
         try:
-            response = self.request(url=url, method=method, **kwargs)
+            response = self.request(url=url, method=method, withhold_token=withhold_token, **kwargs)
             response.raise_for_status()
             return response
         except HTTPStatusError as error:
@@ -263,9 +260,11 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
             A list of dataset names in the format `owner/dataset_slug`.
         """
         with track_progress(description="Fetching datasets", total=1):
-            # Step 1: Fetch enough v2 datasets to cover the offset and limit
             v2_json_response = self._base_request_to_hub(
-                url="/v2/dataset", method="GET", params={"limit": limit, "offset": offset}
+                url="/v2/dataset",
+                method="GET",
+                withhold_token=True,
+                params={"limit": limit, "offset": offset},
             ).json()
             v2_data = v2_json_response["data"]
             v2_datasets = [dataset["artifactId"] for dataset in v2_data]
@@ -280,6 +279,7 @@
             v1_json_response = self._base_request_to_hub(
                 url="/v1/dataset",
                 method="GET",
+                withhold_token=True,
                 params={
                     "limit": remaining_limit,
                     "offset": max(0, offset - v2_json_response["metadata"]["total"]),
@@ -335,23 +335,18 @@ def _get_v1_dataset(
            A `Dataset` instance, if it exists.
         """
         url = f"/v1/dataset/{owner}/{slug}"
-        response = self._base_request_to_hub(url=url, method="GET")
+        response = self._base_request_to_hub(url=url, method="GET", withhold_token=True)
         response_data = response.json()
 
-        # Disregard the Zarr root in the response. We'll get it from the storage token instead.
-        response_data.pop("zarrRootPath", None)
 
-        # Load the dataset table and optional Zarr archive
-        with StorageSession(self, "read", Dataset.urn_for(owner, slug)) as storage:
-            table = pd.read_parquet(BytesIO(storage.get_file("root")))
-            zarr_root_path = storage.paths.extension
+        root_url = response_data.get("root")
+        extension_url = response_data.get("extension")
 
-        if zarr_root_path is not None:
-            # For V1 datasets, the Zarr Root is optional.
-            # It should be None if the dataset does not use pointer columns
-            zarr_root_path = str(zarr_root_path)
+        # Load the dataset table and optional Zarr archive
+        with fsspec.open(root_url, mode="rb") as f:
+            table = pd.read_parquet(f)
 
-        dataset = DatasetV1(table=table, zarr_root_path=zarr_root_path, **response_data)
+        dataset = DatasetV1(table=table, zarr_root_path=extension_url, **response_data)
         md5sum = response_data["md5Sum"]
 
         if dataset.should_verify_checksum(verify_checksum):
@@ -364,17 +359,14 @@ def _get_v1_dataset(
     def _get_v2_dataset(self, owner: str | HubOwner, slug: str) -> DatasetV2:
         """"""
         url = f"/v2/dataset/{owner}/{slug}"
-        response = self._base_request_to_hub(url=url, method="GET")
+        response = self._base_request_to_hub(url=url, method="GET", withhold_token=True)
         response_data = response.json()
 
-        # Disregard the Zarr root in the response. We'll get it from the storage token instead.
-        response_data.pop("zarrRootPath", None)
 
-        # Load the Zarr archive
-        with StorageSession(self, "read", DatasetV2.urn_for(owner, slug)) as storage:
-            zarr_root_path = str(storage.paths.root)
+        root_url = response_data.get("root")
 
-        dataset = DatasetV2(zarr_root_path=zarr_root_path, **response_data)
+        # For v2 datasets, the zarr_path always exists
+        dataset = DatasetV2(zarr_root_path=root_url, **response_data)
         return dataset
 
@@ -391,7 +383,10 @@ def list_benchmarks(self, limit: int = 100, offset: int = 0) -> list[str]:
         with track_progress(description="Fetching benchmarks", total=1):
             # Step 1: Fetch enough v2 benchmarks to cover the offset and limit
             v2_json_response = self._base_request_to_hub(
-                url="/v2/benchmark", method="GET", params={"limit": limit, "offset": offset}
+                url="/v2/benchmark",
+                method="GET",
+                withhold_token=True,
+                params={"limit": limit, "offset": offset},
             ).json()
             v2_data = v2_json_response["data"]
             v2_benchmarks = [benchmark["artifactId"] for benchmark in v2_data]
@@ -402,10 +397,10 @@
 
             # Step 2: Calculate the remaining limit and fetch v1 benchmarks
             remaining_limit = max(0, limit - len(v2_benchmarks))
-
             v1_json_response = self._base_request_to_hub(
                 url="/v1/benchmark",
                 method="GET",
+                withhold_token=True,
                 params={
                     "limit": remaining_limit,
                     "offset": max(0, offset - v2_json_response["metadata"]["total"]),
@@ -449,7 +444,9 @@ def _get_v1_benchmark(
         slug: str,
         verify_checksum: ChecksumStrategy = "verify_unless_zarr",
     ) -> BenchmarkV1Specification:
-        response = self._base_request_to_hub(url=f"/v1/benchmark/{owner}/{slug}", method="GET")
+        response = self._base_request_to_hub(
+            url=f"/v1/benchmark/{owner}/{slug}", method="GET", withhold_token=True
+        )
         response_data = response.json()
 
         # TODO (jstlaurent): response["dataset"]["artifactId"] is the owner/name unique identifier,
@@ -478,14 +475,17 @@ def _get_v1_benchmark(
         return benchmark
 
     def _get_v2_benchmark(self, owner: str | HubOwner, slug: str) -> BenchmarkV2Specification:
-        response = self._base_request_to_hub(url=f"/v2/benchmark/{owner}/{slug}", method="GET")
+        response = self._base_request_to_hub(
+            url=f"/v2/benchmark/{owner}/{slug}", method="GET", withhold_token=True
+        )
         response_data = response.json()
 
         response_data["dataset"] = self.get_dataset(*response_data["dataset"]["artifactId"].split("/"))
 
-        # Load the split index sets
-        with StorageSession(self, "read", BenchmarkV2Specification.urn_for(owner, slug)) as storage:
-            split = {label: storage.get_file(label) for label in response_data.get("split", {}).keys()}
+        split = {}
+        for label, url in response_data.get("split", {}).items():
+            with fsspec.open(url, mode="rb") as f:
+                split[label] = f.read()
 
         return BenchmarkV2Specification(**{**response_data, "split": split})
 
@@ -517,7 +517,9 @@ def upload_results(
         result_json = results.model_dump(by_alias=True, exclude_none=True)
 
         # Make a request to the Hub
-        response = self._base_request_to_hub(url="/v2/result", method="POST", json=result_json)
+        response = self._base_request_to_hub(
+            url="/v2/result", method="POST", withhold_token=False, json=result_json
+        )
 
         # Inform the user about where to find their newly created artifact.
         result_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location"))
@@ -607,15 +609,12 @@ def get_competition(self, artifact_id: str) -> CompetitionSpecification:
            A `CompetitionSpecification` instance, if it exists.
         """
         url = f"/v1/competition/{artifact_id}"
-        response = self._base_request_to_hub(url=url, method="GET")
+        response = self._base_request_to_hub(url=url, method="GET", withhold_token=True)
         response_data = response.json()
 
-        with StorageSession(
-            self, "read", CompetitionSpecification.urn_for(*artifact_id.split("/"))
-        ) as storage:
-            zarr_root_path = str(storage.paths.root)
+        root_url = response_data.get("root")
 
-        return CompetitionSpecification(zarr_root_path=zarr_root_path, **response_data)
+        return CompetitionSpecification(zarr_root_path=root_url, **response_data)
 
     def submit_competition_predictions(
         self,
@@ -641,6 +640,7 @@ def submit_competition_predictions(
         response = self._base_request_to_hub(
             url="/v1/competition-prediction",
             method="POST",
+            withhold_token=False,
             json=prediction_payload,
         )
         return response
@@ -657,15 +657,15 @@ def list_models(self, limit: int = 100, offset: int = 0) -> list[str]:
         """
         with track_progress(description="Fetching models", total=1):
             json_response = self._base_request_to_hub(
-                url="/v2/model", method="GET", params={"limit": limit, "offset": offset}
+                url="/v2/model", method="GET", withhold_token=True, params={"limit": limit, "offset": offset}
             ).json()
             models = [model["artifactId"] for model in json_response["data"]]
 
         return models
 
     def get_model(self, artifact_id: str) -> Model:
         url = f"/v2/model/{artifact_id}"
-        response = self._base_request_to_hub(url=url, method="GET")
+        response = self._base_request_to_hub(url=url, method="GET", withhold_token=True)
         response_data = response.json()
 
         return Model(**response_data)
@@ -703,6 +703,7 @@ def upload_model(
         response = self._base_request_to_hub(
             url=url,
             method="PUT",
+            withhold_token=False,
             json={"parentArtifactId": parent_artifact_id, **model_json},
        )
 
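The net effect: read-only endpoints pass `withhold_token=True` and skip the token check entirely, while mutating endpoints pass `withhold_token=False` and fail fast with `PolarisUnauthorizedError` before hitting the network. A condensed, hypothetical sketch of that dispatch (not the full client; the class and bodies below are stand-ins):

```python
class PolarisUnauthorizedError(Exception):
    """Stand-in for polaris.utils.errors.PolarisUnauthorizedError."""

class HubClientSketch:
    """Hypothetical, condensed version of the request helper above."""

    def __init__(self, logged_in: bool):
        self.logged_in = logged_in

    def ensure_active_token(self) -> bool:
        # The real client refreshes or re-fetches the cached OAuth token here
        return self.logged_in

    def _base_request_to_hub(self, url: str, method: str, withhold_token: bool, **kwargs):
        # Anonymous reads skip the token check entirely; writes fail fast
        if not withhold_token and not self.ensure_active_token():
            raise PolarisUnauthorizedError()
        return f"{method} {url} (anonymous={withhold_token})"

anonymous = HubClientSketch(logged_in=False)
# Read-only endpoints are now public: no login required
print(anonymous._base_request_to_hub("/v2/dataset/owner/slug", "GET", withhold_token=True))
# Mutating endpoints still require authentication
try:
    anonymous._base_request_to_hub("/v2/result", "POST", withhold_token=False)
except PolarisUnauthorizedError:
    print("must log in before uploading results")
```

This pairs with the `__enter__` change above: since the context manager no longer demands a valid token on entry, anonymous users can list and fetch public artifacts, and authentication is enforced per-request only where a write actually happens.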