Merged
polaris/_artifact.py (15 changes: 14 additions & 1 deletion)
@@ -2,7 +2,15 @@
from typing import Dict, Optional, Union

import fsspec
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_serializer, field_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
computed_field,
field_serializer,
field_validator,
)
from pydantic.alias_generators import to_camel

from polaris.utils.types import HubOwner, SlugCompatibleStringType
@@ -38,6 +46,11 @@ class BaseArtifactModel(BaseModel):
owner: Optional[HubOwner] = None
_verified: bool = PrivateAttr(False)

@computed_field
@property
def artifact_id(self) -> Optional[str]:
return f"{self.owner.slug}/{self.name}" if self.owner and self.name else None

@field_serializer("owner")
def _serialize_owner(self, value: HubOwner) -> Union[str, None]:
return self.owner.slug if self.owner else None
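Note: in pydantic v2, stacking `@computed_field` on a `@property` makes the derived value part of serialization, so `artifact_id` appears in `model_dump()` output without becoming a settable field. A minimal sketch of the pattern (toy model, not the actual `BaseArtifactModel`):

```python
from typing import Optional

from pydantic import BaseModel, computed_field


class Artifact(BaseModel):
    name: Optional[str] = None
    owner_slug: Optional[str] = None

    @computed_field
    @property
    def artifact_id(self) -> Optional[str]:
        # Derived and read-only: included in dumps, never set by callers.
        return f"{self.owner_slug}/{self.name}" if self.owner_slug and self.name else None


print(Artifact(name="my-dataset", owner_slug="acme").model_dump())
# {'name': 'my-dataset', 'owner_slug': 'acme', 'artifact_id': 'acme/my-dataset'}
```

With the `to_camel` alias generator configured on the real model, this should serialize as `artifactId`, which is the key the client starts reading from hub responses below.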
polaris/hub/client.py (49 changes: 28 additions & 21 deletions)
@@ -1,8 +1,8 @@
import json
import os
import ssl
import sys
import webbrowser
from hashlib import md5
from io import BytesIO
from typing import Callable, Optional, Union
from urllib.parse import urljoin
@@ -143,7 +143,7 @@ def _base_request_to_hub(self, url: str, method: str, **kwargs):
f"The request to the Polaris Hub failed. See the error message below for more details:\n{response}"
) from error

# Convert the reponse to json format if the reponse contains a 'text' body
# Convert the response to json format if the response contains a 'text' body
try:
response = response.json()
except json.JSONDecodeError:
@@ -300,7 +300,7 @@ def list_datasets(self, limit: int = 100, offset: int = 0) -> list[str]:
response = self._base_request_to_hub(
url="/dataset", method="GET", params={"limit": limit, "offset": offset}
)
dataset_list = [f"{HubOwner(**bm['owner'])}/{bm['name']}" for bm in response["data"]]
dataset_list = [bm["artifactId"] for bm in response["data"]]
return dataset_list

def get_dataset(self, owner: Union[str, HubOwner], name: str) -> Dataset:
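Note: since the hub now serializes the computed `artifactId`, the listing code reads the identifier directly instead of recomposing `owner/name` from the nested owner object. A toy illustration (this response shape is assumed for the example, not taken from hub documentation):

```python
# Assumed response payload, for illustration only.
response = {"data": [{"artifactId": "acme/solubility"}, {"artifactId": "acme/herg"}]}

dataset_list = [bm["artifactId"] for bm in response["data"]]
print(dataset_list)  # ['acme/solubility', 'acme/herg']
```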
Expand Down Expand Up @@ -443,30 +443,37 @@ def upload_dataset(
# 2. Upload the parquet file to the hub
# TODO: Revert step 1 in case step 2 fails - Is this needed? Or should this be taken care of by the hub?

# Step 1: Upload meta-data
# Instead of directly uploading the table, we announce to the hub that we intend to upload one.
dataset_json["tableContent"] = {
"size": sys.getsizeof(dataset.table),
"fileType": "parquet",
"md5sum": dataset._compute_checksum(dataset.table),
"url": urljoin(
self.settings.hub_url, f"/storage/dataset/{dataset.owner}/{dataset.name}/table.parquet"
),
}
dataset_json["access"] = access
url = f"/dataset/{dataset.owner}/{dataset.name}"
response = self._base_request_to_hub(url=url, method="PUT", json=dataset_json)

# Step 2: Upload the parquet file
# Write the parquet file directly to a buffer
buffer = BytesIO()
dataset.table.to_parquet(buffer, engine="auto")
parquet_size = len(buffer.getbuffer())
parquet_md5 = md5(buffer.getbuffer()).hexdigest()

# Step 1: Upload meta-data
# Instead of directly uploading the table, we announce to the hub that we intend to upload one.
url = f"/dataset/{dataset.artifact_id}"
response = self._base_request_to_hub(
url=url,
method="PUT",
json={
"tableContent": {
"size": parquet_size,
"fileType": "parquet",
"md5sum": parquet_md5,
},
"access": access,
**dataset_json,
},
)

# Step 2: Upload the parquet file
# create an empty PUT request to get the table content URL from cloudflare
hub_response = self.request(
url=dataset_json["tableContent"]["url"],
url=response["tableContent"]["url"],
method="PUT",
headers={"Content-type": "application/vnd.apache.parquet"},
headers={
"Content-type": "application/vnd.apache.parquet",
},
)

if hub_response.status_code == 307:
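Note: the substantive fix above is that the metadata announced in step 1 is now computed from the exact bytes uploaded in step 2. `sys.getsizeof(dataset.table)` measures the in-memory footprint of the DataFrame object, not the size of a parquet file, and a checksum computed over the table rather than over the serialized file would not, in general, match the bytes the hub receives. Serializing once into a buffer and reusing it fixes both. A minimal sketch of that computation with a toy DataFrame (assumes pandas with a parquet engine such as pyarrow installed):

```python
from hashlib import md5
from io import BytesIO

import pandas as pd

table = pd.DataFrame({"smiles": ["CCO", "c1ccccc1"], "label": [0.5, 1.2]})

# Serialize once; reuse the same bytes for size, checksum, and upload.
buffer = BytesIO()
table.to_parquet(buffer, engine="auto")

payload = buffer.getbuffer()
parquet_size = len(payload)             # exact on-the-wire size in bytes
parquet_md5 = md5(payload).hexdigest()  # checksum of the bytes being uploaded

print(parquet_size, parquet_md5)
```

A second improvement: the storage URL is no longer guessed client-side with `urljoin`; the response to the PUT announcement supplies `tableContent.url`, so the hub stays the source of truth for where the file lands.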
@@ -517,7 +524,7 @@ def upload_benchmark(self, benchmark: BenchmarkSpecification, access: AccessType

# Get the serialized data-model
# We exclude the dataset as we expect it to exist on the hub already.
benchmark_json = benchmark.model_dump(exclude=["dataset"], exclude_none=True, by_alias=True)
benchmark_json = benchmark.model_dump(exclude={"dataset"}, exclude_none=True, by_alias=True)
benchmark_json["datasetArtifactId"] = f"{benchmark.dataset.owner}/{benchmark.dataset.name}"
benchmark_json["access"] = access

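Note: the `exclude={"dataset"}` change is a small API-correctness fix. Pydantic v2 annotates `model_dump`'s `exclude` parameter as a set or mapping (`IncEx`), so a set literal matches the documented interface, while a list relies on unspecified behavior. A quick sketch:

```python
from pydantic import BaseModel


class Benchmark(BaseModel):
    name: str
    dataset: dict


bm = Benchmark(name="bm-1", dataset={"large": "payload that already exists on the hub"})

# `exclude` expects a set (or mapping) of field names in pydantic v2.
print(bm.model_dump(exclude={"dataset"}))  # {'name': 'bm-1'}
```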