fix: filter NaN and None before model quality metrics (#70)
* fix: filter NaN and None before model quality metrics

* fix: added null filtering to regression

* fix: added tests for binary and regression with nulls

* fix: added tests for multiclass, refactor pyspark functions

* fix: function in spark utils module, removed print

* fix: duplicated function in spark module
SteZamboni committed Jul 4, 2024
1 parent 26ef8f7 commit 02659e7
Showing 15 changed files with 1,507 additions and 65 deletions.
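Background on the fix: Spark aggregations skip SQL nulls but propagate NaN, so a single NaN in a prediction or target column turns an averaged metric such as MAPE into NaN. A minimal standalone sketch that reproduces the problem (the column name is hypothetical):

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.getOrCreate()

# avg() skips the SQL null but propagates the NaN, so the result is NaN
# rather than 2.0 -- hence the need to filter both before computing metrics.
df = spark.createDataFrame([(1.0,), (3.0,), (float("nan"),), (None,)], ["prediction"])
df.agg(F.avg("prediction")).show()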
19 changes: 12 additions & 7 deletions spark/jobs/metrics/model_quality_regression_calculator.py
@@ -1,10 +1,10 @@
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import col
-from pyspark.sql.functions import abs as pyspark_abs
+import pyspark.sql.functions as F

 from models.regression_model_quality import RegressionMetricType, ModelQualityRegression
 from utils.models import ModelOut
 from pyspark.ml.evaluation import RegressionEvaluator
+from utils.spark import is_not_null


 class ModelQualityRegressionCalculator:
@@ -35,12 +35,12 @@ def __eval_model_quality_metric(
         # mape = 100 * (abs(actual - predicted) / actual) / n
         _dataframe = dataframe.withColumn(
             "mape",
-            pyspark_abs(
+            F.abs(
                 (
-                    col(model.outputs.prediction.name)
-                    - col(model.target.name)
+                    F.col(model.outputs.prediction.name)
+                    - F.col(model.target.name)
                 )
-                / col(model.target.name)
+                / F.col(model.target.name)
             ),
         )
         return _dataframe.agg({"mape": "avg"}).collect()[0][0] * 100
@@ -79,6 +79,11 @@ def __calc_mq_metrics(
     def numerical_metrics(
         model: ModelOut, dataframe: DataFrame, dataframe_count: int
     ) -> ModelQualityRegression:
+        # drop rows where prediction or ground truth is null
+        dataframe_clean = dataframe.filter(
+            is_not_null(model.outputs.prediction.name) & is_not_null(model.target.name)
+        )
+        dataframe_clean_count = dataframe_clean.count()
         return ModelQualityRegressionCalculator.__calc_mq_metrics(
-            model, dataframe, dataframe_count
+            model, dataframe_clean, dataframe_clean_count
         )
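The is_not_null helper imported above lives in spark/jobs/utils/spark.py, a file not rendered in this diff. Given that the commit filters both None and NaN, a plausible sketch of the helper (the actual implementation may differ):

import pyspark.sql.functions as F
from pyspark.sql import Column


def is_not_null(column_name: str) -> Column:
    # Assumed implementation: keep rows that are neither SQL NULL nor NaN.
    # Note that F.isnan is only defined for float/double columns.
    return F.col(column_name).isNotNull() & ~F.isnan(F.col(column_name))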
92 changes: 55 additions & 37 deletions spark/jobs/utils/current_binary.py
@@ -5,8 +5,7 @@
     MulticlassClassificationEvaluator,
 )
 from pyspark.sql import DataFrame, SparkSession
-from pyspark.sql.functions import col
-import pyspark.sql.functions as f
+import pyspark.sql.functions as F

 from metrics.data_quality_calculator import DataQualityCalculator
 from metrics.drift_calculator import DriftCalculator
@@ -20,6 +19,7 @@
 from models.reference_dataset import ReferenceDataset
 from .misc import create_time_format
 from .models import Granularity
+from .spark import is_not_null


 class CurrentMetricsService:
@@ -133,7 +133,11 @@ def __calc_bc_metrics(self) -> dict[str, float]:
     def __calc_mc_metrics(self) -> dict[str, float]:
         return {
             label: self.__evaluate_multi_class_classification(
-                self.current.current, name
+                self.current.current.filter(
+                    is_not_null(self.current.model.outputs.prediction.name)
+                    & is_not_null(self.current.model.target.name)
+                ),
+                name,
             )
             for (name, label) in self.model_quality_multiclass_classificator.items()
         }
@@ -166,16 +170,21 @@ def __evaluate_multi_class_classification(
return float("nan")

def calculate_multiclass_model_quality_group_by_timestamp(self):
current_df_clean = self.current.current.filter(
is_not_null(self.current.model.outputs.prediction.name)
& is_not_null(self.current.model.target.name)
)

if self.current.model.granularity == Granularity.WEEK:
dataset_with_group = self.current.current.select(
dataset_with_group = current_df_clean.select(
[
self.current.model.outputs.prediction.name,
self.current.model.target.name,
f.date_format(
f.to_timestamp(
f.date_sub(
f.next_day(
f.date_format(
F.date_format(
F.to_timestamp(
F.date_sub(
F.next_day(
F.date_format(
self.current.model.timestamp.name,
create_time_format(
self.current.model.granularity
@@ -191,13 +200,13 @@
                 ]
             )
         else:
-            dataset_with_group = self.current.current.select(
+            dataset_with_group = current_df_clean.select(
                 [
                     self.current.model.outputs.prediction.name,
                     self.current.model.target.name,
-                    f.date_format(
-                        f.to_timestamp(
-                            f.date_format(
+                    F.date_format(
+                        F.to_timestamp(
+                            F.date_format(
                                 self.current.model.timestamp.name,
                                 create_time_format(self.current.model.granularity),
                             )
@@ -210,12 +219,12 @@
         list_of_time_group = (
             dataset_with_group.select("time_group")
             .distinct()
-            .orderBy(f.col("time_group").asc())
+            .orderBy(F.col("time_group").asc())
             .rdd.flatMap(lambda x: x)
             .collect()
         )
         array_of_groups = [
-            dataset_with_group.where(f.col("time_group") == x)
+            dataset_with_group.where(F.col("time_group") == x)
             for x in list_of_time_group
         ]

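The arguments to F.next_day and F.date_sub are collapsed in the weekly branch above; the nesting matches the standard PySpark idiom for snapping a timestamp to the start of its week. A hedged reconstruction (the day-of-week argument is an assumption):

# next_day(ts, "Mon") is the first Monday strictly after ts, so subtracting
# 7 days yields the Monday on or before ts -- the start of ts's week.
week_start = F.date_sub(F.next_day(F.col("ts"), "Mon"), 7)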
@@ -233,16 +242,21 @@ def calculate_multiclass_model_quality_group_by_timestamp(self):
         }

     def calculate_binary_class_model_quality_group_by_timestamp(self):
+        current_df_clean = self.current.current.filter(
+            is_not_null(self.current.model.outputs.prediction_proba.name)
+            & is_not_null(self.current.model.target.name)
+        )
+
         if self.current.model.granularity == Granularity.WEEK:
-            dataset_with_group = self.current.current.select(
+            dataset_with_group = current_df_clean.select(
                 [
                     self.current.model.outputs.prediction_proba.name,
                     self.current.model.target.name,
-                    f.date_format(
-                        f.to_timestamp(
-                            f.date_sub(
-                                f.next_day(
-                                    f.date_format(
+                    F.date_format(
+                        F.to_timestamp(
+                            F.date_sub(
+                                F.next_day(
+                                    F.date_format(
                                         self.current.model.timestamp.name,
                                         create_time_format(
                                             self.current.model.granularity
@@ -258,13 +272,13 @@ def calculate_binary_class_model_quality_group_by_timestamp(self):
                 ]
             )
         else:
-            dataset_with_group = self.current.current.select(
+            dataset_with_group = current_df_clean.select(
                 [
                     self.current.model.outputs.prediction_proba.name,
                     self.current.model.target.name,
-                    f.date_format(
-                        f.to_timestamp(
-                            f.date_format(
+                    F.date_format(
+                        F.to_timestamp(
+                            F.date_format(
                                 self.current.model.timestamp.name,
                                 create_time_format(self.current.model.granularity),
                             )
@@ -277,12 +291,12 @@
         list_of_time_group = (
             dataset_with_group.select("time_group")
             .distinct()
-            .orderBy(f.col("time_group").asc())
+            .orderBy(F.col("time_group").asc())
             .rdd.flatMap(lambda x: x)
             .collect()
         )
         array_of_groups = [
-            dataset_with_group.where(f.col("time_group") == x)
+            dataset_with_group.where(F.col("time_group") == x)
             for x in list_of_time_group
         ]

@@ -299,33 +313,37 @@

     def calculate_confusion_matrix(self) -> dict[str, float]:
         prediction_and_label = (
-            self.current.current.select(
+            self.current.current.filter(
+                is_not_null(self.current.model.outputs.prediction.name)
+                & is_not_null(self.current.model.target.name)
+            )
+            .select(
                 [
                     self.current.model.outputs.prediction.name,
                     self.current.model.target.name,
                 ]
             )
             .withColumn(
-                self.current.model.target.name, f.col(self.current.model.target.name)
+                self.current.model.target.name, F.col(self.current.model.target.name)
             )
             .orderBy(self.current.model.target.name)
         )

         tp = prediction_and_label.filter(
-            (col(self.current.model.outputs.prediction.name) == 1)
-            & (col(self.current.model.target.name) == 1)
+            (F.col(self.current.model.outputs.prediction.name) == 1)
+            & (F.col(self.current.model.target.name) == 1)
         ).count()
         tn = prediction_and_label.filter(
-            (col(self.current.model.outputs.prediction.name) == 0)
-            & (col(self.current.model.target.name) == 0)
+            (F.col(self.current.model.outputs.prediction.name) == 0)
+            & (F.col(self.current.model.target.name) == 0)
         ).count()
         fp = prediction_and_label.filter(
-            (col(self.current.model.outputs.prediction.name) == 1)
-            & (col(self.current.model.target.name) == 0)
+            (F.col(self.current.model.outputs.prediction.name) == 1)
+            & (F.col(self.current.model.target.name) == 0)
         ).count()
         fn = prediction_and_label.filter(
-            (col(self.current.model.outputs.prediction.name) == 0)
-            & (col(self.current.model.target.name) == 1)
+            (F.col(self.current.model.outputs.prediction.name) == 0)
+            & (F.col(self.current.model.target.name) == 1)
         ).count()

         return {
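The return block is collapsed above; however its keys are named, the four counts combine into the standard derived metrics. For reference (the variable names and NaN guards here are illustrative, not the repository's code):

total = tp + tn + fp + fn
accuracy = (tp + tn) / total if total else float("nan")
precision = tp / (tp + fp) if (tp + fp) else float("nan")
recall = tp / (tp + fn) if (tp + fn) else float("nan")
f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else float("nan")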
42 changes: 28 additions & 14 deletions spark/jobs/utils/reference_binary.py
@@ -5,8 +5,7 @@
     BinaryClassificationEvaluator,
     MulticlassClassificationEvaluator,
 )
-from pyspark.sql.functions import col
-import pyspark.sql.functions as f
+import pyspark.sql.functions as F

 from metrics.data_quality_calculator import DataQualityCalculator
 from models.data_quality import (
@@ -16,6 +15,7 @@
     BinaryClassDataQuality,
 )
 from models.reference_dataset import ReferenceDataset
+from .spark import is_not_null


 class ReferenceMetricsService:
@@ -73,15 +73,25 @@ def __evaluate_multi_class_classification(
     # FIXME use pydantic struct like data quality
     def __calc_bc_metrics(self) -> dict[str, float]:
         return {
-            label: self.__evaluate_binary_classification(self.reference.reference, name)
+            label: self.__evaluate_binary_classification(
+                self.reference.reference.filter(
+                    is_not_null(self.reference.model.outputs.prediction.name)
+                    & is_not_null(self.reference.model.target.name)
+                ),
+                name,
+            )
             for (name, label) in self.model_quality_binary_classificator.items()
         }

     # FIXME use pydantic struct like data quality
     def __calc_mc_metrics(self) -> dict[str, float]:
         return {
             label: self.__evaluate_multi_class_classification(
-                self.reference.reference, name
+                self.reference.reference.filter(
+                    is_not_null(self.reference.model.outputs.prediction.name)
+                    & is_not_null(self.reference.model.target.name)
+                ),
+                name,
             )
             for (name, label) in self.model_quality_multiclass_classificator.items()
         }
@@ -98,34 +108,38 @@ def calculate_model_quality(self) -> dict[str, float]:
     # FIXME use pydantic struct like data quality
     def calculate_confusion_matrix(self) -> dict[str, float]:
         prediction_and_label = (
-            self.reference.reference.select(
+            self.reference.reference.filter(
+                is_not_null(self.reference.model.outputs.prediction.name)
+                & is_not_null(self.reference.model.target.name)
+            )
+            .select(
                 [
                     self.reference.model.outputs.prediction.name,
                     self.reference.model.target.name,
                 ]
             )
             .withColumn(
                 self.reference.model.target.name,
-                f.col(self.reference.model.target.name),
+                F.col(self.reference.model.target.name),
             )
             .orderBy(self.reference.model.target.name)
         )

         tp = prediction_and_label.filter(
-            (col(self.reference.model.outputs.prediction.name) == 1)
-            & (col(self.reference.model.target.name) == 1)
+            (F.col(self.reference.model.outputs.prediction.name) == 1)
+            & (F.col(self.reference.model.target.name) == 1)
         ).count()
         tn = prediction_and_label.filter(
-            (col(self.reference.model.outputs.prediction.name) == 0)
-            & (col(self.reference.model.target.name) == 0)
+            (F.col(self.reference.model.outputs.prediction.name) == 0)
+            & (F.col(self.reference.model.target.name) == 0)
         ).count()
         fp = prediction_and_label.filter(
-            (col(self.reference.model.outputs.prediction.name) == 1)
-            & (col(self.reference.model.target.name) == 0)
+            (F.col(self.reference.model.outputs.prediction.name) == 1)
+            & (F.col(self.reference.model.target.name) == 0)
         ).count()
         fn = prediction_and_label.filter(
-            (col(self.reference.model.outputs.prediction.name) == 0)
-            & (col(self.reference.model.target.name) == 1)
+            (F.col(self.reference.model.outputs.prediction.name) == 0)
+            & (F.col(self.reference.model.target.name) == 1)
         ).count()

         return {
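The commit message mentions tests for binary, multiclass, and regression with nulls, but the test files are among the changed files not rendered here. A sketch of the shape such a test might take (the fixture names, import path, and asserted field are assumptions):

import math

from metrics.model_quality_regression_calculator import ModelQualityRegressionCalculator


def test_regression_metrics_with_nulls(regression_model, df_with_nulls):
    # Rows containing None/NaN should be dropped before aggregation, so the
    # computed metrics must come back finite rather than NaN.
    metrics = ModelQualityRegressionCalculator.numerical_metrics(
        regression_model, df_with_nulls, df_with_nulls.count()
    )
    assert not math.isnan(metrics.mse)  # "mse" is an assumed field name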