diff --git a/polaris/_artifact.py b/polaris/_artifact.py index 318cad98..609e4fcd 100644 --- a/polaris/_artifact.py +++ b/polaris/_artifact.py @@ -37,7 +37,7 @@ class BaseArtifactModel(BaseModel): owner: A slug-compatible name for the owner of the dataset. If the dataset comes from the Polaris Hub, this is the associated owner (organization or user). Together with the name, this is used by the Hub to uniquely identify the benchmark. - version: The version of the Polaris library that was used to create the artifact. + polaris_version: The version of the Polaris library that was used to create the artifact. """ model_config = ConfigDict(alias_generator=to_camel, populate_by_name=True, arbitrary_types_allowed=True) @@ -47,14 +47,14 @@ class BaseArtifactModel(BaseModel): tags: list[str] = Field(default_factory=list) user_attributes: Dict[str, str] = Field(default_factory=dict) owner: Optional[HubOwner] = None - version: str = po.__version__ + polaris_version: str = po.__version__ @computed_field @property def artifact_id(self) -> Optional[str]: return f"{self.owner}/{sluggify(self.name)}" if self.owner and self.name else None - @field_validator("version") + @field_validator("polaris_version") @classmethod def _validate_version(cls, value: str) -> str: if value != "dev": diff --git a/polaris/dataset/_adapters.py b/polaris/dataset/_adapters.py index 6bf51ba5..b97b55af 100644 --- a/polaris/dataset/_adapters.py +++ b/polaris/dataset/_adapters.py @@ -1,8 +1,11 @@ -from enum import Enum - +from enum import Enum, auto, unique import datamol as dm +# Map of conversion operations which can be applied to dataset columns +conversion_map = {"SMILES_TO_MOL": dm.to_mol, "BYTES_TO_MOL": dm.Mol} + +@unique class Adapter(Enum): """ Adapters are predefined callables that change the format of the data. @@ -13,10 +16,10 @@ class Adapter(Enum): BYTES_TO_MOL: Convert a RDKit binary string to a RDKit molecule. """ - SMILES_TO_MOL = dm.to_mol - BYTES_TO_MOL = dm.Mol + SMILES_TO_MOL = auto() + BYTES_TO_MOL = auto() def __call__(self, data): if isinstance(data, tuple): - return tuple(self.value(d) for d in data) - return self.value(data) + return tuple(conversion_map[self.name](d) for d in data) + return conversion_map[self.name](data) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 46632531..e496e1ae 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -154,9 +154,9 @@ def _validate_model(cls, m: "Dataset"): return m - @field_validator("default_adapters") + @field_validator("default_adapters", mode="before") def _validate_adapters(cls, value): - """Serializes the adapters""" + """Validate the adapters""" return {k: Adapter[v] if isinstance(v, str) else v for k, v in value.items()} @field_serializer("default_adapters") @@ -270,11 +270,15 @@ def get_data(self, row: int, col: str, adapters: Optional[List[Adapter]] = None) the content of the referenced file is loaded to memory. """ + # Fetch adapters for dataset and a given column adapters = adapters or self.default_adapters + adapter = adapters.get(col) - # If not a pointer, we can just return here + # If not a pointer, return it here. Apply adapter if specified. value = self.table.loc[row, col] if not self.annotations[col].is_pointer: + if adapter is not None: + return adapter(value) return value # Load the data from the Zarr archive @@ -285,8 +289,7 @@ def get_data(self, row: int, col: str, adapters: Optional[List[Adapter]] = None) if isinstance(index, slice): arr = tuple(arr) - # Adapt the input - adapter = adapters.get(col) + # Adapt the input to the specified format if adapter is not None: arr = adapter(arr) diff --git a/tests/conftest.py b/tests/conftest.py index c0c633f8..c48de044 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,7 +14,7 @@ def check_version(artifact): - assert po.__version__ == artifact.version + assert po.__version__ == artifact.polaris_version @pytest.fixture(scope="module") diff --git a/tests/test_evaluate.py b/tests/test_evaluate.py index 072f2db2..5a4332c9 100644 --- a/tests/test_evaluate.py +++ b/tests/test_evaluate.py @@ -40,7 +40,7 @@ def test_result_to_json(tmpdir: str, test_user_owner: HubOwner): path = os.path.join(tmpdir, "result.json") result.to_json(path) BenchmarkResults.from_json(path) - assert po.__version__ == result.version + assert po.__version__ == result.polaris_version def test_metrics_singletask_reg(tmpdir: str, test_single_task_benchmark: SingleTaskBenchmarkSpecification): diff --git a/tests/test_type_checks.py b/tests/test_type_checks.py index 5eebc54d..78068141 100644 --- a/tests/test_type_checks.py +++ b/tests/test_type_checks.py @@ -74,6 +74,6 @@ def test_license(): def test_version(): with pytest.raises(ValidationError): - BaseArtifactModel(version="invalid") - assert BaseArtifactModel().version == po.__version__ - assert BaseArtifactModel(version="0.1.2") + BaseArtifactModel(polaris_version="invalid") + assert BaseArtifactModel().polaris_version == po.__version__ + assert BaseArtifactModel(polaris_version="0.1.2")