In [None]:
from pathlib import Path
from pprint import pformat, pprint
import logging
import json
import re
import sys
from math import ceil
from itertools import repeat, chain, product
import traceback

import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
import dynamic_yaml
import yaml

# Notebook-wide logging: DEBUG level for our own messages, with the level name
# and originating file in front of every record.
logging.basicConfig(level=logging.DEBUG,
                    format='%(levelname)-8s [%(filename)s] %(message)s')
# matplotlib emits a lot of DEBUG/INFO records; only let its errors through.
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.ERROR)
# Figure defaults: use the 'simhei' sans-serif font and disable the unicode minus glyph.
mpl.rcParams.update({'font.sans-serif': ['simhei'],
                     'axes.unicode_minus': False})
%load_ext pycodestyle_magic

# Draw the training process

In [None]:
def _parse_model_struct_info(model_struct_str, field_patterns):
    """Extract hyper-parameter fields from the logged model/optimizer description text.

    Returns a dict with one entry per key of ``field_patterns``. ``gra_enc_l`` is the
    number of pattern occurrences, ``decoder`` is the captured name, every other field
    is the captured number as ``float``. Fields that cannot be detected are ``None``.
    """
    parsed = {}
    for field, pattern in field_patterns.items():
        if field == "gra_enc_l":
            # The layer count is the number of matches, not a captured group.
            parsed[field] = len(re.findall(pattern, model_struct_str))
            continue
        match = re.search(pattern, model_struct_str)
        raw_value = match.group(field) if match else ""
        if not raw_value:
            parsed[field] = None
            logging.info(f"Can't detect {field}")
        elif field == "decoder":
            parsed[field] = raw_value
        else:
            parsed[field] = float(raw_value)
    return parsed


def _format_optional(value, spec):
    """Format ``value`` with ``spec``; a missing (``None``) value is rendered as 'None'."""
    return "None" if value is None else format(value, spec)


def mts_corr_ad_tr_proc_est(log_path_list: list, condition_dict: dict,  plot_pic:bool = True):
    """Summarise MTSCorrAD training logs and optionally plot each matching run.

    Every JSON log is loaded, enriched with hyper-parameters parsed from its
    ``model_structure`` + ``optimizer`` text and with a few derived metrics
    (min train loss, max train/val edge accuracy, ...). Runs whose values equal every
    entry of ``condition_dict`` are collected into a summary frame, which is displayed
    and returned.

    Args:
        log_path_list: iterable of paths to JSON training logs. One path component
            must start with ``"corr"``; it is recorded as ``corr_info``.
        condition_dict: ``{field: required_value}`` filter; ``{}`` keeps every run.
            Every key must be a field of the log, a derived field or a parsed
            model-structure field.
        plot_pic: when true, call ``plot_mts_corr_ad_tr_process`` for each kept run.

    Returns:
        pandas.DataFrame with one row per kept run, restricted to the summary columns
        and sorted in descending order. If an error occurs while processing, it is
        logged (not raised) and the rows collected so far are returned as-is.
    """
    # Raw strings: the patterns contain regex escapes such as ``\d`` and ``\(``.
    model_struct_info_fields = {"opt_lr": r".?initial_lr: (?P<opt_lr>\d\.\d+)\n",
                                "opt_weight_decay": r".?weight_decay: (?P<opt_weight_decay>\d\.\d+)\n",
                                "gru_l": r"\(gru1\): GRU\(\d+, \d+, num_layers=(?P<gru_l>\d+)\)",
                                "gru_h": r"\(gru1\): GRU\(\d+, (?P<gru_h>\d+).+\)",
                                "gra_enc_h": r"(\(\d\)\:\s.*Conv.*\n.*)(out_features\=)(?P<gra_enc_h>\d*)",
                                "gra_enc_l": r"└─\(.\)GINEConv",
                                "decoder": r"\(decoder\): (?P<decoder>\w+)\("}
    derived_fields = ["dataset", "corr_info", "min_tr_loss", "min_tr_loss_edge_acc", "max_tr_edge_acc", "max_val_edge_acc"]
    summary_columns = ["corr_info", "epochs", "batch_size", "graph_nodes_v_mode", "filt_mode", "filt_quan", "quan_discrete_bins",
                       "custom_discrete_bins", "seq_len", "loss_fns", "optimizer", "opt_lr", "opt_weight_decay", "gra_enc_weight_l2_reg_lambda",
                       "drop_pos", "drop_p", "graph_enc", "gra_enc_aggr", "gra_enc_l", "gra_enc_h", "gru_l", "gru_h", "decoder",
                       "output_type", "output_bins", "target_mats_bins", "edge_acc_loss_atol", "min_tr_loss", "min_tr_loss_edge_acc",
                       "max_tr_edge_acc", "min_val_loss", "max_val_edge_acc", "min_val_loss_edge_acc"]
    sort_columns = ["batch_size", "seq_len", "gra_enc_aggr", "gra_enc_l", "gra_enc_h", "gru_l", "gru_h",
                    "filt_mode", "filt_quan", "graph_enc", "graph_nodes_v_mode", "loss_fns", "opt_lr", "drop_p"]

    records = []          # one dict per kept run; turned into a frame once, after the loop
    df = pd.DataFrame()
    log_path = None       # last path being processed, for error reporting
    summarised = False
    try:
        for raw_log_path in log_path_list:
            log_path = Path(raw_log_path)
            with open(log_path, "r") as source:
                log_dict = json.load(source)

            # Missing entries are treated as empty text instead of failing on `None + str`.
            model_struct_str = (log_dict.get('model_structure') or "") + (log_dict.get('optimizer') or "")
            struct_info = _parse_model_struct_info(model_struct_str, model_struct_info_fields) if model_struct_str else {}

            corr_info = next((part for part in log_path.parts if part.startswith("corr")), None)
            if corr_info is None:
                raise ValueError(f"no path component of {log_path} starts with 'corr'")
            tr_loss_history = log_dict["tr_loss_history"]
            tr_edge_acc_history = log_dict["tr_edge_acc_history"]
            derived = {"corr_info": str(corr_info),
                       "min_tr_loss": min(tr_loss_history),
                       "min_tr_loss_edge_acc": tr_edge_acc_history[np.argmin(np.array(tr_loss_history))],
                       "max_tr_edge_acc": max(tr_edge_acc_history),
                       "max_val_edge_acc": max(log_dict["val_edge_acc_history"])}

            # Explicit per-file namespace (log values < parsed fields < derived metrics).
            # Writing into `locals()` is not supported and no longer persists on Python 3.13+.
            values = {**log_dict, **struct_info, **derived}
            record_fields = list(log_dict.keys()) + derived_fields + list(model_struct_info_fields.keys())
            if set(condition_dict.keys()) - set(record_fields):
                raise AssertionError("one of condition_dict.keys() doesn't match the record_fields if mts_corr_ad_est()")
            if not all(key in values and values[key] == wanted for key, wanted in condition_dict.items()):
                continue

            main_title_str = (f"Nodes mode-{values.get('graph_nodes_v_mode')} + {values.get('corr_info')} with filt:{values.get('filt_mode')}-{values.get('filt_quan')}-{values.get('discrete_bin')}"
                              f"and batch_size({values.get('batch_size')}) and seq_len({values.get('seq_len')}) "
                              f"input to MTSCorrAD with gra_enc_aggr{values.get('gra_enc_aggr')}-gra_enc_l{_format_optional(values.get('gra_enc_l'), '.0f')}-gra_enc_h{_format_optional(values.get('gra_enc_h'), '.0f')}"
                              f"-gru_l{values.get('gru_l')}-gru_h{values.get('gru_h')}\n"
                              f"with drop: {values.get('drop_p')} and loss_fns:{values.get('loss_fns')}\n"
                              f"min val-loss:{_format_optional(values.get('min_val_loss'), '8f')} min tr-loss:{_format_optional(values.get('min_tr_loss'), '8f')}\n"
                              f"max_tr_edge_acc:{values.get('max_tr_edge_acc')}, max_val_edge_acc:{values.get('max_val_edge_acc')}")
            logging.info(f"file_name:{log_path.parts[-1]}")
            logging.info(f"file_path:{log_path.parts[2:-2]}")
            logging.info(f"main_title_str:\n{main_title_str}")
            records.append(values)
            if plot_pic:
                gra_enc_l, gra_enc_h = values.get('gra_enc_l'), values.get('gra_enc_h')
                if gra_enc_l is None or gra_enc_h is None:
                    logging.warning(f"Skip plotting {log_path.parts[-1]}: graph-encoder size could not be detected")
                else:
                    plot_mts_corr_ad_tr_process(main_title=main_title_str, model_struct=model_struct_str,
                                                metrics_history={k: v for k, v in log_dict.items() if "history" in k},
                                                best_epoch=values['best_val_epoch'], batches_per_epoch=values['batches_per_epoch'],
                                                gra_emb_size=int(gra_enc_l*gra_enc_h))

        df = pd.DataFrame(records).reindex(summary_columns, axis=1)
        df = df.sort_values(sort_columns, ascending=False)
        df = df.reset_index(drop=True)
        summarised = True
        pd.options.display.float_format = '{:.6f}'.format
        pd.set_option('display.max_columns', None)
        display(df)
    except Exception as e:
        error_class = e.__class__.__name__  # error type
        detail = e.args[0] if e.args else ""  # error detail; some exceptions carry no args
        tb = e.__traceback__  # call stack
        last_call_stack = traceback.extract_tb(tb)[-1]  # innermost frame of the call stack
        file_name = last_call_stack.filename  # file where the error occurred
        line_num = last_call_stack.lineno  # line where the error occurred
        func_name = last_call_stack.name  # function where the error occurred
        err_msg = "File \"{}\", line {}, in {}: [{}] {}".format(file_name, line_num, func_name, error_class, detail)
        if log_path is not None:
            logging.error(f"file:{log_path.parts[-1]}, path:{log_path}")
        logging.error(f"===\n{err_msg}")
        logging.error(f"===\n{traceback.extract_tb(tb)}")
        if not summarised:
            # Keep the best-effort contract: hand back whatever was collected before the failure.
            df = pd.DataFrame(records)
    return df


def plot_mts_corr_ad_tr_process(main_title: str, model_struct: str, metrics_history: dict, best_epoch: int, batches_per_epoch: int, gra_emb_size: int):
    """Draw the training overview figure of one MTSCorrAD run.

    The figure has three panels: train loss vs. train edge accuracy (twin y-axes),
    validation loss vs. validation edge accuracy (twin y-axes) and the textual model
    structure.

    Args:
        main_title: figure title.
        model_struct: model description, rendered as monospace text in the last panel.
        metrics_history: must contain ``tr_loss_history``, ``tr_edge_acc_history``,
            ``val_loss_history`` and ``val_edge_acc_history``.
        best_epoch, batches_per_epoch, gra_emb_size: accepted for interface
            compatibility. They were only needed by the graph-embedding panels, which
            are disabled, so the embedding histories are no longer required in
            ``metrics_history``.

    Raises:
        Any exception raised while drawing a panel is logged with the panel title and
        re-raised.
    """
    # Panel specs. Optional keys understood by the drawing loop below:
    #   "double_y": plot the 2nd series of a dict on a twin y-axis,
    #   "axvline":  iterable of x positions for vertical marker lines,
    #   "xticks":   {"label": [...], "intv": int} custom tick labels,
    #   "xlabel":   x-axis label.
    data_info_dict = [{"sub_title": 'train loss_history & edge_acc_history',
                       "data": {'tr_loss_history': metrics_history['tr_loss_history'],
                                'tr_edge_acc_history': metrics_history['tr_edge_acc_history']},
                       "xticks": None,
                       "xlabel": "epochs",
                       "double_y": True},
                      {"sub_title": 'val  loss_history & edge_acc_history',
                       "data": {'val_loss_history': metrics_history['val_loss_history'],
                                'val_edge_acc_history': metrics_history['val_edge_acc_history']},
                       "xticks": None,
                       "xlabel": "epochs",
                       "double_y": True},
                      {"sub_title": "model structure",
                       "data": str(model_struct)}]

    # figure settings
    line_style = {"linewidth": 2, "alpha": 0.5}
    axvline_style = {"color": 'k', "linewidth": 5, "linestyle": '--', "alpha": 0.3}
    # Two metric panels on the first row, the model-structure panel spanning the rest.
    fig, axs = plt.subplot_mosaic("""
                                  ab
                                  cc
                                  cc
                                  cc
                                  """,
                                  figsize=(25, 40))
    fig.suptitle(main_title, fontsize=30)

    try:
        for ax, data_plot in zip(axs.values(), data_info_dict):
            panel_data = data_plot["data"]
            ax.set_title(data_plot["sub_title"], fontsize=30)
            ax.yaxis.offsetText.set_fontsize(18)
            ax.tick_params(axis='both', which='major', labelsize=24)
            if isinstance(panel_data, dict) and data_plot.get("double_y"):
                for i, (key, series) in enumerate(panel_data.items()):
                    if i == 0:
                        ax.plot(series, label=key, **line_style)
                        ax.set_ylabel(key, fontsize=24)
                        ax.legend(fontsize=18)
                    else:
                        # Every further series gets its own red y-axis on the right.
                        new_ax = ax.twinx()
                        new_ax.plot(series, label=key, color='r')
                        new_ax.set_ylabel(key, color='r', fontsize=24)
                        new_ax.legend(fontsize=18)
                        new_ax.tick_params(axis='both', colors='r', which='major', labelsize=24)
            elif isinstance(panel_data, dict):
                for key, series in panel_data.items():
                    ax.plot(series, label=key, **line_style)
                ax.legend(fontsize=18)
            elif isinstance(panel_data, str):
                ax.annotate(text=f"{panel_data}",
                            xy=(0.15, 0.5), bbox={'facecolor': 'green', 'alpha': 0.4, 'pad': 5},
                            fontsize=20, fontfamily='monospace', xycoords='axes fraction', va='center')
            else:
                ax.plot(panel_data, **line_style)
            if pos_tuple := data_plot.get("axvline"):
                for x_pos in pos_tuple:
                    ax.axvline(x=x_pos, **axvline_style)
            if xlabel := data_plot.get("xlabel"):
                ax.set_xlabel(xlabel, fontsize=24)
            if t := data_plot.get("xticks"):
                ax.set_xticks(ticks=range(0, len(t["label"])*t["intv"], t["intv"]), labels=t["label"], rotation=45)
    except Exception:
        logging.error(f"Encounter error when draw figure of {data_plot['sub_title']}")
        raise  # bare raise keeps the original traceback

    fig.tight_layout(rect=(0, 0, 1, 0.97))
    plt.show()


In [None]:
# Summarise the archived 2023-08-05 class-model runs (pearson correlation).
mts_corr_model_log_dir = Path("./save_models/class_mts_corr_ad_model/archive/20230805/sp500_20082017_corr_ser_reg_std_corr_mat_hrchy_10_cluster_label_7th-train_train/pearson")
# NOTE(review): in glob patterns `[...]` / `[!...]` is a single-character class, so
# `[!deprecated]` means "one character that is none of d,e,p,r,c,a,t" and `[.json]` means
# "one of . j s o n" - not "exclude this name" / "ends with .json". The patterns are left
# untouched because changing them changes which logs are loaded; confirm the intent.
# These are one-shot generators: each list can be iterated only once.
log_path_list1 = mts_corr_model_log_dir.glob("./*[!deprecated][!archive][!.ipynb_checkpoints]*/train_logs/*[!.ipynb_checkpoints]*[.json]")
log_path_list2 = mts_corr_model_log_dir.glob("./*[archive][!deprecated][!.ipynb_checkpoints]*/**/train_logs/*[!.ipynb_checkpoints]*[.json]")
log_path_list3 = mts_corr_model_log_dir.glob("./**/train_logs/*[!.ipynb_checkpoints]*[.json]")

# Example filtered calls, kept for reference:
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "tr_batch": 32, "gra_enc_l": 1, "gra_enc_h": 4, "gru_l": 1, "gru_h": 8})
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "gra_enc_l": 5, "gru_l": 1, "gru_h": 8})
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "gra_enc_l": 5, "gra_enc_h": 16, "filt_mode": "keep_strong", "graph_enc":"GineEncoder"}, plot_pic=True)
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "loss_fns": str(['MSELoss()', 'discr_loss'])}, plot_pic=True)
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list2, {"corr_info": "corr_s1_w10", "loss_fns": str(['MSELoss()'])}, plot_pic=True)
# No filter and no figures: just build the full summary table.
model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list2, {}, plot_pic=False)
# Narrow the table to the columns of interest, best configurations first.
filt_model_tr_summary_df = (
    model_tr_summary_df
    .loc[:, ["seq_len", 'gra_enc_weight_l2_reg_lambda', "gra_enc_l", "gra_enc_h", "max_tr_edge_acc", "max_val_edge_acc"]]
    .sort_values(["gra_enc_l", "gra_enc_h", "seq_len", "max_tr_edge_acc", "max_val_edge_acc"], ascending=False)
)
display(filt_model_tr_summary_df)

In [None]:
# Summarise (and plot) the deprecated 2023-06-05 runs.
mts_corr_model_log_dir = Path("./save_models/mts_corr_ad_model/deprecated/20230605/sp500_20082017_corr_ser_reg_std_corr_mat_hrchy_9_cluster_label_last-train_train/")
# NOTE(review): in glob patterns `[...]` / `[!...]` is a single-character class, so
# `[!deprecated]` means "one character that is none of d,e,p,r,c,a,t" and `[.json]` means
# "one of . j s o n" - not "exclude this name" / "ends with .json". The patterns are left
# untouched because changing them changes which logs are loaded; confirm the intent.
# These are one-shot generators: each list can be iterated only once.
log_path_list1 = mts_corr_model_log_dir.glob("./*[!deprecated][!archive][!.ipynb_checkpoints]*/train_logs/*[!.ipynb_checkpoints]*[.json]")
log_path_list2 = mts_corr_model_log_dir.glob("./*[archive][!deprecated][!.ipynb_checkpoints]*/**/train_logs/*[!.ipynb_checkpoints]*[.json]")
log_path_list3 = mts_corr_model_log_dir.glob("./**/train_logs/*[!.ipynb_checkpoints]*[.json]")

# Example filtered calls, kept for reference:
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "tr_batch": 32, "gra_enc_l": 1, "gra_enc_h": 4, "gru_l": 1, "gru_h": 8})
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "gra_enc_l": 5, "gru_l": 1, "gru_h": 8})
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "gra_enc_l": 5, "gra_enc_h": 16, "filt_mode": "keep_strong", "graph_enc":"GineEncoder"}, plot_pic=True)
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list1, {"corr_info": "corr_s1_w10", "loss_fns": str(['MSELoss()', 'discr_loss'])}, plot_pic=True)
# model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list2, {"corr_info": "corr_s1_w10", "loss_fns": str(['MSELoss()'])}, plot_pic=True)
# No filter; draw the training figure of every run as well.
model_tr_summary_df = mts_corr_ad_tr_proc_est(log_path_list2, {}, plot_pic=True)
# Narrow the table to the columns of interest, best configurations first.
filt_model_tr_summary_df = (
    model_tr_summary_df
    .loc[:, ["seq_len", "opt_lr", "opt_weight_decay", 'gra_enc_weight_l2_reg_lambda', "drop_pos", "drop_p", "gra_enc_aggr", "gra_enc_l", "gra_enc_h", "gru_l", "gru_h", "decoder", "max_tr_edge_acc", "max_val_edge_acc"]]
    .sort_values(["opt_weight_decay", "drop_p", "opt_lr", "gra_enc_aggr", "gra_enc_l", "gra_enc_h", "gru_l", "gru_h", "max_tr_edge_acc", "max_val_edge_acc"], ascending=False)
)
display(filt_model_tr_summary_df)

# Find the most different graph

In [None]:
# NOTE(review): hard-coded absolute path; consider deriving it from the project root.
sys.path.append("/workspace/correlation-change-predict/ywt_library")
# `__file__` is not defined inside a notebook kernel, so fall back to the kernel's
# working directory there; when run as a script the script's directory is still used.
current_dir = Path(__file__).parent if "__file__" in globals() else Path.cwd()
data_config_path = current_dir/"../config/data_config.yaml"
with open(data_config_path) as f:
    # Round-trip through dynamic_yaml and dump the result back into a plain YAML structure.
    data = dynamic_yaml.load(f)
    data_cfg = yaml.full_load(dynamic_yaml.dump(data))

# ## Data implement & output setting & testset setting
# data implement setting
data_implement = "SP500_20082017_CORR_SER_REG_CORR_MAT_HRCHY_11_CLUSTER"  # watch options by operate: logging.info(data_cfg["DATASETS"].keys())
# train set setting
train_items_setting = "-train_train"  # -train_train|-train_all
# setting of name of output files and pictures title
output_file_name = data_cfg["DATASETS"][data_implement]['OUTPUT_FILE_NAME_BASIS'] + train_items_setting
# setting of output files
logging.info(f"===== file_name basis:{output_file_name} =====")
graph_data_dir = Path(data_cfg["DIRS"]["PIPELINE_DATA_DIR"])/f"{output_file_name}-graph_data"
graph_arr = np.load(graph_data_dir/"corr_s1_w10_graph.npy")  # each graph consist of 66 node & 66^2 edges

stride = 12
# Chronological 90% / 5% / 5% split into train / validation / test.
train_end = int(len(graph_arr)*0.9)
val_end = int(len(graph_arr)*0.95)
train_arr = graph_arr[:train_end]
val_arr = graph_arr[train_end:val_end]
test_arr = graph_arr[val_end:]
# Element-wise change between each training graph and the one `stride` steps later.
train_diff_arr = train_arr[stride:] - train_arr[:-stride]
# NOTE(review): this sums *signed* differences, so increases and decreases cancel out;
# if "most different" means largest total change, np.abs(train_diff_arr) may be intended - confirm.
max_diff_ind = np.argmax(train_diff_arr.sum(axis=1).sum(axis=1))
logging.info(f"train_arr.shape: {train_arr.shape}")
logging.info(f"train_diff_arr.shape: {train_diff_arr.shape}")
logging.info(f"train_arr[0][0][:5]: \n{train_arr[0][0][:5]}")
logging.info(f"max_difference index of train_arr: {max_diff_ind}")
logging.info(f"train_diff_arr[{max_diff_ind}][0]: \n{train_diff_arr[max_diff_ind][0]}")
logging.info(f"train_arr[{max_diff_ind}][0]: \n{train_arr[max_diff_ind][0]}")
logging.info(f"train_arr[{max_diff_ind+stride}][0]: \n{train_arr[max_diff_ind+stride][0]}")