From cceb2fb12c22e318e9ee81071a4c1f4f0133f03f Mon Sep 17 00:00:00 2001 From: cwognum Date: Mon, 4 Dec 2023 14:07:33 -0500 Subject: [PATCH 1/6] Add additional meta-data to the columns, dataset and benchmark to be visualized in the Hub --- polaris/benchmark/_base.py | 110 +++++++++++++++++++++++++++--- polaris/benchmark/_definitions.py | 15 +++- polaris/dataset/_column.py | 17 +++++ polaris/dataset/_dataset.py | 49 ++++++++++++- polaris/utils/types.py | 15 ++++ tests/conftest.py | 2 + 6 files changed, 198 insertions(+), 10 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index c1318316..b40608c8 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -7,11 +7,14 @@ import numpy as np import pandas as pd from pydantic import ( + Field, FieldValidationInfo, + computed_field, field_serializer, field_validator, model_validator, ) +from sklearn.utils.multiclass import type_of_target from polaris._artifact import BaseArtifactModel from polaris.dataset import Dataset, Subset @@ -22,7 +25,15 @@ from polaris.utils.dict2html import dict2html from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError from polaris.utils.misc import listit -from polaris.utils.types import AccessType, DataFormat, HubOwner, PredictionsType, SplitType +from polaris.utils.types import ( + AccessType, + DataFormat, + HubOwner, + PredictionsType, + SplitType, + TargetType, + TaskType, +) ColumnsType = Union[str, list[str]] @@ -47,13 +58,26 @@ class BenchmarkSpecification(BaseArtifactModel): ```python import polaris as po - benchmark = po.load_benchmark("/path/to/benchmark") + # Load the benchmark from the Hub + benchmark = po.load_benchmark("polaris/hello_world_benchmark") + + # Get the train and test data-loaders train, test = benchmark.get_train_test_split() - # Work your magic - predictions = ... + # Use the training data to train your model + # Get the input as an array with 'train.inputs' and 'train.targets' + # Or simply iterate over the train object. + for x, y in train: + ... + + # Work your magic to accurately predict the test set + predictions = [0.0 for x in test] - benchmark.evaluate(predictions) + # Evaluate your predictions + results = benchmark.evaluate(predictions) + + # Submit your results + results.upload_to_hub(owner="dummy-user") ``` Attributes: @@ -68,6 +92,7 @@ class BenchmarkSpecification(BaseArtifactModel): readme: Markdown text that can be used to provide a formatted description of the benchmark. If using the Polaris Hub, it is worth noting that this field is more easily edited through the Hub UI as it provides a rich text editor for writing markdown. + target_types: A dictionary that maps target columns to their type. If not specified, this is automatically inferred. For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class. """ @@ -83,6 +108,9 @@ class BenchmarkSpecification(BaseArtifactModel): # Additional meta-data readme: str = "" + target_types: dict[str, Optional[Union[TargetType, str]]] = Field( + default_factory=dict, validate_default=True + ) @field_validator("dataset") def _validate_dataset(cls, v): @@ -175,9 +203,31 @@ def _validate_split(cls, v, info: FieldValidationInfo): raise InvalidBenchmarkError("The predefined split contains invalid indices") return v + @field_validator("target_types") + def _validate_target_types(cls, v, info: FieldValidationInfo): + """Try to automatically infer the target types if not already set""" + + dataset = info.data.get("dataset") + target_cols = info.data.get("target_cols") + if dataset is None or target_cols is None: + return v + + for target in target_cols: + if target not in v: + target_type = type_of_target(dataset[:, target]) + if target_type == "continuous": + v[target] = TargetType.REGRESSION + elif target_type in ["binary", "multiclass"]: + v[target] = TargetType.CLASSIFICATION + else: + v[target] = None + elif not isinstance(v, TargetType): + v[target] = TargetType(v[target]) + return v + @model_validator(mode="after") @classmethod - def _validate_checksum(cls, m: "BenchmarkSpecification"): + def _validate_model(cls, m: "BenchmarkSpecification"): """ If a checksum is provided, verify it matches what the checksum should be. If no checksum is provided, make sure it is set. @@ -221,6 +271,11 @@ def _serialize_split(self, v): """Convert any tuple to list to make sure it's serializable""" return listit(v) + @field_serializer("target_types") + def _serialize_target_types(self, v): + """Convert from enum to string to make sure it's serializable""" + return {k: v.value for k, v in self.target_types.items()} + @staticmethod def _compute_checksum(dataset, target_cols, input_cols, split, metrics): """ @@ -254,6 +309,47 @@ def _compute_checksum(dataset, target_cols, input_cols, split, metrics): checksum = hash_fn.hexdigest() return checksum + @computed_field + @property + def no_train_datapoints(self) -> int: + """The size of the train set.""" + return len(self.split[0]) + + @computed_field + @property + def no_test_sets(self) -> int: + """The number of test sets""" + return len(self.split[1]) if isinstance(self.split[1], dict) else 1 + + @computed_field + @property + def no_test_datapoints(self) -> dict[str, int]: + """The size of (each of) the test set(s).""" + if self.no_test_sets == 1: + return {"test": len(self.split[1])} + else: + return {k: len(v) for k, v in self.split[1].items()} + + @computed_field + @property + def no_classes(self) -> dict[str, int]: + """The number of classes for each of the target columns.""" + no_classes = {} + for target in self.target_cols: + target_type = self.target_types[target] + if target_type is None or target_type == TargetType.REGRESSION: + no_classes[target] = None + else: + no_classes[target] = self.dataset.loc[:, target].nunique() + return no_classes + + @computed_field + @property + def task_type(self) -> TaskType: + """The high-level task type of the benchmark.""" + v = TaskType.MULTI_TASK if len(self.target_cols) > 1 else TaskType.SINGLE_TASK + return v.value + def get_train_test_split( self, input_format: DataFormat = "dict", target_format: DataFormat = "dict" ) -> tuple[Subset, Union["Subset", dict[str, Subset]]]: @@ -417,8 +513,6 @@ def _repr_dict_(self) -> dict: repr_dict.pop("dataset") repr_dict.pop("split") repr_dict["dataset_name"] = self.dataset.name - repr_dict["n_input_cols"] = len(self.input_cols) - repr_dict["n_target_cols"] = len(self.target_cols) return repr_dict def _repr_html_(self): diff --git a/polaris/benchmark/_definitions.py b/polaris/benchmark/_definitions.py index 430f5bdf..e2caa9f1 100644 --- a/polaris/benchmark/_definitions.py +++ b/polaris/benchmark/_definitions.py @@ -1,6 +1,7 @@ -from pydantic import field_validator +from pydantic import computed_field, field_validator from polaris.benchmark._base import BenchmarkSpecification +from polaris.utils.types import TaskType class SingleTaskBenchmarkSpecification(BenchmarkSpecification): @@ -16,6 +17,12 @@ def validate_target_cols(cls, v): raise ValueError("A single-task benchmark should specify a single target column") return v + @computed_field + @property + def task_type(self) -> TaskType: + """The high-level task type of the benchmark.""" + return TaskType.SINGLE_TASK.value + class MultiTaskBenchmarkSpecification(BenchmarkSpecification): """Subclass for any multi-task benchmark specification @@ -29,3 +36,9 @@ def validate_target_cols(cls, v): if not len(v) > 1: raise ValueError("A multi-task benchmark should specify at least two target columns") return v + + @computed_field + @property + def task_type(self) -> TaskType: + """The high-level task type of the benchmark.""" + return TaskType.MULTI_TASK.value diff --git a/polaris/dataset/_column.py b/polaris/dataset/_column.py index cf960bb7..0474d930 100644 --- a/polaris/dataset/_column.py +++ b/polaris/dataset/_column.py @@ -1,6 +1,8 @@ import enum from typing import Dict, Optional, Union +import numpy as np +from numpy.typing import DTypeLike from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator from pydantic.alias_generators import to_camel @@ -34,6 +36,7 @@ class ColumnAnnotation(BaseModel): modality: Union[str, Modality] = Modality.UNKNOWN description: Optional[str] = None user_attributes: Dict[str, str] = Field(default_factory=dict) + dtype: Optional[Union[np.dtype, str]] = None model_config = ConfigDict(arbitrary_types_allowed=True, alias_generator=to_camel, populate_by_name=True) @@ -44,7 +47,21 @@ def _validate_modality(cls, v): v = Modality[v.upper()] return v + @field_validator("dtype") + def _validate_dtype(cls, v): + """Tries to convert a string to the Enum""" + if isinstance(v, str): + v = np.dtype(v) + return v + @field_serializer("modality") def _serialize_modality(self, v: Modality): """Return the modality as a string, keeping it serializable""" return v.name + + @field_serializer("dtype") + def _serialize_dtype(self, v: Optional[DTypeLike]): + """Return the modality as a string, keeping it serializable""" + if v is not None: + v = v.name + return v diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index f5f658d7..346019a0 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -11,6 +11,7 @@ from loguru import logger from pydantic import ( Field, + computed_field, field_validator, model_validator, ) @@ -55,6 +56,8 @@ class Dataset(BaseArtifactModel): annotations: Each column _can be_ annotated with a [`ColumnAnnotation`][polaris.dataset.ColumnAnnotation] object. Importantly, this is used to annotate whether a column is a pointer column. source: The data source, e.g. a DOI, Github repo or URI. + license: The dataset license + curation_reference: A reference to the curation process, e.g. a DOI, Github repo or URI. For additional meta-data attributes, see the [`BaseArtifactModel`][polaris._artifact.BaseArtifactModel] class. Raises: @@ -72,6 +75,7 @@ class Dataset(BaseArtifactModel): annotations: Dict[str, ColumnAnnotation] = Field(default_factory=dict) source: Optional[HttpUrlString] = None license: Optional[License] = None + curation_reference: Optional[HttpUrlString] = None # Config cache_dir: Optional[str] = None # Where to cache the data to if cache() is called. @@ -106,6 +110,7 @@ def _validate_model(cls, m: "Dataset"): for c in m.table.columns: if c not in m.annotations: m.annotations[c] = ColumnAnnotation() + m.annotations[c].dtype = m.table[c].dtype # Verify the checksum # NOTE (cwognum): Is it still reasonable to always verify this as the dataset size grows? @@ -152,6 +157,28 @@ def _compute_checksum(table): checksum = hash_fn.hexdigest() return checksum + @computed_field + @property + def no_rows(self) -> int: + """The number of datapoints in the dataset.""" + return len(self.rows) + + @computed_field + @property + def no_columns(self) -> int: + """The number of columns in the dataset.""" + return len(self.columns) + + @property + def rows(self) -> list: + """Return all row indices for the dataset""" + return self.table.index.tolist() + + @property + def columns(self) -> list: + """Return all columns for the dataset""" + return self.table.columns.tolist() + def get_data(self, row: Union[str, int], col: str) -> np.ndarray: """Since the dataset might contain pointers to external files, data retrieval is more complicated than just indexing the `table` attribute. This method provides an end-point for seamlessly @@ -453,7 +480,7 @@ def _get_cache_path(self, column: str, value: str) -> Optional[str]: return self._path_to_hash[column][value] def size(self): - return len(self), len(self.table.columns) + return self.no_datapoints, self.no_columns def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: """ @@ -499,6 +526,26 @@ def fn(path): table[c] = table[c].apply(fn) return table + def __getitem__(self, item): + """Allows for indexing the dataset directly""" + ret = self.table.loc[item] + if isinstance(ret, pd.Series): + # Load the data from the pointer columns + + if len(ret) == self.no_columns: + # Returning a row + ret = ret.to_dict() + for k in ret.keys(): + ret[k] = self.get_data(item, k) + + if len(ret) == self.no_rows: + # Returnin a column + if self.annotations[ret.name].is_pointer: + ret = [self.get_data(item, ret.name) for item in ret.index] + return np.array(ret) + + return ret + def _repr_dict_(self) -> dict: """Utility function for pretty-printing to the command line and jupyter notebooks""" repr_dict = self.model_dump() diff --git a/polaris/utils/types.py b/polaris/utils/types.py index 9bc8f14f..0f850dc8 100644 --- a/polaris/utils/types.py +++ b/polaris/utils/types.py @@ -1,4 +1,5 @@ import json +from enum import Enum from typing import Annotated, Any, ClassVar, Literal, Optional, Union import fsspec @@ -145,3 +146,17 @@ def _validate_license_id(cls, m: "License"): "It is required to then also specify the name and reference." ) return m + + +class TargetType(Enum): + """The high-level classification of different targets.""" + + REGRESSION = "regression" + CLASSIFICATION = "classification" + + +class TaskType(Enum): + """The high-level classification of different tasks.""" + + MULTI_TASK = "multi_task" + SINGLE_TASK = "single_task" diff --git a/tests/conftest.py b/tests/conftest.py index bfa6d8a1..f5b0077b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -60,6 +60,7 @@ def test_dataset(test_data, test_org_owner): user_attributes={"attributeA": "valueA", "attributeB": "valueB"}, owner=test_org_owner, license=License(id="MIT"), + curation_reference="https://www.example.com", ) @@ -152,6 +153,7 @@ def test_multi_task_benchmark(test_dataset): split=(train_indices, test_indices), target_cols=["expt", "calc"], input_cols="smiles", + target_types={"expt": "regression"}, ) From 2c05e02dc81bfe7a6dfe3bd5b7da9be442719bec Mon Sep 17 00:00:00 2001 From: cwognum Date: Mon, 4 Dec 2023 14:23:39 -0500 Subject: [PATCH 2/6] Rename from no_ to n_ --- polaris/dataset/_dataset.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index 346019a0..b1673b80 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -159,13 +159,13 @@ def _compute_checksum(table): @computed_field @property - def no_rows(self) -> int: + def n_rows(self) -> int: """The number of datapoints in the dataset.""" return len(self.rows) @computed_field @property - def no_columns(self) -> int: + def n_columns(self) -> int: """The number of columns in the dataset.""" return len(self.columns) @@ -480,7 +480,7 @@ def _get_cache_path(self, column: str, value: str) -> Optional[str]: return self._path_to_hash[column][value] def size(self): - return self.no_datapoints, self.no_columns + return self.n_datapoints, self.n_columns def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: """ @@ -532,13 +532,13 @@ def __getitem__(self, item): if isinstance(ret, pd.Series): # Load the data from the pointer columns - if len(ret) == self.no_columns: + if len(ret) == self.n_columns: # Returning a row ret = ret.to_dict() for k in ret.keys(): ret[k] = self.get_data(item, k) - if len(ret) == self.no_rows: + if len(ret) == self.n_rows: # Returnin a column if self.annotations[ret.name].is_pointer: ret = [self.get_data(item, ret.name) for item in ret.index] From a666b0868bb2966599e616497fe0df72a5cd7745 Mon Sep 17 00:00:00 2001 From: cwognum Date: Mon, 4 Dec 2023 14:47:16 -0500 Subject: [PATCH 3/6] updated dataset code snippet --- docs/quickstart.md | 19 +++++++++++++++++-- polaris/dataset/_dataset.py | 15 +++++++++++---- 2 files changed, 28 insertions(+), 6 deletions(-) diff --git a/docs/quickstart.md b/docs/quickstart.md index 2553d22e..64335bf9 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -53,8 +53,23 @@ Similarly, you can easily access a dataset. ```python import polaris as po -dataset = po.load_dataset("org_or_user/name") -dataset.get_data(col=..., row=...) +# Load the dataset from the hub +dataset = po.load_dataset("polaris/hello-world-dataset") + +# Get information on the dataset size +dataset.size() + +# Load a datapoint in memory +dataset.get_data( + row=dataset.rows[0], + col=dataset.columns[0], +) + +# Or, similarly: +dataset[dataset.rows[0], dataset.columns[0]] + +# Get the first 10 rows in memory +dataset[:10] ``` ## Core concepts diff --git a/polaris/dataset/_dataset.py b/polaris/dataset/_dataset.py index b1673b80..fb4c4dd8 100644 --- a/polaris/dataset/_dataset.py +++ b/polaris/dataset/_dataset.py @@ -160,7 +160,7 @@ def _compute_checksum(table): @computed_field @property def n_rows(self) -> int: - """The number of datapoints in the dataset.""" + """The number of rows in the dataset.""" return len(self.rows) @computed_field @@ -480,7 +480,7 @@ def _get_cache_path(self, column: str, value: str) -> Optional[str]: return self._path_to_hash[column][value] def size(self): - return self.n_datapoints, self.n_columns + return self.rows, self.n_columns def _split_index_from_path(self, path: str) -> Tuple[str, Optional[int]]: """ @@ -539,11 +539,18 @@ def __getitem__(self, item): ret[k] = self.get_data(item, k) if len(ret) == self.n_rows: - # Returnin a column + # Returning a column if self.annotations[ret.name].is_pointer: ret = [self.get_data(item, ret.name) for item in ret.index] return np.array(ret) + # Returning a dataframe + if isinstance(ret, pd.DataFrame): + for c in ret.columns: + if self.annotations[c].is_pointer: + ret[c] = [self.get_data(item, c) for item in ret.index] + return ret + return ret def _repr_dict_(self) -> dict: @@ -557,7 +564,7 @@ def _repr_html_(self): return dict2html(self._repr_dict_()) def __len__(self): - return len(self.table) + return self.n_rows def __repr__(self): return json.dumps(self._repr_dict_(), indent=2) From 7190074fa12a642af7fd1080c6d066e25faf5876 Mon Sep 17 00:00:00 2001 From: cwognum Date: Mon, 4 Dec 2023 14:52:50 -0500 Subject: [PATCH 4/6] Change from no_ to n_ for benchmark too --- polaris/benchmark/_base.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index b40608c8..f1feb256 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -311,37 +311,37 @@ def _compute_checksum(dataset, target_cols, input_cols, split, metrics): @computed_field @property - def no_train_datapoints(self) -> int: + def n_train_datapoints(self) -> int: """The size of the train set.""" return len(self.split[0]) @computed_field @property - def no_test_sets(self) -> int: + def n_test_sets(self) -> int: """The number of test sets""" return len(self.split[1]) if isinstance(self.split[1], dict) else 1 @computed_field @property - def no_test_datapoints(self) -> dict[str, int]: + def n_test_datapoints(self) -> dict[str, int]: """The size of (each of) the test set(s).""" - if self.no_test_sets == 1: + if self.n_test_sets == 1: return {"test": len(self.split[1])} else: return {k: len(v) for k, v in self.split[1].items()} @computed_field @property - def no_classes(self) -> dict[str, int]: + def n_classes(self) -> dict[str, int]: """The number of classes for each of the target columns.""" - no_classes = {} + n_classes = {} for target in self.target_cols: target_type = self.target_types[target] if target_type is None or target_type == TargetType.REGRESSION: - no_classes[target] = None + n_classes[target] = None else: - no_classes[target] = self.dataset.loc[:, target].nunique() - return no_classes + n_classes[target] = self.dataset.loc[:, target].nunique() + return n_classes @computed_field @property From b916df5db725dcd6c559584a44d97e29c0bb69da Mon Sep 17 00:00:00 2001 From: Cas Wognum Date: Mon, 4 Dec 2023 22:03:31 -0500 Subject: [PATCH 5/6] Minor: More informative error in client --- polaris/hub/client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/polaris/hub/client.py b/polaris/hub/client.py index fe982365..dafd2ba0 100644 --- a/polaris/hub/client.py +++ b/polaris/hub/client.py @@ -319,7 +319,10 @@ def get_dataset(self, owner: Union[str, HubOwner], name: str) -> Dataset: # This should be a 307 redirect with the signed URL if storage_response.status_code != 307: - raise PolarisHubError("Could not get signed URL from Polaris Hub.") + try: + storage_response.raise_for_status() + except HTTPStatusError as error: + raise PolarisHubError("Could not get signed URL from Polaris Hub.") from error storage_response = storage_response.json() url = storage_response["url"] From f1e717aac95a16a491d0b6b906407be3c90fec9d Mon Sep 17 00:00:00 2001 From: cwognum Date: Tue, 5 Dec 2023 10:22:37 -0500 Subject: [PATCH 6/6] Don't save None's to nClasses --- polaris/benchmark/_base.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/polaris/benchmark/_base.py b/polaris/benchmark/_base.py index f1feb256..a5440610 100644 --- a/polaris/benchmark/_base.py +++ b/polaris/benchmark/_base.py @@ -338,9 +338,8 @@ def n_classes(self) -> dict[str, int]: for target in self.target_cols: target_type = self.target_types[target] if target_type is None or target_type == TargetType.REGRESSION: - n_classes[target] = None - else: - n_classes[target] = self.dataset.loc[:, target].nunique() + continue + n_classes[target] = self.dataset.loc[:, target].nunique() return n_classes @computed_field