@@ -411,17 +413,10 @@
"id": "3c7c11ac",
"metadata": {},
"outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "A\n"
- ]
- },
{
"data": {
"text/plain": [
- "'/home/cas/.cache/polaris-tutorials/002/zarr/data.zarr//images#0'"
+ "'images#0'"
]
},
"execution_count": 13,
@@ -531,7 +526,7 @@
{
"data": {
"text/html": [
-       "(HTML table rendering of the dataset metadata, mirroring the text/plain output below: md5sum 6ef8d23737aafcbf82c421e7f99e1d95, default_adapters None, no zarr_root_path)"
+       "(HTML table rendering of the dataset metadata, mirroring the text/plain output below: md5sum 5488b4909fd67d3208624288e720e1b8, default_adapters {}, zarr_root_path /home/cas/.cache/polaris-tutorials/002/json/data.zarr)"
],
"text/plain": [
"{\n",
@@ -540,8 +535,9 @@
" \"tags\": [],\n",
" \"user_attributes\": {},\n",
" \"owner\": null,\n",
- " \"default_adapters\": null,\n",
- " \"md5sum\": \"6ef8d23737aafcbf82c421e7f99e1d95\",\n",
+ " \"default_adapters\": {},\n",
+ " \"zarr_root_path\": \"/home/cas/.cache/polaris-tutorials/002/json/data.zarr\",\n",
+ " \"md5sum\": \"5488b4909fd67d3208624288e720e1b8\",\n",
" \"readme\": \"\",\n",
" \"annotations\": {\n",
" \"images\": {\n",
@@ -555,7 +551,7 @@
" \"source\": null,\n",
" \"license\": null,\n",
" \"curation_reference\": null,\n",
- " \"cache_dir\": \"/home/cas/.cache/polaris/datasets/None/6ef8d23737aafcbf82c421e7f99e1d95\",\n",
+ " \"cache_dir\": \"/home/cas/.cache/polaris/datasets/None/5488b4909fd67d3208624288e720e1b8\",\n",
" \"artifact_id\": null,\n",
" \"n_rows\": 1000,\n",
" \"n_columns\": 1\n",
diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py
index b3170a52..0d4f12ed 100644
--- a/polaris/benchmark/_base.py
+++ b/polaris/benchmark/_base.py
@@ -6,6 +6,7 @@
import fsspec
import numpy as np
import pandas as pd
+from datamol.utils import fs
from pydantic import (
Field,
FieldValidationInfo,
@@ -20,7 +21,6 @@
from polaris.dataset import Dataset, Subset
from polaris.evaluate import BenchmarkResults, Metric, ResultsType
from polaris.hub.settings import PolarisHubSettings
-from polaris.utils import fs
from polaris.utils.context import tmp_attribute_change
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError
diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py
index e22f1ba8..5200b2de 100644
--- a/polaris/dataset/_dataset.py
+++ b/polaris/dataset/_dataset.py
@@ -1,6 +1,4 @@
import json
-import os.path
-from collections import defaultdict
from hashlib import md5
from typing import Dict, List, Optional, Tuple, Union
@@ -8,9 +6,11 @@
import numpy as np
import pandas as pd
import zarr
+from datamol.utils import fs
from loguru import logger
from pydantic import (
Field,
+ PrivateAttr,
computed_field,
field_serializer,
field_validator,
@@ -20,12 +20,10 @@
from polaris._artifact import BaseArtifactModel
from polaris.dataset._adapters import Adapter
from polaris.dataset._column import ColumnAnnotation
-from polaris.hub.settings import PolarisHubSettings
-from polaris.utils import fs
+from polaris.hub.polarisfs import PolarisFileSystem
from polaris.utils.constants import DEFAULT_CACHE_DIR
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError
-from polaris.utils.io import get_zarr_root, robust_copy
from polaris.utils.types import AccessType, HttpUrlString, HubOwner, License
# Constants
@@ -51,6 +49,7 @@ class Dataset(BaseArtifactModel):
path to a `.parquet` file or a `pandas.DataFrame`.
default_adapters: The adapters that the Dataset recommends to use by default to change the format of the data
for specific columns.
+        zarr_root_path: The path to the Zarr archive that stores the data for all pointer columns.
md5sum: The checksum is used to verify the version of the dataset specification. If specified, it will
raise an error if the specified checksum doesn't match the computed checksum.
readme: Markdown text that can be used to provide a formatted description of the dataset.
@@ -72,6 +71,7 @@ class Dataset(BaseArtifactModel):
# Data
table: Union[pd.DataFrame, str]
default_adapters: Dict[str, Adapter] = Field(default_factory=dict)
+ zarr_root_path: Optional[str] = None
md5sum: Optional[str] = None
# Additional meta-data
@@ -85,7 +85,8 @@ class Dataset(BaseArtifactModel):
cache_dir: Optional[str] = None # Where to cache the data to if cache() is called.
# Private attributes
- _path_to_hash: Dict[str, Dict[str, str]] = defaultdict(dict)
+ _zarr_root: Optional[zarr.Group] = PrivateAttr(None)
+ _client = PrivateAttr(None) # Optional[PolarisHubClient]
_has_been_warned: bool = False
_has_been_cached: bool = False
@@ -188,6 +189,42 @@ def _compute_checksum(table):
checksum = hash_fn.hexdigest()
return checksum
+ @property
+ def client(self):
+ """The Polaris Hub client used to interact with the Polaris Hub."""
+
+ # Import it here to prevent circular imports
+ from polaris.hub.client import PolarisHubClient
+
+ if self._client is None:
+ self._client = PolarisHubClient()
+ return self._client
+
+ @property
+ def zarr_root(self):
+ """Open the zarr archive in read-write mode if it is not already open."""
+ if self.zarr_root_path is None or not any(anno.is_pointer for anno in self.annotations.values()):
+ return None
+
+ saved_on_hub = PolarisFileSystem.is_polarisfs_path(self.zarr_root_path)
+ saved_remote = saved_on_hub or not fs.is_local_path(self.zarr_root_path)
+
+ if saved_remote and not self._has_been_warned:
+ logger.warning(
+ f"You're loading data from a remote location. "
+ f"To speed up this process, consider caching the dataset first "
+ f"using {self.__class__.__name__}.cache()"
+ )
+ self._has_been_warned = True
+
+        # Open the archive through the Hub client if it is saved on the Hub
+ if self._zarr_root is None:
+ if saved_on_hub:
+ self._zarr_root = self.client.open_zarr_file(self.owner, self.name, self.zarr_root_path, "r+")
+ else:
+ self._zarr_root = zarr.open(self.zarr_root_path, "r+")
+ return self._zarr_root
+
@computed_field
@property
def n_rows(self) -> int:
@@ -228,64 +265,34 @@ def get_data(self, row: int, col: str, adapters: Optional[List[Adapter]] = None)
adapters = adapters or self.default_adapters
- def _load(p: str, index: Union[int, slice]) -> np.ndarray:
- """Tiny helper function to reduce code repetition."""
- arr = zarr.open(p, mode="r")
- arr = arr[index]
-
- if isinstance(index, slice):
- arr = tuple(arr)
-
- adapter = adapters.get(col)
- if adapter is not None:
- arr = adapter(arr)
-
- return arr
-
+ # If not a pointer, we can just return here
value = self.table.loc[row, col]
if not self.annotations[col].is_pointer:
return value
- value, index = self._split_index_from_path(value)
+ # Load the data from the Zarr archive
+ path, index = self._split_index_from_path(value)
+ arr = self.zarr_root[path][index]
- # In the case it is a pointer column, we need to load additional data into memory
- # We first check if the data has been downloaded to the cache.
- path = self._get_cache_path(column=col, value=value)
- if fs.exists(path):
- return _load(path, index)
+ # Change to tuple if a slice
+ if isinstance(index, slice):
+ arr = tuple(arr)
- # If it doesn't exist, we load from the original path and warn if not local
- if not fs.is_local_path(value) and not self._has_been_warned:
- logger.warning(
- f"You're loading data from a remote location. "
- f"To speed up this process, consider caching the dataset first "
- f"using {self.__class__.__name__}.cache()"
- )
- self._has_been_warned = True
- return _load(value, index)
+ # Adapt the input
+ adapter = adapters.get(col)
+ if adapter is not None:
+ arr = adapter(arr)
+
+ return arr
def upload_to_hub(
- self,
- env_file: Optional[Union[str, os.PathLike]] = None,
- settings: Optional[PolarisHubSettings] = None,
- cache_auth_token: bool = True,
- access: Optional[AccessType] = "private",
- owner: Optional[Union[HubOwner, str]] = None,
- **kwargs: dict,
+ self, access: Optional[AccessType] = "private", owner: Optional[Union[HubOwner, str]] = None
):
"""
Very light, convenient wrapper around the
[`PolarisHubClient.upload_dataset`][polaris.hub.client.PolarisHubClient.upload_dataset] method.
"""
- from polaris.hub.client import PolarisHubClient
-
- with PolarisHubClient(
- env_file=env_file,
- settings=settings,
- cache_auth_token=cache_auth_token,
- **kwargs,
- ) as client:
- return client.upload_dataset(self, access=access, owner=owner)
+ self.client.upload_dataset(self, access=access, owner=owner)
@classmethod
def from_json(cls, path: str):
@@ -323,19 +330,19 @@ def to_json(self, destination: str) -> str:
fs.mkdir(destination, exist_ok=True)
table_path = fs.join(destination, "table.parquet")
dataset_path = fs.join(destination, "dataset.json")
- pointer_dir = fs.join(destination, "data")
-
- # Save additional data
- new_table = self._copy_and_update_pointers(pointer_dir, inplace=False)
+ zarr_archive = fs.join(destination, "data.zarr")
# Lu: Avoid serilizing and sending None to hub app.
serialized = self.model_dump(exclude={"cache_dir"}, exclude_none=True)
serialized["table"] = table_path
- # We need to recompute the checksum, as the pointer paths have changed
- serialized["md5sum"] = self._compute_checksum(new_table)
+ # Copy over Zarr data to the destination
+ if self.zarr_root is not None:
+ dest = zarr.open(zarr_archive, "w")
+ zarr.copy_all(source=self.zarr_root, dest=dest)
+ serialized["zarr_root_path"] = zarr_archive
- new_table.to_parquet(table_path)
+ self.table.to_parquet(table_path)
with fsspec.open(dataset_path, "w") as f:
json.dump(serialized, f)
@@ -355,32 +362,15 @@ def cache(self, cache_dir: Optional[str] = None) -> str:
if cache_dir is not None:
self.cache_dir = cache_dir
+ self.to_json(self.cache_dir)
+
+ if self.zarr_root_path is not None:
+ self.zarr_root_path = fs.join(self.cache_dir, "data.zarr")
+
if not self._has_been_cached:
- self._copy_and_update_pointers(self.cache_dir, inplace=True)
self._has_been_cached = True
return self.cache_dir
- def _get_cache_path(self, column: str, value: str) -> Optional[str]:
- """
- Returns where the data _would be_ cached for any entry in the pointer columns,
- or None if the column is not a pointer column.
- """
- if not self.annotations[column].is_pointer:
- return
-
- if value not in self._path_to_hash[column]:
- h = md5(value.encode("utf-8")).hexdigest()
-
- value, _ = self._split_index_from_path(value)
- ext = fs.get_extension(value)
- dst = fs.join(self.cache_dir, column, f"{h}.{ext}")
-
- # The reason for caching the path is to speed-up retrieval. Hashing can be slow and with large
- # datasets this could become a bottleneck.
- self._path_to_hash[column][value] = dst
-
- return self._path_to_hash[column][value]
-
def size(self):
return self.rows, self.n_columns
@@ -402,39 +392,6 @@ def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]:
raise ValueError(f"Invalid index format: {index}")
return path, index
- def _copy_and_update_pointers(
- self, save_dir: str, table: Optional[pd.DataFrame] = None, inplace: bool = False
- ) -> pd.DataFrame:
- """Copy and update the path in the table to the new destination"""
-
- def fn(path):
- """Helper function that can be used within Pandas apply to copy and update all files"""
-
- # We copy the entire .zarr hierarchy
- root = get_zarr_root(path)
- if root is None:
- raise NotImplementedError(
- "Only the .zarr file format is currently supported for pointer columns"
- )
-
- # We could introduce name collisions here and thus use a hash of the original path for the destination
- dst = fs.join(save_dir, f"{md5(root.encode('utf-8')).hexdigest()}.zarr")
- robust_copy(root, dst)
-
- diff = os.path.relpath(path, root)
- dst = fs.join(dst, diff)
- return dst
-
- if table is None:
- table = self.table
- if not inplace:
- table = self.table.copy(deep=True)
-
- for c in table.columns:
- if self.annotations[c].is_pointer:
- table[c] = table[c].apply(fn)
- return table
-
def __getitem__(self, item):
"""Allows for indexing the dataset directly"""
ret = self.table.loc[item]
@@ -486,3 +443,8 @@ def __eq__(self, other):
if not isinstance(other, Dataset):
return False
return self.md5sum == other.md5sum
+
+ def __del__(self):
+        """Close the Hub client connection, if one was opened"""
+ if self._client is not None:
+ self._client.close()
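
For orientation, here is a minimal sketch of how a pointer column is consumed after this change, modelled on the updated `tests/test_dataset.py` further down: pointer values are now relative paths such as `"A#0"` that are resolved against the new `zarr_root_path` field, instead of absolute paths into a Zarr archive. The `/tmp/example` location is illustrative.

```python
import numpy as np
import pandas as pd
import zarr

from polaris.dataset import Dataset

# Create a small Zarr archive with a single array "A" (illustrative location).
zarr_path = "/tmp/example/data.zarr"
root = zarr.open(zarr_path, "w")
root.array("A", data=np.random.random((100, 100)))

# Pointer values are relative to the archive root: "<column>#<index>".
table = pd.DataFrame({"A": ["A#0"]}, index=[0])

dataset = Dataset(
    table=table,
    annotations={"A": {"is_pointer": True}},
    zarr_root_path=zarr_path,  # new field introduced in this diff
)

# get_data() lazily opens dataset.zarr_root and resolves the pointer.
row_data = dataset.get_data(row=0, col="A")
```
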
diff --git a/polaris/dataset/_factory.py b/polaris/dataset/_factory.py
index aff4ba93..b6dd48e3 100644
--- a/polaris/dataset/_factory.py
+++ b/polaris/dataset/_factory.py
@@ -89,7 +89,11 @@ def zarr_root(self) -> zarr.Group:
All data for a single dataset is expected to be stored in the same Zarr archive.
"""
if self._zarr_root is None:
- self._zarr_root = zarr.open(self.zarr_root_path, "w")
+ # NOTE (cwognum): The DirectoryStore is the default store when calling zarr.open
+ # I nevertheless explicitly set it here to make it clear that this is a design decision.
+ # We could consider using different stores, such as the NestedDirectoryStore.
+ store = zarr.DirectoryStore(self.zarr_root_path)
+ self._zarr_root = zarr.open(store, "w")
if not isinstance(self._zarr_root, zarr.Group):
raise ValueError("The root of the zarr hierarchy should be a group")
return self._zarr_root
@@ -215,6 +219,7 @@ def build(self) -> Dataset:
table=self._table,
annotations=self._annotations,
default_adapters=self._adapters,
+ zarr_root_path=self.zarr_root_path,
)
def reset(self, zarr_root_path: Optional[str] = None):
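
A small sketch, with an illustrative path, of the store choice that `DatasetFactory.zarr_root` now makes explicit; switching to the `NestedDirectoryStore` mentioned in the note would only change the store construction.

```python
import zarr

zarr_root_path = "/tmp/example/data.zarr"  # illustrative

# What the factory now does explicitly instead of relying on zarr.open()'s default store:
store = zarr.DirectoryStore(zarr_root_path)
root = zarr.open(store, "w")
assert isinstance(root, zarr.Group)

# Alternative layout mentioned in the note above, nesting chunk files in subdirectories:
# store = zarr.NestedDirectoryStore(zarr_root_path)
```
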
diff --git a/polaris/dataset/converters/_base.py b/polaris/dataset/converters/_base.py
index 8e3c64af..5ca31b37 100644
--- a/polaris/dataset/converters/_base.py
+++ b/polaris/dataset/converters/_base.py
@@ -17,12 +17,11 @@ def convert(self, path: str) -> FactoryProduct:
raise NotImplementedError
@staticmethod
- def get_pointer(root: str, column: str, index: Union[int, slice]) -> str:
+ def get_pointer(column: str, index: Union[int, slice]) -> str:
"""
Creates a pointer.
Args:
- root: The root path of the zarr hierarchy.
column: The name of the column. Each column has its own group in the root.
index: The index or slice of the pointer.
"""
@@ -30,4 +29,4 @@ def get_pointer(root: str, column: str, index: Union[int, slice]) -> str:
index_substr = f"{_INDEX_SEP}{index.start}:{index.stop}"
else:
index_substr = f"{_INDEX_SEP}{index}"
- return f"{root}/{column}{index_substr}"
+ return f"{column}{index_substr}"
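
To make the new pointer format concrete, here is a standalone mirror of the updated `get_pointer` logic; the `#` separator is an assumption inferred from the `'images#0'` output in the tutorial diff above.

```python
_INDEX_SEP = "#"  # assumed value, matching the 'images#0' pointers shown above


def get_pointer(column: str, index) -> str:
    """Standalone mirror of the updated Converter.get_pointer: no Zarr root prefix."""
    if isinstance(index, slice):
        index_substr = f"{_INDEX_SEP}{index.start}:{index.stop}"
    else:
        index_substr = f"{_INDEX_SEP}{index}"
    return f"{column}{index_substr}"


assert get_pointer("images", 0) == "images#0"
assert get_pointer("images", slice(0, 5)) == "images#0:5"
```

Because the root is no longer baked into each pointer, re-rooting the archive (as `to_json` and `cache` now do) no longer requires rewriting the table.
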
diff --git a/polaris/dataset/converters/_sdf.py b/polaris/dataset/converters/_sdf.py
index 76eab5cc..f78c072b 100644
--- a/polaris/dataset/converters/_sdf.py
+++ b/polaris/dataset/converters/_sdf.py
@@ -123,7 +123,7 @@ def _get_name(mol: dm.Mol):
# Get the pointer path
pointer_idx = f"{start}:{end}" if start != end else f"{start}"
- pointer = self.get_pointer(factory.zarr_root_path, self.mol_column, pointer_idx)
+ pointer = self.get_pointer(self.mol_column, pointer_idx)
# Get the single unique value per column for the group and append
unique_values = [group[col].unique()[0] for col in df.columns]
@@ -132,7 +132,7 @@ def _get_name(mol: dm.Mol):
df = grouped
else:
- pointers = [self.get_pointer(factory.zarr_root_path, self.mol_column, i) for i in range(len(df))]
+ pointers = [self.get_pointer(self.mol_column, i) for i in range(len(df))]
df[self.mol_column] = pd.Series(pointers)
# Set the annotations
diff --git a/polaris/dataset/converters/_zarr.py b/polaris/dataset/converters/_zarr.py
index 7d26ea61..5ed706d0 100644
--- a/polaris/dataset/converters/_zarr.py
+++ b/polaris/dataset/converters/_zarr.py
@@ -34,17 +34,15 @@ def convert(self, path: str, factory: "DatasetFactory") -> FactoryProduct:
if v is not None:
raise ValueError("The root of the zarr hierarchy should only contain arrays.")
+        # Copy the source Zarr into the factory's Zarr root, so everything is in one place
+ zarr.copy_all(source=src, dest=factory.zarr_root)
+
# Construct the table
# Parse any group into a column
data = defaultdict(dict)
for col, arr in src.arrays():
- # Copy to the source zarr, so everything is in one place
- dst = zarr.open_group("/".join([factory.zarr_root_path, col]), "w")
- zarr.copy(arr, dst)
-
for i in range(len(arr)):
- # In case all data is saved in a single array, we construct a path with an index suffix.
- data[col][i] = self.get_pointer(path, arr.name, i)
+ data[col][i] = self.get_pointer(arr.name.removeprefix("/"), i)
# Construct the dataset
table = pd.DataFrame(data)
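
A brief, hedged illustration of the two details this hunk relies on: `zarr.copy_all` moves every source array into the factory's root, and `zarr.Array.name` carries a leading `/` that must be stripped to produce relative pointers. Paths are illustrative.

```python
import numpy as np
import zarr

src = zarr.open("/tmp/example/src.zarr", "w")    # illustrative source archive
src.array("images", data=np.zeros((4, 2, 2)))

dest = zarr.open("/tmp/example/data.zarr", "w")  # stands in for factory.zarr_root
zarr.copy_all(source=src, dest=dest)

arr = src["images"]
print(arr.name)                               # "/images" -- note the leading slash
print(f"{arr.name.removeprefix('/')}#0")      # "images#0", a relative pointer
```
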
diff --git a/polaris/hub/client.py b/polaris/hub/client.py
index 2aff2cb7..9f7e2292 100644
--- a/polaris/hub/client.py
+++ b/polaris/hub/client.py
@@ -16,6 +16,7 @@
from authlib.integrations.base_client.errors import InvalidTokenError, MissingTokenError
from authlib.integrations.httpx_client import OAuth2Client, OAuthError
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
+from datamol.utils import fs
from httpx import HTTPStatusError
from httpx._types import HeaderTypes, URLTypes
from loguru import logger
@@ -29,8 +30,8 @@
from polaris.evaluate import BenchmarkResults
from polaris.hub.polarisfs import PolarisFileSystem
from polaris.hub.settings import PolarisHubSettings
-from polaris.utils import fs
from polaris.utils.constants import DEFAULT_CACHE_DIR
+from polaris.utils.context import tmp_attribute_change
from polaris.utils.errors import PolarisHubError, PolarisUnauthorizedError
from polaris.utils.types import AccessType, HubOwner, IOMode, TimeoutTypes
@@ -152,6 +153,25 @@ def _base_request_to_hub(self, url: str, method: str, **kwargs):
return response
+ def _normalize_owner(
+ self,
+ artifact_owner: Optional[Union[str, HubOwner]] = None,
+ parameter_owner: Optional[Union[str, HubOwner]] = None,
+ ) -> HubOwner:
+ """
+ Normalize the owner of an artifact to a `HubOwner` instance.
+ The parameter owner takes precedence over the artifact owner.
+ """
+ if parameter_owner is not None:
+ artifact_owner = parameter_owner
+
+ if artifact_owner is None:
+ raise ValueError(
+ "Either specify the `owner` attribute for the artifact or pass the `owner` parameter."
+ )
+
+ return artifact_owner if isinstance(artifact_owner, HubOwner) else HubOwner(slug=artifact_owner)
+
# =========================
# Overrides
# =========================
@@ -360,6 +380,7 @@ def open_zarr_file(
try:
store = zarr.storage.FSStore(path, fs=polaris_fs)
return zarr.open(store, mode=mode)
+
except Exception as e:
raise PolarisHubError("Error opening Zarr store") from e
@@ -443,19 +464,11 @@ def upload_results(
Args:
results: The results to upload.
access: Grant public or private access to result
- owner: Which Hub user or organization owns the artifact.
- Optional if and only if the `benchmark.owner` attribute is set.
+ owner: Which Hub user or organization owns the artifact. Takes precedence over `results.owner`.
"""
# Get the serialized model data-structure
-
- if results.owner is None:
- if owner is None:
- raise ValueError(
- "The `owner` argument must be specified if the `results.owner` attribute is not set."
- )
- results.owner = owner if isinstance(owner, HubOwner) else HubOwner(slug=owner)
-
+ results.owner = self._normalize_owner(results.owner, owner)
result_json = results.model_dump(by_alias=True, exclude_none=True)
# Make a request to the hub
@@ -498,24 +511,24 @@ def upload_dataset(
tuple with (connect_timeout, write_timeout). The type of the the timout parameter comes from `httpx`.
Since datasets can get large, it might be needed to increase the write timeout for larger datasets.
See also: https://www.python-httpx.org/advanced/#timeout-configuration
- owner: Which Hub user or organization owns the artifact.
- Optional if and only if the `benchmark.owner` attribute is set.
+ owner: Which Hub user or organization owns the artifact. Takes precedence over `dataset.owner`.
"""
-
- if dataset.owner is None:
- if owner is None:
- raise ValueError(
- "The `owner` argument must be specified if the `dataset.owner` attribute is not set."
- )
- dataset.owner = owner if isinstance(owner, HubOwner) else HubOwner(slug=owner)
+ # Normalize timeout
+ if timeout is None:
+ timeout = self.settings.default_timeout
# Get the serialized data-model
- # We exclude the table as it handled separately and the cache_dir as it is user-specific
+        # We exclude the table as it is handled separately and we exclude the cache_dir as it is user-specific
+ dataset.owner = self._normalize_owner(dataset.owner, owner)
dataset_json = dataset.model_dump(exclude={"cache_dir", "table"}, exclude_none=True, by_alias=True)
- # Uploading a dataset is a two-step process.
+ # We will save the Zarr archive to the Hub as well
+ dataset_json["zarrRootPath"] = f"{PolarisFileSystem.protocol}://data.zarr"
+
+ # Uploading a dataset is a three-step process.
# 1. Upload the dataset meta data to the hub and prepare the hub to receive the parquet file
# 2. Upload the parquet file to the hub
+ # 3. Upload the associated Zarr archive
# TODO: Revert step 1 in case step 2 fails - Is this needed? Or should this be taken care of by the hub?
# Write the parquet file directly to a buffer
@@ -539,6 +552,7 @@ def upload_dataset(
"access": access,
**dataset_json,
},
+ timeout=timeout,
)
# Step 2: Upload the parquet file
@@ -549,6 +563,7 @@ def upload_dataset(
headers={
"Content-type": "application/vnd.apache.parquet",
},
+ timeout=timeout,
)
if hub_response.status_code == 307:
@@ -568,6 +583,20 @@ def upload_dataset(
else:
hub_response.raise_for_status()
+ # Step 3: Upload any associated Zarr archive
+ if dataset.zarr_root is not None:
+ with tmp_attribute_change(self.settings, "default_timeout", timeout):
+ # Copy the Zarr archive to the hub
+            # This does not copy the consolidated metadata
+ dest = self.open_zarr_file(
+ owner=dataset.owner,
+ name=dataset.name,
+ path=dataset_json["zarrRootPath"],
+ mode="w",
+ )
+ logger.info("Copying Zarr archive to the Hub. This may take a while.")
+ zarr.copy_all(source=dataset.zarr_root, dest=dest, log=logger.info)
+
logger.success(
"Your dataset has been successfully uploaded to the Hub. "
f"View it here: {urljoin(self.settings.hub_url, f'datasets/{dataset.owner}/{dataset.name}')}"
@@ -600,18 +629,11 @@ def upload_benchmark(
Args:
benchmark: The benchmark to upload.
access: Grant public or private access to result
- owner: Which Hub user or organization owns the artifact.
- Optional if and only if the `benchmark.owner` attribute is set.
+ owner: Which Hub user or organization owns the artifact. Takes precedence over `benchmark.owner`.
"""
- if benchmark.owner is None:
- if owner is None:
- raise ValueError(
- "The `owner` argument must be specified if the `benchmark.owner` attribute is not set."
- )
- benchmark.owner = owner if isinstance(owner, HubOwner) else HubOwner(slug=owner)
-
# Get the serialized data-model
# We exclude the dataset as we expect it to exist on the hub already.
+ benchmark.owner = self._normalize_owner(benchmark.owner, owner)
benchmark_json = benchmark.model_dump(exclude={"dataset"}, exclude_none=True, by_alias=True)
benchmark_json["datasetArtifactId"] = benchmark.dataset.artifact_id
benchmark_json["access"] = access
diff --git a/polaris/hub/polarisfs.py b/polaris/hub/polarisfs.py
index 5412feed..8d0efa9d 100644
--- a/polaris/hub/polarisfs.py
+++ b/polaris/hub/polarisfs.py
@@ -1,7 +1,8 @@
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+import hashlib
from datetime import datetime, timezone
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
import fsspec
-import hashlib
from polaris.utils.errors import PolarisHubError
from polaris.utils.types import TimeoutTypes
@@ -54,6 +55,18 @@ def __init__(
self.prefix = f"dataset/{dataset_owner}/{dataset_name}/"
self.base_path = f"/storage/{self.prefix.rstrip('/')}"
+ @staticmethod
+ def is_polarisfs_path(path: str) -> bool:
+ """Check if the given path is a PolarisFS path.
+
+ Args:
+ path: The path to check.
+
+ Returns:
+ True if the path is a PolarisFS path; otherwise, False.
+ """
+ return path.startswith(f"{PolarisFileSystem.protocol}://")
+
def ls(
self,
path: str,
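
A quick usage sketch of the new `is_polarisfs_path` helper, which the `Dataset.zarr_root` property above uses to decide whether to route reads through the Hub client; the local path is illustrative.

```python
from polaris.hub.polarisfs import PolarisFileSystem

hub_path = f"{PolarisFileSystem.protocol}://data.zarr"  # how upload_dataset stores the Zarr root
local_path = "/home/user/.cache/polaris-tutorials/002/json/data.zarr"  # illustrative

assert PolarisFileSystem.is_polarisfs_path(hub_path)
assert not PolarisFileSystem.is_polarisfs_path(local_path)
```
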
diff --git a/polaris/loader/load.py b/polaris/loader/load.py
index 4f4d10c8..5bebf291 100644
--- a/polaris/loader/load.py
+++ b/polaris/loader/load.py
@@ -1,6 +1,7 @@
import json
import fsspec
+from datamol.utils import fs
from polaris.benchmark._definitions import (
MultiTaskBenchmarkSpecification,
@@ -8,7 +9,6 @@
)
from polaris.dataset import Dataset, create_dataset_from_file
from polaris.hub.client import PolarisHubClient
-from polaris.utils import fs
def load_dataset(path: str, verify_checksum: bool = True) -> Dataset:
diff --git a/polaris/utils/fs.py b/polaris/utils/fs.py
deleted file mode 100644
index 13cb95c4..00000000
--- a/polaris/utils/fs.py
+++ /dev/null
@@ -1,395 +0,0 @@
-"""
-The `fs` module makes it easier to work with all type of path (the ones supported by `fsspec`).
-"""
-
-import hashlib
-import io
-import os
-import pathlib
-from typing import List, Optional, Union
-
-import fsspec
-import fsspec.utils
-from datamol.utils import parallelized
-
-
-def _import_tqdm():
- try:
- from tqdm.auto import tqdm
-
- return tqdm
- except ImportError:
- return None
-
-
-def get_mapper(path: Union[str, os.PathLike]):
- """Get the fsspec mapper.
-
- Args:
- path: a path supported by `fsspec` such as local, s3, gcs, etc.
- """
- return fsspec.get_mapper(str(path))
-
-
-def get_basename(path: Union[str, os.PathLike]):
- """Get the basename of a file or a folder.
-
- Args:
- path: a path supported by `fsspec` such as local, s3, gcs, etc.
- """
- path = str(path)
- mapper = get_mapper(path)
- clean_path = path.rstrip(mapper.fs.sep)
- return str(clean_path).split(mapper.fs.sep)[-1]
-
-
-def get_extension(path: Union[str, os.PathLike]):
- """Get the extension of a file.
-
- Args:
- path: a path supported by `fsspec` such as local, s3, gcs, etc.
- """
- basename = get_basename(path)
- return basename.split(".")[-1]
-
-
-def exists(path: Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]):
- """Check whether a file or a directory exists.
-
- Important: File-like object always exists.
-
- Args:
- path: a path supported by `fsspec` such as local, s3, gcs, etc.
- """
- return is_file(path) or is_dir(path)
-
-
-def is_file(path: Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]):
- """Check whether a file exists.
-
- Args:
- path: a path supported by `fsspec` such as local, s3, gcs, etc.
- """
- if isinstance(path, fsspec.core.OpenFile):
- return path.fs.isfile(path.path)
-
- elif isinstance(path, (str, os.PathLike)):
- mapper = get_mapper(str(path))
- return mapper.fs.isfile(str(path))
-
- else:
- return False
-
-
-def is_dir(path: Union[str, os.PathLike, fsspec.core.OpenFile, io.IOBase]):
- """Check whether a file exists.
-
- Args:
- path: a path supported by `fsspec` such as local, s3, gcs, etc.
- """
- if isinstance(path, fsspec.core.OpenFile):
- return path.fs.isdir(path.path)
-
- elif isinstance(path, (str, os.PathLike)):
- mapper = get_mapper(str(path))
- return mapper.fs.isdir(str(path))
-
- else:
- return False
-
-
-def get_protocol(path: Union[str, os.PathLike], fs: Optional[fsspec.AbstractFileSystem] = None):
- """Return the name of the path protocol.
-
- Args:
- path: a path supported by `fsspec` such as local, s3, gcs, etc.
- """
-
- if fs is None:
- fs = get_mapper(path).fs
-
- protocol = fs.protocol # type: ignore
-
- if "s3" in protocol:
- return "s3"
- elif "gs" in protocol:
- return "gs"
- elif isinstance(protocol, (tuple, list)):
- return protocol[0]
- return protocol
-
-
-def is_local_path(path: Union[str, os.PathLike]):
- """Check whether a path is local."""
- return get_protocol(str(path)) == "file"
-
-
-def join(*paths: str):
- """Join paths together. The first element determine the
- filesystem to use (and so the separator.
-
- Args:
- *paths: a list of paths supported by `fsspec` such as local, s3, gcs, etc.
- """
- _paths = [str(path).rstrip("/") for path in paths]
- source_path = _paths[0]
- fs = get_mapper(source_path).fs
- full_path = fs.sep.join(_paths)
- return full_path
-
-
-def get_size(file: Union[str, os.PathLike, io.IOBase, fsspec.core.OpenFile]) -> Optional[int]:
- """Get the size of a file given its path. Return None if the
- size can't be retrieved.
- """
-
- if isinstance(file, io.IOBase) and hasattr(file, "name"):
- fs_local = fsspec.filesystem("file")
- file_size = fs_local.size(getattr(file, "name"))
-
- elif isinstance(file, (str, os.PathLike)):
- fs = get_mapper(str(file)).fs
- file_size = fs.size(str(file))
-
- elif isinstance(file, fsspec.core.OpenFile):
- file_size = file.fs.size(file.path)
-
- else:
- file_size = None
-
- return file_size
-
-
-def copy_file(
- source: Union[str, pathlib.Path, io.IOBase, fsspec.core.OpenFile],
- destination: Union[str, pathlib.Path, io.IOBase, fsspec.core.OpenFile],
- chunk_size: Optional[int] = None,
- force: bool = False,
- progress: bool = False,
- leave_progress: bool = True,
-):
- """Copy one file to another location across different filesystem (local, S3, GCS, etc).
-
- Args:
- source: path or file-like object to copy from.
- destination: path or file-like object to copy to.
- chunk_size: the chunk size to use. If progress is enabled the chunk
- size is `None`, it is set to 1MB (1024 * 1024).
- force: whether to overwrite the destination file if it exists.
- progress: whether to display a progress bar.
- leave_progress: whether to hide the progress bar once the copy is done.
- """
-
- if progress and chunk_size is None:
- chunk_size = 1024 * 1024
-
- if isinstance(source, (str, os.PathLike)):
- source_file = fsspec.open(str(source), "rb")
- else:
- source_file = source
-
- if isinstance(destination, (str, os.PathLike)):
- # adapt the file mode of the destination depending on the source file.
- destination_mode = "wb"
- if hasattr(source_file, "mode"):
- destination_mode = "wb" if "b" in getattr(source_file, "mode") else "w"
- elif isinstance(source_file, io.BytesIO):
- destination_mode = "wb"
- elif isinstance(source_file, io.StringIO):
- destination_mode = "w"
-
- destination_file = fsspec.open(str(destination), destination_mode)
- else:
- destination_file = destination
-
- if not is_file(source_file): # type: ignore
- raise ValueError(f"The file being copied does not exist or is not a file: {source}")
-
- if not force and is_file(destination_file): # type: ignore
- raise ValueError(f"The destination file to copy already exists: {destination}")
-
- with source_file as source_stream:
- with destination_file as destination_stream:
- if chunk_size is None:
- # copy without chunks
- destination_stream.write(source_stream.read()) # type: ignore
-
- else:
- # copy with chunks
-
- # determine the size of the source file
- source_size = None
- if progress:
- source_size = get_size(source)
-
- pbar = None
- if progress:
- tqdm = _import_tqdm()
-
- if tqdm is None:
- raise ImportError(
- "If the progress bar is enabled, you must have `tqdm` "
- "installed: `conda install tqdm`."
- )
- else:
- # init progress bar
- pbar = tqdm(
- total=source_size,
- leave=leave_progress,
- disable=not progress,
- unit="B",
- unit_divisor=1024,
- unit_scale=True,
- )
-
- # start the loop
- while True:
- data = source_stream.read(chunk_size) # type: ignore
- if not data:
- break
- destination_stream.write(data) # type: ignore
-
- if pbar is not None:
- pbar.update(chunk_size)
-
- if pbar is not None:
- pbar.close()
-
-
-def mkdir(dir_path: Union[str, os.PathLike], exist_ok: bool = False):
- """Create a directory.
-
- Args:
- dir_path: The path of the directory to create.
- exist_ok: Whether to ignore the error if the directory
- already exists.
- """
- fs = get_mapper(str(dir_path)).fs
- fs.mkdirs(str(dir_path), exist_ok=exist_ok)
-
-
-def glob(path: str, detail: bool = False, **kwargs) -> List[str]:
- """Find files by glob-matching.
-
- Args:
- path: A glob-style path.
- """
- # Get the list of paths
- fs = get_mapper(path).fs
- paths = fs.glob(path, detail=detail, **kwargs)
- paths = [fsspec.utils._unstrip_protocol(d, fs) for d in paths]
- return paths
-
-
-def copy_dir(
- source: Union[str, pathlib.Path],
- destination: Union[str, pathlib.Path],
- force: bool = False,
- progress: bool = False,
- leave_progress: bool = True,
- file_progress: bool = False,
- file_leave_progress: bool = False,
- chunk_size: Optional[int] = None,
-):
- """Copy one directory to another location across different filesystem (local, S3, GCS, etc).
-
- Note that if both FS from source and destination are the same, progress won't be shown.
-
- Args:
- source: Path to the source directory.
- destination: Path to the destination directory.
- chunk_size: the chunk size to use. If progress is enabled the chunk
- size is `None`, it is set to 2048.
- force: whether to overwrite the destination directory if it exists.
- progress: Whether to display a progress bar.
- leave_progress: Whether to hide the progress bar once the copy is done.
- file_progress: Whether to display a progress bar for each file.
- file_leave_progress: Whether to hide the progress bar once a file copy is done.
- chunk_size: See `po.utils.fs.copy_file`.
- """
-
- source = str(source)
- destination = str(destination)
-
- source_fs = get_mapper(source).fs
- destination_fs = get_mapper(destination).fs
-
- # Sanity check
- if not is_dir(source):
- raise ValueError(f"The directory being copied does not exist or is not a directory: {source}")
-
- if not force and is_dir(destination):
- raise ValueError(f"The destination folder to copy already exists: {destination}")
-
- # If both fs are the same then we just rely on the internal `copy` method
- # which is much faster.
- if destination_fs.__class__ == source_fs.__class__:
- source_fs.copy(source, destination, recursive=True)
- return
-
- # Get all input paths with details
- # NOTE(hadim): we could have use `.glob(..., detail=True)` here but that API is inconsistent
- # between the backends resulting in different object types being returned (dict, list, etc).
- detailed_paths = source_fs.find(source, withdirs=True, detail=True)
- detailed_paths = list(detailed_paths.values())
-
- # Get list of input types
- input_types = [d["type"] for d in detailed_paths]
-
- # Get list of input path + add protocol if needed
- input_paths = [d["name"] for d in detailed_paths]
- input_paths = [fsspec.utils._unstrip_protocol(p, source_fs) for p in input_paths]
-
- # Build all the output paths
- output_paths: List[str] = fsspec.utils.other_paths(input_paths, destination) # type: ignore
-
- def _copy_source_to_destination(input_path, input_type, output_path):
- # A directory
- if input_type == "directory":
- destination_fs.mkdir(output_path)
-
- # A file
- else:
- copy_file(
- input_path,
- output_path,
- force=force,
- progress=file_progress,
- leave_progress=file_leave_progress,
- chunk_size=chunk_size,
- )
-
- # Copy source files/directories to destination in parallel
- parallelized(
- _copy_source_to_destination,
- inputs_list=list(zip(input_paths, input_types, output_paths)),
- arg_type="args",
- progress=progress,
- tqdm_kwargs=dict(leave=leave_progress),
- scheduler="threads",
- )
-
-
-def hash_file(file: Union[str, os.PathLike, io.BytesIO, io.IOBase], chunk_size: int = 4096):
- """Return the md5 hash of a file."""
-
- md5 = hashlib.md5()
-
- if isinstance(file, (io.BytesIO, io.TextIOBase, io.BufferedIOBase)):
- sentinel = b""
- if isinstance(file, io.TextIOBase):
- sentinel = ""
-
- for block in iter(lambda: file.read(chunk_size), sentinel):
- if isinstance(block, str):
- block = block.encode()
-
- md5.update(block)
- file.seek(0)
-
- elif is_file(file):
- with fsspec.open(file, "rb") as f:
- for block in iter(lambda: f.read(chunk_size), b""):
- md5.update(block)
-
- return md5.hexdigest()
diff --git a/polaris/utils/io.py b/polaris/utils/io.py
deleted file mode 100644
index 56d49a62..00000000
--- a/polaris/utils/io.py
+++ /dev/null
@@ -1,145 +0,0 @@
-import os.path
-import uuid
-from typing import Optional
-
-import filelock
-import fsspec
-from loguru import logger
-from tenacity import Retrying
-from tenacity.stop import stop_after_attempt
-from tenacity.wait import wait_fixed
-
-from polaris.utils import fs
-from polaris.utils.constants import DEFAULT_CACHE_DIR
-from polaris.utils.errors import PolarisChecksumError
-
-
-def create_filelock(lock_name: str, cache_dir_path: str = DEFAULT_CACHE_DIR):
- """Create an empty lock file into `cache_dir_path/locks/lock_name`"""
- lock_path = fs.join(cache_dir_path, "_lock_files", lock_name)
- with fsspec.open(lock_path, "w", auto_mkdir=True):
- pass
- return filelock.FileLock(lock_path)
-
-
-def robust_copy(
- source_path: str,
- destination_path: str,
- md5sum: Optional[str] = None,
- max_retries: int = 5,
- wait_after_try: int = 2,
- progress: bool = True,
- leave_progress: bool = True,
- chunk_size: int = 2048,
-):
- if not fs.is_file(source_path) and get_zarr_root(source_path) is None:
- raise ValueError(f"{source_path} is a directory and not part of a .zarr hierarchy!")
-
- if md5sum is None and fs.is_file(source_path):
- # NOTE (cwognum): This effectively means we will not check the checksum of .zarr files.
- # The reason being that I'm not sure how to effectively compute a checksum for a .zarr
- md5sum = fs.hash_file(source_path)
-
- artifact_cache_lock = create_filelock(f"artifact_version_{md5sum or uuid.uuid4()}.lock")
-
- def log_failure(retry_state):
- logger.warning(
- f"""Downloading the artifact from {source_path} to {destination_path} failed. """
- f"""Retrying attempt {retry_state.attempt_number}/{max_retries} """
- f"""after a sleeping period of {wait_after_try} seconds."""
- )
-
- # This context manager will lock any process that try to download the same file. Only one process
- # will be able to download the artifact and all the other ones will be waiting at that line.
- # Once the lock is released the other processes will call `download_with_checksum` but the download will
- # not happen since the artifact file will already exist and its checksum will be correct.
- with artifact_cache_lock:
- # This loop will retry downloading the artifact for multiple attempts. Downloading an artifact
- # might fail for multiple reasons such as disk IO failures or network failures. The checksum logic
- # and the retry mechanism together allow to be resilient in case of intermitent failures.
- for attempt in Retrying(
- reraise=True,
- stop=stop_after_attempt(max_retries),
- after=log_failure,
- wait=wait_fixed(wait_after_try),
- ):
- with attempt:
- # The checksum logic will only validate an artifact download if its checksum matches
- # the excepted one. If not then it will be deleted and the download will happen again
- # until it succeeds (or until the number of attemps have been reached).
- download_with_checksum(
- source_path=source_path,
- destination_path=destination_path,
- md5sum=md5sum,
- progress=progress,
- leave_progress=leave_progress,
- chunk_size=chunk_size,
- )
-
- return destination_path
-
-
-def download_with_checksum(
- source_path: str,
- destination_path: str,
- md5sum: Optional[str],
- progress: bool = False,
- leave_progress: bool = True,
- chunk_size: int = 2048,
-):
- """Download an artifact from the bucket to a cache path while checking for its md5sum given a true md5sum.
-
- Args:
- source_path: The path to the artifact in the bucket.
- destination_path: The path of the artifact in the local cache.
- md5sum: The true md5sum to check against. If None, no checksum is performed but a warning is logged.
- progress: whether to display a progress bar.
- leave_progress: whether to hide the progress bar once the copy is done.
- chunk_size: the chunk size for the download.
- """
-
- # Download the artifact if not already in the cache.
- if not fs.exists(destination_path):
- if fs.is_dir(source_path):
- fs.copy_dir(
- source_path,
- destination_path,
- progress=progress,
- leave_progress=leave_progress,
- chunk_size=chunk_size,
- )
-
- else:
- fs.copy_file(
- source_path,
- destination_path,
- progress=progress,
- leave_progress=leave_progress,
- chunk_size=chunk_size,
- )
-
- # Check the cached artifact has the correct md5sum
- if md5sum is not None:
- cache_md5sum = fs.hash_file(destination_path)
- if cache_md5sum != md5sum:
- file_system = fs.get_mapper(destination_path).fs
- file_system.delete(destination_path)
-
- raise PolarisChecksumError(
- f"""The destination artifact at {destination_path} has a different md5sum ({cache_md5sum})"""
- f"""than the expected artifact md5sum ({md5sum}). The destination artifact has been deleted. """
- )
-
-
-def get_zarr_root(path):
- """
- Recursive function to find the root of a .zarr file.
- Finds the highest level directory that has the .zarr extension.
- """
- if os.path.dirname(path) == path:
- # We reached the root of the filesystem
- return
- root = get_zarr_root(os.path.dirname(path))
- if root is None and fs.get_extension(path) == "zarr":
- root = path
- return root
diff --git a/tests/conftest.py b/tests/conftest.py
index 42df3b5e..8874e473 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -2,13 +2,13 @@
import numpy as np
import pytest
import zarr
+from datamol.utils import fs
from polaris.benchmark import (
MultiTaskBenchmarkSpecification,
SingleTaskBenchmarkSpecification,
)
from polaris.dataset import ColumnAnnotation, Dataset
-from polaris.utils import fs
from polaris.utils.types import HubOwner, License
@@ -59,7 +59,7 @@ def test_dataset(test_data, test_org_owner):
table=test_data,
name="test-dataset",
source="https://www.example.com",
- annotations={"expt": ColumnAnnotation(is_pointer=False, user_attributes={"unit": "kcal/mol"})},
+ annotations={"expt": ColumnAnnotation(user_attributes={"unit": "kcal/mol"})},
tags=["tagA", "tagB"],
user_attributes={"attributeA": "valueA", "attributeB": "valueB"},
owner=test_org_owner,
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 9eab85e2..26da2eb8 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -2,11 +2,11 @@
import pandas as pd
import pytest
import zarr
+from datamol.utils import fs
from pydantic import ValidationError
from polaris.dataset import Dataset, create_dataset_from_file
from polaris.loader import load_dataset
-from polaris.utils import fs
from polaris.utils.errors import PolarisChecksumError
@@ -43,17 +43,17 @@ def test_load_data(tmp_path, with_slice, with_caching):
arr = np.random.random((100, 100))
tmpdir = str(tmp_path)
- path = fs.join(tmpdir, "data.zarr")
+ zarr_path = fs.join(tmpdir, "data.zarr")
- root = zarr.open(path, "w")
+ root = zarr.open(zarr_path, "w")
root.array("A", data=arr)
- path = f"{path}/A#0:5" if with_slice else f"{path}/A#0"
+ path = "A#0:5" if with_slice else "A#0"
table = pd.DataFrame({"A": [path]}, index=[0])
- dataset = Dataset(table=table, annotations={"A": {"is_pointer": True}})
+ dataset = Dataset(table=table, annotations={"A": {"is_pointer": True}}, zarr_root_path=zarr_path)
if with_caching:
- dataset.cache(tmpdir)
+ dataset.cache(fs.join(tmpdir, "cache"))
data = dataset.get_data(row=0, col="A")
@@ -164,8 +164,6 @@ def test_dataset_caching(zarr_archive, tmpdir):
assert original_dataset == cached_dataset
cache_dir = cached_dataset.cache(tmpdir.join("cached").strpath)
- for i in range(len(cached_dataset)):
- assert cached_dataset.table.loc[i, "A"].startswith(cache_dir)
- assert cached_dataset.table.loc[i, "B"].startswith(cache_dir)
+ assert cached_dataset.zarr_root_path.startswith(cache_dir)
assert _equality_test(cached_dataset, original_dataset)