feat(job): tune eval info classification ui for standalone #1219

Merged: 1 commit, Sep 16, 2022
1 change: 0 additions & 1 deletion client/starwhale/api/_impl/data_store.py
@@ -817,7 +817,6 @@ def __init__(
self.columns = columns
self.keep_none = keep_none

logger.debug(f"scan enter, table size:{len(tables)}")
infos: List[TableInfo] = []
for table_desc in tables:
table = self.tables.get(table_desc.table_name, None)
55 changes: 39 additions & 16 deletions client/starwhale/api/_impl/metric.py
@@ -11,6 +11,7 @@
confusion_matrix,
cohen_kappa_score,
classification_report,
multilabel_confusion_matrix,
)

from starwhale.utils.flatten import do_flatten_dict
@@ -46,45 +47,66 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
cr = classification_report(
y_true, y_pred, output_dict=True, labels=all_labels
)
_summary_m = ["accuracy", "macro avg", "weighted avg"]
_r["summary"] = {k: cr.get(k) for k in _summary_m}
_summary_m = ["accuracy", "micro avg", "weighted avg", "macro avg"]
_r["summary"] = {}
for k in _summary_m:
v = cr.get(k)
if not v:
continue
_r["summary"][k] = v

if show_hamming_loss:
_r["summary"]["hamming_loss"] = hamming_loss(y_true, y_pred)
if show_cohen_kappa_score:
_r["summary"]["cohen_kappa_score"] = cohen_kappa_score(y_true, y_pred)

_record_summary = do_flatten_dict(_r["summary"])
_record_summary["kind"] = _r["kind"]
handler.evaluation.log_metrics(_record_summary)

_r["labels"] = {}
for k, v in cr.items():
if k in _summary_m:
mcm = multilabel_confusion_matrix(
y_true, y_pred, labels=all_labels
).tolist()

labels = all_labels or sorted([k for k in cr.keys() if k not in _summary_m])
for _label, matrix in zip(labels, mcm):
_label = str(_label)
_report = cr.get(_label)
if not _report:
continue
_r["labels"][k] = v
handler.evaluation.log("labels", id=k, **v)

_report.update(
{
"TP-True Positive": matrix[0][0],
"TN-True Negative": matrix[0][1],
"FP-False Positive": matrix[1][0],
"FN-False Negative": matrix[1][1],
}
)

_r["labels"][_label] = _report
handler.evaluation.log("labels", id=_label, **_report)

# TODO: tune performance, use intermediated result
cm = confusion_matrix(
y_true, y_pred, labels=all_labels, normalize=confusion_matrix_normalize
)

_cm_list = cm.tolist()
_r["confusion_matrix"] = {"binarylabel": _cm_list}

for idx, _pa in enumerate(_cm_list):
for _idx, _pa in enumerate(_cm_list):
handler.evaluation.log(
"confusion_matrix/binarylabel",
id=idx,
id=_idx,
**{str(_id): _v for _id, _v in enumerate(_pa)},
)

if show_hamming_loss:
_r["summary"]["hamming_loss"] = hamming_loss(y_true, y_pred)
if show_cohen_kappa_score:
_r["summary"]["cohen_kappa_score"] = cohen_kappa_score(y_true, y_pred)

if show_roc_auc and all_labels is not None and y_true and y_pr:
_r["roc_auc"] = {}
for _idx, _label in enumerate(all_labels):
_ra_value = _calculate_roc_auc(y_true, y_pr, _label, _idx)
_r["roc_auc"][_label] = _ra_value
_r["roc_auc"][str(_label)] = _ra_value

for _fpr, _tpr, _threshold in zip(
_ra_value["fpr"], _ra_value["tpr"], _ra_value["thresholds"]
@@ -96,8 +118,9 @@ def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Dict[str, t.Any]:
tpr=_tpr,
threshold=_threshold,
)

handler.evaluation.log(
"roc_auc/summary", id=_label, auc=_ra_value["auc"]
"labels", id=str(_label), auc=_ra_value["auc"]
)
return _r

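For readers unfamiliar with scikit-learn's multilabel_confusion_matrix, here is a minimal standalone sketch (not part of the diff) of the call the new per-label report relies on, reusing the same toy labels as the tests below. Note that scikit-learn documents each per-label 2x2 matrix as [[TN, FP], [FN, TP]], so the matrix[0][0]-as-TP mapping used in the diff above is worth double-checking against that layout.

from sklearn.metrics import multilabel_confusion_matrix

y_true = ["a", "b", "c", "d", "a", "a", "a"]
y_pred = ["b", "b", "d", "d", "a", "a", "b"]
labels = ["a", "b", "c", "d"]

# One 2x2 matrix per label; per the sklearn docs the layout is [[TN, FP], [FN, TP]].
mcm = multilabel_confusion_matrix(y_true, y_pred, labels=labels).tolist()
for label, matrix in zip(labels, mcm):
    tn, fp = matrix[0]
    fn, tp = matrix[1]
    print(label, {"TP": tp, "TN": tn, "FP": fp, "FN": fn})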
3 changes: 1 addition & 2 deletions client/starwhale/api/_impl/model.py
@@ -222,7 +222,7 @@ def _starwhale_internal_run_cmp(self) -> None:
_iter = PPLResultIterator(
data=self.evaluation.get_results(), deserializer=self.deserialize
)
output = self.cmp(_iter)
self.cmp(_iter)
except Exception as e:
self._sw_logger.exception(f"cmp exception: {e}")
self._timeline_writer.write(
@@ -231,7 +231,6 @@ def _starwhale_internal_run_cmp(self) -> None:
raise
else:
self._timeline_writer.write({"time": now, "status": True, "exception": ""})
self._sw_logger.debug(f"cmp result:{output}")

@_record_status # type: ignore
def _starwhale_internal_run_ppl(self) -> None:
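A condensed sketch (hypothetical names, not the project's API) of the status-recording pattern this hunk sits in: the cmp return value is now discarded because the @multi_classification wrapper logs metrics through handler.evaluation itself, so nothing downstream reads the returned report dict.

import time

def run_cmp(cmp_fn, iterator, timeline_writer, logger):
    now = time.asctime()
    try:
        cmp_fn(iterator)  # return value intentionally ignored
    except Exception as e:
        logger.exception(f"cmp exception: {e}")
        timeline_writer.write({"time": now, "status": False, "exception": str(e)})
        raise
    else:
        timeline_writer.write({"time": now, "status": True, "exception": ""})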
11 changes: 2 additions & 9 deletions client/starwhale/core/eval/view.py
@@ -202,7 +202,7 @@ def _print_labels() -> None:
_k, *(f"{float(_v[_k2]):.4f}" for _k2 in keys if _k2 != "id")
)

console.rule(f"[bold green]{report['kind'].upper()} Report")
console.rule(f"[bold green]{report['kind'].upper()} Label Metrics Report")
console.print(table)

def _print_confusion_matrix() -> None:
@@ -220,15 +220,8 @@ def _print_confusion_matrix() -> None:
*[f"{float(bl[i]):.4f}" for i in bl if i != "id"],
)

mtable = Table(box=box.SIMPLE)
mtable.add_column("Label", style="cyan")
for n in ("TP", "TN", "FP", "FN"):
mtable.add_column(n)
for idx, ml in enumerate(cm.get("multilabel", [])):
mtable.add_row(sort_label_names[idx], *[str(_) for _ in ml[0] + ml[1]])

console.rule(f"[bold green]{report['kind'].upper()} Confusion Matrix")
console.print(self.comparison(mtable, btable))
console.print(btable)

_print_labels()
_print_confusion_matrix()
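With the multilabel table removed, the standalone CLI renders only the binary-label confusion matrix. A minimal rich sketch (made-up labels and values, mirroring how btable is built above) of what that output path produces:

from rich import box
from rich.console import Console
from rich.table import Table

console = Console()
btable = Table(box=box.SIMPLE)
btable.add_column("Label", style="cyan")
for name in ("a", "b"):
    btable.add_column(name)
btable.add_row("a", "0.5000", "0.5000")
btable.add_row("b", "0.2500", "0.7500")

console.rule("[bold green]MULTI_CLASSIFICATION Confusion Matrix")
console.print(btable)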
51 changes: 45 additions & 6 deletions client/tests/sdk/test_metric.py
@@ -1,13 +1,39 @@
import random
from unittest import skip, TestCase
from unittest import TestCase
from unittest.mock import MagicMock

from starwhale.api._impl.metric import multi_classification


@skip
class TestMultiClassificationMetric(TestCase):
def test_multi_classification_metric(self):
def _cmp():
def test_multi_classification_metric(
self,
) -> None:
def _cmp(handler, data):
return (
["a", "b", "c", "d", "a", "a", "a"],
["b", "b", "d", "d", "a", "a", "b"],
)

eval_handler = MagicMock()
rt = multi_classification(
confusion_matrix_normalize="all",
show_hamming_loss=True,
show_cohen_kappa_score=True,
show_roc_auc=False,
all_labels=["a", "b", "c", "d"],
)(_cmp)(eval_handler, None)
assert rt["kind"] == "multi_classification"

metric_call = eval_handler.evaluation.log_metrics.call_args[0][0]
assert "weighted avg/precision" in metric_call
assert list(rt["labels"].keys()) == ["a", "b", "c", "d"]
assert "confusion_matrix/binarylabel" not in rt

def test_multi_classification_metric_with_pa(
self,
) -> None:
def _cmp(handler, data):
return (
[1, 2, 3, 4, 5, 6, 7, 8, 9],
[1, 3, 2, 4, 5, 6, 7, 8, 9],
@@ -17,18 +43,31 @@ def _cmp():
],
)

eval_handler = MagicMock()
rt = multi_classification(
confusion_matrix_normalize="all",
show_hamming_loss=True,
show_cohen_kappa_score=True,
show_roc_auc=True,
all_labels=[i for i in range(1, 10)],
)(_cmp)()
)(_cmp)(eval_handler, None)

assert rt["kind"] == "multi_classification"
assert "accuracy" in rt["summary"]
assert "macro avg" in rt["summary"]
assert len(rt["labels"]) == 9
assert "binarylabel" in rt["confusion_matrix"]
assert "multilabel" in rt["confusion_matrix"]
assert len(rt["roc_auc"]) == 9

metric_call = eval_handler.evaluation.log_metrics.call_args[0][0]
assert isinstance(metric_call, dict)
assert metric_call["kind"] == rt["kind"]
assert "macro avg/f1-score" in metric_call

log_calls = set(
[args[0][0] for args in eval_handler.evaluation.log.call_args_list]
)
assert "labels" in log_calls
assert "confusion_matrix/binarylabel" in log_calls
assert "roc_auc/9" in log_calls
assert "roc_auc/1" in log_calls