Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/dvclive/live.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def log_sklearn_plot(self, kind, labels, predictions, name=None, **kwargs):
if name in self._plots:
data = self._plots[name]
elif kind in SKLEARN_PLOTS and SKLEARN_PLOTS[kind].could_log(val):
data = SKLEARN_PLOTS[kind](name, self.plots_dir)
data = SKLEARN_PLOTS[kind](name, self.plots_dir, **kwargs)
self._plots[data.name] = data
else:
raise InvalidPlotTypeError(name)
Expand Down
118 changes: 65 additions & 53 deletions src/dvclive/plots/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from pathlib import Path

from dvclive.serialize import dump_json
Expand All @@ -9,7 +10,7 @@ class SKLearnPlot(Data):
suffixes = [".json"]
subfolder = "sklearn"

def __init__(self, name: str, output_folder: str) -> None:
def __init__(self, name: str, output_folder: str, **kwargs) -> None:  # noqa: ARG002
    """Set up the plot through the parent initializer and normalize its name.

    Args:
        name: Plot name, forwarded to the parent initializer.
        output_folder: Output location, forwarded to the parent initializer.
        **kwargs: Accepted so that callers may pass plot-specific options;
            this initializer does not read them.
    """
    super().__init__(name, output_folder)
    # Every ".json" occurrence is dropped from the name, not only a trailing one.
    cleaned_name = self.name.replace(".json", "")
    self.name = cleaned_name

Expand All @@ -25,22 +26,22 @@ def could_log(val: object) -> bool:
return True
return False

@staticmethod
def get_properties():
def get_properties(self):
    """Return the rendering properties of the plot.

    This base version has no implementation.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()


class Roc(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "simple",
"x": "fpr",
"y": "tpr",
"title": "Receiver operating characteristic (ROC)",
"x_label": "False Positive Rate",
"y_label": "True Positive Rate",
}
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "fpr",
"y": "tpr",
"title": "Receiver operating characteristic (ROC)",
"x_label": "False Positive Rate",
"y_label": "True Positive Rate",
}

def get_properties(self):
    """Return an independent deep copy of ``DEFAULT_PROPERTIES``.

    Copying keeps the shared class-level mapping untouched when a caller
    mutates the returned dictionary.
    """
    defaults = self.DEFAULT_PROPERTIES
    return copy.deepcopy(defaults)

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -58,16 +59,17 @@ def dump(self, val, **kwargs) -> None:


class PrecisionRecall(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "simple",
"x": "recall",
"y": "precision",
"title": "Precision-Recall Curve",
"x_label": "Recall",
"y_label": "Precision",
}
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "recall",
"y": "precision",
"title": "Precision-Recall Curve",
"x_label": "Recall",
"y_label": "Precision",
}

def get_properties(self):
    """Return a deep copy of ``DEFAULT_PROPERTIES`` that is safe to modify."""
    properties = copy.deepcopy(self.DEFAULT_PROPERTIES)
    return properties

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -86,16 +88,17 @@ def dump(self, val, **kwargs) -> None:


class Det(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "simple",
"x": "fpr",
"y": "fnr",
"title": "Detection error tradeoff (DET)",
"x_label": "False Positive Rate",
"y_label": "False Negative Rate",
}
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "fpr",
"y": "fnr",
"title": "Detection error tradeoff (DET)",
"x_label": "False Positive Rate",
"y_label": "False Negative Rate",
}

def get_properties(self):
    """Return a fresh deep copy of ``DEFAULT_PROPERTIES``.

    The class-level defaults are never handed out directly, so edits to the
    result do not leak back into them.
    """
    class_defaults = self.DEFAULT_PROPERTIES
    fresh_copy = copy.deepcopy(class_defaults)
    return fresh_copy

def dump(self, val, **kwargs) -> None:
from sklearn import metrics
Expand All @@ -114,16 +117,24 @@ def dump(self, val, **kwargs) -> None:


class ConfusionMatrix(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "confusion",
"x": "actual",
"y": "predicted",
"title": "Confusion Matrix",
"x_label": "True Label",
"y_label": "Predicted Label",
}
DEFAULT_PROPERTIES = {
"template": "confusion",
"x": "actual",
"y": "predicted",
"title": "Confusion Matrix",
"x_label": "True Label",
"y_label": "Predicted Label",
}

def __init__(self, name: str, output_folder: str, **kwargs) -> None:
    """Initialize the plot and record whether normalization was requested.

    Args:
        name: Plot name, forwarded to the parent initializer.
        output_folder: Output location, forwarded to the parent initializer.
        **kwargs: Only the ``"normalized"`` key is read here; the keyword
            arguments are not forwarded to the parent initializer.
    """
    super().__init__(name, output_folder)
    # A missing or falsy "normalized" value is stored as False; a truthy
    # value is stored as given.
    requested = kwargs.get("normalized")
    self.normalized = requested if requested else False

def get_properties(self):
    """Return a copy of the default properties, adjusted for normalization.

    The result is a deep copy of ``DEFAULT_PROPERTIES``. When
    ``self.normalized`` is truthy, its ``"template"`` entry is set to
    ``"confusion_normalized"``.
    """
    overrides = {"template": "confusion_normalized"} if self.normalized else {}
    result = copy.deepcopy(self.DEFAULT_PROPERTIES)
    result.update(overrides)
    return result

def dump(self, val, **kwargs) -> None: # noqa: ARG002
cm = [
Expand All @@ -134,16 +145,17 @@ def dump(self, val, **kwargs) -> None: # noqa: ARG002


class Calibration(SKLearnPlot):
@staticmethod
def get_properties():
return {
"template": "simple",
"x": "prob_pred",
"y": "prob_true",
"title": "Calibration Curve",
"x_label": "Mean Predicted Probability",
"y_label": "Fraction of Positives",
}
DEFAULT_PROPERTIES = {
"template": "simple",
"x": "prob_pred",
"y": "prob_true",
"title": "Calibration Curve",
"x_label": "Mean Predicted Probability",
"y_label": "Fraction of Positives",
}

def get_properties(self):
    """Return a deep copy of ``DEFAULT_PROPERTIES`` for the caller to own."""
    source = self.DEFAULT_PROPERTIES
    duplicate = copy.deepcopy(source)
    return duplicate

def dump(self, val, **kwargs) -> None:
from sklearn import calibration
Expand Down
17 changes: 6 additions & 11 deletions src/dvclive/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,12 @@ def get_plot_renderers(plots_folder, live):
name = file.relative_to(plots_folder).with_suffix("").as_posix()
properties = {}

if name in SKLEARN_PLOTS:
properties = SKLEARN_PLOTS[name].get_properties()
data_field = name
else:
# Plot with custom name
logged_plot = live._plots[name]
for default_name, plot_class in SKLEARN_PLOTS.items():
if isinstance(logged_plot, plot_class):
properties = plot_class.get_properties()
data_field = default_name
break
logged_plot = live._plots[name]
for default_name, plot_class in SKLEARN_PLOTS.items():
if isinstance(logged_plot, plot_class):
properties = logged_plot.get_properties()
data_field = default_name
break

data = json.loads(file.read_text())

Expand Down
17 changes: 17 additions & 0 deletions tests/test_dvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def test_make_dvcyaml_all_plots(tmp_dir):
live.log_metric("bar", 2)
live.log_image("img.png", Image.new("RGB", (10, 10), (250, 250, 250)))
live.log_sklearn_plot("confusion_matrix", [0, 0, 1, 1], [0, 1, 1, 0])
live.log_sklearn_plot(
"confusion_matrix",
[0, 0, 1, 1],
[0, 1, 1, 0],
name="confusion_matrix_normalized",
normalized=True,
)
live.log_sklearn_plot("roc", [0, 0, 1, 1], [0.0, 0.5, 0.5, 0.0], "custom_name_roc")
make_dvcyaml(live)

Expand All @@ -84,6 +91,16 @@ def test_make_dvcyaml_all_plots(tmp_dir):
"y_label": "Predicted Label",
},
},
{
"plots/sklearn/confusion_matrix_normalized.json": {
"template": "confusion_normalized",
"title": "Confusion Matrix",
"x": "actual",
"x_label": "True Label",
"y": "predicted",
"y_label": "Predicted Label",
}
},
{
"plots/sklearn/custom_name_roc.json": {
"template": "simple",
Expand Down
4 changes: 2 additions & 2 deletions tests/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def test_get_plot_renderers(tmp_dir, mocker):
{"fpr": 1.0, "rev": "workspace", "threshold": 0.1, "tpr": 0.5},
{"fpr": 1.0, "rev": "workspace", "threshold": 0.0, "tpr": 1.0},
]
assert plot_renderer.properties == Roc.get_properties()
assert plot_renderer.properties == Roc.DEFAULT_PROPERTIES

for name in ("confusion_matrix", "train/cm"):
plot_renderer = plot_renderers_dict[name]
Expand All @@ -172,7 +172,7 @@ def test_get_plot_renderers(tmp_dir, mocker):
{"actual": "1", "rev": "workspace", "predicted": "0"},
{"actual": "1", "rev": "workspace", "predicted": "1"},
]
assert plot_renderer.properties == ConfusionMatrix.get_properties()
assert plot_renderer.properties == ConfusionMatrix.DEFAULT_PROPERTIES


def test_report_auto_doesnt_set_notebook(tmp_dir, mocker):
Expand Down