In [None]:
import numpy as np
from sklearn.metrics import (
    mean_absolute_error,
    mean_absolute_percentage_error,
    mean_squared_error
)
import matplotlib.pyplot as plt
import pandas as pd
import pickle

from sssd.utils.visual import (
    plot_data,
    picp_metic,
    nlpd_metric
)

from properscoring import crps_gaussian

In [None]:
# Load the pickled inference results into `d`.
# The file name looks like a placeholder -- point it at your own output file.
# NOTE(review): pickle.load can execute arbitrary code; only open files you trust.
with open("path_to_inference_ouput.pkl", "rb") as results_file:
    d = pickle.load(results_file)

Prepare the data of one scenario (observed, evaluated and imputed values) for plotting

In [None]:
# Scenario to inspect: key of the missingness pattern and the missing rate.
mode = "rbm"
mr = 0.1

# Work on a copy: the NaN masking at the end of this cell would otherwise
# overwrite the loaded results stored in `d`.
imputation = np.array(d[mode][mr]["imputation"], copy=True)
# Number of imputation samples (length of the leading axis).
N = len(imputation)

original = d["original"]
missing_mask = d[mode][mr]["missing_mask"]

# Ground truth with the masked-out entries hidden.
observed = original.copy()
observed[missing_mask] = np.nan

# Ground truth restricted to the masked-out entries.
evaluated = original.copy()
evaluated[~missing_mask] = np.nan

# Repeat the mask once per sample so it lines up with `imputation`.
missing_mask = np.array([missing_mask] * N)

# Keep imputed values only where the input was masked out.
imputation[~missing_mask] = np.nan

Calculate metrics

In [None]:
# Scenarios to evaluate: every missingness pattern at every missing rate.
modes = ["mar", "bom", "rbm", "tsf"]
mrs = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

# One table per metric: rows are missing rates, columns are patterns.
mae = pd.DataFrame(columns=modes, index=mrs)
mape = pd.DataFrame(columns=modes, index=mrs)
rmse = pd.DataFrame(columns=modes, index=mrs)
picp = pd.DataFrame(columns=modes, index=mrs)
crps = pd.DataFrame(columns=modes, index=mrs)
nlpd = pd.DataFrame(columns=modes, index=mrs)

original = d["original"]
# Entries for which a ground-truth value exists (loop invariant).
has_truth = ~np.isnan(original)

for mode in modes:
    for mr in mrs:
        missing_mask = d[mode][mr]["missing_mask"]
        # Score only entries that were masked out AND have a ground truth.
        compare_mask = has_truth & missing_mask

        # The samples are neither modified nor stacked against a fixed
        # sample count: `compare_mask` is a subset of `missing_mask`, so
        # NaN-ing the other entries cannot change any score, and this way
        # the cell works for any number of samples and leaves `d` intact.
        imputation = d[mode][mr]["imputation"]

        # Per-entry mean / standard deviation over the samples (axis 0),
        # computed once and shared by all metrics below.
        truth = original[compare_mask]
        pred_mean = np.mean(imputation, axis=0)[compare_mask]
        pred_std = np.std(imputation, axis=0)[compare_mask]

        mae.loc[mr, mode] = mean_absolute_error(truth, pred_mean)

        # Zero-valued targets get weight 0 so they do not contribute to the
        # percentage error.
        w = (truth != 0).astype(float)
        if w.sum() > 0:
            w /= w.sum()
            mape.loc[mr, mode] = mean_absolute_percentage_error(
                truth, pred_mean, sample_weight=w
            )
        else:
            # No non-zero target to compare against: MAPE is undefined.
            mape.loc[mr, mode] = np.nan

        rmse.loc[mr, mode] = np.sqrt(mean_squared_error(truth, pred_mean))

        picp.loc[mr, mode] = picp_metic(truth, pred_mean, pred_std)

        # crps_gaussian returns one score per entry; report the average.
        crps.loc[mr, mode] = np.mean(crps_gaussian(truth, pred_mean, pred_std))

        nlpd.loc[mr, mode] = nlpd_metric(truth, pred_mean, pred_std)

Plot metrics

In [None]:
# One panel per metric: missing rate on the x-axis, one line per pattern.
fig, axes = plt.subplots(
    nrows=3, ncols=2, figsize=(20, 20)
)
# Legend text for each missingness-pattern key.
labels = {
    "mar": "Missing at random",
    "bom": "Blackout missing",
    "rbm": "Random block missing",
    "tsf": "Forecasting"
}
# (axes, metric table, y-axis label) for every panel, row by row.
panels = [
    (axes[0][0], mae, "Mean Absolute Error"),
    (axes[0][1], mape, "Mean Absolute Percentage Error"),
    (axes[1][0], rmse, "Root Mean Squared Error"),
    (axes[1][1], picp, "Prediction Interval Coverage Probability"),
    (axes[2][0], crps, "Mean Continuous Ranked Probability Score"),
    (axes[2][1], nlpd, "Mean Negative Log Probability Density"),
]
for ax, table, ylabel in panels:
    for mode in modes:
        ax.plot(mrs, table[mode], label=labels[mode])
    # Label and add the legend once per panel, after all lines are drawn.
    ax.set_xlabel("Missing Rate")
    ax.set_ylabel(ylabel)
    ax.legend()