In [None]:
from pathlib import Path
from pprint import pformat, pprint
import logging
import json
import re
import sys
from math import ceil
from itertools import repeat, chain, product
import traceback

import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
import dynamic_yaml
import yaml

logging.basicConfig(format='%(levelname)-8s [%(filename)s] %(message)s',
                    level=logging.DEBUG)
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.ERROR)
mpl.rcParams[u'font.sans-serif'] = ['simhei']
mpl.rcParams['axes.unicode_minus'] = False
%load_ext pycodestyle_magic

# Draw the training process

In [None]:
# Patterns applied to the printed model structure when ``regexp_res`` is enabled.
# Raw strings: the original non-raw literals relied on invalid escape sequences.
_GRU_STRUCT_PATTERNS = {
    "gru_l": r"\(gru\): GRU\(\d+, \d+, num_layers=(?P<gru_l>\d+).*\)",
    "gru_h": r"\(gru\): GRU\(\d+, (?P<gru_h>\d+), .+\)",
}

# Column order of the summary table returned on success.
_BASELINE_GRU_SUMMARY_COLUMNS = [
    "corr_info", "epochs", "batch_size", "gra_nodes_v_mode", "filt_mode", "filt_quan", "quan_discrete_bins",
    "custom_discrete_bins", "seq_len", "loss_fns", "drop_pos", "drop_p", "gru_l", "gru_h", "decoder",
    "output_type", "output_bins", "target_mats_bins", "edge_acc_loss_atol", "two_ord_pred_prob_edge_accu_thres",
    "min_tr_loss", "min_tr_loss_edge_acc", "max_tr_edge_acc", "min_val_loss", "max_val_edge_acc",
    "min_val_loss_edge_acc",
]

# Rows of the summary table are sorted (descending) by these columns.
_BASELINE_GRU_SORT_COLUMNS = ["batch_size", "seq_len", "gru_l", "gru_h", "drop_p"]


def _baseline_gru_log_values(log_path, log_dict: dict, regexp_res: bool) -> dict:
    """Return the fields of one train log merged with the summary fields derived from it.

    The result is a fresh dict per log file, so a field missing from one log can never
    inherit the value of a previously processed log.
    """
    corr_info = next((part for part in log_path.parts if part.startswith("corr")), None)
    if corr_info is None:
        raise ValueError(f"no path component starting with 'corr' in {log_path}")

    values = dict(log_dict)
    tr_loss_history = values["tr_loss_history"]
    tr_edge_acc_history = values["tr_edge_acc_history"]
    values.update({
        "corr_info": str(corr_info),
        "min_tr_loss": min(tr_loss_history),
        # training edge accuracy at the epoch where the training loss is lowest
        "min_tr_loss_edge_acc": tr_edge_acc_history[int(np.argmin(np.array(tr_loss_history)))],
        "max_tr_edge_acc": max(tr_edge_acc_history),
        "max_val_edge_acc": max(values["val_edge_acc_history"]),
    })

    model_struct_str = log_dict.get("model_structure")
    if regexp_res and model_struct_str:
        # Values parsed from the structure text override the logged ones and are strings.
        for field, pattern in _GRU_STRUCT_PATTERNS.items():
            match = re.search(pattern, model_struct_str)
            if match:
                values[field] = match.group(field)
            else:
                values[field] = None
                logging.info(f"Can't detect {field}")
        # No ``num_layers=`` in the structure text is treated as a single layer.
        if not values["gru_l"]:
            values["gru_l"] = 1
    return values


def _log_baseline_gru_failure(err: Exception, log_path) -> None:
    """Log where and why processing stopped, without raising from the handler itself."""
    tb_frames = traceback.extract_tb(err.__traceback__)
    detail = err.args[0] if err.args else ""
    if log_path is not None:
        logging.error(f"file:{log_path.parts[-1]}, path:{log_path}")
    if tb_frames:
        last_call = tb_frames[-1]
        err_msg = "File \"{}\", line {}, in {}: [{}] {}".format(
            last_call.filename, last_call.lineno, last_call.name, err.__class__.__name__, detail)
    else:
        err_msg = "[{}] {}".format(err.__class__.__name__, detail)
    logging.error(f"===\n{err_msg}")
    logging.error(f"===\n{tb_frames}")


def baseline_gru_tr_proc_est(log_path_list: list, condition_dict: dict, regexp_res: bool = False, plot_pic: bool = True):
    """Summarise Baseline-GRU training logs that satisfy ``condition_dict``.

    Each JSON log is loaded, enriched with ``corr_info`` (first path component starting
    with "corr"), ``min_tr_loss``, ``min_tr_loss_edge_acc``, ``max_tr_edge_acc`` and
    ``max_val_edge_acc``, and kept only if every ``condition_dict`` item equals the
    corresponding field. Matching logs are optionally plotted with
    ``plot_mts_corr_ad_tr_process`` and collected into a summary table that is displayed
    with the notebook's ``display``.

    Parameters
    ----------
    log_path_list : iterable of path-like
        Paths of the JSON train logs (a list or a ``Path.glob`` generator).
    condition_dict : dict
        ``{field: required_value}`` filter; ``{}`` keeps every log. Fields may be log keys
        or the derived fields listed above.
    regexp_res : bool
        When True, ``gru_l`` / ``gru_h`` are parsed from the log's ``model_structure``
        text (as strings) instead of being taken from the log.
    plot_pic : bool
        When True, draw the training-process figure of every matching log.

    Returns
    -------
    pandas.DataFrame
        One row per matching log, restricted to the summary columns and sorted by
        batch_size, seq_len, gru_l, gru_h, drop_p (descending). If processing fails, the
        error is logged (not raised) and the rows gathered up to that point are returned,
        unformatted.

    Notes
    -----
    Sets the global pandas display options ``float_format`` and ``max_columns``, which
    also affects DataFrames displayed afterwards. The original discarded
    ``df.style.set_caption(...)`` call had no effect and was dropped.
    """
    records = []
    df = None
    log_path = None
    try:
        for raw_path in log_path_list:
            log_path = Path(raw_path)
            with open(log_path, "r") as source:
                log_dict = json.load(source)

            values = _baseline_gru_log_values(log_path, log_dict, regexp_res)

            unknown_fields = set(condition_dict) - set(values)
            if unknown_fields:
                raise KeyError(f"condition_dict keys {sorted(unknown_fields)} are not fields of "
                               f"baseline_gru_tr_proc_est() for {log_path}")
            if not all(values[field] == wanted for field, wanted in condition_dict.items()):
                continue

            main_title_str = (f"{values.get('corr_info')} with filt:{values.get('filt_mode')}-{values.get('filt_quan')} "
                              f"and batch_size({values.get('batch_size')}) "
                              f"input to GRU with gru_l{values.get('gru_l')}-gru_h{values.get('gru_h')}\n"
                              f"with drop: {values.get('drop_p')} and loss_fns:{values.get('loss_fns')}\n"
                              f"min val-loss:{values.get('min_val_loss'):8f} min tr-loss:{values.get('min_tr_loss'):8f}")
            logging.info(f"file_name:{log_path.parts[-1]}")
            logging.info(f"file_path:{log_path.parts[2:-2]}")
            logging.info(f"main_title_str:\n{main_title_str}")
            records.append(values)
            if plot_pic:
                plot_mts_corr_ad_tr_process(main_title=main_title_str,
                                            model_struct=log_dict.get('model_structure'),
                                            metrics_history={k: v for k, v in log_dict.items() if "history" in k},
                                            best_epoch=values['best_val_epoch'],
                                            batches_per_epoch=values['batches_per_epoch'])

        # Build the table once instead of concatenating one-row frames inside the loop.
        df = pd.DataFrame(records)
        df = df.reindex(columns=_BASELINE_GRU_SUMMARY_COLUMNS)
        df = df.sort_values(_BASELINE_GRU_SORT_COLUMNS, ascending=False)
        df = df.reset_index(drop=True)
        # Deliberately global: later cells display frames with the same formatting.
        pd.options.display.float_format = '{:.6f}'.format
        pd.set_option('display.max_columns', None)
        display(df)
    except Exception as e:
        # Best effort: report the failure and hand back whatever was gathered so far.
        _log_baseline_gru_failure(e, log_path)
        if df is None:
            df = pd.DataFrame(records)

    return df


def plot_mts_corr_ad_tr_process(main_title: str, model_struct: str, metrics_history: dict, best_epoch: int, batches_per_epoch: int):
    """Draw the training-process figure of one model run.

    The figure has three panels: train loss vs. train edge accuracy (twin y-axes),
    validation loss vs. validation edge accuracy (twin y-axes), and the model structure
    rendered as text. The figure is shown with ``plt.show()`` and then closed.

    Parameters
    ----------
    main_title : str
        Figure-level title.
    model_struct : str
        Model structure; converted with ``str()`` and drawn as an annotation.
    metrics_history : dict
        Must contain ``tr_loss_history``, ``tr_edge_acc_history``, ``val_loss_history``
        and ``val_edge_acc_history`` (one value per plotted point).
    best_epoch, batches_per_epoch : int
        Accepted for call compatibility; the current panels do not use them.

    Raises
    ------
    Exception
        Any error raised while drawing a panel is logged with the panel title and
        re-raised; the figure is closed first.
    """
    data_info_dict = [{"sub_title": 'train loss_history & edge_acc_history',
                       "data": {'tr_loss_history': metrics_history['tr_loss_history'],
                                'tr_edge_acc_history': metrics_history['tr_edge_acc_history']},
                       "xticks": None,
                       "xlabel": "epochs",
                       "double_y": True},
                      {"sub_title": 'val  loss_history & edge_acc_history',
                       "data": {'val_loss_history': metrics_history['val_loss_history'],
                                'val_edge_acc_history': metrics_history['val_edge_acc_history']},
                       "xticks": None,
                       "xlabel": "epochs",
                       "double_y": True},
                      # {"sub_title": 'train gradient_history',
                      #  "data": {"gru_gradient_history": metrics_history['gru_gradient_history'],
                      #           "fc_gradient_history": metrics_history['fc_gradient_history']},
                      #  "xticks": None,
                      #  "xlabel": "epochs"},
                      {"sub_title": "model structure",
                       "data": str(model_struct)}]

    # figure settings
    line_style = {"linewidth": 2, "alpha": 0.5}
    axvline_style = {"color": 'k', "linewidth": 5, "linestyle": '--', "alpha": 0.3}
    fig, axs = plt.subplot_mosaic("""
                                  ab
                                  cc
                                  """,
                                  figsize=(30, 20), gridspec_kw={'hspace': 0.2, 'wspace': 0.3})
    fig.suptitle(main_title, fontsize=30)

    for ax, data_plot in zip(axs.values(), data_info_dict):
        try:
            panel_data = data_plot["data"]
            ax.set_title(data_plot["sub_title"], fontsize=30)
            ax.yaxis.offsetText.set_fontsize(18)
            ax.tick_params(axis='both', which='major', labelsize=24)
            if isinstance(panel_data, dict) and data_plot.get("double_y"):
                # First series on the left axis, every further series on its own red twin axis.
                for i, key in enumerate(panel_data):
                    if i == 0:
                        ax.plot(panel_data[key], label=key, **line_style)
                        ax.set_ylabel(key, fontsize=24)
                        ax.legend(fontsize=18)
                    else:
                        new_ax = ax.twinx()
                        new_ax.plot(panel_data[key], label=key, color='r')
                        new_ax.set_ylabel(key, color='r', fontsize=24)
                        new_ax.legend(fontsize=18)
                        new_ax.tick_params(axis='both', colors='r', which='major', labelsize=24)
            elif isinstance(panel_data, dict):
                for key in panel_data:
                    ax.plot(panel_data[key], label=key, **line_style)
                ax.legend(fontsize=18)
            elif isinstance(panel_data, str):
                ax.annotate(text=f"{panel_data}",
                            xy=(0.15, 0.5), bbox={'facecolor': 'green', 'alpha': 0.4, 'pad': 5},
                            fontsize=20, fontfamily='monospace', xycoords='axes fraction', va='center')
            else:
                ax.plot(panel_data, **line_style)
            # Optional per-panel decorations; none of the panels above sets "axvline" or "xticks".
            if pos_tuple := data_plot.get("axvline"):
                for x_pos in pos_tuple:
                    ax.axvline(x=x_pos, **axvline_style)
            if xlabel := data_plot.get("xlabel"):
                ax.set_xlabel(xlabel, fontsize=24)
            if t := data_plot.get("xticks"):
                ax.set_xticks(ticks=range(0, len(t["label"])*t["intv"], t["intv"]), labels=t["label"], rotation=45)
        except Exception:
            logging.error(f"Encounter error when draw figure of {data_plot['sub_title']}")
            # Do not leave a half-drawn figure registered with pyplot.
            plt.close(fig)
            raise

    # NOTE(review): rect=(0, 0, 0, 0) is a degenerate rectangle; matplotlib presumably cannot
    # apply a tight layout with it — confirm whether e.g. (0, 0, 1, 0.9) was intended. Kept
    # unchanged so the rendered layout stays as before.
    fig.tight_layout(rect=(0, 0, 0, 0))
    plt.show()
    # Close this figure explicitly rather than whichever figure happens to be current.
    plt.close(fig)

In [None]:
baseline_gru_log_dir = Path("./save_models/class_baseline_gru/archive/20230827/sp500_20082017_corr_ser_reg_std_corr_mat_hrchy_10_cluster_label_last-train_train/pearson/")
# Glob character classes such as "[!archive]" or "[.json]" match ONE character from/not from the
# set, not a directory name or a suffix. Match "*.json" and exclude directories by name instead.
# Names are checked relative to the base directory because the base path itself contains "archive".
# log_path_list1: logs one level below the base dir, skipping deprecated/archive/checkpoint dirs.
log_path_list1 = sorted(p for p in baseline_gru_log_dir.glob("*/train_logs/*.json")
                        if {"deprecated", "archive", ".ipynb_checkpoints"}.isdisjoint(p.relative_to(baseline_gru_log_dir).parts))
# log_path_list2: logs anywhere below the "archive" sub-directory (lazy; not used below).
log_path_list2 = (p for p in baseline_gru_log_dir.glob("archive/**/train_logs/*.json")
                  if {"deprecated", ".ipynb_checkpoints"}.isdisjoint(p.relative_to(baseline_gru_log_dir).parts))
# log_path_list3: every log below the base dir except notebook checkpoints (lazy; not used below).
log_path_list3 = (p for p in baseline_gru_log_dir.glob("**/train_logs/*.json")
                  if ".ipynb_checkpoints" not in p.relative_to(baseline_gru_log_dir).parts)

# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "tr_batch": 32, "gra_enc_l": 1, "gra_enc_h": 4, "gru_l": 1, "gru_h": 8})
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "gra_enc_l": 5, "gru_l": 1, "gru_h": 8})
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "gra_enc_l": 5, "gra_enc_h": 16, "filt_mode": "keep_strong", "graph_enc":"GineEncoder"}, plot_pic=True)
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "loss_fn": str(['MSELoss()', 'discr_loss'])}, plot_pic=True)
# model_tr_summary_df =  baseline_gru_tr_proc_est(log_path_list1, {}, plot_pic=True)
# filt_model_tr_summary_df = model_tr_summary_df.loc[::, ["corr_info", "custom_discrete_bins", "seq_len", "gru_l", "gru_h", "decoder", "output_type", "target_mats_bins", "edge_acc_loss_atol", "max_tr_edge_acc", "max_val_edge_acc"]]
# filt_model_tr_summary_df

model_tr_summary_df = baseline_gru_tr_proc_est(log_path_list=log_path_list1, condition_dict={}, regexp_res=False, plot_pic=True)
# Everything below inspects the first row, so fail with a clear message when nothing was summarised.
if model_tr_summary_df.empty:
    raise ValueError(f"no training logs were summarised from {baseline_gru_log_dir}")

# Split the summary columns into hyper-parameters that vary between runs ("independent")
# and those that are identical for every run ("control").
# Columns whose cells hold lists cannot go through nunique() (lists are unhashable).
holds_list = model_tr_summary_df.apply(lambda col: isinstance(col.iloc[0], list))
columns_containing_lists = holds_list[holds_list].index
columns_not_containing_lists = model_tr_summary_df.columns.difference(columns_containing_lists)
scalar_nunique = model_tr_summary_df.loc[:, columns_not_containing_lists].nunique()
independent_variables_columns = scalar_nunique[scalar_nunique > 1].index
control_variables_columns = scalar_nunique.index.difference(independent_variables_columns)
for col in columns_containing_lists:
    # Compare list cells through their repr: works for unhashable and mixed-type cells.
    if model_tr_summary_df[col].map(repr).nunique() > 1:
        independent_variables_columns = independent_variables_columns.append(pd.Index([col]))
    else:
        control_variables_columns = control_variables_columns.append(pd.Index([col]))

independent_variables_tr_summary_df = model_tr_summary_df.loc[:, independent_variables_columns].sort_index(axis=1)
# Control variables are the same in every row, so the first row represents all runs.
control_variables_tr_summary_df = model_tr_summary_df[control_variables_columns].iloc[0].sort_index(axis=0)
display(control_variables_tr_summary_df)
display(independent_variables_tr_summary_df)

In [None]:
baseline_gru_log_dir2 = Path("./save_models/baseline_gru/archive/20230621/sp500_20082017_corr_ser_reg_std_corr_mat_hrchy_10_cluster_label_half_mix-train_train/")
# Glob character classes such as "[!archive]" or "[.json]" match ONE character from/not from the
# set, not a directory name or a suffix. Match "*.json" and exclude directories by name instead.
# Names are checked relative to the base directory because the base path itself contains "archive".
log_path_list1 = sorted(p for p in baseline_gru_log_dir2.glob("*/train_logs/*.json")
                        if {"deprecated", "archive", ".ipynb_checkpoints"}.isdisjoint(p.relative_to(baseline_gru_log_dir2).parts))
# Lazy alternatives (not used below): logs below "archive", and every log below the base dir.
log_path_list2 = (p for p in baseline_gru_log_dir2.glob("archive/**/train_logs/*.json")
                  if {"deprecated", ".ipynb_checkpoints"}.isdisjoint(p.relative_to(baseline_gru_log_dir2).parts))
log_path_list3 = baseline_gru_log_dir2.glob("**/train_logs/*.json")
# NOTE(review): "gru_h" is compared as the string "80" while regexp_res is left at False, so the
# value comes straight from the log — confirm the logs store gru_h as a string, otherwise no run matches.
model_tr_summary_df = baseline_gru_tr_proc_est(log_path_list1, {"gru_h": "80", "seq_len": 5}, plot_pic=False)
filt_model_tr_summary_df = model_tr_summary_df.loc[:, ["corr_info", "seq_len", "gru_l", "gru_h", "max_tr_edge_acc", "max_val_edge_acc"]]
filt_model_tr_summary_df

In [None]:
baseline_gru_log_dir3 = Path("./save_models/baseline_gru/sp500_20082017_corr_ser_reg_std_corr_mat_hrchy_10_cluster_label_half_mix-train_train/")
# Glob character classes such as "[!archive]" or "[.json]" match ONE character from/not from the
# set, not a directory name or a suffix. Match "*.json" and exclude directories by name instead,
# checking names relative to the base directory only.
log_path_list1 = sorted(p for p in baseline_gru_log_dir3.glob("*/train_logs/*.json")
                        if {"deprecated", "archive", ".ipynb_checkpoints"}.isdisjoint(p.relative_to(baseline_gru_log_dir3).parts))
# Lazy alternatives (not used below): logs below "archive", and every non-checkpoint log.
log_path_list2 = (p for p in baseline_gru_log_dir3.glob("archive/**/train_logs/*.json")
                  if {"deprecated", ".ipynb_checkpoints"}.isdisjoint(p.relative_to(baseline_gru_log_dir3).parts))
log_path_list3 = (p for p in baseline_gru_log_dir3.glob("**/train_logs/*.json")
                  if ".ipynb_checkpoints" not in p.relative_to(baseline_gru_log_dir3).parts)
# Summarise and plot every run, then keep only the columns of interest for display.
model_tr_summary_df = baseline_gru_tr_proc_est(log_path_list1, {}, plot_pic=True)
filt_model_tr_summary_df = model_tr_summary_df.loc[:, ["corr_info", "seq_len", "gru_l", "gru_h", "max_tr_edge_acc", "max_val_edge_acc"]]
filt_model_tr_summary_df