Skip to content

Commit

Permalink
Merge 4aa6ce6 into 41c41b7
Browse files Browse the repository at this point in the history
  • Loading branch information
lalmei committed Apr 6, 2021
2 parents 41c41b7 + 4aa6ce6 commit 939224d
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 10 deletions.
15 changes: 10 additions & 5 deletions src/whylogs/core/metrics/model_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,22 @@ def merge(self, other):
if other.confusion_matrix is None and other.regression_metrics is None:
# TODO: return a copy instead
return self

if self.confusion_matrix is None and self.regression_metrics is None:
return other

if self.model_type is None or other.model_type is None:
model_type = ModelType.UNKNOWN
elif other.model_type != self.model_type:
model_type = ModelType.UNKNOWN
if (self.model_type not in (ModelType.REGRESSION, ModelType.CLASSIFICATION)):
if other.model_type in (ModelType.REGRESSION, ModelType.CLASSIFICATION):
model_type = other.model_type

elif other.model_type not in (ModelType.REGRESSION, ModelType.CLASSIFICATION):
if self.model_type in (ModelType.REGRESSION, ModelType.CLASSIFICATION):
model_type = self.model_type
else:
model_type = self.model_type

return ModelMetrics(
confusion_matrix=self.confusion_matrix.merge(other.confusion_matrix) if self.confusion_matrix else None,
regression_metrics=self.regression_metrics.merge(other.regression_metrics)if self.regression_metrics else None,
regression_metrics=self.regression_metrics.merge(
other.regression_metrics)if self.regression_metrics else None,
model_type=model_type)
57 changes: 52 additions & 5 deletions tests/unit/core/metrics/test_model_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def tests_model_metrics():
jdx].floats.count == expected_1[idx][jdx]


def tests_model_metrics_to_protobuf():
def tests_model_metrics_to_protobuf_classification():
mod_met = ModelMetrics(model_type=ModelType.CLASSIFICATION)

targets_1 = ["cat", "dog", "pig"]
Expand All @@ -39,7 +39,46 @@ def tests_model_metrics_to_protobuf():

message = mod_met.to_protobuf()

ModelMetrics.from_protobuf(message)
model_metrics = ModelMetrics.from_protobuf(message)
assert model_metrics.model_type == ModelType.CLASSIFICATION
assert model_metrics.confusion_matrix.labels == ["cat", "dog", "pig"]


def tests_no_metrics_to_protobuf_classification():
mod_met = ModelMetrics(model_type=ModelType.CLASSIFICATION)


message = mod_met.to_protobuf()

model_metrics = ModelMetrics.from_protobuf(message)
assert model_metrics.model_type == ModelType.CLASSIFICATION

def tests_no_metrics_to_protobuf_regression():\

mod_met = ModelMetrics(model_type=ModelType.REGRESSION)
assert mod_met.model_type == ModelType.REGRESSION
message = mod_met.to_protobuf()


model_metrics = ModelMetrics.from_protobuf(message)
assert model_metrics.model_type == ModelType.REGRESSION

def tests_model_metrics_to_protobuf_regression():


regression_model = ModelMetrics(model_type=ModelType.REGRESSION)


targets_1 = [0.1, 0.3, 0.4]
predictions_1 = [0.5, 0.5, 0.5]
regression_model.compute_regression_metrics(predictions_1, targets_1)
regression_message = regression_model.to_protobuf()
model_metrics_from_message = ModelMetrics.from_protobuf(regression_message)
assert model_metrics_from_message.model_type == ModelType.REGRESSION






def test_merge_none():
Expand All @@ -54,11 +93,19 @@ def test_merge_metrics_with_none_confusion_matrix():
new_metrics = metrics.merge(other)


def test_merge_metrics_model():
metrics = ModelMetrics()
other = ModelMetrics()
other.regression_metrics = None
new_metrics = metrics.merge(other)


def test_merge_metrics_with_none_regression_matrix():
metrics = ModelMetrics()
other = ModelMetrics()
other.regression_metrics = None
new_metrics= metrics.merge(other)
new_metrics = metrics.merge(other)


def test_merge_metrics_with_none_confusion_matrix():
metrics = ModelMetrics()
Expand All @@ -68,9 +115,9 @@ def test_merge_metrics_with_none_confusion_matrix():

new_metrics = metrics.merge(other)


def test_model_metrics_init():
reg_met = RegressionMetrics()
conf_ma= ConfusionMatrix()
conf_ma = ConfusionMatrix()
with pytest.raises(NotImplementedError):
metrics = ModelMetrics(confusion_matrix=conf_ma, regression_metrics=reg_met)

0 comments on commit 939224d

Please sign in to comment.