Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] incorporate typing for Conformal class and conformal_predict method #1104

Merged
merged 6 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 7 additions & 5 deletions neuralprophet/conformal.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import dataclass
from typing import List, Optional

import matplotlib
import numpy as np
import pandas as pd

from neuralprophet.plot_forecast_matplotlib import plot_nonconformity_scores
Expand Down Expand Up @@ -28,9 +30,9 @@ class Conformal:

alpha: float
method: str
quantiles: list = None
quantiles: Optional[List[float]] = None

def predict(self, df, df_cal):
def predict(self, df: pd.DataFrame, df_cal: pd.DataFrame) -> pd.DataFrame:
"""Apply a given conformal prediction technique to get the uncertainty prediction intervals (or q-hat) for test dataframe.

Parameters
Expand Down Expand Up @@ -67,7 +69,7 @@ def predict(self, df, df_cal):

return df

def _get_nonconformity_scores(self, df_cal):
def _get_nonconformity_scores(self, df_cal: pd.DataFrame) -> np.ndarray:
"""Get the nonconformity scores using the given conformal prediction technique.

Parameters
Expand Down Expand Up @@ -109,7 +111,7 @@ def _get_nonconformity_scores(self, df_cal):

return noncon_scores

def _get_q_hat(self, df_cal):
def _get_q_hat(self, df_cal: pd.DataFrame) -> float:
"""Get the q_hat that is derived from the nonconformity scores.

Parameters
Expand All @@ -129,7 +131,7 @@ def _get_q_hat(self, df_cal):

return q_hat

def plot(self, plotting_backend):
def plot(self, plotting_backend: str):
"""Apply a given conformal prediction technique to get the uncertainty prediction intervals (or q-hats).

Parameters
Expand Down
10 changes: 9 additions & 1 deletion neuralprophet/forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3104,7 +3104,15 @@ def _reshape_raw_predictions_to_forecst_df(self, df, predicted, components):
df_forecast = pd.concat([df_forecast, yhat_df], axis=1, ignore_index=False)
return df_forecast

def conformal_predict(self, df, calibration_df, alpha, method="naive", plotting_backend="default", **kwargs):
def conformal_predict(
self,
df: pd.DataFrame,
calibration_df: pd.DataFrame,
alpha: float,
method: str = "naive",
plotting_backend: str = "default",
**kwargs,
) -> pd.DataFrame:
"""Apply a given conformal prediction technique to get the uncertainty prediction intervals (or q-hats). Then predict.

Parameters
Expand Down