✅ add checks on fields and merge them accordingly
lalmei committed Mar 15, 2021
1 parent a2b33b5 commit 8dc266e
Showing 4 changed files with 43 additions and 40 deletions.
12 changes: 6 additions & 6 deletions src/whylogs/core/metrics/model_metrics.py
@@ -14,14 +14,14 @@ class ModelMetrics:
         regression_metrics (RegressionMetrics): keeps track of common regression metrics in case the targets are continuous.
     """
 
-    def __init__(self, confusion_matrix: ConfusionMatrix = ConfusionMatrix(),
-                 regression_metrics: RegressionMetrics = RegressionMetrics(),
+    def __init__(self, confusion_matrix: ConfusionMatrix = None,
+                 regression_metrics: RegressionMetrics = None,
                  model_type: ModelType = ModelType.UNKNOWN):
-        # if confusion_matrix is None:
-        #     confusion_matrix = ConfusionMatrix()
+        if confusion_matrix is None:
+            confusion_matrix = ConfusionMatrix()
         self.confusion_matrix = confusion_matrix
-        # if regression_metrics is None:
-        #     regression_metrics = RegressionMetrics()
+        if regression_metrics is None:
+            regression_metrics = RegressionMetrics()
         self.regression_metrics = regression_metrics
         self.model_type = ModelType.UNKNOWN
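Context for the model_metrics.py change: a Python default argument is evaluated once, when the function is defined, so the old signature made every ModelMetrics() constructed without arguments share a single ConfusionMatrix and RegressionMetrics instance. The None-sentinel pattern adopted here builds a fresh object per call instead. A minimal sketch of the pitfall; the Accumulator and SafeAccumulator classes are hypothetical, for illustration only:

    class Accumulator:
        def __init__(self, values: list = []):  # bug: this default list is created once and shared
            self.values = values

    a = Accumulator()
    b = Accumulator()
    a.values.append(1)
    print(b.values)  # prints [1] -- b silently sees a's data

    class SafeAccumulator:
        def __init__(self, values: list = None):  # None sentinel, as in this commit
            self.values = values if values is not None else []  # fresh list per instance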
24 changes: 15 additions & 9 deletions src/whylogs/core/metrics/regression_metrics.py
@@ -65,26 +65,32 @@ def root_mean_squared_error(self):
             return None
         return math.sqrt(self.sum2_diff / self.count)
 
-    def merge(self, other_reg_met):
+    def merge(self, other):
         """
         Merge two separate regression metrics objects that track the same prediction and target fields.
         Args:
-            other_reg_met : regression metrics to merge with self
+            other : regression metrics to merge with self
         Returns:
             RegressionMetrics: merged regression metrics
         """
 
         if self.count == 0:
-            return other_reg_met
-        if other_reg_met.count == 0:
+            return other
+        if other.count == 0:
             return self
 
-        new_reg = RegressionMetrics()
-        new_reg.count = self.count + other_reg_met.count
-        new_reg.sum_abs_diff = self.sum_abs_diff + other_reg_met.sum_abs_diff
-        new_reg.sum_diff = self.sum_diff + other_reg_met.sum_diff
-        new_reg.sum2_diff = self.sum2_diff + other_reg_met.sum2_diff
+        if self.prediction_field != other.prediction_field:
+            raise ValueError("prediction fields differ")
+        if self.target_field != other.target_field:
+            raise ValueError("target fields differ")
+
+        new_reg = RegressionMetrics(prediction_field=self.prediction_field,
+                                    target_field=self.target_field)
+        new_reg.count = self.count + other.count
+        new_reg.sum_abs_diff = self.sum_abs_diff + other.sum_abs_diff
+        new_reg.sum_diff = self.sum_diff + other.sum_diff
+        new_reg.sum2_diff = self.sum2_diff + other.sum2_diff
 
         return new_reg
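With these checks, merging two RegressionMetrics objects that track different columns fails fast instead of silently adding incompatible sums. A usage sketch, assuming the whylogs package at this commit, the constructor keywords shown in the diff, and the two-list add() signature used in the tests below:

    from whylogs.core.metrics.regression_metrics import RegressionMetrics

    a = RegressionMetrics(prediction_field="predictions", target_field="targets")
    b = RegressionMetrics(prediction_field="predictions", target_field="targets")
    a.add([1.0, 2.0], [1.5, 1.5])
    b.add([3.0], [2.0])

    merged = a.merge(b)  # fields match: counts and sums are combined
    assert merged.count == 3

    c = RegressionMetrics(prediction_field="scores", target_field="targets")
    c.add([1.0], [1.0])
    try:
        a.merge(c)  # prediction fields differ
    except ValueError as err:
        print(err)  # "prediction fields differ"

Note the early returns: if either side has count == 0, the other object is returned unchanged and the field checks never run.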
8 changes: 4 additions & 4 deletions tests/unit/core/metrics/test_regression_metrics.py
@@ -55,19 +55,19 @@ def test_empty_protobuf_should_return_none():
 def test_merging():
     regmet_sum = RegressionMetrics()
 
-    regmet = RegressionMetrics()
+    regmet = RegressionMetrics(prediction_field="predictions", target_field="targets")
     df = pd.read_parquet(os.path.join(os.path.join(TEST_DATA_PATH, "metrics", "2021-02-12.parquet")))
     regmet.add(df["predictions"].to_list(), df["targets"].to_list())
     regmet_sum.add(df["predictions"].to_list(), df["targets"].to_list())
 
-    regmet_2 = RegressionMetrics()
+    regmet_2 = RegressionMetrics(prediction_field="predictions", target_field="targets")
     df_2 = pd.read_parquet(os.path.join(os.path.join(TEST_DATA_PATH, "metrics", "2021-02-13.parquet")))
     regmet_2.add(df_2["predictions"].to_list(), df_2["targets"].to_list())
     regmet_sum.add(df_2["predictions"].to_list(), df_2["targets"].to_list())
 
     merged_reg_metr = regmet.merge(regmet_2)
 
     assert merged_reg_metr.count == regmet_sum.count
-    assert merged_reg_metr.mean_squared_error() == regmet_sum.mean_squared_error()
-    assert merged_reg_metr.root_mean_squared_error() == regmet_sum.root_mean_squared_error()
+    assert merged_reg_metr.mean_squared_error() == pytest.approx(regmet_sum.mean_squared_error(), 0.001)
+    assert merged_reg_metr.root_mean_squared_error() == pytest.approx(regmet_sum.root_mean_squared_error(), 0.001)
     assert merged_reg_metr.mean_absolute_error() == pytest.approx(regmet_sum.mean_absolute_error(), 0.001)
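The switch from exact equality to pytest.approx in the two changed asserts reflects floating-point reality: the merged object accumulates its sums in a different order than regmet_sum, so the derived metrics can differ in the least significant bits. A self-contained illustration of the non-associativity involved:

    import pytest

    x = (0.1 + 0.2) + 0.3  # 0.6000000000000001
    y = 0.1 + (0.2 + 0.3)  # 0.6
    assert x != y                        # exact comparison fails
    assert x == pytest.approx(y)         # passes within the default relative tolerance
    assert x == pytest.approx(y, 0.001)  # the relative tolerance used in these tests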
39 changes: 18 additions & 21 deletions tests/unit/core/test_datasetprofile_metrics.py
@@ -2,7 +2,6 @@
 import pytest
 
 
-
 from whylogs.core import DatasetProfile
 from whylogs.core.model_profile import ModelProfile
 
@@ -39,39 +38,37 @@ def test_read_java_protobuf():
     assert lbl == labels[idx]
 
 
-
 def test_parse_from_protobuf_with_regression():
     dir_path = os.path.dirname(os.path.realpath(__file__))
-    prof= DatasetProfile.read_protobuf(os.path.join(
-        TEST_DATA_PATH, "metrics","regression_java.bin"))
+    prof = DatasetProfile.read_protobuf(os.path.join(
+        TEST_DATA_PATH, "metrics", "regression_java.bin"))
     assert prof.name == 'my-model-name'
     assert prof.model_profile is not None
     assert prof.model_profile.metrics is not None
     confusion_M = prof.model_profile.metrics.confusion_matrix
-    regression_met= prof.model_profile.metrics.regression_metrics
+    regression_met = prof.model_profile.metrics.regression_metrics
     assert regression_met is not None
-    assert confusion_M is None
+    assert confusion_M is not None
     # metrics
-    assert regression_met.count==89
-    assert regression_met.sum_abs_diff==pytest.approx(7649.1, 0.1)
-    assert regression_met.sum_diff==pytest.approx(522.7, 0.1)
-    assert regression_met.sum2_diff==pytest.approx(1021265.7, 0.1)
+    assert regression_met.count == 89
+    assert regression_met.sum_abs_diff == pytest.approx(7649.1, 0.1)
+    assert regression_met.sum_diff == pytest.approx(522.7, 0.1)
+    assert regression_met.sum2_diff == pytest.approx(1021265.7, 0.1)
 
 
 def test_track_metrics():
     import pandas as pd
-    mean_absolute_error=85.94534216005789
-    mean_squared_error =11474.89611670205
-    root_mean_squared_error =107.12094154133472
+    mean_absolute_error = 85.94534216005789
+    mean_squared_error = 11474.89611670205
+    root_mean_squared_error = 107.12094154133472
 
     x1 = DatasetProfile(name="test")
-    df= pd.read_parquet(os.path.join(os.path.join(TEST_DATA_PATH,"metrics","2021-02-12.parquet")))
-    x1.track_metrics(df["predictions"].to_list(),df["targets"].to_list())
-    regression_metrics=x1.model_profile.metrics.regression_metrics
+    df = pd.read_parquet(os.path.join(os.path.join(TEST_DATA_PATH, "metrics", "2021-02-12.parquet")))
+    x1.track_metrics(df["predictions"].to_list(), df["targets"].to_list())
+    regression_metrics = x1.model_profile.metrics.regression_metrics
     assert regression_metrics is not None
-    assert regression_metrics.count==len(df["predictions"].to_list())
-    assert regression_metrics.mean_squared_error()==pytest.approx(mean_squared_error,0.01)
-
-    assert regression_metrics.mean_absolute_error() == pytest.approx(mean_absolute_error,0.01)
-    assert regression_metrics.root_mean_squared_error() == pytest.approx(root_mean_squared_error,0.01)
+    assert regression_metrics.count == len(df["predictions"].to_list())
+    assert regression_metrics.mean_squared_error() == pytest.approx(mean_squared_error, 0.01)
+
+    assert regression_metrics.mean_absolute_error() == pytest.approx(mean_absolute_error, 0.01)
+    assert regression_metrics.root_mean_squared_error() == pytest.approx(root_mean_squared_error, 0.01)
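As a cross-check, the raw sums asserted in test_parse_from_protobuf_with_regression reproduce, to rounding, the expected constants in test_track_metrics: the accessors divide the accumulated sums by the count, and root_mean_squared_error takes a square root, as shown in regression_metrics.py above. A short worked sketch; treating sum_abs_diff / count as the mean absolute error is an assumption inferred from the field name:

    import math

    count = 89
    sum_abs_diff = 7649.1    # sum of |prediction - target|
    sum2_diff = 1021265.7    # sum of (prediction - target) ** 2

    mae = sum_abs_diff / count   # assumed MAE formula: ~85.945
    mse = sum2_diff / count      # mean squared error: ~11474.9
    rmse = math.sqrt(mse)        # root mean squared error: ~107.12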
