In [None]:
from pathlib import Path
from functools import partial
import json
import sys
import traceback

import numpy as np
import pandas as pd

sys.path.append("/workspace/multivariate-correlation-anomaly-detection/")
from utils.plot_utils import plot_gru_tr_process
from utils.log_utils import Log

# Notebook-wide logger used by the helper functions in later cells.
# `df_max_rows` is forwarded to the project's utils.log_utils.Log — presumably it caps
# how many DataFrame rows are rendered in log output; TODO confirm.
JPY_LOGGER = (
    Log(df_max_rows=50)
    .init_logger(logger_name="ywt_jupyter")
)

In [None]:
def check_type(x, type_info):
    """Element-wise type predicate.

    Args:
        x: any object.
        type_info: a type, or a tuple of types, accepted by ``isinstance``.

    Returns:
        True when ``x`` is an instance of ``type_info``, otherwise False.
    """
    matches_type = isinstance(x, type_info)
    return matches_type


# Columns kept in the training summary table, in display order. Columns absent from
# every matched log come out all-NaN; log keys not listed here are dropped.
_GRU_TR_SUMMARY_COLUMNS = [
    "corr_info", "num_train_data", "num_val_data", "seq_len", "epochs", "batch_size",
    "tr_batches_per_epoch", "val_batches_per_epoch", "model_input_cus_bins", "input_feature_idx",
    "loss_fns", "custom_indices_loss_indices", "loss_weight", "opt_lr", "opt_weight_decay",
    "optimizer", "opt_scheduler", "metric_fn", "custom_indices_metric_indices", "drop_pos",
    "drop_p", "gru_l", "gru_h", "output_type", "target_data_bins", "tol_edge_acc_loss_atol",
    "best_val_epoch", "min_tr_loss", "max_tr_edge_acc", "min_val_loss", "max_val_edge_acc",
]
# Row order of the summary: feature set ascending, then larger configurations first.
# A single multi-key sort is used because two successive sorts with a non-stable
# algorithm do not preserve the secondary ordering.
_GRU_TR_SORT_KEYS = ["input_feature_idx", "batch_size", "seq_len", "gru_l", "gru_h", "drop_p"]
_GRU_TR_SORT_ASCENDING = [True, False, False, False, False, False]


def _load_gru_tr_record(log_path: Path) -> dict:
    """Read one JSON training log and return it merged with derived summary fields.

    Raises KeyError if the log lacks any of the four ``*_history`` lists used below.
    """
    with open(log_path, "r", encoding="utf-8") as source:
        log_dict = json.load(source)
    corr_part = next((part for part in log_path.parts if part.startswith("corr")), None)
    # Derived fields go last so they override any same-named raw log key.
    return {
        **log_dict,
        "corr_info": None if corr_part is None else str(corr_part),
        "min_tr_loss": min(log_dict["tr_loss_history"]),
        "min_val_loss": min(log_dict["val_loss_history"]),
        "max_tr_edge_acc": max(log_dict["tr_edge_acc_history"]),
        "max_val_edge_acc": max(log_dict["val_edge_acc_history"]),
    }


def _format_gru_tr_title(record: dict) -> str:
    """Build the multi-line plot/log title for one training record."""
    return (f"{record.get('corr_info')} with batch_size({record.get('batch_size')}) input to\n"
            f"GRU with gru_l{record.get('gru_l')}-gru_h{record.get('gru_h')} "
            f"and drop: {record.get('drop_p')} and loss_fns:{record.get('loss_fns')}\n"
            f"min val-loss:{record['min_val_loss']:8f} min tr-loss:{record['min_tr_loss']:8f}")


def _split_gru_tr_variables(summary_df: pd.DataFrame):
    """Stringify list/dict-valued columns and split columns into varying vs constant ones.

    Returns ``(stringified_df, independent_columns, control_columns)``. A column is
    independent when it holds more than one distinct value across rows (NaN counts as
    its own value, so a field missing from only some runs is treated as varying) and
    control otherwise. The input frame is not modified.
    """
    # list/dict cells are unhashable, so nunique() cannot count them; use their str() form.
    has_unhashable = summary_df.apply(
        lambda col: col.map(lambda value: isinstance(value, (list, dict))).any()
    ).astype(bool)
    stringified_df = summary_df.copy()
    for col in has_unhashable[has_unhashable].index:
        stringified_df[col] = stringified_df[col].map(str)
    n_distinct = stringified_df.nunique(dropna=False)
    independent_columns = n_distinct[n_distinct > 1].index
    control_columns = n_distinct[n_distinct <= 1].index
    return stringified_df, independent_columns, control_columns


def gru_class_tr_proc_est(log_path_list: list, condition_dict: dict, regexp_res: bool = False, plot_pic: bool = True):
    """Summarise GRU_CLASS training logs and split hyper-parameters into independent/control sets.

    Each JSON log is loaded and augmented with ``corr_info`` (the first path component
    starting with "corr", or None) and the min loss / max edge-accuracy of its histories.
    It is kept when every ``condition_dict`` item equals the log's value. Matched runs are
    logged (and optionally plotted), then tabulated and displayed.

    Args:
        log_path_list: iterable of paths (``Path`` or str) to JSON training logs.
        condition_dict: field -> required value; an empty dict keeps every log.
        regexp_res: unused; kept for backward compatibility.
        plot_pic: when True, call ``plot_gru_tr_process`` for each matched log.

    Returns:
        ``(model_tr_summary_df, independent_variables_tr_summary_df,
        control_variables_tr_summary_df)``: the full sorted summary, the columns that vary
        across runs, and row 0's values of the columns that do not vary (a Series).
        List/dict-valued columns are converted to ``str``. When no log matches, the first
        element is an empty summary frame, the second an empty frame and the third an
        empty Series.

    Raises:
        ValueError: a ``condition_dict`` key is neither a log field nor a derived field.
        Any other error (unreadable log, missing history key, ...) is logged via
        ``JPY_LOGGER`` and re-raised.

    Side effects:
        Sets global pandas display options (float format, max columns, max colwidth) and
        calls ``display`` on the three results.
    """
    current_log_path = None  # file being processed, reported if an error occurs
    try:
        matched_records = []
        for raw_log_path in log_path_list:
            current_log_path = Path(raw_log_path)
            record = _load_gru_tr_record(current_log_path)
            unknown_keys = set(condition_dict) - set(record)
            if unknown_keys:
                raise ValueError(f"condition_dict keys {sorted(map(str, unknown_keys))} are not fields of "
                                 f"{current_log_path} in gru_class_tr_proc_est()")
            if not all(record[key] == value for key, value in condition_dict.items()):
                continue

            main_title_str = _format_gru_tr_title(record)
            JPY_LOGGER.info(f"file_name:{current_log_path.parts[-1]}")
            JPY_LOGGER.info(f"file_path:{current_log_path.parts[2:-2]}")
            JPY_LOGGER.info(f"main_title_str:\n{main_title_str}")
            JPY_LOGGER.info("="*30)
            matched_records.append(record)
            if plot_pic:
                plot_gru_tr_process(main_title=main_title_str,
                                    model_struct=record.get("model_structure"),
                                    metrics_history={k: v for k, v in record.items() if "history" in k},
                                    best_epoch=record["best_val_epoch"])
        current_log_path = None  # errors past this point are not tied to a single file

        # Build the frame once (instead of concat per row) and fix the column set/order.
        summary_df = pd.DataFrame(matched_records).reindex(columns=_GRU_TR_SUMMARY_COLUMNS)
        if summary_df.empty:
            JPY_LOGGER.warning("gru_class_tr_proc_est(): no training log matched condition_dict")
            return summary_df, summary_df.loc[:, []], pd.Series(dtype=object)

        summary_df, independent_columns, control_columns = _split_gru_tr_variables(summary_df)
        model_tr_summary_df = (summary_df
                               .sort_values(_GRU_TR_SORT_KEYS, ascending=_GRU_TR_SORT_ASCENDING)
                               .reset_index(drop=True))
        independent_variables_tr_summary_df = model_tr_summary_df.loc[:, independent_columns].sort_index(axis=1)
        # Control columns hold a single value across all rows, so row 0 represents them.
        control_variables_tr_summary_df = model_tr_summary_df.loc[0, control_columns].sort_index(axis=0)

        pd.options.display.float_format = '{:.6f}'.format
        pd.set_option('display.max_columns', None)
        pd.set_option('display.max_colwidth', 80)
        display(model_tr_summary_df)
        display(independent_variables_tr_summary_df)
        display(control_variables_tr_summary_df)
    except Exception as e:
        last_frame = traceback.extract_tb(e.__traceback__)[-1]  # innermost frame of the failure
        detail = e.args[0] if e.args else ""
        err_msg = "File \"{}\", line {}, in {}: [{}] {}".format(
            last_frame.filename, last_frame.lineno, last_frame.name, e.__class__.__name__, detail)
        if current_log_path is not None:
            JPY_LOGGER.error(f"file:{current_log_path.parts[-1]}, path:{current_log_path}")
        JPY_LOGGER.error(f"===\n{err_msg}")
        JPY_LOGGER.error(f"===\n{traceback.format_exc()}")
        # Re-raise: the result variables are not available, so returning would only mask the cause.
        raise

    return model_tr_summary_df, independent_variables_tr_summary_df, control_variables_tr_summary_df

In [None]:
# Training logs of the baseline GRU runs (path relative to this notebook).
baseline_gru_log_dir = Path("../save_models/gru_corr_class_custom_features/sp500_20112015-train_train/pearson/")
# Path fragments whose logs are skipped. Glob bracket classes such as "[!deprecated]"
# match exactly ONE character, so they cannot exclude whole directory names — filter explicitly.
EXCLUDED_LOG_DIR_KEYWORDS = ("deprecated", "archive", ".ipynb_checkpoints")
log_path_list1 = sorted(
    log_path
    for log_path in baseline_gru_log_dir.glob("*/train_logs/*.json")
    if not any(keyword in part
               for part in log_path.relative_to(baseline_gru_log_dir).parts
               for keyword in EXCLUDED_LOG_DIR_KEYWORDS)
)
model_tr_summary_df, independent_variables_tr_summary_df, control_variables_tr_summary_df = gru_class_tr_proc_est(log_path_list=log_path_list1, condition_dict={}, plot_pic=False)

In [None]:
# Select rows with a boolean mask. `.where(cond).dropna()` would also drop matching runs
# that have NaN in any other column, and would upcast integer columns to float.
MAX_VAL_EDGE_ACC_THRESHOLD = 0.6  # inspect runs whose best validation edge accuracy is below this
TARGET_INPUT_FEATURE_IDX = "[158]"  # list-valued columns are stringified by gru_class_tr_proc_est
low_val_acc_mask = independent_variables_tr_summary_df['max_val_edge_acc'] < MAX_VAL_EDGE_ACC_THRESHOLD
filtered_independent_variables_tr_summary_df1 = independent_variables_tr_summary_df.loc[low_val_acc_mask]
display(filtered_independent_variables_tr_summary_df1)
target_feature_mask = independent_variables_tr_summary_df['input_feature_idx'] == TARGET_INPUT_FEATURE_IDX
filtered_independent_variables_tr_summary_df2 = independent_variables_tr_summary_df.loc[target_feature_mask]
display(filtered_independent_variables_tr_summary_df2)