Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions polaris/_artifact.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import json
from typing import Dict, Optional
from typing import Dict, Optional, Union

import fsspec
from pydantic import BaseModel, Field, PrivateAttr
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_serializer, field_validator
from pydantic.alias_generators import to_camel

from polaris.utils.types import HubOwner, SlugCompatibleStringType

Expand All @@ -28,13 +29,26 @@ class BaseArtifactModel(BaseModel):
_verified: Whether the benchmark has been verified through the Polaris Hub.
"""

model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, arbitrary_types_allowed=True)

name: Optional[SlugCompatibleStringType] = None
description: str = ""
tags: list[str] = Field(default_factory=list)
user_attributes: Dict[str, str] = Field(default_factory=dict)
owner: Optional[HubOwner] = None
_verified: bool = PrivateAttr(False)

@field_serializer("owner")
def _serialize_owner(self, value: HubOwner) -> Union[str, None]:
return self.owner.slug if self.owner else None

@field_validator("owner", mode="before")
@classmethod
def _validate_owner(cls, value: Union[str, HubOwner, None]):
if isinstance(value, str):
return HubOwner(slug=value)
return value

@classmethod
def from_json(cls, path: str):
"""Loads a benchmark from a JSON file.
Expand Down
10 changes: 2 additions & 8 deletions polaris/benchmark/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import numpy as np
import pandas as pd
from pydantic import (
ConfigDict,
FieldValidationInfo,
field_serializer,
field_validator,
Expand All @@ -22,8 +21,8 @@
from polaris.utils.context import tmp_attribute_change
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError
from polaris.utils.misc import listit, to_lower_camel
from polaris.utils.types import DataFormat, PredictionsType, SplitType, AccessType
from polaris.utils.misc import listit
from polaris.utils.types import AccessType, DataFormat, PredictionsType, SplitType

ColumnsType = Union[str, list[str]]

Expand Down Expand Up @@ -85,11 +84,6 @@ class BenchmarkSpecification(BaseArtifactModel):
# Additional meta-data
readme: str = ""

# Pydantic config
model_config = ConfigDict(
arbitrary_types_allowed=True, alias_generator=to_lower_camel, populate_by_name=True
)

@field_validator("dataset")
def _validate_dataset(cls, v):
"""
Expand Down
9 changes: 3 additions & 6 deletions polaris/dataset/_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from typing import Dict, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator

from polaris.utils.misc import to_lower_camel
from pydantic.alias_generators import to_camel


class Modality(enum.Enum):
Expand Down Expand Up @@ -36,13 +35,11 @@ class ColumnAnnotation(BaseModel):
description: Optional[str] = None
user_attributes: Dict[str, str] = Field(default_factory=dict)

model_config = ConfigDict(
arbitrary_types_allowed=True, alias_generator=to_lower_camel, populate_by_name=True
)
model_config = ConfigDict(arbitrary_types_allowed=True, alias_generator=to_camel, populate_by_name=True)

@field_validator("modality")
def _validate_modality(cls, v):
"""Tries to converts a string to the Enum"""
"""Tries to convert a string to the Enum"""
if isinstance(v, str):
v = Modality[v.upper()]
return v
Expand Down
9 changes: 1 addition & 8 deletions polaris/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import zarr
from loguru import logger
from pydantic import (
ConfigDict,
Field,
field_validator,
model_validator,
Expand All @@ -24,8 +23,7 @@
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError
from polaris.utils.io import get_zarr_root, robust_copy
from polaris.utils.misc import to_lower_camel
from polaris.utils.types import HttpUrlString, License, AccessType
from polaris.utils.types import AccessType, HttpUrlString, License

# Constants
_SUPPORTED_TABLE_EXTENSIONS = ["parquet"]
Expand Down Expand Up @@ -83,11 +81,6 @@ class Dataset(BaseArtifactModel):
_has_been_warned: bool = False
_has_been_cached: bool = False

# Pydantic config
model_config = ConfigDict(
arbitrary_types_allowed=True, alias_generator=to_lower_camel, populate_by_name=True
)

@field_validator("table")
def _validate_table(cls, v):
"""If the table is not a dataframe yet, assume it's a path and try load it."""
Expand Down
22 changes: 15 additions & 7 deletions polaris/evaluate/_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,22 @@
from typing import ClassVar, Optional, Union

import pandas as pd
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_serializer, field_validator
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
computed_field,
field_serializer,
field_validator,
)
from pydantic.alias_generators import to_camel

from polaris._artifact import BaseArtifactModel
from polaris.evaluate import Metric
from polaris.hub.settings import PolarisHubSettings
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidResultError
from polaris.utils.misc import to_lower_camel
from polaris.utils.types import AccessType, HttpUrlString, HubOwner, HubUser

# Define some helpful type aliases
Expand All @@ -30,7 +38,7 @@ class ResultRecords(BaseModel):
scores: dict[Union[Metric, str], float]

# Model config
model_config = ConfigDict(alias_generator=to_lower_camel, populate_by_name=True)
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)

@field_validator("scores")
def validate_scores(cls, v):
Expand Down Expand Up @@ -104,10 +112,10 @@ class BenchmarkResults(BaseArtifactModel):
# Private attributes
_created_at: datetime = PrivateAttr(default_factory=datetime.now)

# Model config
model_config = ConfigDict(
alias_generator=to_lower_camel, populate_by_name=True, arbitrary_types_allowed=True
)
@computed_field
@property
def benchmark_artifact_id(self) -> str:
return f"{self.benchmark_owner}/{self.benchmark_name}"

@field_validator("results")
def _validate_results(cls, v):
Expand Down
10 changes: 5 additions & 5 deletions polaris/hub/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from authlib.integrations.httpx_client import OAuth2Client, OAuthError
from authlib.oauth2.client import OAuth2Client as _OAuth2Client
from httpx import HTTPStatusError
from httpx._types import HeaderTypes, URLTypes, TimeoutTypes
from httpx._types import HeaderTypes, TimeoutTypes, URLTypes
from loguru import logger

from polaris.benchmark import (
Expand All @@ -30,7 +30,7 @@
from polaris.utils import fs
from polaris.utils.constants import DEFAULT_CACHE_DIR
from polaris.utils.errors import PolarisHubError, PolarisUnauthorizedError
from polaris.utils.types import HubOwner, AccessType
from polaris.utils.types import AccessType, HubOwner

_HTTPX_SSL_ERROR_CODE = "[SSL: CERTIFICATE_VERIFY_FAILED]"

Expand Down Expand Up @@ -399,11 +399,11 @@ def upload_results(self, results: BenchmarkResults, access: AccessType = "privat

# Get the serialized model data-structure
result_json = results.model_dump(by_alias=True, exclude_none=True)
result_json["access"] = access

# Make a request to the hub
url = f"/benchmark/{results.benchmark_owner}/{results.benchmark_name}/result"
response = self._base_request_to_hub(url=url, method="POST", json=result_json)
response = self._base_request_to_hub(
url="/result", method="POST", json={"access": access, **result_json}
)

# Inform the user about where to find their newly created artifact.
result_url = urljoin(
Expand Down
6 changes: 0 additions & 6 deletions polaris/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,3 @@ def listit(t: Any):
https://stackoverflow.com/questions/1014352/how-do-i-convert-a-nested-tuple-of-tuples-and-lists-to-lists-of-lists-in-python
"""
return list(map(listit, t)) if isinstance(t, (list, tuple)) else t


def to_lower_camel(name: str) -> str:
"""Converts a snake_case string to lowerCamelCase"""
upper = "".join(word.capitalize() for word in name.split("_"))
return upper[:1].lower() + upper[1:]
45 changes: 13 additions & 32 deletions polaris/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@
BaseModel,
ConfigDict,
HttpUrl,
computed_field,
constr,
model_validator,
)
from pydantic.alias_generators import to_camel
from typing_extensions import TypeAlias

from polaris.utils.misc import to_lower_camel

SplitIndicesType: TypeAlias = list[int]
"""
A split is defined by a sequence of integers.
Expand Down Expand Up @@ -51,9 +49,14 @@
The target formats that are supported by the `Subset` class.
"""

SlugCompatibleStringType: TypeAlias = constr(pattern="^[A-Za-z0-9_-]+$", min_length=4, max_length=64)
SlugStringType: TypeAlias = constr(pattern="^[a-z0-9-]+$", min_length=4, max_length=64)
"""
A URL-compatible string that can serve as slug on the hub.
"""

SlugCompatibleStringType: TypeAlias = constr(pattern="^[A-Za-z0-9_-]+$", min_length=4, max_length=64)
"""
A URL-compatible string that can be turned into a slug by the hub.

Can only use alpha-numeric characters, underscores and dashes.
The string must be at least 4 and at most 64 characters long.
Expand Down Expand Up @@ -87,38 +90,16 @@
class HubOwner(BaseModel):
"""An owner of an artifact on the Polaris Hub

The slug is most important as it is the user-facing part of this data model. The organization
and user id are added to be consistent with the Polaris Hub.

The username is specified as a [`SlugCompatibleStringType`][polaris.utils.types.SlugCompatibleStringType],
whereas the organization is specified as a string that can contain only alpha-numeric characters,
underscores and dashes. Contrary to the username, an organization name can currently be of arbitrary length.
The slug is most important as it is the user-facing part of this data model.
The externalId and type are added to be consistent with the model returned by the Polaris Hub .
"""

slug: constr(pattern="^[A-Za-z0-9_-]+$")
organization_id: Optional[constr(pattern="^[A-Za-z0-9_-]+$")] = None
user_id: Optional[HubUser] = None
slug: SlugStringType
external_id: Optional[str] = None
type: Optional[Literal["user", "organization"]] = None

# Pydantic config
model_config = ConfigDict(alias_generator=to_lower_camel, populate_by_name=True)

@model_validator(mode="after") # type: ignore
@classmethod
def _validate_model(cls, m: "HubOwner"):
if m.organization_id is not None and m.user_id is not None:
raise ValueError("An owner cannot both have an `organization_id` and a `user_id`")
return m

@computed_field
@property
def owner(self) -> str:
return self.organization_id or self.user_id # type: ignore

def __str__(self) -> str:
return self.slug

def __repr__(self) -> str:
return self.__str__()
model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True)


class License(BaseModel):
Expand Down
34 changes: 17 additions & 17 deletions tests/test_type_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,36 @@
from polaris.utils.types import HubOwner, License


def test_slug_string_type():
"""
Verifies that the slug is validated correctly.
Fails if:
- Is too short (<4 characters)
- Is too long (>64 characters)
- Contains something other than lowercase letters, numbers, and hyphens.
"""
for name in ["", "x", "xx", "xxx", "x" * 65, "invalid@", "invalid!", "InvalidName1", "invalid_name"]:
with pytest.raises(ValueError):
HubOwner(slug=name)

for name in ["valid", "valid-name-1", "x" * 64, "x" * 4]:
HubOwner(slug=name)


def test_slug_compatible_string_type():
"""Verifies that the artifact name is validated correctly."""

# Fails if:
# - Is too short (<4 characters)
# - Is too long (>64 characters)
# - Contains non alpha-numeric characters
# - Contains non-alphanumeric characters
for name in ["", "x", "xx", "xxx", "x" * 65, "invalid@", "invalid!"]:
with pytest.raises(ValueError):
BaseArtifactModel(name=name)
with pytest.raises(ValueError):
HubOwner(userId=name, slug=name)

# Does not fail
for name in ["valid", "valid-name", "valid_name", "ValidName1", "Valid_", "Valid-", "x" * 64, "x" * 4]:
BaseArtifactModel(name=name)
HubOwner(userId=name, slug=name)


def test_artifact_owner():
with pytest.raises(ValueError):
# No owner specified
HubOwner()
with pytest.raises(ValueError):
# Conflicting owner specified
HubOwner(organizationId="org", userId="user", slug="test")

# Valid - Only specifies one!
assert HubOwner(organizationId="org", slug="org").owner == "org"
assert HubOwner(userId="user", slug="user").owner == "user"


def test_license():
Expand Down