diff --git a/src/whylogs/core/metrics/model_metrics.py b/src/whylogs/core/metrics/model_metrics.py index 7d1eb379d4..e9506d913b 100644 --- a/src/whylogs/core/metrics/model_metrics.py +++ b/src/whylogs/core/metrics/model_metrics.py @@ -96,8 +96,9 @@ def merge(self, other): """ if other is None: return self - - model_type =ModelType.UNKNOWN + + model_type = ModelType.UNKNOWN + if (self.model_type not in (ModelType.REGRESSION, ModelType.CLASSIFICATION)): if other.model_type in (ModelType.REGRESSION, ModelType.CLASSIFICATION): model_type = other.model_type @@ -108,7 +109,6 @@ def merge(self, other): else: model_type = self.model_type - return ModelMetrics( confusion_matrix=self.confusion_matrix.merge(other.confusion_matrix) if self.confusion_matrix else None, regression_metrics=self.regression_metrics.merge( diff --git a/tests/unit/core/metrics/test_model_metrics.py b/tests/unit/core/metrics/test_model_metrics.py index d9bd83edb8..ff2d235d57 100644 --- a/tests/unit/core/metrics/test_model_metrics.py +++ b/tests/unit/core/metrics/test_model_metrics.py @@ -79,7 +79,7 @@ def tests_model_metrics_to_protobuf_regression(): def test_merge_none(): metrics = ModelMetrics() - metrics.merge(None) + assert metrics.merge(None) == metrics def test_merge_metrics_with_none_confusion_matrix(): @@ -95,7 +95,15 @@ def test_merge_metrics_model(): other.regression_metrics = None new_metrics = metrics.merge(other) assert new_metrics.model_type==ModelType.REGRESSION + assert new_metrics.confusion_matrix is None + # keep initial model type during merge + metrics = ModelMetrics(model_type=ModelType.REGRESSION) + other = ModelMetrics(model_type=ModelType.CLASSIFICATION) + other.regression_metrics = None + new_metrics = metrics.merge(other) + assert new_metrics.model_type==ModelType.REGRESSION + assert new_metrics.confusion_matrix is None def test_merge_metrics_with_none_regression_matrix(): metrics = ModelMetrics() @@ -111,6 +119,7 @@ def test_merge_metrics_with_none_confusion_matrix(): other.regression_metrics = None new_metrics = metrics.merge(other) + assert new_metrics.model_type == ModelType.UNKNOWN def test_model_metrics_init():