Skip to content

Commit

Permalink
fix pep8 add anoter merge test
Browse files Browse the repository at this point in the history
  • Loading branch information
lalmei committed Apr 6, 2021
1 parent f487c85 commit e3c9336
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
6 changes: 3 additions & 3 deletions src/whylogs/core/metrics/model_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ def merge(self, other):
"""
if other is None:
return self

model_type =ModelType.UNKNOWN

model_type = ModelType.UNKNOWN

if (self.model_type not in (ModelType.REGRESSION, ModelType.CLASSIFICATION)):
if other.model_type in (ModelType.REGRESSION, ModelType.CLASSIFICATION):
model_type = other.model_type
Expand All @@ -108,7 +109,6 @@ def merge(self, other):
else:
model_type = self.model_type


return ModelMetrics(
confusion_matrix=self.confusion_matrix.merge(other.confusion_matrix) if self.confusion_matrix else None,
regression_metrics=self.regression_metrics.merge(
Expand Down
11 changes: 10 additions & 1 deletion tests/unit/core/metrics/test_model_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def tests_model_metrics_to_protobuf_regression():

def test_merge_none():
metrics = ModelMetrics()
metrics.merge(None)
assert metrics.merge(None) == metrics


def test_merge_metrics_with_none_confusion_matrix():
Expand All @@ -95,7 +95,15 @@ def test_merge_metrics_model():
other.regression_metrics = None
new_metrics = metrics.merge(other)
assert new_metrics.model_type==ModelType.REGRESSION
assert new_metrics.confusion_matrix is None

# keep initial model type during merge
metrics = ModelMetrics(model_type=ModelType.REGRESSION)
other = ModelMetrics(model_type=ModelType.CLASSIFICATION)
other.regression_metrics = None
new_metrics = metrics.merge(other)
assert new_metrics.model_type==ModelType.REGRESSION
assert new_metrics.confusion_matrix is None

def test_merge_metrics_with_none_regression_matrix():
metrics = ModelMetrics()
Expand All @@ -111,6 +119,7 @@ def test_merge_metrics_with_none_confusion_matrix():
other.regression_metrics = None

new_metrics = metrics.merge(other)
assert new_metrics.model_type == ModelType.UNKNOWN


def test_model_metrics_init():
Expand Down

0 comments on commit e3c9336

Please sign in to comment.