diff --git a/polaris/_artifact.py b/polaris/_artifact.py
index 364591c9..34fc2e4e 100644
--- a/polaris/_artifact.py
+++ b/polaris/_artifact.py
@@ -2,7 +2,15 @@
 from typing import Dict, Optional, Union
 
 import fsspec
-from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_serializer, field_validator
+from pydantic import (
+    BaseModel,
+    ConfigDict,
+    Field,
+    PrivateAttr,
+    computed_field,
+    field_serializer,
+    field_validator,
+)
 from pydantic.alias_generators import to_camel
 
 from polaris.utils.types import HubOwner, SlugCompatibleStringType
@@ -38,6 +46,11 @@ class BaseArtifactModel(BaseModel):
     owner: Optional[HubOwner] = None
     _verified: bool = PrivateAttr(False)
 
+    @computed_field
+    @property
+    def artifact_id(self) -> Optional[str]:
+        return f"{self.owner.slug}/{self.name}" if self.owner and self.name else None
+
     @field_serializer("owner")
     def _serialize_owner(self, value: HubOwner) -> Union[str, None]:
         return self.owner.slug if self.owner else None
diff --git a/polaris/hub/client.py b/polaris/hub/client.py
index 4751780b..ef7785c4 100644
--- a/polaris/hub/client.py
+++ b/polaris/hub/client.py
@@ -1,8 +1,8 @@
 import json
 import os
 import ssl
-import sys
 import webbrowser
+from hashlib import md5
 from io import BytesIO
 from typing import Callable, Optional, Union
 from urllib.parse import urljoin
@@ -143,7 +143,7 @@ def _base_request_to_hub(self, url: str, method: str, **kwargs):
                 f"The request to the Polaris Hub failed. See the error message below for more details:\n{response}"
             ) from error
 
-        # Convert the reponse to json format if the reponse contains a 'text' body
+        # Convert the response to json format if the response contains a 'text' body
         try:
             response = response.json()
         except json.JSONDecodeError:
@@ -300,7 +300,7 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
         response = self._base_request_to_hub(
             url="/dataset", method="GET", params={"limit": limit, "offset": offset}
         )
-        dataset_list = [f"{HubOwner(**bm['owner'])}/{bm['name']}" for bm in response["data"]]
+        dataset_list = [bm["artifactId"] for bm in response["data"]]
         return dataset_list
 
     def get_dataset(self, owner: Union[str, HubOwner], name: str) -> Dataset:
@@ -443,30 +443,37 @@ def upload_dataset(
         # 2. Upload the parquet file to the hub
         # TODO: Revert step 1 in case step 2 fails - Is this needed? Or should this be taken care of by the hub?
 
-        # Step 1: Upload meta-data
-        # Instead of directly uploading the table, we announce to the hub that we intend to upload one.
-        dataset_json["tableContent"] = {
-            "size": sys.getsizeof(dataset.table),
-            "fileType": "parquet",
-            "md5sum": dataset._compute_checksum(dataset.table),
-            "url": urljoin(
-                self.settings.hub_url, f"/storage/dataset/{dataset.owner}/{dataset.name}/table.parquet"
-            ),
-        }
-        dataset_json["access"] = access
-        url = f"/dataset/{dataset.owner}/{dataset.name}"
-        response = self._base_request_to_hub(url=url, method="PUT", json=dataset_json)
-
-        # Step 2: Upload the parquet file
         # Write the parquet file directly to a buffer
         buffer = BytesIO()
         dataset.table.to_parquet(buffer, engine="auto")
+        parquet_size = len(buffer.getbuffer())
+        parquet_md5 = md5(buffer.getbuffer()).hexdigest()
 
+        # Step 1: Upload meta-data
+        # Instead of directly uploading the table, we announce to the hub that we intend to upload one.
+        url = f"/dataset/{dataset.artifact_id}"
+        response = self._base_request_to_hub(
+            url=url,
+            method="PUT",
+            json={
+                "tableContent": {
+                    "size": parquet_size,
+                    "fileType": "parquet",
+                    "md5sum": parquet_md5,
+                },
+                "access": access,
+                **dataset_json,
+            },
+        )
+
+        # Step 2: Upload the parquet file
         # create an empty PUT request to get the table content URL from cloudflare
         hub_response = self.request(
-            url=dataset_json["tableContent"]["url"],
+            url=response["tableContent"]["url"],
             method="PUT",
-            headers={"Content-type": "application/vnd.apache.parquet"},
+            headers={
+                "Content-type": "application/vnd.apache.parquet",
+            },
         )
 
         if hub_response.status_code == 307:
@@ -517,7 +524,7 @@ def upload_benchmark(self, benchmark: BenchmarkSpecification, access: AccessType
 
         # Get the serialized data-model
         # We exclude the dataset as we expect it to exist on the hub already.
-        benchmark_json = benchmark.model_dump(exclude=["dataset"], exclude_none=True, by_alias=True)
+        benchmark_json = benchmark.model_dump(exclude={"dataset"}, exclude_none=True, by_alias=True)
 
         benchmark_json["datasetArtifactId"] = f"{benchmark.dataset.owner}/{benchmark.dataset.name}"
         benchmark_json["access"] = access