diff --git a/polaris/_artifact.py b/polaris/_artifact.py index 6bccbf01..364591c9 100644 --- a/polaris/_artifact.py +++ b/polaris/_artifact.py @@ -1,8 +1,9 @@ import json -from typing import Dict, Optional +from typing import Dict, Optional, Union import fsspec -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_serializer, field_validator +from pydantic.alias_generators import to_camel from polaris.utils.types import HubOwner, SlugCompatibleStringType @@ -28,6 +29,8 @@ class BaseArtifactModel(BaseModel): _verified: Whether the benchmark has been verified through the Polaris Hub. """ + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, arbitrary_types_allowed=True) + name: Optional[SlugCompatibleStringType] = None description: str = "" tags: list[str] = Field(default_factory=list) @@ -35,6 +38,17 @@ class BaseArtifactModel(BaseModel): owner: Optional[HubOwner] = None _verified: bool = PrivateAttr(False) + @field_serializer("owner") + def _serialize_owner(self, value: HubOwner) -> Union[str, None]: + return self.owner.slug if self.owner else None + + @field_validator("owner", mode="before") + @classmethod + def _validate_owner(cls, value: Union[str, HubOwner, None]): + if isinstance(value, str): + return HubOwner(slug=value) + return value + @classmethod def from_json(cls, path: str): """Loads a benchmark from a JSON file. 
diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index 6ff6dc16..87908a4a 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd from pydantic import ( - ConfigDict, FieldValidationInfo, field_serializer, field_validator, @@ -22,8 +21,8 @@ from polaris.utils.context import tmp_attribute_change from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError -from polaris.utils.misc import listit, to_lower_camel -from polaris.utils.types import DataFormat, PredictionsType, SplitType, AccessType +from polaris.utils.misc import listit +from polaris.utils.types import AccessType, DataFormat, PredictionsType, SplitType ColumnsType = Union[str, list[str]] @@ -85,11 +84,6 @@ class BenchmarkSpecification(BaseArtifactModel): # Additional meta-data readme: str = "" - # Pydantic config - model_config = ConfigDict( - arbitrary_types_allowed=True, alias_generator=to_lower_camel, populate_by_name=True - ) - @field_validator("dataset") def _validate_dataset(cls, v): """ diff --git a/polaris/dataset/_column.py b/polaris/dataset/_column.py index 70ce66cb..cf960bb7 100644 --- a/polaris/dataset/_column.py +++ b/polaris/dataset/_column.py @@ -2,8 +2,7 @@ from typing import Dict, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator - -from polaris.utils.misc import to_lower_camel +from pydantic.alias_generators import to_camel class Modality(enum.Enum): @@ -36,13 +35,11 @@ class ColumnAnnotation(BaseModel): description: Optional[str] = None user_attributes: Dict[str, str] = Field(default_factory=dict) - model_config = ConfigDict( - arbitrary_types_allowed=True, alias_generator=to_lower_camel, populate_by_name=True - ) + model_config = ConfigDict(arbitrary_types_allowed=True, alias_generator=to_camel, populate_by_name=True) @field_validator("modality") def _validate_modality(cls, v): - 
"""Tries to converts a string to the Enum""" + """Tries to convert a string to the Enum""" if isinstance(v, str): v = Modality[v.upper()] return v diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 1f8c17fe..f0e3b53c 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -10,7 +10,6 @@ import zarr from loguru import logger from pydantic import ( - ConfigDict, Field, field_validator, model_validator, @@ -24,8 +23,7 @@ from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError from polaris.utils.io import get_zarr_root, robust_copy -from polaris.utils.misc import to_lower_camel -from polaris.utils.types import HttpUrlString, License, AccessType +from polaris.utils.types import AccessType, HttpUrlString, License # Constants _SUPPORTED_TABLE_EXTENSIONS = ["parquet"] @@ -83,11 +81,6 @@ class Dataset(BaseArtifactModel): _has_been_warned: bool = False _has_been_cached: bool = False - # Pydantic config - model_config = ConfigDict( - arbitrary_types_allowed=True, alias_generator=to_lower_camel, populate_by_name=True - ) - @field_validator("table") def _validate_table(cls, v): """If the table is not a dataframe yet, assume it's a path and try load it.""" diff --git a/polaris/evaluate/_results.py b/polaris/evaluate/_results.py index 448925bc..7e019177 100644 --- a/polaris/evaluate/_results.py +++ b/polaris/evaluate/_results.py @@ -4,14 +4,22 @@ from typing import ClassVar, Optional, Union import pandas as pd -from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_serializer, field_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + PrivateAttr, + computed_field, + field_serializer, + field_validator, +) +from pydantic.alias_generators import to_camel from polaris._artifact import BaseArtifactModel from polaris.evaluate import Metric from polaris.hub.settings import PolarisHubSettings from polaris.utils.dict2html import dict2html 
from polaris.utils.errors import InvalidResultError -from polaris.utils.misc import to_lower_camel from polaris.utils.types import AccessType, HttpUrlString, HubOwner, HubUser # Define some helpful type aliases @@ -30,7 +38,7 @@ class ResultRecords(BaseModel): scores: dict[Union[Metric, str], float] # Model config - model_config = ConfigDict(alias_generator=to_lower_camel, populate_by_name=True) + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True) @field_validator("scores") def validate_scores(cls, v): @@ -104,10 +112,10 @@ class BenchmarkResults(BaseArtifactModel): # Private attributes _created_at: datetime = PrivateAttr(default_factory=datetime.now) - # Model config - model_config = ConfigDict( - alias_generator=to_lower_camel, populate_by_name=True, arbitrary_types_allowed=True - ) + @computed_field + @property + def benchmark_artifact_id(self) -> str: + return f"{self.benchmark_owner}/{self.benchmark_name}" @field_validator("results") def _validate_results(cls, v): diff --git a/polaris/hub/client.py b/polaris/hub/client.py index 11859cb8..4751780b 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -16,7 +16,7 @@ from authlib.integrations.httpx_client import OAuth2Client, OAuthError from authlib.oauth2.client import OAuth2Client as _OAuth2Client from httpx import HTTPStatusError -from httpx._types import HeaderTypes, URLTypes, TimeoutTypes +from httpx._types import HeaderTypes, TimeoutTypes, URLTypes from loguru import logger from polaris.benchmark import ( @@ -30,7 +30,7 @@ from polaris.utils import fs from polaris.utils.constants import DEFAULT_CACHE_DIR from polaris.utils.errors import PolarisHubError, PolarisUnauthorizedError -from polaris.utils.types import HubOwner, AccessType +from polaris.utils.types import AccessType, HubOwner _HTTPX_SSL_ERROR_CODE = "[SSL: CERTIFICATE_VERIFY_FAILED]" @@ -399,11 +399,11 @@ def upload_results(self, results: BenchmarkResults, access: AccessType = "privat # Get the serialized model 
data-structure result_json = results.model_dump(by_alias=True, exclude_none=True) - result_json["access"] = access # Make a request to the hub - url = f"/benchmark/{results.benchmark_owner}/{results.benchmark_name}/result" - response = self._base_request_to_hub(url=url, method="POST", json=result_json) + response = self._base_request_to_hub( + url="/result", method="POST", json={"access": access, **result_json} + ) # Inform the user about where to find their newly created artifact. result_url = urljoin( diff --git a/polaris/utils/misc.py b/polaris/utils/misc.py index 0292835a..dc2a720e 100644 --- a/polaris/utils/misc.py +++ b/polaris/utils/misc.py @@ -7,9 +7,3 @@ def listit(t: Any): https://stackoverflow.com/questions/1014352/how-do-i-convert-a-nested-tuple-of-tuples-and-lists-to-lists-of-lists-in-python """ return list(map(listit, t)) if isinstance(t, (list, tuple)) else t - - -def to_lower_camel(name: str) -> str: - """Converts a snake_case string to lowerCamelCase""" - upper = "".join(word.capitalize() for word in name.split("_")) - return upper[:1].lower() + upper[1:] diff --git a/polaris/utils/types.py b/polaris/utils/types.py index 193c8770..2186eff6 100644 --- a/polaris/utils/types.py +++ b/polaris/utils/types.py @@ -9,14 +9,12 @@ BaseModel, ConfigDict, HttpUrl, - computed_field, constr, model_validator, ) +from pydantic.alias_generators import to_camel from typing_extensions import TypeAlias -from polaris.utils.misc import to_lower_camel - SplitIndicesType: TypeAlias = list[int] """ A split is defined by a sequence of integers. @@ -51,9 +49,14 @@ The target formats that are supported by the `Subset` class. """ -SlugCompatibleStringType: TypeAlias = constr(pattern="^[A-Za-z0-9_-]+$", min_length=4, max_length=64) +SlugStringType: TypeAlias = constr(pattern="^[a-z0-9-]+$", min_length=4, max_length=64) """ A URL-compatible string that can serve as slug on the hub. 
+""" + +SlugCompatibleStringType: TypeAlias = constr(pattern="^[A-Za-z0-9_-]+$", min_length=4, max_length=64) +""" +A URL-compatible string that can be turned into a slug by the hub. Can only use alpha-numeric characters, underscores and dashes. The string must be at least 4 and at most 64 characters long. @@ -87,38 +90,16 @@ class HubOwner(BaseModel): """An owner of an artifact on the Polaris Hub - The slug is most important as it is the user-facing part of this data model. The organization - and user id are added to be consistent with the Polaris Hub. - - The username is specified as a [`SlugCompatibleStringType`][polaris.utils.types.SlugCompatibleStringType], - whereas the organization is specified as a string that can contain only alpha-numeric characters, - underscores and dashes. Contrary to the username, an organization name can currently be of arbitrary length. + The slug is most important as it is the user-facing part of this data model. + The externalId and type are added to be consistent with the model returned by the Polaris Hub. 
""" - slug: constr(pattern="^[A-Za-z0-9_-]+$") - organization_id: Optional[constr(pattern="^[A-Za-z0-9_-]+$")] = None - user_id: Optional[HubUser] = None + slug: SlugStringType + external_id: Optional[str] = None + type: Optional[Literal["user", "organization"]] = None # Pydantic config - model_config = ConfigDict(alias_generator=to_lower_camel, populate_by_name=True) - - @model_validator(mode="after") # type: ignore - @classmethod - def _validate_model(cls, m: "HubOwner"): - if m.organization_id is not None and m.user_id is not None: - raise ValueError("An owner cannot both have an `organization_id` and a `user_id`") - return m - - @computed_field - @property - def owner(self) -> str: - return self.organization_id or self.user_id # type: ignore - - def __str__(self) -> str: - return self.slug - - def __repr__(self) -> str: - return self.__str__() + model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True) class License(BaseModel): diff --git a/tests/test_type_checks.py b/tests/test_type_checks.py index 40042efe..a9471a89 100644 --- a/tests/test_type_checks.py +++ b/tests/test_type_checks.py @@ -4,36 +4,36 @@ from polaris.utils.types import HubOwner, License +def test_slug_string_type(): + """ + Verifies that the slug is validated correctly. + Fails if: + - Is too short (<4 characters) + - Is too long (>64 characters) + - Contains something other than lowercase letters, numbers, and hyphens. 
+ """ + for name in ["", "x", "xx", "xxx", "x" * 65, "invalid@", "invalid!", "InvalidName1", "invalid_name"]: + with pytest.raises(ValueError): + HubOwner(slug=name) + + for name in ["valid", "valid-name-1", "x" * 64, "x" * 4]: + HubOwner(slug=name) + + def test_slug_compatible_string_type(): """Verifies that the artifact name is validated correctly.""" # Fails if: # - Is too short (<4 characters) # - Is too long (>64 characters) - # - Contains non alpha-numeric characters + # - Contains non-alphanumeric characters for name in ["", "x", "xx", "xxx", "x" * 65, "invalid@", "invalid!"]: with pytest.raises(ValueError): BaseArtifactModel(name=name) - with pytest.raises(ValueError): - HubOwner(userId=name, slug=name) # Does not fail for name in ["valid", "valid-name", "valid_name", "ValidName1", "Valid_", "Valid-", "x" * 64, "x" * 4]: BaseArtifactModel(name=name) - HubOwner(userId=name, slug=name) - - -def test_artifact_owner(): - with pytest.raises(ValueError): - # No owner specified - HubOwner() - with pytest.raises(ValueError): - # Conflicting owner specified - HubOwner(organizationId="org", userId="user", slug="test") - - # Valid - Only specifies one! - assert HubOwner(organizationId="org", slug="org").owner == "org" - assert HubOwner(userId="user", slug="user").owner == "user" def test_license():