diff --git a/client/starwhale/api/_impl/metric.py b/client/starwhale/api/_impl/metric.py index 72dd344877..110c06d173 100644 --- a/client/starwhale/api/_impl/metric.py +++ b/client/starwhale/api/_impl/metric.py @@ -10,11 +10,8 @@ confusion_matrix, cohen_kappa_score, classification_report, - multilabel_confusion_matrix, ) -from starwhale.utils.flatten import do_flatten_dict - class MetricKind: MultiClassification = "multi_classification" @@ -50,11 +47,9 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]: cm = confusion_matrix( y_true, y_pred, labels=all_labels, normalize=confusion_matrix_normalize ) - mcm = multilabel_confusion_matrix(y_true, y_pred, labels=all_labels) _r["confusion_matrix"] = { "binarylabel": cm.tolist(), - "multilabel": mcm.tolist(), } if show_hamming_loss: _r["summary"]["hamming_loss"] = hamming_loss(y_true, y_pred) @@ -67,7 +62,7 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]: _r["roc_auc"][_label] = _calculate_roc_auc( y_true, y_pr, _label, _idx ) - return do_flatten_dict(_r) + return _r return _wrapper diff --git a/client/starwhale/api/_impl/model.py b/client/starwhale/api/_impl/model.py index da1fb46307..0d35490f4a 100644 --- a/client/starwhale/api/_impl/model.py +++ b/client/starwhale/api/_impl/model.py @@ -26,6 +26,7 @@ from starwhale.consts.env import SWEnv from starwhale.utils.error import FieldTypeOrValueError from starwhale.api._impl.job import Context +from starwhale.utils.flatten import do_flatten_dict from starwhale.core.job.model import STATUS from starwhale.core.eval.store import EvaluationStorage from starwhale.api._impl.dataset import DataField, get_data_loader @@ -254,7 +255,35 @@ def _starwhale_internal_run_cmp(self) -> None: else: self._timeline_writer.write({"time": now, "status": True, "exception": ""}) self._sw_logger.debug(f"cmp result:{output}") - self.evaluation.log_metrics(output) + + self.evaluation.log_metrics(do_flatten_dict(output["summary"])) + self.evaluation.log_metrics({"kind": output["kind"]}) + + for i, label in output["labels"].items(): + self.evaluation.log("labels", id=i, **label) + + _binary_label = output["confusion_matrix"]["binarylabel"] + for _label, _probability in enumerate(_binary_label): + self.evaluation.log( + "confusion_matrix/binarylabel", + id=str(_label), + **{str(k): v for k, v in enumerate(_probability)}, + ) + + for _label, _roc_auc in output["roc_auc"].items(): + _id = 0 + for _fpr, _tpr, _threshold in zip( + _roc_auc["fpr"], _roc_auc["tpr"], _roc_auc["thresholds"] + ): + self.evaluation.log( + f"roc_auc/{_label}", + id=str(_id), + fpr=_fpr, + tpr=_tpr, + threshold=_threshold, + ) + _id += 1 + self.evaluation.log_metrics({f"roc_auc/{_label}": _roc_auc["auc"]}) @_record_status # type: ignore def _starwhale_internal_run_ppl(self) -> None: diff --git a/client/starwhale/api/_impl/wrapper.py b/client/starwhale/api/_impl/wrapper.py index 7f12184365..e15c10a8ab 100644 --- a/client/starwhale/api/_impl/wrapper.py +++ b/client/starwhale/api/_impl/wrapper.py @@ -3,6 +3,7 @@ import threading from typing import Any, Dict, List, Union, Iterator, Optional +from starwhale.consts import VERSION_PREFIX_CNT from starwhale.consts.env import SWEnv from . import data_store @@ -23,6 +24,8 @@ def close(self) -> None: def _log(self, table_name: str, record: Dict[str, Any]) -> None: with self._lock: + if table_name not in self._writers: + self._writers.setdefault(table_name, None) writer = self._writers[table_name] if writer is None: writer = data_store.TableWriter(table_name) @@ -44,11 +47,14 @@ def __init__(self, eval_id: Optional[str] = None): self.project = os.getenv(SWEnv.project) if self.project is None: raise RuntimeError(f"{SWEnv.project} is not set") - self._results_table_name = f"project/{self.project}/eval/{self.eval_id}/results" + self._results_table_name = self._get_datastore_table_name("results") self._summary_table_name = f"project/{self.project}/eval/summary" self._init_writers([self._results_table_name, self._summary_table_name]) self._data_store = data_store.get_data_store() + def _get_datastore_table_name(self, table_name: str) -> str: + return f"project/{self.project}/eval/{self.eval_id[:VERSION_PREFIX_CNT]}/{self.eval_id}/{table_name}" + def log_result(self, data_id: str, result: Any, **kwargs: Any) -> None: record = {"id": data_id, "result": result} for k, v in kwargs.items(): @@ -59,6 +65,7 @@ def log_metrics( self, metrics: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> None: record = {"id": self.eval_id} + # TODO: without if else? if metrics is not None: for k, v in metrics.items(): k = k.lower() @@ -69,6 +76,12 @@ def log_metrics( record[k.lower()] = v self._log(self._summary_table_name, record) + def log(self, table_name: str, **kwargs: Any) -> None: + record = {} + for k, v in kwargs.items(): + record[k.lower()] = v + self._log(self._get_datastore_table_name(table_name), record) + def get_results(self) -> Iterator[Dict[str, Any]]: return self._data_store.scan_tables( [data_store.TableDesc(self._results_table_name)] @@ -83,6 +96,11 @@ def get_metrics(self) -> Dict[str, Any]: return {} + def get(self, table_name: str) -> Iterator[Dict[str, Any]]: + return self._data_store.scan_tables( + [data_store.TableDesc(self._get_datastore_table_name(table_name))] + ) + class Dataset(Logger): def __init__(self, dataset_id: str, project: str = "") -> None: diff --git a/client/starwhale/core/eval/model.py b/client/starwhale/core/eval/model.py index 8e92b7a9d2..ebdc723a46 100644 --- a/client/starwhale/core/eval/model.py +++ b/client/starwhale/core/eval/model.py @@ -174,7 +174,15 @@ def _get_report(self) -> t.Dict[str, t.Any]: f"datastore path:{str(self.sw_config.datastore_dir)}, eval_id:{self.store.id}" ) _datastore = wrapper.Evaluation() - return _datastore.get_metrics() + _labels = list(_datastore.get("labels")) + return dict( + summary=_datastore.get_metrics(), + labels={str(i): l for i, l in enumerate(_labels)}, + confusion_matrix=dict( + binarylabel=list(_datastore.get("confusion_matrix/binarylabel")) + ), + kind=_datastore.get_metrics()["kind"], + ) @staticmethod def _do_flatten_summary(summary: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: diff --git a/client/starwhale/core/eval/view.py b/client/starwhale/core/eval/view.py index c8aa5a3211..f2748457d0 100644 --- a/client/starwhale/core/eval/view.py +++ b/client/starwhale/core/eval/view.py @@ -175,10 +175,14 @@ def _r(_tree: t.Any, _obj: t.Any) -> None: _tree.add(str(_obj)) for _k, _v in _obj.items(): + if _k == "id": + continue if isinstance(_v, (list, tuple)): _k = f"{_k}: [green]{'|'.join(_v)}" elif isinstance(_v, dict): _k = _k + elif isinstance(_v, str): + _k = f"{_k}:{_v}" else: _k = f"{_k}: [green]{_v:.4f}" @@ -198,7 +202,7 @@ def _r(_tree: t.Any, _obj: t.Any) -> None: table.add_column(_k.capitalize()) for _k, _v in labels.items(): - table.add_row(_k, *(f"{_v[_k2]:.4f}" for _k2 in keys)) + table.add_row(_k, *(f"{float(_v[_k2]):.4f}" for _k2 in keys)) console.rule(f"[bold green]{report['kind'].upper()} Report") console.print(self.comparison(tree, table)) @@ -213,7 +217,10 @@ def _print_confusion_matrix() -> None: for n in sort_label_names: btable.add_column(n) for idx, bl in enumerate(cm.get("binarylabel", [])): - btable.add_row(sort_label_names[idx], *[f"{_:.4f}" for _ in bl]) + btable.add_row( + sort_label_names[idx], + *[f"{float(bl[i]):.4f}" for i in bl if i != "id"], + ) mtable = Table(box=box.SIMPLE) mtable.add_column("Label", style="cyan")