In [None]:
from pathlib import Path

import pandas as pd
import numpy as np
from sklearn.model_selection import TimeSeriesSplit
from sklearn.calibration import calibration_curve
from sklearn.metrics import (
    f1_score,
    precision_recall_fscore_support,
    roc_auc_score,
    roc_curve,
    auc,
    precision_recall_curve,
    confusion_matrix
)
from catboost import Pool
import mlflow
from mlflow.types.schema import Schema, ColSpec
import optuna
import shap
shap.initjs()
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt

from var import DATA_OUT, IMAGE_OUT, FORECAST_HOURS_IN_ADVANCE
from src import ML_SERVER_URI, EXPERIMENT_NAME
from src.opt import objective
from src.model import (
    get_or_create_experiment,
    start_crossvalidated_run,
    evaluate_crossvalidated_metrics,
)

# Point the MLflow client at the tracking web server.
mlflow.set_tracking_uri(ML_SERVER_URI)

# Optuna: report WARNING and above only (per-trial INFO messages are hidden).
optuna.logging.set_verbosity(optuna.logging.WARNING)

In [None]:
# NOTE: unpickling can execute arbitrary code — only load files produced by this project.
dataset_path = Path(DATA_OUT, 'df_dataset.pickle')
df = pd.read_pickle(dataset_path)

In [None]:
# Target switch: False -> 'tid_within_{N}h'; True -> 'tid_within_{2N}h_to_{N}h',
# with N = FORECAST_HOURS_IN_ADVANCE.
IS_MIDTRM_MIDACC = False

# Feature families picked by column-name prefix; disabled families stay as
# comments so they are easy to switch back on.
leading_prefixes = (
    'ie_',
    # 'il_',
    'iu_',
    # 'io_',
)
scalar_features = [
    'hf',
    'hf_mav_2h',
    'f_107_adj',
    'hp_30',
    'smr',
    'solar_zenith_angle',
    # 'newell',
    'bz',
    'vx',
    'rho',
]
trailing_prefixes = (
    'local_warning_',
    'spectral_contribution_',
    'azimuth_',
    'velocity_',
)

# Column order: every 'ie_' column, every 'iu_' column, the scalar features,
# then the trailing prefix families, each in df's own column order.
feature_columns = [
    *(col_ for prefix in leading_prefixes for col_ in df.columns if col_.startswith(prefix)),
    *scalar_features,
    *(col_ for prefix in trailing_prefixes for col_ in df.columns if col_.startswith(prefix)),
]
X = df[feature_columns].copy()

target = (
    f'tid_within_{2*FORECAST_HOURS_IN_ADVANCE}h_to_{FORECAST_HOURS_IN_ADVANCE}h'
    if IS_MIDTRM_MIDACC
    else f'tid_within_{FORECAST_HOURS_IN_ADVANCE}h'
)

y = df[target].copy()

In [None]:
# Columns CatBoost must treat as categorical: the two variation columns plus
# every local-warning column present in the dataset.
cat_features = ['ie_variation', 'iu_variation']
cat_features += [col_ for col_ in df.columns if col_.startswith('local_warning_')]

# CatBoost parameters that Optuna does not tune; the tuned values are merged
# into this dict after the study finishes.
static_params = dict(
    eval_metric='F1',
    random_seed=42,
    # auto_class_weights="SqrtBalanced",  # or "Balanced"
    cat_features=cat_features,
    od_type="Iter",
    use_best_model=True,
    has_time=True,
    od_wait=200,
)

In [None]:
# Alternative: fixed one-year validation windows (presumably 30-minute rows -> 24*2 per day).
# n_days_for_testing = 365
# ts_cv = TimeSeriesSplit(n_splits=5, test_size=n_days_for_testing*24*2)

# The target looks ahead in time: assuming 'tid_within_{N}h' flags a TID in the
# next N hours (TODO confirm against the dataset builder), the labels of the last
# training rows depend on what happens at the start of the validation fold.
# Leave a gap of one label horizon between each training set and its validation
# set so that future information cannot leak into training.
LABEL_HORIZON_HOURS = (2 if IS_MIDTRM_MIDACC else 1) * FORECAST_HOURS_IN_ADVANCE

# The index is used as a time axis elsewhere in the notebook (date-string lookups),
# so convert the horizon to rows using the typical sampling step.
assert isinstance(X.index, pd.DatetimeIndex), 'expected a DatetimeIndex on X'
sampling_step = X.index.to_series().diff().median()
label_horizon_rows = int(np.ceil(pd.Timedelta(hours=LABEL_HORIZON_HOURS) / sampling_step))

ts_cv = TimeSeriesSplit(n_splits=5, gap=label_horizon_rows)

In [None]:
# Optional diagnostic (disabled): for each TimeSeriesSplit fold, plot the 'hf'
# series of the training part and of the validation part, with a dashed line
# where validation starts. Uncomment to regenerate the figure.
# NOTE: the 5 subplot rows assume ts_cv was built with n_splits=5.
# fig, axs = plt.subplots(5, 1, figsize=(40, 20), sharex=True)
# 
# for fold, (train_idx, val_idx) in enumerate(ts_cv.split(X)):
#     train = X.iloc[train_idx]
#     test = X.iloc[val_idx]
#     train['hf'].plot(
#         ax=axs[fold],
#         title=f'Train/Test split - fold {fold + 1}',
#     )
#     test['hf'].plot(ax=axs[fold])
#     axs[fold].axvline(test.index.min(), color='black', ls='--')
#     axs[fold].set_ylabel('HF index')
# 
# plt.savefig(
#     Path(IMAGE_OUT, 'train_test_split.png'), dpi=500, bbox_inches='tight'
# )
# plt.show()

## Optuna (hyperparameter optimisation)

In [None]:
N_TRIALS = 10


def obj(trial):
    """Evaluate one Optuna trial with src.opt.objective on this notebook's data,
    time-series CV splitter and static CatBoost parameters."""
    return objective(
        trial,
        X=X,
        y=y,
        cv=ts_cv,
        params=static_params,
    )


# Seed the sampler (TPE is Optuna's default single-objective sampler) with the
# model seed, so that re-running the notebook proposes the same trials.
study = optuna.create_study(
    study_name='catboost_clf',
    direction='maximize',
    sampler=optuna.samplers.TPESampler(seed=static_params['random_seed']),
)
study.optimize(obj, n_trials=N_TRIALS, show_progress_bar=True)

trial = study.best_trial

In [None]:
# Summarise the best trial: its objective value and the hyper-parameters it used.
summary_lines = ["Best trial:", f"  F1: {trial.value:.3f}", "  Params: "]
summary_lines += [f"    {key}: {value}" for key, value in trial.params.items()]
print("\n".join(summary_lines))

In [None]:
# optuna.visualization.plot_optimization_history(study)

In [None]:
# optuna.visualization.plot_param_importances(study)

In [None]:
# Merge the tuned hyper-parameters into static_params in place, so later cells
# that read static_params train with the tuned values.
best_params = study.best_params
for param_name, param_value in best_params.items():
    static_params[param_name] = param_value

## Fine-tuned and/or cross-validated model

In [None]:
# Make EXPERIMENT_NAME (created on first use) the active MLflow experiment.
experiment_id = get_or_create_experiment(EXPERIMENT_NAME)
mlflow.set_experiment(experiment_id=experiment_id);

### Pre-trained model (retrieved by `run_id`)

In [None]:
# Retrieve the model and per-fold metrics of an existing MLflow run by its id.
# NOTE: the "train from scratch" cell below reassigns cat_model/f1s/prs/rcs, so
# under Run All this result is superseded; run only one of the two cells.
PRETRAINED_RUN_ID = 'ed451d15e8094aa593f44d07bc77696d'

cat_model, (f1s, prs, rcs) = start_crossvalidated_run(
    X=X,
    y=y,
    time_series_cross_validator=ts_cv,
    run_id=PRETRAINED_RUN_ID,
)

### Train a model from scratch (usually after hyperparameter optimisation)

**[Model signatures](https://www.mlflow.org/docs/latest/models.html#model-signature)** define what the model expects (input, output and parameters) and enforce it later in deployment.

Signatures are fetched by the Tracking UI and the Model Registry UI to display model inputs, outputs and params; MLflow's model deployment tools also use them to validate inference inputs against the model's assigned signature.

In [None]:
# Model signature = input/output schema stored with the run. Infer it from the
# actual dtypes instead of declaring every input column as "double". CatBoost
# rejects float-typed categorical features, so the `cat_features` columns are
# integer- or string-typed. A blanket "double" schema misdescribes them, and
# MLflow's schema enforcement would then reject valid inference inputs.
signature = mlflow.models.infer_signature(X, y)

cat_model, (f1s, prs, rcs) = start_crossvalidated_run(
    X=X,
    y=y,
    time_series_cross_validator=ts_cv,
    model_params=static_params,
    model_signature=signature,
)

In [None]:
# Weight each fold by the share of all rows that sits in its training set.
n_rows = len(X)
weights = [len(train_idx) / n_rows for train_idx, _ in ts_cv.split(X)]

cv_metrics = {'F1-score': f1s, 'Precision': prs, 'Recall': rcs}
evaluate_crossvalidated_metrics(metrics=cv_metrics, weights=weights)

In [None]:
# Same summary with every fold weighted equally.
evaluate_crossvalidated_metrics(
    metrics={'F1-score': f1s, 'Precision': prs, 'Recall': rcs},
    weights=None,
)

In [None]:
# Hold out the most recent TimeSeriesSplit fold for the evaluation below.
# NOTE(review): confirm that `cat_model` (returned by start_crossvalidated_run)
# was not fitted on rows of this fold; the helper's training scheme is not visible here.
train_idx, test_idx = list(ts_cv.split(X))[-1]

X_train, y_train = X.iloc[train_idx], y.iloc[train_idx]
X_test, y_test = X.iloc[test_idx], y.iloc[test_idx]

y_pred = cat_model.predict(X_test)

## SHAP

In [None]:
test_pool = Pool(X_test, label=y_test, cat_features=cat_features)

# CatBoost's "ShapValues" matrix has one column per feature plus a final column
# holding the expected value (bias), per the CatBoost docs; keep the feature columns only.
shap_values = cat_model.get_feature_importance(test_pool, type="ShapValues")[:, :-1]

In [None]:
# df_shap = pd.DataFrame(
#     (
#         zip(
#             X_train.columns[np.argsort(np.abs(shap_values).mean(0))][::-1],
#             -np.sort(-np.abs(shap_values).mean(0))
#         )
#     ),
#     columns=["feature", "importance"],
# )
# 
# df_shap.to_pickle(Path(DATA_OUT, 'df_feat_imp.pickle'))

In [None]:
# SHAP summary plot for the held-out fold, limited to the 15 top-ranked features.
# Pass show=False to keep the figure open for the savefig call below.
shap.summary_plot(shap_values, X_test, max_display=15)

# plt.savefig(
#     Path(IMAGE_OUT, 'shap_summary.png'),
#     dpi=300,
# )

In [None]:
# 'newell' is commented out of the feature matrix X, so it is not a column of
# X_test and shap.dependence_plot would raise. Use the preferred features when
# present; otherwise fall back to SHAP's own choices: "rank(0)" is the feature
# with the largest mean |SHAP value|, and 'auto' picks the interaction feature.
preferred_feature, preferred_interaction = 'newell', 'iu_fix'

shap.dependence_plot(
    preferred_feature if preferred_feature in X_test.columns else 'rank(0)',
    shap_values,
    X_test,
    interaction_index=(
        preferred_interaction if preferred_interaction in X_test.columns else 'auto'
    ),
)

In [None]:
# Pairwise correlation (pandas default: Pearson) between a few drivers and the
# model target. Use `target` instead of a hard-coded label name, so the matrix
# follows FORECAST_HOURS_IN_ADVANCE / IS_MIDTRM_MIDACC. Columns are read from
# df, so features excluded from X (e.g. 'newell') can still be inspected.
corr_columns = [
    'ie_fix',
    'iu_fix',
    'hf',
    'hp_30',
    'smr',
    'newell',
    'bz',
    'vx',
    'rho',
    'solar_zenith_angle',
    target,
]
corr = df.select_dtypes(exclude=['object'])[corr_columns].corr()

corr.style.background_gradient(cmap='coolwarm')

In [None]:
from src.io import read_time_series
from var import DATA_IN

# Column names assigned to the TID catalogue CSV, in file order.
tid_catalog_columns = [
    'duration', 'period', 'amplitude', 'spectral_contribution',
    'velocity', 'azimuth', 'quality_index', 'datetime',
]

df_tid = read_time_series(Path(DATA_IN, 'TID_catalog.csv'), column_names=tid_catalog_columns)

In [None]:
# shap's TreeExplainer on the fitted CatBoost model; `explainer` is reused by
# the force plot and the bar plot below.
explainer = shap.TreeExplainer(model=cat_model)

# NOTE(review): `shap_values_` is not read later in this notebook (the plots use
# the CatBoost-computed `shap_values`); keep it only if it is needed for a comparison.
shap_values_ = explainer.shap_values(X_test)

In [None]:
# Explain a single held-out prediction at a chosen timestamp.
FORCE_PLOT_TIMESTAMP = '2022-01-04 16:00'
row = X_test.index.get_loc(FORCE_PLOT_TIMESTAMP)

# link='logit' displays the (presumably log-odds) contributions on the probability scale.
shap.force_plot(
    explainer.expected_value,
    shap_values[row],
    X_test.iloc[row],
    link='logit',
)

In [None]:
# Actual vs. predicted label for the force-plot row.
(y_test.iloc[row], y_pred[row])

In [None]:
# Share of held-out samples predicted as the positive class.
np.count_nonzero(y_pred == 1) / len(y_pred)

In [None]:
# Global importance bar chart (mean |SHAP| per feature, top 15) from shap's Explanation API.
# Pass show=False to keep the figure open for the savefig call below.
test_explanation = explainer(X_test)
shap.plots.bar(test_explanation, max_display=15)

# plt.savefig(
#     Path(IMAGE_OUT, 'shap_bar_plot.png'),
#     dpi=300,
#     bbox_inches='tight',
# )

## Confusion matrix

In [None]:
import seaborn as sns

# Row-normalised confusion matrix (normalize='true'): each row is an actual class
# and sums to 1, so the diagonal holds the per-class recall.
conf_matrix = confusion_matrix(y_test, y_pred, normalize='true')

# sns.set(font_scale=1.5) would change the global theme for every later matplotlib
# figure in the kernel, so scope the larger fonts to this figure only. 'notebook'
# is the context sns.set uses; font_scale is only applied to a named context.
with sns.plotting_context('notebook', font_scale=1.5):
    fig, ax = plt.subplots(figsize=(8, 8))

    sns.heatmap(
        conf_matrix,
        annot=True,
        fmt='.2f',
        cmap='Blues',
        cbar=False,
        xticklabels=['TID not predicted', 'TID predicted'],
        yticklabels=['TID doesn\'t occur', 'TID occurs'],
        ax=ax,
    )

    # fig.savefig(
    #     Path(IMAGE_OUT, 'confusion_matrix.png'), dpi=500, bbox_inches='tight'
    # )

    plt.show()

## Evaluation of classification

In [None]:
# Held-out features alongside the actual label, the predicted class and the
# positive-class probability.
df_eval = X_test.copy(deep=True).assign(
    true=y_test,
    pred=cat_model.predict(X_test),
    pred_proba=cat_model.predict_proba(X_test)[:, 1],
)

In [None]:
# F1-score of the default class predictions on the held-out fold (3 decimals).
round(f1_score(y_true=df_eval['true'], y_pred=df_eval['pred']), 3)

### ROC curve

In [None]:
# ROC points (false/true positive rate for each threshold) on the held-out fold.
fpr, tpr, thresholds = roc_curve(y_true=df_eval['true'], y_score=df_eval['pred_proba'])

In [None]:
# Area under the ROC curve, shown in the plot title below.
roc_auc = roc_auc_score(y_true=df_eval['true'], y_score=df_eval['pred_proba'])

In [None]:
# ROC curve of the held-out predictions, with the chance diagonal for reference.
chance_diagonal = dict(
    type='line',
    x0=0, y0=0,
    x1=1, y1=1,
    line=dict(color='navy', width=2, dash='dash'),
)

fig = px.scatter(x=fpr, y=tpr)
fig.update_layout(
    height=700,
    width=800,
    autosize=False,
    shapes=[chance_diagonal],
    title=f'ROC Curve (ROC-AUC: <b>{roc_auc:.2f}</b>)',
    xaxis_title='False Positive Rate',
    yaxis_title='True Positive Rate',
    template='ggplot2',
)

# fig.write_html(
#     Path(IMAGE_OUT, f'plot_roc_curve.html')
# )

fig.show()

### PR curve

In [None]:
# Precision / recall for every candidate threshold on the held-out fold.
# drop_intermediate=True discards thresholds that would not appear on a plotted curve.
pr_curve_inputs = (df_eval['true'], df_eval['pred_proba'])
p, r, t = precision_recall_curve(*pr_curve_inputs, drop_intermediate=True)

In [None]:
from sklearn.metrics import average_precision_score

# Summarise the PR curve with average precision (the step-wise sum of precision
# times recall increments). scikit-learn recommends this over the trapezoidal
# auc(r, p), because linear interpolation between PR points is overly optimistic.
# The name `pr_auc` is kept because the PR plot title reads it.
pr_auc = average_precision_score(df_eval['true'], df_eval['pred_proba'])

In [None]:
# F1 at every point of the PR curve. Points with precision == recall == 0 would
# give 0/0 = NaN, and np.argmax treats NaN as the maximum, so define F1 = 0 there.
f1_denominator = p + r
f1_scores = np.divide(
    2 * p * r,
    f1_denominator,
    out=np.zeros_like(f1_denominator),
    where=f1_denominator > 0,
)

# The final PR point (precision=1, recall=0) has no threshold: len(t) == len(p) - 1.
best_f1_idx = np.argmax(f1_scores[:-1])
thr_f1_max = t[best_f1_idx]

print(
    f'{np.round(thr_f1_max, 3)} is the threshold that maximises F1-score to {np.round(f1_scores[best_f1_idx], 3)}'
)

In [None]:
# Indices of PR points whose precision lies in [0.80, 0.8001).
np.nonzero((p >= 0.80) & (p < 0.8001))

In [None]:
# Operating point with precision >= 80%. Thresholds increase along `t`, so recall
# never rises as the index grows; the first qualifying index keeps the most recall.
# Computed rather than hard-coded so it stays valid after the model is retrained.
TARGET_PRECISION = 0.80
precision_candidates = np.flatnonzero(p[:-1] >= TARGET_PRECISION)
if precision_candidates.size == 0:
    raise ValueError(f'No threshold reaches a precision of {TARGET_PRECISION:.2f}')

idx = precision_candidates[0]
thr_p_80 = t[idx]

print(
    f'{thr_p_80.round(3)} is the threshold that gives a precision of {p[idx].round(3)} (recall: {r[idx].round(3)} | F1-score: {f1_scores[idx].round(3)})'
)

In [None]:
# Indices of PR points whose recall lies in [0.60, 0.6005).
np.nonzero((r >= 0.60) & (r < 0.6005))

In [None]:
# Operating point with recall >= 60%. Recall never rises as the threshold index
# grows, so the last qualifying index is the highest threshold that still reaches
# the target recall (and hence the best precision among them). Computed rather
# than hard-coded so it stays valid after the model is retrained.
TARGET_RECALL = 0.60
recall_candidates = np.flatnonzero(r[:-1] >= TARGET_RECALL)
if recall_candidates.size == 0:
    raise ValueError(f'No threshold reaches a recall of {TARGET_RECALL:.2f}')

idx = recall_candidates[-1]
thr_r_60 = t[idx]

print(
    f'{thr_r_60.round(3)} is the threshold that gives a recall of {r[idx].round(3)} (precision: {p[idx].round(3)} | F1-score: {f1_scores[idx].round(3)})'
)

In [None]:
# The three selected decision thresholds: max-F1, precision-80 %, recall-60 %.
tuple(threshold.round(3) for threshold in (thr_f1_max, thr_p_80, thr_r_60))

In [None]:
# precision_recall_curve counts a sample as positive when score >= threshold; use
# the same inclusive comparison so this column reproduces the chosen operating
# point (`gt` would drop the sample(s) whose score equals the threshold).
df_eval['pred_f1_max'] = np.where(
    df_eval['pred_proba'].ge(thr_f1_max),
    1,
    0,
)

In [None]:
# precision_recall_curve counts a sample as positive when score >= threshold; use
# the same inclusive comparison so this column reproduces the chosen operating
# point (`gt` would drop the sample(s) whose score equals the threshold).
df_eval['pred_p_80'] = np.where(
    df_eval['pred_proba'].ge(thr_p_80),
    1,
    0,
)

In [None]:
# precision_recall_curve counts a sample as positive when score >= threshold; use
# the same inclusive comparison so this column reproduces the chosen operating
# point (`gt` would drop the sample(s) whose score equals the threshold).
df_eval['pred_r_60'] = np.where(
    df_eval['pred_proba'].ge(thr_r_60),
    1,
    0,
)

In [None]:
# df_eval.to_pickle(Path(DATA_OUT, 'df_eval.pickle'))

In [None]:
# Precision-recall curve of the held-out predictions. The highlighted operating
# point is whichever `idx` was selected last (the recall-60 % cell when the
# notebook runs top to bottom).
fig = px.scatter(x=r, y=p)

# No-skill baseline: a classifier without information has precision equal to the
# positive-class rate at every recall, i.e. a horizontal line. The anti-diagonal
# from (0, 1) to (1, 0) is not a meaningful reference for a PR curve.
positive_rate = df_eval['true'].mean()
fig.add_shape(
    type='line',
    x0=0,
    y0=positive_rate,
    x1=1,
    y1=positive_rate,
    line=dict(color='navy', width=2, dash='dash'),
)

# Dashed guides from both axes to the selected operating point.
fig.add_shape(
    type='line',
    x0=r[idx],
    y0=0,
    x1=r[idx],
    y1=p[idx],
    line=dict(color='red', width=2, dash='dash'),
)

fig.add_shape(
    type='line',
    x0=0,
    y0=p[idx],
    x1=r[idx],
    y1=p[idx],
    line=dict(color='red', width=2, dash='dash'),
)

DOT_SIZE = 15
fig.add_trace(
    go.Scatter(
        x=[r[idx]],
        y=[p[idx]],
        mode='markers',
        marker=dict(color='red', size=DOT_SIZE),
        showlegend=False,
    )
)

fig.update_layout(
    height=700,
    width=800,
    autosize=False,
    title=f'PR Curve (PR-AUC: <b>{pr_auc:.2f}</b>)',
    xaxis=dict(title='Recall'),
    yaxis=dict(title='Precision'),
    template='ggplot2',
)

# fig.write_html(
#     Path(IMAGE_OUT,f'plot_pr_curve.html')
# )

fig.show()

## Calibration curve

In [None]:
# Reliability data: fraction of positives vs. mean predicted probability in
# 10 bins (default 'uniform' strategy, i.e. equal-width bins).
N_CALIBRATION_BINS = 10
prob_true, prob_pred = calibration_curve(
    df_eval['true'], df_eval['pred_proba'], n_bins=N_CALIBRATION_BINS,
)

In [None]:
# Reliability diagram, with the perfect-calibration diagonal and a histogram of
# predicted probabilities drawn on a secondary y-axis.
fig = px.line(x=prob_pred, y=prob_true, markers=True)

diagonal = dict(
    type='line',
    x0=0, y0=0,
    x1=1, y1=1,
    line=dict(color='red', width=2, dash='dash'),
)
fig.add_shape(**diagonal, name='Perfectly calibrated', showlegend=True)

proba_histogram = go.Histogram(
    x=df_eval['pred_proba'],
    yaxis='y2',
    opacity=0.3,
    showlegend=False,
    nbinsx=25,
)
fig.add_trace(proba_histogram)

fig.update_layout(
    xaxis_title='Mean predicted probability',
    yaxis_title='Fraction of positives',
    yaxis2=dict(title='Count of samples', overlaying='y', side='right'),
)

fig.show()

In [None]:
from sklearn.metrics import brier_score_loss

In [None]:
# Brier score of the positive-class probabilities on the held-out fold (3 decimals).
positive_class_proba = cat_model.predict_proba(X_test)[:, 1]
round(brier_score_loss(y_test, positive_class_proba), 3)

## Plot features vs target

In [None]:
# from src.preprocess import resample_time_series
# 
# df_tid_30 = resample_time_series(df_tid, 'median')

In [None]:
period = '2022-01'
plot_columns = ['hf', 'iu_mav_3h', 'smr', 'true', 'pred']

# Slice the evaluation frame to one period (partial-string indexing on the time index).
df_plt = df_eval.loc[f'{period}', plot_columns]
n_cols = len(df_plt.columns)

# One stacked subplot per column, all sharing the time axis.
fig = make_subplots(
    rows=n_cols,
    cols=1,
    shared_xaxes=True,
    vertical_spacing=0.04,
    subplot_titles=df_plt.columns,
)

for subplot_row, (col, series) in enumerate(df_plt.items(), start=1):
    fig.add_trace(
        go.Scatter(x=series.index, y=series.values, name=col),
        row=subplot_row,
        col=1,
    )

fig.update_layout(
    template='plotly_white',
    height=800,
    width=1_000,
    autosize=False,
    title=f'Period: <b>{period}</b>',
)

fig.show()

# fig.write_html(
#     Path(IMAGE_OUT,f'plot_features_target_{period}.html')
# )