5 changes: 3 additions & 2 deletions polaris/benchmark/_base.py
@@ -23,7 +23,7 @@
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidBenchmarkError, PolarisChecksumError
from polaris.utils.misc import listit, to_lower_camel
from polaris.utils.types import DataFormat, PredictionsType, SplitType
from polaris.utils.types import DataFormat, PredictionsType, SplitType, AccessType

ColumnsType = Union[str, list[str]]

@@ -376,6 +376,7 @@ def upload_to_hub(
env_file: Optional[Union[str, os.PathLike]] = None,
settings: Optional[PolarisHubSettings] = None,
cache_auth_token: bool = True,
access: Optional[AccessType] = "private",
**kwargs: dict,
):
"""
@@ -387,7 +388,7 @@ def upload_to_hub(
with PolarisHubClient(
env_file=env_file, settings=settings, cache_auth_token=cache_auth_token, **kwargs
) as client:
return client.upload_benchmark(self)
return client.upload_benchmark(self, access)

def to_json(self, destination: str) -> str:
"""Save the benchmark to a destination directory as a JSON file.
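
As a usage note, here is a minimal sketch of the new `access` argument from the caller's side. The `publish_benchmark` helper and the `polaris.benchmark` import path are assumptions for illustration and not part of this diff:

```python
from polaris.benchmark import SingleTaskBenchmarkSpecification


def publish_benchmark(benchmark: SingleTaskBenchmarkSpecification):
    """Upload a benchmark and make it publicly visible (artifacts stay private by default)."""
    # The access value is forwarded to PolarisHubClient.upload_benchmark (see polaris/hub/client.py below).
    return benchmark.upload_to_hub(access="public")
```
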
5 changes: 3 additions & 2 deletions polaris/dataset/_dataset.py
@@ -25,7 +25,7 @@
from polaris.utils.errors import InvalidDatasetError, PolarisChecksumError
from polaris.utils.io import get_zarr_root, robust_copy
from polaris.utils.misc import to_lower_camel
from polaris.utils.types import HttpUrlString, License
from polaris.utils.types import HttpUrlString, License, AccessType

# Constants
_SUPPORTED_TABLE_EXTENSIONS = ["parquet"]
@@ -207,6 +207,7 @@ def upload_to_hub(
env_file: Optional[Union[str, os.PathLike]] = None,
settings: Optional[PolarisHubSettings] = None,
cache_auth_token: bool = True,
access: Optional[AccessType] = "private",
**kwargs: dict,
):
"""
@@ -218,7 +219,7 @@ def upload_to_hub(
with PolarisHubClient(
env_file=env_file, settings=settings, cache_auth_token=cache_auth_token, **kwargs
) as client:
return client.upload_dataset(self)
return client.upload_dataset(self, access)

@classmethod
def from_zarr(cls, path: str) -> "Dataset":
60 changes: 52 additions & 8 deletions polaris/evaluate/_metric.py
@@ -2,21 +2,51 @@
from typing import Callable

import numpy as np
from pydantic import BaseModel
from sklearn.metrics import accuracy_score, mean_absolute_error, mean_squared_error
from pydantic import BaseModel, Field
from sklearn.metrics import (
accuracy_score,
mean_absolute_error,
mean_squared_error,
r2_score,
explained_variance_score,
)
from sklearn.metrics import (
f1_score,
matthews_corrcoef,
roc_auc_score,
average_precision_score,
cohen_kappa_score,
)
from scipy import stats

from polaris.utils.types import DirectionType


def pearsonr(y_true: np.ndarray, y_pred: np.ndarray):
"""Calculate a pearson r correlation"""
return stats.pearsonr(y_true, y_pred).statistic


def spearman(y_true: np.ndarray, y_pred: np.ndarray):
"""Calculate a Spearman correlation"""
return stats.spearmanr(y_true, y_pred).statistic


class MetricInfo(BaseModel):
"""
Metric metadata

Attributes:
fn: The callable that actually computes the metric
fn: The callable that actually computes the metric.
is_multitask: Whether the metric expects a single set of predictions or a dict of predictions.
kwargs: Additional parameters required for the metric.
direction: The direction used to rank the metric: "max" for maximization and "min" for minimization.
"""

fn: Callable
is_multitask: bool = False
kwargs: dict = Field(default_factory=dict)
direction: DirectionType


class Metric(Enum):
Expand All @@ -28,12 +58,26 @@ class Metric(Enum):
"""

# TODO (cwognum):
# - Add support for more metrics
# - Any preprocessing needed? For example changing the shape / dtype? Converting from torch tensors or lists?

mean_absolute_error = MetricInfo(fn=mean_absolute_error)
mean_squared_error = MetricInfo(fn=mean_squared_error)
accuracy = MetricInfo(fn=accuracy_score)
# regression
mean_absolute_error = MetricInfo(fn=mean_absolute_error, direction="min")
mean_squared_error = MetricInfo(fn=mean_squared_error, direction="min")
r2 = MetricInfo(fn=r2_score, direction="max")
pearsonr = MetricInfo(fn=pearsonr, direction="max")
spearmanr = MetricInfo(fn=spearman, direction="max")
explained_var = MetricInfo(fn=explained_variance_score, direction="max")

# classification
accuracy = MetricInfo(fn=accuracy_score, direction="max")
f1 = MetricInfo(fn=f1_score, kwargs={"average": "binary"}, direction="max")
f1_macro = MetricInfo(fn=f1_score, kwargs={"average": "macro"}, direction="max")
f1_micro = MetricInfo(fn=f1_score, kwargs={"average": "micro"}, direction="max")
roc_auc = MetricInfo(fn=roc_auc_score, direction="max")
pr_auc = MetricInfo(fn=average_precision_score, direction="max")
mcc = MetricInfo(fn=matthews_corrcoef, direction="max")
cohen_kappa = MetricInfo(fn=cohen_kappa_score, direction="max")
# TODO: add metrics for multiclass tasks

@property
def fn(self) -> Callable:
@@ -55,7 +99,7 @@ def score(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
assert metric.score(y_true=first, y_pred=second) == metric(y_true=first, y_pred=second)
```
"""
return self.fn(y_true, y_pred)
return self.fn(y_true, y_pred, **self.value.kwargs)

def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
"""For convenience, make metrics callable"""
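
A short sketch of how the expanded `Metric` enum behaves after this change; the `polaris.evaluate` import path and the toy arrays are assumptions for illustration:

```python
import numpy as np

from polaris.evaluate import Metric

y_true = np.array([0, 1, 1, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1])

# Stored kwargs (e.g. average="macro" for f1_macro) are now forwarded by Metric.score.
print(Metric.accuracy.score(y_true=y_true, y_pred=y_pred))  # 0.8
print(Metric.f1_macro.score(y_true=y_true, y_pred=y_pred))

# The new direction field records whether higher ("max") or lower ("min") scores are better.
print(Metric.mean_absolute_error.value.direction)  # "min"
```
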
5 changes: 3 additions & 2 deletions polaris/evaluate/_results.py
@@ -12,7 +12,7 @@
from polaris.utils.dict2html import dict2html
from polaris.utils.errors import InvalidResultError
from polaris.utils.misc import to_lower_camel
from polaris.utils.types import HttpUrlString, HubOwner, HubUser
from polaris.utils.types import HttpUrlString, HubOwner, HubUser, AccessType

# Define some helpful type aliases
TestLabelType = str
@@ -119,6 +119,7 @@ def upload_to_hub(
env_file: Optional[Union[str, os.PathLike]] = None,
settings: Optional[PolarisHubSettings] = None,
cache_auth_token: bool = True,
access: Optional[AccessType] = "private",
**kwargs: dict,
):
"""
@@ -130,7 +131,7 @@ def upload_to_hub(
with PolarisHubClient(
env_file=env_file, settings=settings, cache_auth_token=cache_auth_token, **kwargs
) as client:
return client.upload_results(self)
return client.upload_results(self, access)

def _repr_dict_(self) -> dict:
"""Utility function for pretty-printing to the command line and jupyter notebooks"""
18 changes: 13 additions & 5 deletions polaris/hub/client.py
@@ -30,7 +30,7 @@
from polaris.utils import fs
from polaris.utils.constants import DEFAULT_CACHE_DIR
from polaris.utils.errors import PolarisHubError, PolarisUnauthorizedError
from polaris.utils.types import HubOwner
from polaris.utils.types import HubOwner, AccessType

_HTTPX_SSL_ERROR_CODE = "[SSL: CERTIFICATE_VERIFY_FAILED]"

@@ -373,7 +373,7 @@ def get_benchmark(self, owner: Union[str, HubOwner], name: str) -> BenchmarkSpec
)
return benchmark_cls(**response)

def upload_results(self, results: BenchmarkResults):
def upload_results(self, results: BenchmarkResults, access: AccessType = "private"):
"""Upload the results to the Polaris Hub.

Info: Owner
@@ -394,10 +394,12 @@ def upload_results(self, results: BenchmarkResults):

Args:
results: The results to upload.
access: Grant public or private access to the result.
"""

# Get the serialized model data-structure
result_json = results.model_dump(by_alias=True, exclude_none=True)
result_json["access"] = access

# Make a request to the hub
url = f"/benchmark/{results.benchmark_owner}/{results.benchmark_name}/result"
@@ -411,7 +413,9 @@ def upload_results(self, results: BenchmarkResults):
logger.success(f"Your result has been successfully uploaded to the Hub. View it here: {result_url}")
return response

def upload_dataset(self, dataset: Dataset, timeout: TimeoutTypes = (10, 200)):
def upload_dataset(
self, dataset: Dataset, access: AccessType = "private", timeout: TimeoutTypes = (10, 200)
):
"""Upload the dataset to the Polaris Hub.

Info: Owner
@@ -426,6 +430,7 @@ def upload_dataset(self, dataset: Dataset, timeout: TimeoutTypes = (10, 200)):

Args:
dataset: The dataset to upload.
access: Grant public or private access to the dataset.
timeout: Request timeout values. User can modify the value when uploading large dataset as needed.
"""

@@ -448,6 +453,7 @@ def upload_dataset(self, dataset: Dataset, timeout: TimeoutTypes = (10, 200)):
self.settings.hub_url, f"/storage/dataset/{dataset.owner}/{dataset.name}/table.parquet"
),
}
dataset_json["access"] = access
url = f"/dataset/{dataset.owner}/{dataset.name}"
response = self._base_request_to_hub(url=url, method="PUT", json=dataset_json)

@@ -487,7 +493,7 @@ def upload_dataset(self, dataset: Dataset, timeout: TimeoutTypes = (10, 200)):

return response

def upload_benchmark(self, benchmark: BenchmarkSpecification):
def upload_benchmark(self, benchmark: BenchmarkSpecification, access: AccessType = "private"):
"""Upload the benchmark to the Polaris Hub.

Info: Owner
@@ -506,12 +512,14 @@ def upload_benchmark(self, benchmark: BenchmarkSpecification):

Args:
benchmark: The benchmark to upload.
access: Grant public or private access to the benchmark.
"""

# Get the serialized data-model
# We exclude the dataset as we expect it to exist on the hub already.
benchmark_json = benchmark.model_dump(exclude=["dataset"], exclude_none=True, by_alias=True)
benchmark_json["datasetName"] = f"{benchmark.dataset.owner}/{benchmark.dataset.name}"
benchmark_json["datasetArtifactId"] = f"{benchmark.dataset.owner}/{benchmark.dataset.name}"
benchmark_json["access"] = access

url = f"/benchmark/{benchmark.owner}/{benchmark.name}"
response = self._base_request_to_hub(url=url, method="PUT", json=benchmark_json)
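
The same option is available when calling the client directly; the helper below is a sketch, with the import paths and the larger timeout value assumed for illustration:

```python
from polaris.dataset import Dataset
from polaris.hub.client import PolarisHubClient


def publish_dataset(dataset: Dataset):
    """Upload a dataset through the client and make it public."""
    with PolarisHubClient() as client:
        # access defaults to "private"; the value is written into the request payload
        # (dataset_json["access"]) before the PUT request to the Hub.
        return client.upload_dataset(dataset, access="public", timeout=(10, 600))
```
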
11 changes: 11 additions & 0 deletions polaris/utils/types.py
@@ -72,6 +72,17 @@
This is useful for interactions with httpx and authlib, who have their own URL types.
"""

DirectionType: TypeAlias = Literal["min", "max"]
"""
The direction in which a variable should be sorted or optimized.
This can be used to rank metric scores or to indicate the optimization direction of an endpoint.
"""

AccessType: TypeAlias = Literal["public", "private"]
"""
Type to specify access to a dataset, benchmark or result in the Hub.
"""


class HubOwner(BaseModel):
"""An owner of an artifact on the Polaris Hub
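
A small sketch of how the new Literal aliases can be consumed downstream; the `check_access` helper is purely illustrative and not part of the PR:

```python
from typing import get_args

from polaris.utils.types import AccessType, DirectionType


def check_access(access: str) -> AccessType:
    """Runtime guard mirroring the static Literal["public", "private"] type."""
    if access not in get_args(AccessType):
        raise ValueError(f"access must be one of {get_args(AccessType)}, got {access!r}")
    return access  # type: ignore[return-value]


direction: DirectionType = "max"  # e.g. roc_auc should be maximized
```
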
66 changes: 62 additions & 4 deletions tests/conftest.py
@@ -32,7 +32,11 @@ def _populate_group(group):

@pytest.fixture(scope="module")
def test_data():
return dm.data.freesolv()[:100]
data = dm.data.freesolv()[:100]
# Set an arbitrary threshold for testing purposes.
data["CLASS_expt"] = data["expt"].gt(0).astype(int).values
data["CLASS_calc"] = data["calc"].gt(0).astype(int).values
return data


@pytest.fixture(scope="module")
@@ -76,22 +80,52 @@ def test_single_task_benchmark(test_dataset):
return SingleTaskBenchmarkSpecification(
name="single-task-benchmark",
dataset=test_dataset,
metrics=["mean_absolute_error", "mean_squared_error"],
metrics=[
"mean_absolute_error",
"mean_squared_error",
"r2",
"spearmanr",
"pearsonr",
"explained_var",
],
main_metric="mean_absolute_error",
split=(train_indices, test_indices),
target_cols="expt",
input_cols="smiles",
)


@pytest.fixture(scope="module")
def test_single_task_benchmark_clf(test_dataset):
train_indices = list(range(90))
test_indices = list(range(90, 100))
return SingleTaskBenchmarkSpecification(
name="single-task-benchmark",
dataset=test_dataset,
main_metric="accuracy",
metrics=["accuracy", "f1", "roc_auc", "pr_auc", "mcc", "cohen_kappa"],
split=(train_indices, test_indices),
target_cols="CLASS_expt",
input_cols="smiles",
)


@pytest.fixture(scope="module")
def test_single_task_benchmark_multiple_test_sets(test_dataset):
train_indices = list(range(90))
test_indices = {"test_1": list(range(90, 95)), "test_2": list(range(95, 100))}
return SingleTaskBenchmarkSpecification(
name="single-task-benchmark",
dataset=test_dataset,
metrics=["mean_absolute_error", "mean_squared_error"],
metrics=[
"mean_absolute_error",
"mean_squared_error",
"r2",
"spearmanr",
"pearsonr",
"explained_var",
],
main_metric="r2",
split=(train_indices, test_indices),
target_cols="expt",
input_cols="smiles",
@@ -106,8 +140,32 @@ def test_multi_task_benchmark(test_dataset):
return MultiTaskBenchmarkSpecification(
name="multi-task-benchmark",
dataset=test_dataset,
metrics=["mean_absolute_error"],
main_metric="mean_absolute_error",
metrics=[
"mean_absolute_error",
"mean_squared_error",
"r2",
"spearmanr",
"pearsonr",
"explained_var",
],
split=(train_indices, test_indices),
target_cols=["expt", "calc"],
input_cols="smiles",
)


@pytest.fixture(scope="module")
def test_multi_task_benchmark_clf(test_dataset):
# For the sake of simplicity, just use a small set of indices
train_indices = list(range(90))
test_indices = list(range(90, 100))
return MultiTaskBenchmarkSpecification(
name="multi-task-benchmark",
dataset=test_dataset,
main_metric="accuracy",
metrics=["accuracy", "f1", "roc_auc", "pr_auc", "mcc", "cohen_kappa"],
split=(train_indices, test_indices),
target_cols=["CLASS_expt", "CLASS_calc"],
input_cols="smiles",
)
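
To illustrate what the new `CLASS_*` columns enable, here is a hedged sketch that scores the binarized targets with the new classification metrics outside of the fixtures; using `calc` as a stand-in prediction for `expt` is an assumption for illustration only:

```python
import datamol as dm

from polaris.evaluate import Metric

data = dm.data.freesolv()[:100]
y_true = data["expt"].gt(0).astype(int).values  # same arbitrary threshold as the fixture
y_pred = data["calc"].gt(0).astype(int).values

for metric in (Metric.accuracy, Metric.f1, Metric.mcc, Metric.cohen_kappa):
    print(metric.name, metric.score(y_true=y_true, y_pred=y_pred))
```
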