Skip to content

Commit

Permalink
feat: add residuals to regression dto (#106)
Browse files Browse the repository at this point in the history
  • Loading branch information
dtria91 authored Jul 15, 2024
1 parent 2d4c024 commit 6a662aa
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 10 deletions.
28 changes: 20 additions & 8 deletions api/app/models/metrics/model_quality_dto.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,25 @@ class CurrentMultiClassificationModelQuality(BaseModel):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class RegressionMetricsBase(BaseModel):
r2: Optional[float] = None
mae: Optional[float] = None
mse: Optional[float] = None
variance: Optional[float] = None
mape: Optional[float] = None
rmse: Optional[float] = None
adj_r2: Optional[float] = None
class KsMetrics(BaseModel):
p_value: Optional[float] = None
statistic: Optional[float] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class Histogram(BaseModel):
buckets: List[float]
values: Optional[List[int]] = None


class ResidualsMetrics(BaseModel):
ks: KsMetrics
correlation_coefficient: Optional[float] = None
histogram: Histogram
standardized_residuals: List[float]
predictions: List[float]
targets: List[float]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

Expand All @@ -145,6 +156,7 @@ class BaseRegressionMetrics(BaseModel):
mape: Optional[float] = None
rmse: Optional[float] = None
adj_r2: Optional[float] = None
residuals: ResidualsMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

Expand Down
14 changes: 14 additions & 0 deletions api/tests/commons/db_mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,20 @@ def get_sample_current_dataset(
'rmse': 202.23194752188695,
'adj_r2': 0.9116805380966796,
'variance': 0.23,
'residuals': {
'ks': {
'p_value': 0.01,
'statistic': 0.4,
},
'histogram': {
'values': [1, 2, 3],
'buckets': [-3.2, -1, 2.2],
},
'correlation_coefficient': 0.01,
'standardized_residuals': [0.02, 0.03],
'targets': [1, 2.2, 3],
'predictions': [1.3, 2, 4.5],
},
}

grouped_regression_model_quality_dict = {
Expand Down
24 changes: 24 additions & 0 deletions sdk/radicalbit_platform_sdk/models/dataset_model_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,29 @@ class CurrentMultiClassificationModelQuality(ModelQuality):
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class KsMetrics(BaseModel):
p_value: Optional[float] = None
statistic: Optional[float] = None

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class Histogram(BaseModel):
buckets: List[float]
values: Optional[List[int]] = None


class ResidualsMetrics(BaseModel):
ks: KsMetrics
correlation_coefficient: Optional[float] = None
histogram: Histogram
standardized_residuals: List[float]
predictions: List[float]
targets: List[float]

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)


class BaseRegressionMetrics(BaseModel):
r2: Optional[float] = None
mae: Optional[float] = None
Expand All @@ -132,6 +155,7 @@ class BaseRegressionMetrics(BaseModel):
mape: Optional[float] = None
rmse: Optional[float] = None
adj_r2: Optional[float] = None
residuals: ResidualsMetrics

model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)

Expand Down
25 changes: 24 additions & 1 deletion sdk/tests/apis/model_current_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,9 @@ def test_regression_model_quality_ok(self):
mape = 35.19
rmse = 202.23
adj_r2 = 0.91
p_value = 0.2
statistic = 0.4
correlation_coefficient = 0.2
model_current_dataset = ModelCurrentDataset(
base_url,
model_id,
Expand Down Expand Up @@ -986,7 +989,21 @@ def test_regression_model_quality_ok(self):
"variance": {variance},
"mape": {mape},
"rmse": {rmse},
"adjR2": {adj_r2}
"adjR2": {adj_r2},
"residuals": {{
"ks": {{
"p_value": {p_value},
"statistic": {statistic}
}},
"histogram": {{
"values": [1, 2, 3],
"buckets": [-3.2, -1, 2.2]
}},
"correlationCoefficient": {correlation_coefficient},
"standardizedResiduals": [0.02, 0.03],
"targets": [1, 2.2, 3],
"predictions": [1.3, 2, 4.5]
}}
}},
"grouped_metrics": {{
"r2": [
Expand Down Expand Up @@ -1032,6 +1049,12 @@ def test_regression_model_quality_ok(self):
assert metrics.global_metrics.mape == mape
assert metrics.global_metrics.rmse == rmse
assert metrics.global_metrics.adj_r2 == adj_r2
assert (
metrics.global_metrics.residuals.correlation_coefficient
== correlation_coefficient
)
assert metrics.global_metrics.residuals.ks.p_value == p_value
assert metrics.global_metrics.residuals.ks.statistic == statistic
assert metrics.grouped_metrics.r2[0].value == r2
assert metrics.grouped_metrics.mae[0].value == mae
assert metrics.grouped_metrics.mse[0].value == mse
Expand Down
22 changes: 21 additions & 1 deletion sdk/tests/apis/model_reference_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,9 @@ def test_regression_model_metrics_ok(self):
mape = 35.19
rmse = 202.23
adj_r2 = 0.91
p_value = 0.2
statistic = 0.4
correlation_coefficient = 0.2
model_reference_dataset = ModelReferenceDataset(
base_url,
model_id,
Expand All @@ -375,7 +378,21 @@ def test_regression_model_metrics_ok(self):
"variance": {variance},
"mape": {mape},
"rmse": {rmse},
"adjR2": {adj_r2}
"adjR2": {adj_r2},
"residuals": {{
"ks": {{
"p_value": {p_value},
"statistic": {statistic}
}},
"histogram": {{
"values": [1, 2, 3],
"buckets": [-3.2, -1, 2.2]
}},
"correlationCoefficient": {correlation_coefficient},
"standardizedResiduals": [0.02, 0.03],
"targets": [1, 2.2, 3],
"predictions": [1.3, 2, 4.5]
}}
}}
}}""",
)
Expand All @@ -390,6 +407,9 @@ def test_regression_model_metrics_ok(self):
assert metrics.mape == mape
assert metrics.rmse == rmse
assert metrics.adj_r2 == adj_r2
assert metrics.residuals.correlation_coefficient == correlation_coefficient
assert metrics.residuals.ks.p_value == p_value
assert metrics.residuals.ks.statistic == statistic
assert model_reference_dataset.status() == JobStatus.SUCCEEDED

@responses.activate
Expand Down

0 comments on commit 6a662aa

Please sign in to comment.