tune eval info classification ui for standalone
tianweidut committed Sep 16, 2022
1 parent cafb52b commit 4f257e9
Showing 5 changed files with 63 additions and 34 deletions.
1 change: 0 additions & 1 deletion client/starwhale/api/_impl/data_store.py
@@ -817,7 +817,6 @@ def __init__(
         self.columns = columns
         self.keep_none = keep_none
 
-        logger.debug(f"scan enter, table size:{len(tables)}")
         infos: List[TableInfo] = []
         for table_desc in tables:
             table = self.tables.get(table_desc.table_name, None)
55 changes: 39 additions & 16 deletions client/starwhale/api/_impl/metric.py
@@ -11,6 +11,7 @@
     confusion_matrix,
     cohen_kappa_score,
     classification_report,
+    multilabel_confusion_matrix,
 )
 
 from starwhale.utils.flatten import do_flatten_dict
@@ -46,45 +47,66 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
             cr = classification_report(
                 y_true, y_pred, output_dict=True, labels=all_labels
             )
-            _summary_m = ["accuracy", "macro avg", "weighted avg"]
-            _r["summary"] = {k: cr.get(k) for k in _summary_m}
+            _summary_m = ["accuracy", "micro avg", "weighted avg", "macro avg"]
+            _r["summary"] = {}
+            for k in _summary_m:
+                v = cr.get(k)
+                if not v:
+                    continue
+                _r["summary"][k] = v
 
+            if show_hamming_loss:
+                _r["summary"]["hamming_loss"] = hamming_loss(y_true, y_pred)
+            if show_cohen_kappa_score:
+                _r["summary"]["cohen_kappa_score"] = cohen_kappa_score(y_true, y_pred)
+
             _record_summary = do_flatten_dict(_r["summary"])
             _record_summary["kind"] = _r["kind"]
             handler.evaluation.log_metrics(_record_summary)

_r["labels"] = {}
for k, v in cr.items():
if k in _summary_m:
mcm = multilabel_confusion_matrix(
y_true, y_pred, labels=all_labels
).tolist()

labels = all_labels or sorted([k for k in cr.keys() if k not in _summary_m])
for _label, matrix in zip(labels, mcm):
_label = str(_label)
_report = cr.get(_label)
if not _report:
continue
_r["labels"][k] = v
handler.evaluation.log("labels", id=k, **v)

_report.update(
{
"TP-True Positive": matrix[0][0],
"TN-True Negative": matrix[0][1],
"FP-False Positive": matrix[1][0],
"FN-False Negative": matrix[1][1],
}
)

_r["labels"][_label] = _report
handler.evaluation.log("labels", id=_label, **_report)

             # TODO: tune performance, use intermediated result
             cm = confusion_matrix(
                 y_true, y_pred, labels=all_labels, normalize=confusion_matrix_normalize
             )
 
             _cm_list = cm.tolist()
             _r["confusion_matrix"] = {"binarylabel": _cm_list}
 
-            for idx, _pa in enumerate(_cm_list):
+            for _idx, _pa in enumerate(_cm_list):
                 handler.evaluation.log(
                     "confusion_matrix/binarylabel",
-                    id=idx,
+                    id=_idx,
                     **{str(_id): _v for _id, _v in enumerate(_pa)},
                 )
 
-            if show_hamming_loss:
-                _r["summary"]["hamming_loss"] = hamming_loss(y_true, y_pred)
-            if show_cohen_kappa_score:
-                _r["summary"]["cohen_kappa_score"] = cohen_kappa_score(y_true, y_pred)

             if show_roc_auc and all_labels is not None and y_true and y_pr:
                 _r["roc_auc"] = {}
                 for _idx, _label in enumerate(all_labels):
                     _ra_value = _calculate_roc_auc(y_true, y_pr, _label, _idx)
-                    _r["roc_auc"][_label] = _ra_value
+                    _r["roc_auc"][str(_label)] = _ra_value
 
                     for _fpr, _tpr, _threshold in zip(
                         _ra_value["fpr"], _ra_value["tpr"], _ra_value["thresholds"]
@@ -96,8 +118,9 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
                             tpr=_tpr,
                             threshold=_threshold,
                         )
+
                     handler.evaluation.log(
-                        "roc_auc/summary", id=_label, auc=_ra_value["auc"]
+                        "labels", id=str(_label), auc=_ra_value["auc"]
                     )
             return _r

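A note on the sklearn pieces the new code leans on: classification_report(..., output_dict=True) keys its per-label entries by the string form of each label (hence the str(_label) casts), and multilabel_confusion_matrix returns one 2x2 matrix per label, laid out as [[TN, FP], [FN, TP]]. A minimal standalone sketch, with made-up data and labels:

    from sklearn.metrics import classification_report, multilabel_confusion_matrix

    y_true = [1, 2, 3, 1, 2, 3]
    y_pred = [1, 3, 3, 1, 2, 2]
    labels = [1, 2, 3]

    # Per-label entries are keyed by str(label); summary rows sit alongside them.
    cr = classification_report(y_true, y_pred, output_dict=True, labels=labels)
    assert "1" in cr and "macro avg" in cr

    # One 2x2 matrix per label, as [[TN, FP], [FN, TP]]: matrix[1][1] is TP.
    mcm = multilabel_confusion_matrix(y_true, y_pred, labels=labels)
    tn, fp = mcm[0][0]
    fn, tp = mcm[0][1]
    assert (tn, fp, fn, tp) == (4, 0, 0, 2)  # label 1 is predicted perfectly here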
3 changes: 1 addition & 2 deletions client/starwhale/api/_impl/model.py
@@ -222,7 +222,7 @@ def _starwhale_internal_run_cmp(self) -> None:
             _iter = PPLResultIterator(
                 data=self.evaluation.get_results(), deserializer=self.deserialize
             )
-            output = self.cmp(_iter)
+            self.cmp(_iter)
         except Exception as e:
             self._sw_logger.exception(f"cmp exception: {e}")
             self._timeline_writer.write(
@@ -231,7 +231,6 @@ def _starwhale_internal_run_cmp(self) -> None:
             raise
         else:
             self._timeline_writer.write({"time": now, "status": True, "exception": ""})
-            self._sw_logger.debug(f"cmp result:{output}")
 
     @_record_status  # type: ignore
     def _starwhale_internal_run_ppl(self) -> None:
11 changes: 2 additions & 9 deletions client/starwhale/core/eval/view.py
@@ -202,7 +202,7 @@ def _print_labels() -> None:
_k, *(f"{float(_v[_k2]):.4f}" for _k2 in keys if _k2 != "id")
)

console.rule(f"[bold green]{report['kind'].upper()} Report")
console.rule(f"[bold green]{report['kind'].upper()} Label Metrics Report")
console.print(table)

def _print_confusion_matrix() -> None:
@@ -220,15 +220,8 @@ def _print_confusion_matrix() -> None:
*[f"{float(bl[i]):.4f}" for i in bl if i != "id"],
)

mtable = Table(box=box.SIMPLE)
mtable.add_column("Label", style="cyan")
for n in ("TP", "TN", "FP", "FN"):
mtable.add_column(n)
for idx, ml in enumerate(cm.get("multilabel", [])):
mtable.add_row(sort_label_names[idx], *[str(_) for _ in ml[0] + ml[1]])

console.rule(f"[bold green]{report['kind'].upper()} Confusion Matrix")
console.print(self.comparison(mtable, btable))
console.print(btable)

_print_labels()
_print_confusion_matrix()
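With the multilabel table gone, the view now prints just btable, the per-label table built above. A standalone sketch of that rendering, assuming rich's Table/Console API as used in this file (the labels and matrix values below are made up):

    from rich import box
    from rich.console import Console
    from rich.table import Table

    console = Console()
    btable = Table(box=box.SIMPLE)
    btable.add_column("Label", style="cyan")
    for name in ("1", "2", "3"):
        btable.add_column(name)

    # One row per true label; cells are the normalized prediction rates.
    matrix = [[0.30, 0.03, 0.00], [0.02, 0.31, 0.00], [0.01, 0.00, 0.33]]
    for label, row in zip(("1", "2", "3"), matrix):
        btable.add_row(label, *[f"{v:.4f}" for v in row])

    console.rule("[bold green]MULTI_CLASSIFICATION Confusion Matrix")
    console.print(btable)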
27 changes: 21 additions & 6 deletions client/tests/sdk/test_metric.py
@@ -1,13 +1,15 @@
 import random
-from unittest import skip, TestCase
+from unittest import TestCase
+from unittest.mock import MagicMock
 
 from starwhale.api._impl.metric import multi_classification
 
 
-@skip
 class TestMultiClassificationMetric(TestCase):
-    def test_multi_classification_metric(self):
-        def _cmp():
+    def test_multi_classification_metric(
+        self,
+    ) -> None:
+        def _cmp(handler, data):
             return (
                 [1, 2, 3, 4, 5, 6, 7, 8, 9],
                 [1, 3, 2, 4, 5, 6, 7, 8, 9],
@@ -17,18 +19,31 @@ def _cmp():
                 ],
             )
 
+        eval_handler = MagicMock()
         rt = multi_classification(
             confusion_matrix_normalize="all",
             show_hamming_loss=True,
             show_cohen_kappa_score=True,
             show_roc_auc=True,
             all_labels=[i for i in range(1, 10)],
-        )(_cmp)()
+        )(_cmp)(eval_handler, None)
 
         assert rt["kind"] == "multi_classification"
         assert "accuracy" in rt["summary"]
         assert "macro avg" in rt["summary"]
         assert len(rt["labels"]) == 9
         assert "binarylabel" in rt["confusion_matrix"]
-        assert "multilabel" in rt["confusion_matrix"]
         assert len(rt["roc_auc"]) == 9
+
+        metric_call = eval_handler.evaluation.log_metrics.call_args[0][0]
+        assert isinstance(metric_call, dict)
+        assert metric_call["kind"] == rt["kind"]
+        assert "macro avg/f1-score" in metric_call
+
+        log_calls = set(
+            [args[0][0] for args in eval_handler.evaluation.log.call_args_list]
+        )
+        assert "labels" in log_calls
+        assert "confusion_matrix/binarylabel" in log_calls
+        assert "roc_auc/9" in log_calls
+        assert "roc_auc/1" in log_calls

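The new assertions lean on unittest.mock's call recording: call_args_list keeps one (args, kwargs) entry per call, so args[0][0] is a given call's first positional argument, i.e. the log category, while call_args holds only the most recent call. A minimal sketch of the pattern (the handler and categories here are illustrative, not starwhale APIs):

    from unittest.mock import MagicMock

    handler = MagicMock()  # auto-creates handler.evaluation.log on first use
    handler.evaluation.log("labels", id="1", f1=0.9)
    handler.evaluation.log("roc_auc/1", id=0, fpr=0.0, tpr=0.0)

    # First positional argument of every recorded call: the category.
    categories = {args[0][0] for args in handler.evaluation.log.call_args_list}
    assert categories == {"labels", "roc_auc/1"}

    # call_args is just the most recent call.
    assert handler.evaluation.log.call_args[0][0] == "roc_auc/1"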