In [None]:
from pathlib import Path
from pprint import pformat, pprint
import logging
import json
import re
import sys
from math import ceil
from itertools import repeat, chain, product, cycle, combinations
import traceback
import math

import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
from seaborn import heatmap
import dynamic_yaml
import yaml

logging.basicConfig(format='%(levelname)-8s [%(filename)s] %(message)s',
                    level=logging.DEBUG)
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.ERROR)
mpl.rcParams[u'font.sans-serif'] = ['simhei']
mpl.rcParams['axes.unicode_minus'] = False
%load_ext pycodestyle_magic

# Draw the training process

In [None]:
def baseline_gru_tr_proc_est(log_path_list: list, condition_dict: dict, regexp_res: bool = False,  plot_pic:bool = True):
    """Summarise Baseline-GRU training logs into one DataFrame and optionally plot each run.

    Every JSON log is loaded into a per-run record, a few summary metrics are
    derived from its history lists, and the run is kept only when every entry
    of ``condition_dict`` equals the matching field of that record.

    Args:
        log_path_list: Iterable of paths to JSON training logs. One component
            of each path must start with ``"corr"``; it is stored as
            ``corr_info``. Each log must contain ``tr_loss_history``,
            ``tr_edge_acc_history`` and ``val_edge_acc_history``.
        condition_dict: ``{field: expected_value}`` filter. Every key must be a
            field of the run record (a log key, a derived metric, or
            ``gru_l``/``gru_h`` when ``regexp_res`` is set).
        regexp_res: When True, ``gru_l`` and ``gru_h`` are parsed out of the
            log's ``model_structure`` string with regular expressions.
        plot_pic: When True, ``plot_mts_corr_ad_tr_process`` is called for
            every run that passes the filter.

    Returns:
        pandas.DataFrame with one row per kept run, restricted to a fixed set
        of summary columns and sorted by batch size, sequence length and GRU
        settings. If an error occurs, it is logged (not raised) and the rows
        collected so far are returned.
    """
    # Raw strings: the patterns contain regex escapes such as ``\(`` and ``\d``.
    model_struct_info_fields = {"gru_l": r"\(gru\): GRU\(\d+, \d+, num_layers=(?P<gru_l>\d+).*\)",
                                "gru_h": r"\(gru\): GRU\(\d+, (?P<gru_h>\d+), .+\)"}
    summary_columns = ["corr_info", "epochs", "batch_size", "gra_nodes_v_mode", "filt_mode", "filt_quan", "quan_discrete_bins",
                       "custom_discrete_bins", "seq_len", "loss_fns", "loss_weight", "drop_pos", "drop_p", "gru_l", "gru_h", "decoder",
                       "output_type", "output_bins", "target_mats_bins", "edge_acc_loss_atol", "two_ord_pred_prob_edge_accu_thres", 'best_val_epoch',
                       "min_tr_loss", "min_tr_loss_edge_acc", "max_tr_edge_acc", "min_val_loss", "max_val_edge_acc", "min_val_loss_edge_acc"]
    sort_columns = ["batch_size", "seq_len", "gru_l", "gru_h", "drop_p"]

    def _fmt_loss(value):
        """Format a loss for the figure title; tolerate a missing (non-numeric) value."""
        return f"{value:8f}" if isinstance(value, (int, float)) else str(value)

    records = []          # one dict per run that passed the filter
    df = pd.DataFrame()   # defined up-front so the final ``return`` is always valid
    log_path = None       # defined up-front so the error handler can always refer to it
    try:
        for raw_path in log_path_list:
            log_path = Path(raw_path)
            with open(log_path, "r") as source:
                log_dict = json.load(source)

            # All per-run values live in an explicit dict. Writing to ``locals()``
            # inside a function is not supported and carries values over between logs.
            record = dict(log_dict)
            tr_loss_history = log_dict["tr_loss_history"]
            tr_edge_acc_history = log_dict["tr_edge_acc_history"]
            record["corr_info"] = str(next(filter(lambda p: p.startswith("corr"), log_path.parts)))
            record["min_tr_loss"] = min(tr_loss_history)
            # Training edge accuracy at the epoch where the training loss is lowest.
            record["min_tr_loss_edge_acc"] = tr_edge_acc_history[np.argmin(np.array(tr_loss_history))]
            record["max_tr_edge_acc"] = max(tr_edge_acc_history)
            record["max_val_edge_acc"] = max(log_dict["val_edge_acc_history"])
            model_struct_str = log_dict.get('model_structure')

            if regexp_res and model_struct_str:
                for field, pattern in model_struct_info_fields.items():
                    match = re.search(pattern, model_struct_str)
                    if match:
                        record[field] = match.group(field)
                    else:
                        record[field] = None
                        logging.info(f"Can't detect {field}")
                # When ``num_layers`` is not found in the string, fall back to 1.
                # NOTE(review): parsed values are ``str`` while this fallback is ``int``.
                if not record["gru_l"]:
                    record["gru_l"] = 1

            unknown_keys = set(condition_dict.keys()) - set(record.keys())
            if unknown_keys:
                raise KeyError(f"condition_dict keys {sorted(unknown_keys)} don't match any field available "
                               f"in baseline_gru_tr_proc_est() for {log_path}")
            if not all(record[key] == expected for key, expected in condition_dict.items()):
                continue

            main_title_str = (f"{record.get('corr_info')} with filt:{record.get('filt_mode')}-{record.get('filt_quan')} "
                              f"and batch_size({record.get('batch_size')}) "
                              f"input to GRU with gru_l{record.get('gru_l')}-gru_h{record.get('gru_h')}\n"
                              f"with drop: {record.get('drop_p')} and loss_fns:{record.get('loss_fns')}\n"
                              f"min val-loss:{_fmt_loss(record.get('min_val_loss'))} min tr-loss:{_fmt_loss(record.get('min_tr_loss'))}")
            logging.info(f"file_name:{log_path.parts[-1]}")
            logging.info(f"file_path:{log_path.parts[2:-2]}")
            logging.info(f"main_title_str:\n{main_title_str}")
            records.append(record)
            if plot_pic:
                plot_mts_corr_ad_tr_process(main_title=main_title_str, model_struct=model_struct_str,
                                            metrics_history={k: v for k, v in log_dict.items() if "history" in k},
                                            best_epoch=record['best_val_epoch'], batches_per_epoch=record['batches_per_epoch'])

        # Build the frame once from all records instead of concatenating inside the loop.
        df = pd.DataFrame(records).reindex(summary_columns, axis=1)
        df = df.sort_values(sort_columns, ascending=False)
        df = df.reset_index(drop=True)
        pd.options.display.float_format = '{:.6f}'.format
        pd.set_option('display.max_columns', None)
        display(df)
    except Exception as e:
        # Best effort: report the failure and still return what was collected.
        if df.empty and records:
            df = pd.DataFrame(records)
        error_class = e.__class__.__name__
        detail = e.args[0] if e.args else ""
        tb_frames = traceback.extract_tb(e.__traceback__)
        if tb_frames:
            last_call_stack = tb_frames[-1]  # innermost frame, i.e. where the error was raised
            file_name, line_num, func_name = last_call_stack[0], last_call_stack[1], last_call_stack[2]
        else:
            file_name, line_num, func_name = "<unknown>", "?", "<unknown>"
        err_msg = "File \"{}\", line {}, in {}: [{}] {}".format(file_name, line_num, func_name, error_class, detail)
        if log_path is not None:
            logging.error(f"file:{log_path.parts[-1]}, path:{log_path}")
        logging.error(f"===\n{err_msg}")
        logging.error(f"===\n{tb_frames}")

    return df


def plot_mts_corr_ad_tr_process(main_title: str, model_struct: str, metrics_history: dict, best_epoch: int, batches_per_epoch: int):
    """Plot the training / validation curves of one run together with its model structure.

    The figure has three panels: training loss with training edge accuracy
    (two y-axes), validation loss with validation edge accuracy (two y-axes),
    and a text box holding ``str(model_struct)``.

    Args:
        main_title: Figure-level title.
        model_struct: Model description shown as text in the bottom panel.
        metrics_history: Must contain ``tr_loss_history``, ``tr_edge_acc_history``,
            ``val_loss_history`` and ``val_edge_acc_history``.
        best_epoch: Kept for interface compatibility; not used by the current figure.
        batches_per_epoch: Kept for interface compatibility; not used by the current figure.

    Raises:
        Exception: any error raised while drawing a panel is logged with the
            panel title and re-raised unchanged.
    """
    data_info_dict = [{"sub_title": 'train loss_history & edge_acc_history',
                       "data": {'tr_loss_history': metrics_history['tr_loss_history'],
                                'tr_edge_acc_history': metrics_history['tr_edge_acc_history']},
                       "xticks": None,
                       "xlabel": "epochs",
                       "double_y": True},
                      {"sub_title": 'val  loss_history & edge_acc_history',
                       "data": {'val_loss_history': metrics_history['val_loss_history'],
                                'val_edge_acc_history': metrics_history['val_edge_acc_history']},
                       "xticks": None,
                       "xlabel": "epochs",
                       "double_y": True},
                      {"sub_title": f"model structure",
                       "data": str(model_struct)}]

    # figure settings
    line_style = {"linewidth": 2, "alpha": 0.5}
    axvline_style = {"color": 'k', "linewidth": 5, "linestyle": '--', "alpha": 0.3}
    fig, axs = plt.subplot_mosaic("""
                                  ab
                                  cc
                                  """,
                                  figsize=(30, 20), gridspec_kw={'hspace': 0.2, 'wspace': 0.3})
    fig.suptitle(main_title, fontsize=30)

    try:
        for ax, data_plot in zip(axs.values(), data_info_dict):
            panel_data = data_plot["data"]
            ax.set_title(data_plot["sub_title"], fontsize=30)
            ax.yaxis.offsetText.set_fontsize(18)
            ax.tick_params(axis='both', which='major', labelsize=24)
            if isinstance(panel_data, dict) and data_plot.get("double_y"):
                # First series on the left axis, every later series on its own red right axis.
                for i, (key, series) in enumerate(panel_data.items()):
                    if i == 0:
                        ax.plot(series, label=key, **line_style)
                        ax.set_ylabel(key, fontsize=24)
                        ax.legend(fontsize=18)
                    else:
                        new_ax = ax.twinx()
                        new_ax.plot(series, label=key, color='r')
                        new_ax.set_ylabel(key, color='r', fontsize=24)
                        new_ax.legend(fontsize=18)
                        new_ax.tick_params(axis='both', colors='r', which='major', labelsize=24)
            elif isinstance(panel_data, dict):
                for key, series in panel_data.items():
                    ax.plot(series, label=key, **line_style)
                ax.legend(fontsize=18)
            elif isinstance(panel_data, str):
                ax.annotate(text=f"{panel_data}",
                            xy=(0.15, 0.5), bbox={'facecolor': 'green', 'alpha': 0.4, 'pad': 5},
                            fontsize=20, fontfamily='monospace', xycoords='axes fraction', va='center')
            else:
                ax.plot(panel_data, **line_style)
            if pos_tuple := data_plot.get("axvline"):
                for x_pos in pos_tuple:
                    ax.axvline(x=x_pos, **axvline_style)
            if xlabel := data_plot.get("xlabel"):
                ax.set_xlabel(xlabel, fontsize=24)
            if t := data_plot.get("xticks"):
                ax.set_xticks(ticks=range(0, len(t["label"])*t["intv"], t["intv"]), labels=t["label"], rotation=45)
    except Exception:
        logging.error(f"Encounter error when draw figure of {data_plot['sub_title']}")
        plt.close(fig)  # do not leave a half-drawn figure open
        raise

    # The previous ``fig.tight_layout(rect=(0, 0, 0, 0))`` asked for a zero-area
    # rectangle; it is dropped so the spacing given in ``gridspec_kw`` above is used.
    plt.show()
    plt.close(fig)

In [None]:
baseline_gru_log_dir = Path("./save_models/class_baseline_gru_one_feature/archive/20230913/sp500_20112015_corr_ser_reg_std_corr_mat_negative_filtered-train_train/pearson/")
# NOTE(review): in glob patterns `[...]` is a single-character class, so `[!deprecated]` means
# "one character that is none of d/e/p/r/c/a/t", not "a name other than 'deprecated'", and
# `[.json]` matches one trailing character. Confirm these patterns select the intended files.
log_path_list1 = baseline_gru_log_dir.glob("./*[!deprecated][!archive][!.ipynb_checkpoints]*/train_logs/*[!.ipynb_checkpoints]*[.json]")
log_path_list2 = baseline_gru_log_dir.glob("./*[archive][!deprecated][!.ipynb_checkpoints]*/**/train_logs/*[!.ipynb_checkpoints]*[.json]")
log_path_list3 = baseline_gru_log_dir.glob("./**/train_logs/*[!.ipynb_checkpoints]*[.json]")

model_tr_summary_df = baseline_gru_tr_proc_est(log_path_list=log_path_list1, condition_dict={}, regexp_res=False, plot_pic=True)

if model_tr_summary_df.empty:
    # Nothing to split: the expressions below need at least one row.
    logging.warning(f"No training log was summarised from {baseline_gru_log_dir}")
else:
    # Columns whose first cell is a list cannot go through `nunique()`; they are compared as strings below.
    is_list_column = model_tr_summary_df.apply(lambda col: isinstance(col.iloc[0], list)).astype(bool)
    columns_containing_lists = is_list_column[is_list_column].index
    columns_not_containing_lists = model_tr_summary_df.columns.difference(columns_containing_lists)

    # Independent variables take more than one distinct value across runs; control variables do not.
    non_list_nunique = model_tr_summary_df.loc[::, columns_not_containing_lists].nunique()
    independent_variables_columns = non_list_nunique[non_list_nunique > 1].index
    control_variables_columns = non_list_nunique.index.difference(independent_variables_columns)
    for col in columns_containing_lists:
        col_ser = model_tr_summary_df.loc[::, col].apply(lambda x: str(x))
        if len(np.unique(col_ser)) > 1:
            independent_variables_columns = independent_variables_columns.append(pd.Index([col]))
        else:
            control_variables_columns = control_variables_columns.append(pd.Index([col]))

    pd.set_option('display.max_colwidth', 80)
    independent_variables_tr_summary_df = model_tr_summary_df.loc[::, independent_variables_columns].sort_index(axis=1)
    # Control variables are constant across runs, so the first row is representative.
    control_variables_tr_summary_df = model_tr_summary_df.loc[0, control_variables_columns].sort_index(axis=0)
    display(control_variables_tr_summary_df)
    display(independent_variables_tr_summary_df)

# Observe prediction and labels during training and validation stage.

In [None]:
obs_baseline_gru_log_path = Path("./save_models/class_baseline_gru_without_self_corr/archive/20230908/sp500_20112015_corr_ser_reg_std_corr_mat_large_filtered_hrchy_10_cluster_label_last_v2-train_train/pearson/corr_s1_w50/train_logs/epoch_1499-20230908033654.json")
with open(obs_baseline_gru_log_path, "r") as f:
    log = json.load(f)

tr_preds_history = log['tr_preds_history']
tr_labels_history = log['tr_labels_history']
val_preds_history = log['val_preds_history']
val_labels_history = log['val_labels_history']
last_batch_size = len(tr_preds_history[0])
graph_adj_mat_size = len(tr_preds_history[0][0])  # length of one flattened graph
graph_adj_mat_items = ["ADP", "APH", "MDT", "PAYX"]  ## observe the index by operating `python gen_corr_graph_data.py --data_implement <data_config.yaml['DATASETS'][DATASET_NAME]>`
# graph_adj_mat_items = ['NEM', 'ETR', 'INCY']  ## observe the index by operating `python gen_corr_graph_data.py --data_implement <data_config.yaml['DATASETS'][DATASET_NAME]>`
num_epochs = len(tr_preds_history)
num_obs_epochs = 5
obs_epochs = np.linspace(0, num_epochs-1, num_obs_epochs, dtype="int")
obs_best_val_epoch = np.argmax(np.array(log['val_edge_acc_history']))
obs_batch_idx = 10

# Infer the number of nodes from the flattened length L:
#   * full square matrix:                L == n**2
#   * upper triangle without diagonal:   L == n*(n-1)/2  <=>  n == (1 + sqrt(1 + 8*L)) / 2
# Integer square roots keep the checks exact, and an invalid length raises instead of looping forever.
sqrt_of_size = math.isqrt(graph_adj_mat_size)
if sqrt_of_size**2 == graph_adj_mat_size:
    num_nodes = sqrt_of_size
    is_square_graph = True
else:
    discriminant = 1 + 8*graph_adj_mat_size
    discriminant_root = math.isqrt(discriminant)
    if discriminant_root**2 != discriminant:
        raise ValueError(f"flattened graph length {graph_adj_mat_size} is neither n**2 nor n*(n-1)/2 for any integer n")
    num_nodes = (1 + discriminant_root)//2
    is_square_graph = False
assert is_square_graph == (graph_adj_mat_size == num_nodes**2), "when the graph is square graph, the size of graph should be num_nodes**2"
assert len(tr_preds_history) == len(tr_labels_history) and len(tr_preds_history) == len(val_preds_history) and len(tr_preds_history) == len(val_labels_history), "length of {tr_preds_history, tr_labels_history, val_preds_history, val_labels_history} should be equal to num_epochs"
assert obs_batch_idx < min(last_batch_size, len(val_preds_history[0])), "obs_batch_idx should be smaller than the number of graphs recorded per epoch"

In [None]:
assert isinstance(tr_preds_history[0][0][0], int), "Following operation require model's output are classification"


def _flat_to_adj_matrix(flat_values: list) -> np.ndarray:
    """Turn one flattened graph into a (num_nodes, num_nodes) matrix.

    A square graph is reshaped directly. Otherwise the values fill the upper
    triangle above the diagonal, row by row, and every other cell stays 0.
    """
    if is_square_graph:
        return np.array(flat_values).reshape(num_nodes, num_nodes)
    adj_mat = np.zeros((num_nodes, num_nodes), dtype="int")
    adj_mat[np.triu_indices(num_nodes, k=1)] = flat_values
    return adj_mat


def _add_table_separators(obs_df: pd.DataFrame) -> pd.DataFrame:
    """Insert a horizontal rule between predictions and labels and a vertical rule between epochs."""
    h_line = pd.DataFrame(["―"]*obs_df.shape[1], columns=["h_line"], index=obs_df.columns).T
    separated_df = pd.concat([obs_df.iloc[:num_nodes, ], h_line, obs_df.iloc[num_nodes:, ]], axis=0)
    for i in range(1, num_obs_epochs):
        # Each earlier insertion shifts the later columns one position to the right.
        v_line_loc = i*num_nodes+(i-1)
        separated_df.insert(v_line_loc, f"v_line_{i}", ["|"]*separated_df.shape[0])
    return separated_df


history_record_list = [tr_preds_history, tr_labels_history, val_preds_history, val_labels_history]
tr_obs_blocks = []   # per observed epoch: predictions stacked on top of labels
val_obs_blocks = []
for epoch_idx in obs_epochs:
    (tr_preds_each_obs_epoch, tr_labels_each_obs_epoch,
     val_preds_each_obs_epoch, val_labels_each_obs_epoch) = [_flat_to_adj_matrix(history[epoch_idx][obs_batch_idx])
                                                             for history in history_record_list]
    tr_obs_blocks.append(np.concatenate([tr_preds_each_obs_epoch, tr_labels_each_obs_epoch], axis=0))
    val_obs_blocks.append(np.concatenate([val_preds_each_obs_epoch, val_labels_each_obs_epoch], axis=0))

# Rows: (predictions | labels) x items; columns: observed epoch x items.
multi_level_col = pd.MultiIndex.from_product([[f"epoch_{epoch_idx}" for epoch_idx in obs_epochs], graph_adj_mat_items])
multi_level_idx = pd.MultiIndex.from_product([["predictions", "labels"], graph_adj_mat_items])
tr_obs_df = pd.DataFrame(np.hstack(tr_obs_blocks), index=multi_level_idx, columns=multi_level_col)
val_obs_df = pd.DataFrame(np.hstack(val_obs_blocks), index=multi_level_idx, columns=multi_level_col)
tr_obs_df = _add_table_separators(tr_obs_df)
val_obs_df = _add_table_separators(val_obs_df)

caption_table_styles = [{'selector': 'caption',
                         'props': [('color', 'red'),
                                   ('font-size', '24px')]}]
# From here on both names refer to pandas Styler objects, not DataFrames.
tr_obs_df = tr_obs_df.style.set_caption(f"train prediction & labels of the graph in {obs_batch_idx}th of last_batch for epochs").set_table_styles(caption_table_styles)
val_obs_df = val_obs_df.style.set_caption(f"Validation prediction & labels of the graph in {obs_batch_idx}th of last_batch for epochs").set_table_styles(caption_table_styles)
display(tr_obs_df)
display(val_obs_df)

## Observe the heatmap of validation dataset

In [None]:
val_batch_size = len(val_preds_history[0])
best_epoch_val_preds = np.array(val_preds_history[obs_best_val_epoch])
best_epoch_val_labels = np.array(val_labels_history[obs_best_val_epoch])
# The matrix is computed over class ids 0/1/2 and its rows/columns are labelled -1/0/1.
total_val_data_confusion_matrix = pd.DataFrame(confusion_matrix(best_epoch_val_labels.reshape(-1), best_epoch_val_preds.reshape(-1), labels=[0, 1, 2]), columns=range(-1,2), index=range(-1,2))

# Scope the large font to this figure only, so that figures drawn by other cells keep their own size.
with plt.rc_context({'font.size': 44}):
    fig, ax = plt.subplots(figsize=(10, 10))
    heatmap(total_val_data_confusion_matrix, annot=True, ax=ax, fmt='g')
    ax.set(xlabel="Prediction", ylabel="Ground Truth")
    plt.show()
plt.close(fig)

## Observe the correlation type distribution

In [None]:
items_pairs = list(product(graph_adj_mat_items, repeat=2)) if is_square_graph else list(combinations(graph_adj_mat_items, 2))


def _count_graph_corr_types(class_id_mat: np.ndarray) -> dict:
    """Count identical rows of ``class_id_mat - 1``; keys are tuples of ints, values are counts."""
    unique_rows, row_counts = np.unique(class_id_mat - 1, axis=0, return_counts=True)
    return {tuple(row): count for row, count in zip(unique_rows.tolist(), row_counts.tolist())}


# One pass over the data per array instead of one pass per candidate correlation type.
val_preds_type_counts = _count_graph_corr_types(best_epoch_val_preds)
val_labels_type_counts = _count_graph_corr_types(best_epoch_val_labels)

distribution_rows = []
distribution_index = []
val_preds_pie_counts = {}
val_labels_pie_counts = {}
# The table lists every combination of (-1, 0, 1) over the graph entries, including those never observed.
for graph_corr_type in product((-1, 0, 1), repeat=graph_adj_mat_size):
    val_preds_graph_corr_type_count = val_preds_type_counts.get(graph_corr_type, 0)
    val_labels_graph_corr_type_count = val_labels_type_counts.get(graph_corr_type, 0)
    distribution_rows.append([val_preds_graph_corr_type_count, val_labels_graph_corr_type_count])
    distribution_index.append(str(graph_corr_type))
    if val_labels_graph_corr_type_count:
        val_labels_pie_counts[str(dict(zip(items_pairs, graph_corr_type)))] = val_labels_graph_corr_type_count
    if val_preds_graph_corr_type_count:
        val_preds_pie_counts[str(dict(zip(items_pairs, graph_corr_type)))] = val_preds_graph_corr_type_count

val_graph_corr_type_distribution = pd.DataFrame(distribution_rows, index=distribution_index, columns=["preds", "labels"])
val_preds_graph_corr_type_pie_plot_ser = pd.Series(val_preds_pie_counts, dtype="int64")
val_labels_graph_corr_type_pie_plot_ser = pd.Series(val_labels_pie_counts, dtype="int64")

fig, ax = plt.subplots(nrows=2, ncols=1, figsize=(30, 20))
ax[0].pie(val_labels_graph_corr_type_pie_plot_ser.values, labels=val_labels_graph_corr_type_pie_plot_ser.index, autopct='%1.1f%%', textprops={'fontsize': 20})
ax[0].set_title("Val labels distribution", fontsize=44)
ax[1].pie(val_preds_graph_corr_type_pie_plot_ser.values, labels=val_preds_graph_corr_type_pie_plot_ser.index, autopct='%1.1f%%', textprops={'fontsize': 20})
ax[1].set_title("Val preds distribution", fontsize=44)
plt.show()
plt.close(fig)

val_graph_corr_type_distribution.index.name = str(items_pairs)
# 1-D row mask: keep the types that occur at least once in the predictions or in the labels.
display_mask = (val_graph_corr_type_distribution["preds"] != 0) | (val_graph_corr_type_distribution["labels"] != 0)
display(val_graph_corr_type_distribution.loc[display_mask, ::])
display(val_graph_corr_type_distribution.iloc[:9, ::])
val_graph_corr_type_distribution.index.name = ""
display(val_graph_corr_type_distribution.iloc[9:18, ::])
display(val_graph_corr_type_distribution.iloc[18:, ::])

## Observe which data in the validation dataset has been predicted incorrectly

## Observe the correctness ratio of each correlation type

In [None]:
# Class id used in the labels for each correlation type.
class_id_by_corr_type = {"negative": 0, "no_corr": 1, "positive": 2}

# For every correlation type: a boolean mask over the labels and the number of selected entries.
val_info_dict = {}
for corr_type, class_id in class_id_by_corr_type.items():
    corr_type_mask = (best_epoch_val_labels == class_id)
    val_info_dict[corr_type] = {"mask": corr_type_mask, "num": corr_type_mask.sum()}

pred_correct_mask = (best_epoch_val_preds == best_epoch_val_labels)
assert sum([v["num"] >= 0 for v in val_info_dict.values()]) == 3
assert sum([v["num"] for v in val_info_dict.values()]) == best_epoch_val_labels.size, "the sum of mask should be same with all"

# Share of correctly predicted entries within each correlation type (None when the type never occurs).
for corr_type in val_info_dict:
    info = val_info_dict[corr_type]
    if info["num"] > 0:
        info["correct_ratio"] = np.logical_and(info["mask"], pred_correct_mask).sum()/info["num"]
    else:
        info["correct_ratio"] = None
    logging.info(f"corr_type:{corr_type}, num_corr_type:{val_info_dict[corr_type]['num']}, correct_ratio:{val_info_dict[corr_type]['correct_ratio']}")