Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions polaris/_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class BaseArtifactModel(BaseModel):
owner: A slug-compatible name for the owner of the dataset.
If the dataset comes from the Polaris Hub, this is the associated owner (organization or user).
Together with the name, this is used by the Hub to uniquely identify the benchmark.
version: The version of the Polaris library that was used to create the artifact.
polaris_version: The version of the Polaris library that was used to create the artifact.
"""

model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, arbitrary_types_allowed=True)
Expand All @@ -47,14 +47,14 @@ class BaseArtifactModel(BaseModel):
tags: list[str] = Field(default_factory=list)
user_attributes: Dict[str, str] = Field(default_factory=dict)
owner: Optional[HubOwner] = None
version: str = po.__version__
polaris_version: str = po.__version__

@computed_field
@property
def artifact_id(self) -> Optional[str]:
return f"{self.owner}/{sluggify(self.name)}" if self.owner and self.name else None

@field_validator("version")
@field_validator("polaris_version")
@classmethod
def _validate_version(cls, value: str) -> str:
if value != "dev":
Expand Down
15 changes: 9 additions & 6 deletions polaris/dataset/_adapters.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from enum import Enum

from enum import Enum, auto, unique
import datamol as dm

# Map of conversion operations which can be applied to dataset columns
conversion_map = {"SMILES_TO_MOL": dm.to_mol, "BYTES_TO_MOL": dm.Mol}


@unique
class Adapter(Enum):
"""
Adapters are predefined callables that change the format of the data.
Expand All @@ -13,10 +16,10 @@ class Adapter(Enum):
BYTES_TO_MOL: Convert a RDKit binary string to a RDKit molecule.
"""

SMILES_TO_MOL = dm.to_mol
BYTES_TO_MOL = dm.Mol
SMILES_TO_MOL = auto()
BYTES_TO_MOL = auto()

def __call__(self, data):
if isinstance(data, tuple):
return tuple(self.value(d) for d in data)
return self.value(data)
return tuple(conversion_map[self.name](d) for d in data)
return conversion_map[self.name](data)
13 changes: 8 additions & 5 deletions polaris/dataset/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,9 @@ def _validate_model(cls, m: "Dataset"):

return m

@field_validator("default_adapters")
@field_validator("default_adapters", mode="before")
def _validate_adapters(cls, value):
"""Serializes the adapters"""
"""Validate the adapters"""
return {k: Adapter[v] if isinstance(v, str) else v for k, v in value.items()}

@field_serializer("default_adapters")
Expand Down Expand Up @@ -270,11 +270,15 @@ def get_data(self, row: int, col: str, adapters: Optional[List[Adapter]] = None)
the content of the referenced file is loaded to memory.
"""

# Fetch adapters for dataset and a given column
adapters = adapters or self.default_adapters
adapter = adapters.get(col)

# If not a pointer, we can just return here
# If not a pointer, return it here. Apply adapter if specified.
value = self.table.loc[row, col]
if not self.annotations[col].is_pointer:
if adapter is not None:
return adapter(value)
return value

# Load the data from the Zarr archive
Expand All @@ -285,8 +289,7 @@ def get_data(self, row: int, col: str, adapters: Optional[List[Adapter]] = None)
if isinstance(index, slice):
arr = tuple(arr)

# Adapt the input
adapter = adapters.get(col)
# Adapt the input to the specified format
if adapter is not None:
arr = adapter(arr)

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def check_version(artifact):
assert po.__version__ == artifact.version
assert po.__version__ == artifact.polaris_version


@pytest.fixture(scope="module")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_result_to_json(tmpdir: str, test_user_owner: HubOwner):
path = os.path.join(tmpdir, "result.json")
result.to_json(path)
BenchmarkResults.from_json(path)
assert po.__version__ == result.version
assert po.__version__ == result.polaris_version


def test_metrics_singletask_reg(tmpdir: str, test_single_task_benchmark: SingleTaskBenchmarkSpecification):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_type_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ def test_license():

def test_version():
with pytest.raises(ValidationError):
BaseArtifactModel(version="invalid")
assert BaseArtifactModel().version == po.__version__
assert BaseArtifactModel(version="0.1.2")
BaseArtifactModel(polaris_version="invalid")
assert BaseArtifactModel().polaris_version == po.__version__
assert BaseArtifactModel(polaris_version="0.1.2")