feat: add statistics in current dataset for multiclass (#53)
* feat: add statistics for current multiclass

* feat: add test and handle multiclass in job
rivamarco committed Jul 1, 2024
1 parent e0e68d2 commit 3881f97
Showing 10 changed files with 280 additions and 21 deletions.
44 changes: 27 additions & 17 deletions spark/jobs/current_job.py
@@ -10,7 +10,7 @@
 from models.current_dataset import CurrentDataset
 from models.reference_dataset import ReferenceDataset
 from utils.current_binary import CurrentMetricsService
-from utils.models import JobStatus, ModelOut
+from utils.models import JobStatus, ModelOut, ModelType
 from utils.db import update_job_status, write_to_db

 from pyspark.sql import SparkSession
@@ -49,23 +49,33 @@ def main(
     raw_reference = spark_session.read.csv(reference_dataset_path, header=True)
     reference_dataset = ReferenceDataset(model=model, raw_dataframe=raw_reference)

-    metrics_service = CurrentMetricsService(
-        spark_session, current_dataset.current, reference_dataset.reference, model=model
-    )
-    statistics = calculate_statistics_current(current_dataset)
-    data_quality = metrics_service.calculate_data_quality()
-    model_quality = metrics_service.calculate_model_quality_with_group_by_timestamp()
-    drift = metrics_service.calculate_drift()
+    complete_record = {"UUID": str(uuid.uuid4()), "CURRENT_UUID": current_uuid}

-    # TODO put needed fields here
-    complete_record = {
-        "UUID": str(uuid.uuid4()),
-        "CURRENT_UUID": current_uuid,
-        "STATISTICS": orjson.dumps(statistics).decode("utf-8"),
-        "DATA_QUALITY": data_quality.model_dump_json(serialize_as_any=True),
-        "MODEL_QUALITY": orjson.dumps(model_quality).decode("utf-8"),
-        "DRIFT": orjson.dumps(drift).decode("utf-8"),
-    }
+    match model.model_type:
+        case ModelType.BINARY:
+            metrics_service = CurrentMetricsService(
+                spark_session,
+                current_dataset.current,
+                reference_dataset.reference,
+                model=model,
+            )
+            statistics = calculate_statistics_current(current_dataset)
+            data_quality = metrics_service.calculate_data_quality()
+            model_quality = (
+                metrics_service.calculate_model_quality_with_group_by_timestamp()
+            )
+            drift = metrics_service.calculate_drift()
+            complete_record["MODEL_QUALITY"] = orjson.dumps(model_quality).decode(
+                "utf-8"
+            )
+            complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8")
+            complete_record["DATA_QUALITY"] = data_quality.model_dump_json(
+                serialize_as_any=True
+            )
+            complete_record["DRIFT"] = orjson.dumps(drift).decode("utf-8")
+        case ModelType.MULTI_CLASS:
+            statistics = calculate_statistics_current(current_dataset)
+            complete_record["STATISTICS"] = orjson.dumps(statistics).decode("utf-8")

     schema = StructType(
         [
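Both branches above serialize their payloads to JSON strings before the record is written. A minimal, self-contained sketch of that round trip (the record shape follows the diff above; the example payload and the decode step are illustrative assumptions, not repo code):

import uuid

import orjson

# The job stores STATISTICS as a JSON string inside the record,
# so a downstream consumer has to decode it again before use.
statistics = {"n_observations": 10, "missing_cells": 3}  # illustrative payload
complete_record = {
    "UUID": str(uuid.uuid4()),
    "CURRENT_UUID": str(uuid.uuid4()),  # stands in for the real current_uuid
    "STATISTICS": orjson.dumps(statistics).decode("utf-8"),
}

decoded = orjson.loads(complete_record["STATISTICS"])
assert decoded["n_observations"] == 10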
2 changes: 1 addition & 1 deletion spark/tests/models/reference_dataset_test.py
@@ -18,7 +18,7 @@
 @pytest.fixture()
 def dataset_target_string(spark_fixture, test_data_dir):
     yield spark_fixture.read.csv(
-        f"{test_data_dir}/reference/multiclass/dataset_target_string.csv",
+        f"{test_data_dir}/reference/multiclass/reference/dataset_target_string.csv",
         header=True,
     )

215 changes: 215 additions & 0 deletions spark/tests/multiclass_current_test.py
@@ -0,0 +1,215 @@
import datetime
import uuid

import pytest

from jobs.metrics.statistics import calculate_statistics_current
from jobs.utils.models import (
ModelOut,
ModelType,
DataType,
OutputType,
ColumnDefinition,
SupportedTypes,
Granularity,
)
from models.current_dataset import CurrentDataset
from tests.utils.pytest_utils import my_approx


@pytest.fixture()
def dataset_target_int(spark_fixture, test_data_dir):
yield (
spark_fixture.read.csv(
f"{test_data_dir}/reference/multiclass/current/dataset_target_int.csv",
header=True,
),
spark_fixture.read.csv(
f"{test_data_dir}/reference/multiclass/reference/dataset_target_int.csv",
header=True,
),
)


@pytest.fixture()
def dataset_target_string(spark_fixture, test_data_dir):
yield (
spark_fixture.read.csv(
f"{test_data_dir}/reference/multiclass/current/dataset_target_string.csv",
header=True,
),
spark_fixture.read.csv(
f"{test_data_dir}/reference/multiclass/reference/dataset_target_string.csv",
header=True,
),
)


@pytest.fixture()
def dataset_perfect_classes(spark_fixture, test_data_dir):
yield (
spark_fixture.read.csv(
f"{test_data_dir}/reference/multiclass/current/dataset_perfect_classes.csv",
header=True,
),
spark_fixture.read.csv(
f"{test_data_dir}/reference/multiclass/reference/dataset_perfect_classes.csv",
header=True,
),
)


def test_calculation_dataset_target_int(spark_fixture, dataset_target_int):
output = OutputType(
prediction=ColumnDefinition(name="prediction", type=SupportedTypes.int),
prediction_proba=None,
output=[ColumnDefinition(name="prediction", type=SupportedTypes.int)],
)
target = ColumnDefinition(name="target", type=SupportedTypes.int)
timestamp = ColumnDefinition(name="datetime", type=SupportedTypes.datetime)
granularity = Granularity.HOUR
features = [
ColumnDefinition(name="cat1", type=SupportedTypes.string),
ColumnDefinition(name="cat2", type=SupportedTypes.string),
ColumnDefinition(name="num1", type=SupportedTypes.float),
ColumnDefinition(name="num2", type=SupportedTypes.float),
]
model = ModelOut(
uuid=uuid.uuid4(),
name="model",
description="description",
model_type=ModelType.MULTI_CLASS,
data_type=DataType.TABULAR,
timestamp=timestamp,
granularity=granularity,
outputs=output,
target=target,
features=features,
frameworks="framework",
algorithm="algorithm",
created_at=str(datetime.datetime.now()),
updated_at=str(datetime.datetime.now()),
)

current_dataframe, reference_dataframe = dataset_target_int
current_dataset = CurrentDataset(model=model, raw_dataframe=current_dataframe)

stats = calculate_statistics_current(current_dataset)

assert stats == my_approx(
{
"categorical": 2,
"datetime": 1,
"duplicate_rows": 0,
"duplicate_rows_perc": 0.0,
"missing_cells": 3,
"missing_cells_perc": 4.285714285714286,
"n_observations": 10,
"n_variables": 7,
"numeric": 4,
}
)


def test_calculation_dataset_target_string(spark_fixture, dataset_target_string):
output = OutputType(
prediction=ColumnDefinition(name="prediction", type=SupportedTypes.string),
prediction_proba=None,
output=[ColumnDefinition(name="prediction", type=SupportedTypes.string)],
)
target = ColumnDefinition(name="target", type=SupportedTypes.string)
timestamp = ColumnDefinition(name="datetime", type=SupportedTypes.datetime)
granularity = Granularity.HOUR
features = [
ColumnDefinition(name="cat1", type=SupportedTypes.string),
ColumnDefinition(name="cat2", type=SupportedTypes.string),
ColumnDefinition(name="num1", type=SupportedTypes.float),
ColumnDefinition(name="num2", type=SupportedTypes.float),
]
model = ModelOut(
uuid=uuid.uuid4(),
name="model",
description="description",
model_type=ModelType.MULTI_CLASS,
data_type=DataType.TABULAR,
timestamp=timestamp,
granularity=granularity,
outputs=output,
target=target,
features=features,
frameworks="framework",
algorithm="algorithm",
created_at=str(datetime.datetime.now()),
updated_at=str(datetime.datetime.now()),
)

current_dataframe, reference_dataframe = dataset_target_string
current_dataset = CurrentDataset(model=model, raw_dataframe=current_dataframe)

stats = calculate_statistics_current(current_dataset)

assert stats == my_approx(
{
"categorical": 4,
"datetime": 1,
"duplicate_rows": 0,
"duplicate_rows_perc": 0.0,
"missing_cells": 3,
"missing_cells_perc": 4.285714285714286,
"n_observations": 10,
"n_variables": 7,
"numeric": 2,
}
)


def test_calculation_dataset_perfect_classes(spark_fixture, dataset_perfect_classes):
output = OutputType(
prediction=ColumnDefinition(name="prediction", type=SupportedTypes.string),
prediction_proba=None,
output=[ColumnDefinition(name="prediction", type=SupportedTypes.string)],
)
target = ColumnDefinition(name="target", type=SupportedTypes.string)
timestamp = ColumnDefinition(name="datetime", type=SupportedTypes.datetime)
granularity = Granularity.HOUR
features = [
ColumnDefinition(name="cat1", type=SupportedTypes.string),
ColumnDefinition(name="cat2", type=SupportedTypes.string),
ColumnDefinition(name="num1", type=SupportedTypes.float),
ColumnDefinition(name="num2", type=SupportedTypes.float),
]
model = ModelOut(
uuid=uuid.uuid4(),
name="model",
description="description",
model_type=ModelType.MULTI_CLASS,
data_type=DataType.TABULAR,
timestamp=timestamp,
granularity=granularity,
outputs=output,
target=target,
features=features,
frameworks="framework",
algorithm="algorithm",
created_at=str(datetime.datetime.now()),
updated_at=str(datetime.datetime.now()),
)

current_dataframe, reference_dataframe = dataset_perfect_classes
current_dataset = CurrentDataset(model=model, raw_dataframe=current_dataframe)

stats = calculate_statistics_current(current_dataset)

assert stats == my_approx(
{
"categorical": 4,
"datetime": 1,
"duplicate_rows": 0,
"duplicate_rows_perc": 0.0,
"missing_cells": 3,
"missing_cells_perc": 4.285714285714286,
"n_observations": 10,
"n_variables": 7,
"numeric": 2,
}
)
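
The expected values in all three assertions can be read off the 10-row fixture CSVs included further below: each current dataset has 7 columns (cat1, cat2, num1, num2, prediction, target, datetime) and 3 empty cells, so missing_cells_perc = 3 / (10 * 7) * 100 ≈ 4.2857. The categorical/numeric split follows the declared column types: an int prediction and target yield 4 numeric and 2 categorical columns, while string ones flip the split to 2 and 4. A quick arithmetic check (plain Python, not repo code):

n_observations, n_variables, missing_cells = 10, 7, 3
missing_cells_perc = missing_cells / (n_observations * n_variables) * 100
assert round(missing_cells_perc, 6) == 4.285714  # matches the asserted 4.285714285714286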
7 changes: 4 additions & 3 deletions spark/tests/multiclass_reference_test.py
@@ -22,22 +22,23 @@
 @pytest.fixture()
 def dataset_target_int(spark_fixture, test_data_dir):
     yield spark_fixture.read.csv(
-        f"{test_data_dir}/reference/multiclass/dataset_target_int.csv", header=True
+        f"{test_data_dir}/reference/multiclass/reference/dataset_target_int.csv",
+        header=True,
     )


 @pytest.fixture()
 def dataset_target_string(spark_fixture, test_data_dir):
     yield spark_fixture.read.csv(
-        f"{test_data_dir}/reference/multiclass/dataset_target_string.csv",
+        f"{test_data_dir}/reference/multiclass/reference/dataset_target_string.csv",
         header=True,
     )


 @pytest.fixture()
 def dataset_perfect_classes(spark_fixture, test_data_dir):
     yield spark_fixture.read.csv(
-        f"{test_data_dir}/reference/multiclass/dataset_perfect_classes.csv",
+        f"{test_data_dir}/reference/multiclass/reference/dataset_perfect_classes.csv",
         header=True,
     )

@@ -0,0 +1,11 @@
cat1,cat2,num1,num2,prediction,target,datetime
A,X,1.0,1.4,HEALTHY,HEALTHY,2024-06-16 00:01:00-05:00
B,X,1.5,100.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:02:00-05:00
A,Y,3.0,123.0,HEALTHY,HEALTHY,2024-06-16 00:03:00-05:00
B,X,0.5,,UNKNOWN,UNKNOWN,2024-06-16 00:04:00-05:00
B,X,0.5,,ORPHAN,ORPHAN,2024-06-16 00:05:00-05:00
B,X,,200.0,HEALTHY,HEALTHY,2024-06-16 00:06:00-05:00
C,X,1.0,300.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:07:00-05:00
A,X,1.0,499.0,UNKNOWN,UNKNOWN,2024-06-16 00:08:00-05:00
A,X,1.0,499.0,HEALTHY,HEALTHY,2024-06-16 00:09:00-05:00
A,X,1.0,499.0,ORPHAN,ORPHAN,2024-06-16 00:10:00-05:00
@@ -0,0 +1,11 @@
cat1,cat2,num1,num2,prediction,target,datetime
A,X,1.0,1.4,1,1,2024-06-16 00:01:00-05:00
B,X,1.5,100.0,0,0,2024-06-16 00:02:00-05:00
A,Y,3.0,123.0,1,1,2024-06-16 00:03:00-05:00
B,X,0.5,,2,0,2024-06-16 00:04:00-05:00
B,X,0.5,,3,2,2024-06-16 00:05:00-05:00
B,X,,200.0,1,3,2024-06-16 00:06:00-05:00
C,X,1.0,300.0,0,0,2024-06-16 00:07:00-05:00
A,X,1.0,499.0,2,2,2024-06-16 00:08:00-05:00
A,X,1.0,499.0,1,1,2024-06-16 00:09:00-05:00
A,X,1.0,499.0,3,2,2024-06-16 00:10:00-05:00
@@ -0,0 +1,11 @@
cat1,cat2,num1,num2,prediction,target,datetime
A,X,1.0,1.4,HEALTHY,HEALTHY,2024-06-16 00:01:00-05:00
B,X,1.5,100.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:02:00-05:00
A,Y,3.0,123.0,HEALTHY,HEALTHY,2024-06-16 00:03:00-05:00
B,X,0.5,,UNKNOWN,UNHEALTHY,2024-06-16 00:04:00-05:00
B,X,0.5,,ORPHAN,UNKNOWN,2024-06-16 00:05:00-05:00
B,X,,200.0,HEALTHY,ORPHAN,2024-06-16 00:06:00-05:00
C,X,1.0,300.0,UNHEALTHY,UNHEALTHY,2024-06-16 00:07:00-05:00
A,X,1.0,499.0,UNKNOWN,UNKNOWN,2024-06-16 00:08:00-05:00
A,X,1.0,499.0,HEALTHY,HEALTHY,2024-06-16 00:09:00-05:00
A,X,1.0,499.0,ORPHAN,UNKNOWN,2024-06-16 00:10:00-05:00
