From 4aa6ce6570809cb9ad9ab0cfdde65fd98656c9fe Mon Sep 17 00:00:00 2001 From: "Leandro G. Almeida" Date: Mon, 5 Apr 2021 17:25:07 -0700 Subject: [PATCH] added more tests and make sure merger doesnt affect model type --- src/whylogs/core/metrics/model_metrics.py | 15 +++-- tests/unit/core/metrics/test_model_metrics.py | 57 +++++++++++++++++-- 2 files changed, 62 insertions(+), 10 deletions(-) diff --git a/src/whylogs/core/metrics/model_metrics.py b/src/whylogs/core/metrics/model_metrics.py index 5837384e96..1ebea4c110 100644 --- a/src/whylogs/core/metrics/model_metrics.py +++ b/src/whylogs/core/metrics/model_metrics.py @@ -99,17 +99,22 @@ def merge(self, other): if other.confusion_matrix is None and other.regression_metrics is None: # TODO: return a copy instead return self + if self.confusion_matrix is None and self.regression_metrics is None: return other - if self.model_type is None or other.model_type is None: - model_type = ModelType.UNKNOWN - elif other.model_type != self.model_type: - model_type = ModelType.UNKNOWN + if (self.model_type not in (ModelType.REGRESSION, ModelType.CLASSIFICATION)): + if other.model_type in (ModelType.REGRESSION, ModelType.CLASSIFICATION): + model_type = other.model_type + + elif other.model_type not in (ModelType.REGRESSION, ModelType.CLASSIFICATION): + if self.model_type in (ModelType.REGRESSION, ModelType.CLASSIFICATION): + model_type = self.model_type else: model_type = self.model_type return ModelMetrics( confusion_matrix=self.confusion_matrix.merge(other.confusion_matrix) if self.confusion_matrix else None, - regression_metrics=self.regression_metrics.merge(other.regression_metrics)if self.regression_metrics else None, + regression_metrics=self.regression_metrics.merge( + other.regression_metrics)if self.regression_metrics else None, model_type=model_type) diff --git a/tests/unit/core/metrics/test_model_metrics.py b/tests/unit/core/metrics/test_model_metrics.py index 3384dcf14e..0be59bb3e2 100644 --- a/tests/unit/core/metrics/test_model_metrics.py +++ b/tests/unit/core/metrics/test_model_metrics.py @@ -26,7 +26,7 @@ def tests_model_metrics(): jdx].floats.count == expected_1[idx][jdx] -def tests_model_metrics_to_protobuf(): +def tests_model_metrics_to_protobuf_classification(): mod_met = ModelMetrics(model_type=ModelType.CLASSIFICATION) targets_1 = ["cat", "dog", "pig"] @@ -39,7 +39,46 @@ def tests_model_metrics_to_protobuf(): message = mod_met.to_protobuf() - ModelMetrics.from_protobuf(message) + model_metrics = ModelMetrics.from_protobuf(message) + assert model_metrics.model_type == ModelType.CLASSIFICATION + assert model_metrics.confusion_matrix.labels == ["cat", "dog", "pig"] + + +def tests_no_metrics_to_protobuf_classification(): + mod_met = ModelMetrics(model_type=ModelType.CLASSIFICATION) + + + message = mod_met.to_protobuf() + + model_metrics = ModelMetrics.from_protobuf(message) + assert model_metrics.model_type == ModelType.CLASSIFICATION + +def tests_no_metrics_to_protobuf_regression():\ + + mod_met = ModelMetrics(model_type=ModelType.REGRESSION) + assert mod_met.model_type == ModelType.REGRESSION + message = mod_met.to_protobuf() + + + model_metrics = ModelMetrics.from_protobuf(message) + assert model_metrics.model_type == ModelType.REGRESSION + +def tests_model_metrics_to_protobuf_regression(): + + + regression_model = ModelMetrics(model_type=ModelType.REGRESSION) + + + targets_1 = [0.1, 0.3, 0.4] + predictions_1 = [0.5, 0.5, 0.5] + regression_model.compute_regression_metrics(predictions_1, targets_1) + regression_message = regression_model.to_protobuf() + model_metrics_from_message = ModelMetrics.from_protobuf(regression_message) + assert model_metrics_from_message.model_type == ModelType.REGRESSION + + + + def test_merge_none(): @@ -54,11 +93,19 @@ def test_merge_metrics_with_none_confusion_matrix(): new_metrics = metrics.merge(other) +def test_merge_metrics_model(): + metrics = ModelMetrics() + other = ModelMetrics() + other.regression_metrics = None + new_metrics = metrics.merge(other) + + def test_merge_metrics_with_none_regression_matrix(): metrics = ModelMetrics() other = ModelMetrics() other.regression_metrics = None - new_metrics= metrics.merge(other) + new_metrics = metrics.merge(other) + def test_merge_metrics_with_none_confusion_matrix(): metrics = ModelMetrics() @@ -68,9 +115,9 @@ def test_merge_metrics_with_none_confusion_matrix(): new_metrics = metrics.merge(other) + def test_model_metrics_init(): reg_met = RegressionMetrics() - conf_ma= ConfusionMatrix() + conf_ma = ConfusionMatrix() with pytest.raises(NotImplementedError): metrics = ModelMetrics(confusion_matrix=conf_ma, regression_metrics=reg_met) -