From e63fac20e70eee4a153ce372152cd133cf2f9605 Mon Sep 17 00:00:00 2001 From: gaoxinxing <15931259256@163.com> Date: Sun, 28 Aug 2022 13:58:15 +0800 Subject: [PATCH 1/3] cmp results use multi table in datastore --- client/starwhale/api/_impl/metric.py | 7 +----- client/starwhale/api/_impl/model.py | 32 ++++++++++++++++++++++++++- client/starwhale/api/_impl/wrapper.py | 25 +++++++++++++++++++++ client/starwhale/core/eval/model.py | 12 +++++++++- client/starwhale/core/eval/view.py | 26 +++++++++++++--------- 5 files changed, 83 insertions(+), 19 deletions(-) diff --git a/client/starwhale/api/_impl/metric.py b/client/starwhale/api/_impl/metric.py index 72dd344877..110c06d173 100644 --- a/client/starwhale/api/_impl/metric.py +++ b/client/starwhale/api/_impl/metric.py @@ -10,11 +10,8 @@ confusion_matrix, cohen_kappa_score, classification_report, - multilabel_confusion_matrix, ) -from starwhale.utils.flatten import do_flatten_dict - class MetricKind: MultiClassification = "multi_classification" @@ -50,11 +47,9 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]: cm = confusion_matrix( y_true, y_pred, labels=all_labels, normalize=confusion_matrix_normalize ) - mcm = multilabel_confusion_matrix(y_true, y_pred, labels=all_labels) _r["confusion_matrix"] = { "binarylabel": cm.tolist(), - "multilabel": mcm.tolist(), } if show_hamming_loss: _r["summary"]["hamming_loss"] = hamming_loss(y_true, y_pred) @@ -67,7 +62,7 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]: _r["roc_auc"][_label] = _calculate_roc_auc( y_true, y_pr, _label, _idx ) - return do_flatten_dict(_r) + return _r return _wrapper diff --git a/client/starwhale/api/_impl/model.py b/client/starwhale/api/_impl/model.py index da1fb46307..b17f1f5989 100644 --- a/client/starwhale/api/_impl/model.py +++ b/client/starwhale/api/_impl/model.py @@ -26,6 +26,7 @@ from starwhale.consts.env import SWEnv from starwhale.utils.error import FieldTypeOrValueError from starwhale.api._impl.job import Context +from starwhale.utils.flatten import do_flatten_dict from starwhale.core.job.model import STATUS from starwhale.core.eval.store import EvaluationStorage from starwhale.api._impl.dataset import DataField, get_data_loader @@ -254,7 +255,36 @@ def _starwhale_internal_run_cmp(self) -> None: else: self._timeline_writer.write({"time": now, "status": True, "exception": ""}) self._sw_logger.debug(f"cmp result:{output}") - self.evaluation.log_metrics(output) + + self.evaluation.log_metrics(do_flatten_dict(output["summary"])) + self.evaluation.log_metrics({"kind": output["kind"]}) + + for i, label in output["labels"].items(): + self.evaluation.log_table("labels", label, id=i) + + _binary_label = output["confusion_matrix"]["binarylabel"] + for _label, _probability in enumerate(_binary_label): + self.evaluation.log_table( + "confusion_matrix/binarylabel", + { + **dict( + id=str(_label), + ), + **{str(k): v for k, v in enumerate(_probability)}, + }, + ) + + for _label, _roc_auc in output["roc_auc"].items(): + _id = 0 + for _fpr, _tpr, _threshold in zip( + _roc_auc["fpr"], _roc_auc["tpr"], _roc_auc["thresholds"] + ): + self.evaluation.log_table( + f"roc_auc/{_label}", + dict(id=str(_id), fpr=_fpr, tpr=_tpr, threshold=_threshold), + ) + _id += 1 + self.evaluation.log_metrics({f"roc_auc/{_label}": _roc_auc["auc"]}) @_record_status # type: ignore def _starwhale_internal_run_ppl(self) -> None: diff --git a/client/starwhale/api/_impl/wrapper.py b/client/starwhale/api/_impl/wrapper.py index 7f12184365..4ec1efb962 100644 --- a/client/starwhale/api/_impl/wrapper.py +++ b/client/starwhale/api/_impl/wrapper.py @@ -23,6 +23,8 @@ def close(self) -> None: def _log(self, table_name: str, record: Dict[str, Any]) -> None: with self._lock: + if table_name not in self._writers.keys(): + self._writers.setdefault(table_name, None) writer = self._writers[table_name] if writer is None: writer = data_store.TableWriter(table_name) @@ -59,6 +61,7 @@ def log_metrics( self, metrics: Optional[Dict[str, Any]] = None, **kwargs: Any ) -> None: record = {"id": self.eval_id} + # TODO: without if else? if metrics is not None: for k, v in metrics.items(): k = k.lower() @@ -69,6 +72,19 @@ def log_metrics( record[k.lower()] = v self._log(self._summary_table_name, record) + def log_table( + self, table_name: str, metrics: Optional[Dict[str, Any]] = None, **kwargs: Any + ) -> None: + record = {} + if metrics is not None: + for k, v in metrics.items(): + k = k.lower() + record[k] = v + + for k, v in kwargs.items(): + record[k.lower()] = v + self._log(f"project/{self.project}/eval/{self.eval_id}/{table_name}", record) + def get_results(self) -> Iterator[Dict[str, Any]]: return self._data_store.scan_tables( [data_store.TableDesc(self._results_table_name)] @@ -83,6 +99,15 @@ def get_metrics(self) -> Dict[str, Any]: return {} + def get_results_from_table(self, table_name: str) -> Iterator[Dict[str, Any]]: + return self._data_store.scan_tables( + [ + data_store.TableDesc( + f"project/{self.project}/eval/{self.eval_id}/{table_name}" + ) + ] + ) + class Dataset(Logger): def __init__(self, dataset_id: str, project: str = "") -> None: diff --git a/client/starwhale/core/eval/model.py b/client/starwhale/core/eval/model.py index 8e92b7a9d2..10910880e7 100644 --- a/client/starwhale/core/eval/model.py +++ b/client/starwhale/core/eval/model.py @@ -174,7 +174,17 @@ def _get_report(self) -> t.Dict[str, t.Any]: f"datastore path:{str(self.sw_config.datastore_dir)}, eval_id:{self.store.id}" ) _datastore = wrapper.Evaluation() - return _datastore.get_metrics() + _labels = list(_datastore.get_results_from_table("labels")) + return dict( + summary=_datastore.get_metrics(), + labels={str(i): l for i, l in enumerate(_labels)}, + confusion_matrix=dict( + binarylabel=list( + _datastore.get_results_from_table("confusion_matrix/binarylabel") + ) + ), + kind=_datastore.get_metrics()["kind"], + ) @staticmethod def _do_flatten_summary(summary: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: diff --git a/client/starwhale/core/eval/view.py b/client/starwhale/core/eval/view.py index c8aa5a3211..e66d724ffa 100644 --- a/client/starwhale/core/eval/view.py +++ b/client/starwhale/core/eval/view.py @@ -175,16 +175,17 @@ def _r(_tree: t.Any, _obj: t.Any) -> None: _tree.add(str(_obj)) for _k, _v in _obj.items(): - if isinstance(_v, (list, tuple)): - _k = f"{_k}: [green]{'|'.join(_v)}" - elif isinstance(_v, dict): - _k = _k - else: - _k = f"{_k}: [green]{_v:.4f}" + if _k != "id": + if isinstance(_v, (list, tuple)): + _k = f"{_k}: [green]{'|'.join(_v)}" + elif isinstance(_v, dict) or isinstance(_v, str): + _k = _k + else: + _k = f"{_k}: [green]{_v:.4f}" - _ntree = _tree.add(_k) - if isinstance(_v, dict): - _r(_ntree, _v) + _ntree = _tree.add(_k) + if isinstance(_v, dict): + _r(_ntree, _v) tree = Tree("Summary") _r(tree, report["summary"]) @@ -198,7 +199,7 @@ def _r(_tree: t.Any, _obj: t.Any) -> None: table.add_column(_k.capitalize()) for _k, _v in labels.items(): - table.add_row(_k, *(f"{_v[_k2]:.4f}" for _k2 in keys)) + table.add_row(_k, *(f"{float(_v[_k2]):.4f}" for _k2 in keys)) console.rule(f"[bold green]{report['kind'].upper()} Report") console.print(self.comparison(tree, table)) @@ -213,7 +214,10 @@ def _print_confusion_matrix() -> None: for n in sort_label_names: btable.add_column(n) for idx, bl in enumerate(cm.get("binarylabel", [])): - btable.add_row(sort_label_names[idx], *[f"{_:.4f}" for _ in bl]) + btable.add_row( + sort_label_names[idx], + *[f"{float(bl[i]):.4f}" for i in bl if i != "id"], + ) mtable = Table(box=box.SIMPLE) mtable.add_column("Label", style="cyan") From 6d3f75403ac96586344b7ee80cd8f09616c47420 Mon Sep 17 00:00:00 2001 From: gaoxinxing <15931259256@163.com> Date: Sun, 28 Aug 2022 15:01:23 +0800 Subject: [PATCH 2/3] optimise code --- client/starwhale/api/_impl/model.py | 19 +++++++++--------- client/starwhale/api/_impl/wrapper.py | 29 ++++++++++----------------- client/starwhale/core/eval/model.py | 6 ++---- client/starwhale/core/eval/view.py | 23 ++++++++++++--------- 4 files changed, 35 insertions(+), 42 deletions(-) diff --git a/client/starwhale/api/_impl/model.py b/client/starwhale/api/_impl/model.py index b17f1f5989..0d35490f4a 100644 --- a/client/starwhale/api/_impl/model.py +++ b/client/starwhale/api/_impl/model.py @@ -260,18 +260,14 @@ def _starwhale_internal_run_cmp(self) -> None: self.evaluation.log_metrics({"kind": output["kind"]}) for i, label in output["labels"].items(): - self.evaluation.log_table("labels", label, id=i) + self.evaluation.log("labels", id=i, **label) _binary_label = output["confusion_matrix"]["binarylabel"] for _label, _probability in enumerate(_binary_label): - self.evaluation.log_table( + self.evaluation.log( "confusion_matrix/binarylabel", - { - **dict( - id=str(_label), - ), - **{str(k): v for k, v in enumerate(_probability)}, - }, + id=str(_label), + **{str(k): v for k, v in enumerate(_probability)}, ) for _label, _roc_auc in output["roc_auc"].items(): @@ -279,9 +275,12 @@ def _starwhale_internal_run_cmp(self) -> None: for _fpr, _tpr, _threshold in zip( _roc_auc["fpr"], _roc_auc["tpr"], _roc_auc["thresholds"] ): - self.evaluation.log_table( + self.evaluation.log( f"roc_auc/{_label}", - dict(id=str(_id), fpr=_fpr, tpr=_tpr, threshold=_threshold), + id=str(_id), + fpr=_fpr, + tpr=_tpr, + threshold=_threshold, ) _id += 1 self.evaluation.log_metrics({f"roc_auc/{_label}": _roc_auc["auc"]}) diff --git a/client/starwhale/api/_impl/wrapper.py b/client/starwhale/api/_impl/wrapper.py index 4ec1efb962..b6c953af90 100644 --- a/client/starwhale/api/_impl/wrapper.py +++ b/client/starwhale/api/_impl/wrapper.py @@ -3,6 +3,7 @@ import threading from typing import Any, Dict, List, Union, Iterator, Optional +from starwhale.consts import VERSION_PREFIX_CNT from starwhale.consts.env import SWEnv from . import data_store @@ -23,7 +24,7 @@ def close(self) -> None: def _log(self, table_name: str, record: Dict[str, Any]) -> None: with self._lock: - if table_name not in self._writers.keys(): + if table_name not in self._writers: self._writers.setdefault(table_name, None) writer = self._writers[table_name] if writer is None: @@ -46,11 +47,14 @@ def __init__(self, eval_id: Optional[str] = None): self.project = os.getenv(SWEnv.project) if self.project is None: raise RuntimeError(f"{SWEnv.project} is not set") - self._results_table_name = f"project/{self.project}/eval/{self.eval_id}/results" - self._summary_table_name = f"project/{self.project}/eval/summary" + self._results_table_name = self._get_datastore_table_name("results") + self._summary_table_name = self._get_datastore_table_name("summary") self._init_writers([self._results_table_name, self._summary_table_name]) self._data_store = data_store.get_data_store() + def _get_datastore_table_name(self, table_name: str) -> str: + return f"project/{self.project}/eval/{self.eval_id[:VERSION_PREFIX_CNT]}/{self.eval_id}/{table_name}" + def log_result(self, data_id: str, result: Any, **kwargs: Any) -> None: record = {"id": data_id, "result": result} for k, v in kwargs.items(): @@ -72,18 +76,11 @@ def log_metrics( record[k.lower()] = v self._log(self._summary_table_name, record) - def log_table( - self, table_name: str, metrics: Optional[Dict[str, Any]] = None, **kwargs: Any - ) -> None: + def log(self, table_name: str, **kwargs: Any) -> None: record = {} - if metrics is not None: - for k, v in metrics.items(): - k = k.lower() - record[k] = v - for k, v in kwargs.items(): record[k.lower()] = v - self._log(f"project/{self.project}/eval/{self.eval_id}/{table_name}", record) + self._log(self._get_datastore_table_name(table_name), record) def get_results(self) -> Iterator[Dict[str, Any]]: return self._data_store.scan_tables( @@ -99,13 +96,9 @@ def get_metrics(self) -> Dict[str, Any]: return {} - def get_results_from_table(self, table_name: str) -> Iterator[Dict[str, Any]]: + def get(self, table_name: str) -> Iterator[Dict[str, Any]]: return self._data_store.scan_tables( - [ - data_store.TableDesc( - f"project/{self.project}/eval/{self.eval_id}/{table_name}" - ) - ] + [data_store.TableDesc(self._get_datastore_table_name(table_name))] ) diff --git a/client/starwhale/core/eval/model.py b/client/starwhale/core/eval/model.py index 10910880e7..ebdc723a46 100644 --- a/client/starwhale/core/eval/model.py +++ b/client/starwhale/core/eval/model.py @@ -174,14 +174,12 @@ def _get_report(self) -> t.Dict[str, t.Any]: f"datastore path:{str(self.sw_config.datastore_dir)}, eval_id:{self.store.id}" ) _datastore = wrapper.Evaluation() - _labels = list(_datastore.get_results_from_table("labels")) + _labels = list(_datastore.get("labels")) return dict( summary=_datastore.get_metrics(), labels={str(i): l for i, l in enumerate(_labels)}, confusion_matrix=dict( - binarylabel=list( - _datastore.get_results_from_table("confusion_matrix/binarylabel") - ) + binarylabel=list(_datastore.get("confusion_matrix/binarylabel")) ), kind=_datastore.get_metrics()["kind"], ) diff --git a/client/starwhale/core/eval/view.py b/client/starwhale/core/eval/view.py index e66d724ffa..f2748457d0 100644 --- a/client/starwhale/core/eval/view.py +++ b/client/starwhale/core/eval/view.py @@ -175,17 +175,20 @@ def _r(_tree: t.Any, _obj: t.Any) -> None: _tree.add(str(_obj)) for _k, _v in _obj.items(): - if _k != "id": - if isinstance(_v, (list, tuple)): - _k = f"{_k}: [green]{'|'.join(_v)}" - elif isinstance(_v, dict) or isinstance(_v, str): - _k = _k - else: - _k = f"{_k}: [green]{_v:.4f}" + if _k == "id": + continue + if isinstance(_v, (list, tuple)): + _k = f"{_k}: [green]{'|'.join(_v)}" + elif isinstance(_v, dict): + _k = _k + elif isinstance(_v, str): + _k = f"{_k}:{_v}" + else: + _k = f"{_k}: [green]{_v:.4f}" - _ntree = _tree.add(_k) - if isinstance(_v, dict): - _r(_ntree, _v) + _ntree = _tree.add(_k) + if isinstance(_v, dict): + _r(_ntree, _v) tree = Tree("Summary") _r(tree, report["summary"]) From abe85bf489f31d8dacf04b43706db733da6e87b9 Mon Sep 17 00:00:00 2001 From: gaoxinxing <15931259256@163.com> Date: Sun, 28 Aug 2022 15:10:57 +0800 Subject: [PATCH 3/3] fix summary table name error --- client/starwhale/api/_impl/wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/starwhale/api/_impl/wrapper.py b/client/starwhale/api/_impl/wrapper.py index b6c953af90..e15c10a8ab 100644 --- a/client/starwhale/api/_impl/wrapper.py +++ b/client/starwhale/api/_impl/wrapper.py @@ -48,7 +48,7 @@ def __init__(self, eval_id: Optional[str] = None): if self.project is None: raise RuntimeError(f"{SWEnv.project} is not set") self._results_table_name = self._get_datastore_table_name("results") - self._summary_table_name = self._get_datastore_table_name("summary") + self._summary_table_name = f"project/{self.project}/eval/summary" self._init_writers([self._results_table_name, self._summary_table_name]) self._data_store = data_store.get_data_store()