In [None]:
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

## Read the CSV file
It contains the per-fold metrics for several methods and datasets (produced by the `format_metric` notebook).

In [None]:
# Raw metric table. Columns used later in this notebook:
# Model, Dataset, Division type, Metric, Mode, Fold, Value.
da = pd.read_csv("metrics.csv")

# Call head() -- the bare attribute `da.head` only displays the bound method, not the rows.
da.head()

## Compute the mean value of each metric across folds

In [None]:
# Average `Value` over every column not in the grouping keys (e.g. Fold), giving one row
# per (Model, Dataset, Division type, Metric, Mode) with the mean stored in `mean_value`.
db = (
    da.groupby(["Model", "Dataset", "Division type", "Metric", "Mode"])
    .agg(mean_value=("Value", "mean"))
    .reset_index()
)

# Call head() -- the bare attribute `db.head` only displays the bound method, not the rows.
db.head()

## Text formatting for output

In [None]:
# Display labels for the raw identifiers stored in the CSV.
# Values without an entry in the relevant mapping are left unchanged by `replace`.

# Model identifiers -> display names.
mapping_dict = {
    "ae": "AE",
    "brits": "BRITS",
    "locf": "LOCF",
    "mrnn": "mRNN",
    "rf": "RF",
    "saits": "SAITS",
    "svm": "SVM",
    "transformer": "Transformer",
    "unet": "U-Net",
    "xgboost": "XGBoost",
}

# Missing-data pattern identifiers -> display names.
mapping_dict_mode = {
    "block.20": "Block 20",
    "block.100": "Block 100",
    "single": "Single",
    "profile": "Profile",
}

# Dataset identifiers -> display names.
mapping_dict_dataset = {
    "taranaki": "Taranaki",
    "geolink": "Geolink",
    "teapot": "Teapot",
}

# Relabel the three columns in one step; all other columns are carried over as-is.
db = db.assign(
    Model=db["Model"].replace(mapping_dict),
    Mode=db["Mode"].replace(mapping_dict_mode),
    Dataset=db["Dataset"].replace(mapping_dict_dataset),
)

db

## Create plots for each metric
Bar plots for all four datasets under the different missing-data patterns.

In [None]:
# Each distinct metric gets its own figure.
metrics = db["Metric"].unique()

# Models to plot, in legend / bar order. mRNN and SVM are currently excluded;
# to re-enable them, add them here AND add their colours to `palette` below.
order_model = ["LOCF", "AE", "SAITS", "BRITS", "RF",
               "XGBoost", "Transformer", "U-Net"]  # , "mRNN", "SVM"]

# Missing-data patterns on the x axis, in display order.
order_mode = ["Single", "Block 20", "Block 100", "Profile"]

# One colour per entry of `order_model`, in the same order.
palette = sns.color_palette(['darkgray', 'blueviolet', 'dodgerblue', 'lightskyblue',
                             'forestgreen', 'palegreen', 'darkorange', 'gold'])  # , 'crimson', 'royalblue'])
assert len(palette) == len(order_model), "palette must have exactly one colour per model"

# savefig does not create missing directories, so make sure the output folder exists.
FIG_DIR = Path("figs")
FIG_DIR.mkdir(parents=True, exist_ok=True)

plots = list()
sns.set_theme(style="whitegrid", font_scale=2)

for metric in metrics:
    # One panel per dataset (two per row); bars grouped by mode and coloured by model.
    g = sns.catplot(x='Mode',
                    y='mean_value',
                    hue='Model',
                    col='Dataset',
                    kind='bar',
                    data=db[db["Metric"] == metric],
                    order=order_mode,
                    hue_order=order_model,
                    height=6,
                    aspect=1.5,
                    sharey=True,
                    palette=palette,
                    col_wrap=2)
    g.set_axis_labels('Mode', f'Mean {metric.upper()}')
    g.set_titles('{col_name}')
    # Legend centred above the panels, up to 5 entries per row.
    sns.move_legend(g, "lower center", ncol=5, bbox_to_anchor=(.5, 1))
    plots.append(g)
    g.savefig(FIG_DIR / f"{metric}.png")

### Generate a LaTeX (.tex) table of mean metric values (could be useful; uses the raw, unrelabelled names)

In [None]:
# Mean `Value` per (Dataset, Mode, Metric, Model), computed from the raw `da` frame.
# The display-name mappings were applied only to `db`, so the original identifiers appear here.
table = (
    da.groupby(["Dataset", "Mode", "Metric", "Model"])
    .agg(mean_value=("Value", "mean"))
)

# Write the grouped table out as LaTeX source.
with open('output.tex', 'w') as tex_file:
    tex_file.write(table.to_latex())



## Prepare the data for computing the correlation between metrics

In [None]:
# Drop rows whose Mode is 'time' (presumably runtime measurements rather than
# imputation metrics -- TODO confirm) and keep only the columns needed below.
dclean = da.loc[da["Mode"] != 'time', ["Dataset", "Mode", "Model", "Fold", 'Metric', 'Value']]
dclean.head()

In [None]:
# Pivot to wide form: one row per (Dataset, Mode, Model, Fold), one column per metric.
# 'first' takes the first non-null Value if a combination occurs more than once.
# (No reset_index needed: the original row index plays no part in the grouping.)
dm2 = (
    dclean.groupby(["Dataset", "Mode", "Model", "Fold", 'Metric'])['Value']
    .aggregate('first')
    .unstack()
)
# Call head() -- the bare attribute `dm2.head` only displays the bound method, not the rows.
dm2.head()

## Compute the correlation between metrics and plot it as a heatmap below

In [None]:
# Pearson correlation between every pair of metric columns, computed over all
# (Dataset, Mode, Model, Fold) rows; missing values are excluded pairwise by pandas.
correlation_matrix = dm2.corr(method="pearson")
correlation_matrix

In [None]:
# Heatmap of the metric correlation matrix. clustermap reorders rows/columns by
# hierarchical clustering; dendrogram_ratio=0 gives the dendrogram panels no space.
fig = sns.clustermap(correlation_matrix, annot=True, fmt=".2f", linewidths=.5, cmap="Blues",
                     cbar_pos=None, dendrogram_ratio=0, figsize=(10, 8))
# Title means "Correlation matrix of the metrics" (Portuguese).
# NOTE(review): the other figure labels in this notebook are in English.
fig.figure.suptitle("Matriz de correlação das métricas", fontsize=18, y=1.02)
# Save before showing: some backends release the figure once it has been shown.
fig.savefig("metric_corr.png")
plt.show()