In [None]:
import sys
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import dynamic_yaml
import yaml

sys.path.append("/workspace/correlation-change-predict/utils")
from utils import convert_str_bins_list, split_and_norm_data

# Load the data config: parse it with dynamic_yaml, dump it back to text and re-parse
# with PyYAML's full_load so that `data_cfg` is made of plain Python containers.
# NOTE(review): presumably dynamic_yaml is used to resolve self-references inside
# data_config.yaml — confirm against the config file.
data_config_path = Path("../config/data_config.yaml")
with data_config_path.open() as config_file:
    data_cfg_yaml = dynamic_yaml.load(config_file)
    resolved_cfg_text = dynamic_yaml.dump(data_cfg_yaml)
    data_cfg = yaml.full_load(resolved_cfg_text)

In [None]:
# Experiment parameters.
batch_size = 64
corr_type = "pearson"
s_l = 1   # used in the "corr_s{s_l}_w{w_l}" file names; presumably the sliding step length — confirm
w_l = 50  # presumably the correlation window length — confirm
filt_mode = None
# Only read when filt_mode is set: it becomes part of the directory name
# "filtered_graph_adj_mat/{filt_mode}-quan{filt_quan}". It was previously undefined,
# so enabling filt_mode raised NameError in the data-loading cell.
filt_quan = None
quan_discrete_bins = None
custom_discrete_bins = None
graph_nodes_v_mode = None
etl_items_setting = "-train_train"  # -train_train|-train_all
target_mats_path = "pearson/custom_discretize_graph_adj_mat/bins_-10_-025_025_10"
output_file_name = "sp500_20112015_corr_ser_reg_std_corr_mat_large_filtered_hrchy_10_cluster_label_last_v2_negative_filtered"+etl_items_setting
# Target counting mode: exactly one of the next three settings must be active.
can_count_square_graph = False
can_count_upper_triangle = True
count_one_edge_idx = None # 3*3 matrix as an example, idx:0 for (A & B), idx:1 for (A & C), idx:2 for (B & C)
assert (bool(filt_mode) != bool(quan_discrete_bins)) or (filt_mode is None and quan_discrete_bins is None), "filt_mode and quan_discrete_bins must be both not input or one input"
assert filt_mode is None or filt_quan is not None, "filt_quan must be set when filt_mode is set"
# The directory selection checks filt_mode and quan_discrete_bins before custom_discrete_bins,
# so combining them would silently ignore the custom bins.
assert not (custom_discrete_bins and (filt_mode or quan_discrete_bins)), "custom_discrete_bins cannot be combined with filt_mode or quan_discrete_bins"
assert (can_count_square_graph+can_count_upper_triangle+bool(count_one_edge_idx is not None)) == 1, "can_count_square_graph, can_count_upper_triangle and count_one_edge_idx, only one of them can be True"

In [None]:
# Pick the adjacency-matrix sub-directory matching the active filter / discretization
# setting (checked in priority order), then load edge, node and target matrices and
# hand them to split_and_norm_data for the train/val/test split and normalization.
if filt_mode:
    graph_adj_mode_dir = f"filtered_graph_adj_mat/{filt_mode}-quan{str(filt_quan).replace('.', '')}"
elif quan_discrete_bins:
    graph_adj_mode_dir = f"quan_discretize_graph_adj_mat/bins{quan_discrete_bins}"
elif custom_discrete_bins:
    custom_bins_tag = "_".join(map(str, custom_discrete_bins)).replace(".", "")
    graph_adj_mode_dir = f"custom_discretize_graph_adj_mat/bins_{custom_bins_tag}"
else:
    graph_adj_mode_dir = "graph_adj_mat"

pipeline_data_dir = Path(data_cfg["DIRS"]["PIPELINE_DATA_DIR"])
graph_adj_mat_dir = pipeline_data_dir / f"{output_file_name}/{corr_type}/{graph_adj_mode_dir}"
graph_node_mat_dir = pipeline_data_dir / f"{output_file_name}/graph_node_mat"
target_mat_dir = pipeline_data_dir / f"{output_file_name}/{target_mats_path}"

# Edge and target matrices share the same file name, in different directories.
adj_mat_file_name = f"corr_s{s_l}_w{w_l}_adj_mat.npy"
gra_edges_data_mats = np.load(graph_adj_mat_dir / adj_mat_file_name)
if graph_nodes_v_mode:
    gra_nodes_data_mats = np.load(graph_node_mat_dir / f"{graph_nodes_v_mode}_s{s_l}_w{w_l}_nodes_mat.npy")
else:
    # No node-feature file configured: all-ones placeholder of shape
    # (edges.shape[0], 1, edges.shape[2]).
    gra_nodes_data_mats = np.ones((gra_edges_data_mats.shape[0], 1, gra_edges_data_mats.shape[2]))
if target_mats_path:
    target_mats = np.load(target_mat_dir / adj_mat_file_name)
else:
    target_mats = None

norm_train_dataset, norm_val_dataset, norm_test_dataset, scaler = split_and_norm_data(
    edges_mats=gra_edges_data_mats,
    nodes_mats=gra_nodes_data_mats,
    target_mats=target_mats,
    batch_size=batch_size,
)

In [None]:
# Peek at the first three training samples of each component returned by split_and_norm_data.
preview_sections = (
    ("================ edges ==================", "edges"),
    ("================ nodes ==================", "nodes"),
    ("================ target ==================", "target"),
)
for section_header, dataset_key in preview_sections:
    print(section_header)
    print(norm_train_dataset[dataset_key][:3])

In [None]:
# Flatten each split's target adjacency matrices into the observations whose label
# frequencies are counted in the following cells. Exactly one counting mode is active
# (enforced by the assertion in the parameter cell):
#   - can_count_upper_triangle: strict upper triangle (k=1) of every matrix
#       -> shape (num_mats, num_nodes * (num_nodes - 1) // 2)
#   - count_one_edge_idx:       selected entries of that flattened upper triangle
#       -> shape (num_mats, 1) for an int index
#   - can_count_square_graph:   the full matrices, unchanged
obs_target = {"train": None, "val": None}
obs_graphs_dict = {"train": norm_train_dataset['target'], "val": norm_val_dataset['target']}
num_nodes = norm_train_dataset['target'][0].shape[0]
graph_size = norm_train_dataset['target'][0].size
upper_triangle_idxs = np.triu_indices(num_nodes, 1)
for split, graph_adj_mats in obs_graphs_dict.items():
    if can_count_upper_triangle or count_one_edge_idx is not None:
        # Gather the upper triangle of all matrices in one fancy-indexing step instead of
        # growing an array with np.concatenate per matrix (which re-copied the whole
        # accumulated array each iteration and left None for an empty split).
        upper_triangle_vals = np.asarray(graph_adj_mats)[:, upper_triangle_idxs[0], upper_triangle_idxs[1]]
        if can_count_upper_triangle:
            obs_target[split] = upper_triangle_vals
        else:
            # atleast_1d keeps a 2-D (num_mats, k) result for both an int and a list index.
            obs_target[split] = upper_triangle_vals[:, np.atleast_1d(count_one_edge_idx)]
    else:
        obs_target[split] = graph_adj_mats
print(f"obs_target[train].shape:{obs_target['train'].shape}, obs_target[val].shape:{obs_target['val'].shape}")

In [None]:
# Per-split label histograms of obs_target (np.unique returns the distinct labels sorted,
# together with how often each one occurs), printed next to the raw target shape/size.
tr_labels, tr_labels_freq_counts = np.unique(obs_target["train"], return_counts=True)
val_labels, val_labels_freq_counts = np.unique(obs_target["val"], return_counts=True)
print(f"implement dataset:{output_file_name}")
tr_val_info = {
    "train": {
        "dataset_target": norm_train_dataset['target'],
        "freq_info": dict(zip(tr_labels, tr_labels_freq_counts)),
    },
    "val": {
        "dataset_target": norm_val_dataset['target'],
        "freq_info": dict(zip(val_labels, val_labels_freq_counts)),
    },
}
for data_split, data_info in tr_val_info.items():
    split_target = data_info['dataset_target']
    print(f"norm_{data_split}_dataset[target]:\n  shape: {split_target.shape}\n  size: {split_target.size}")
    print("for obs_target:")
    for label, freq in data_info['freq_info'].items():
        print(f"  {data_split} label :{label}, frequency: {freq}")
    # Total number of counted observations in this split.
    sum_num_freq_each_label = sum(data_info['freq_info'].values())
    print(f"  {data_split} sum_num_freq_each_label:{sum_num_freq_each_label}")
    print("-"*30)

In [None]:
# Pie charts of the obs_target label distribution for the train and validation splits.
# target_retrieve_setting describes the active counting mode; it is only used by the
# alternative (commented-out) figure title below.
if can_count_square_graph:
    target_retrieve_setting = "square"
elif can_count_upper_triangle:
    target_retrieve_setting = "upper_triangle"
elif count_one_edge_idx is not None:
    target_retrieve_setting = f"edge_idx_{count_one_edge_idx}"
# Keyed by numeric value instead of str(label): str() gives "-1" for integer-dtype labels
# and "-0.0" for a negative zero, both of which raised KeyError with string keys.
colors_labels_map = {-1.0: "lime", 0.0: "darkorange", 1.0: "dodgerblue"}
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16,9))
pie_panels = ((axes[0], "Train", tr_labels, tr_labels_freq_counts),
              (axes[1], "Validation", val_labels, val_labels_freq_counts))
for ax, panel_title, panel_labels, panel_counts in pie_panels:
    # Labels outside {-1, 0, 1} are drawn in grey instead of aborting the plot.
    ax.pie(panel_counts, labels=panel_labels, autopct='%1.1f%%', textprops={'fontsize': 24},
           colors=[colors_labels_map.get(float(label), "grey") for label in panel_labels])
    ax.set_title(panel_title, fontsize=32)
#fig.suptitle(f'Irrelevant keep({num_nodes} company) with {target_retrieve_setting}', fontsize=40)
fig.suptitle(f'Positive_Negative keep({num_nodes} company)', fontsize=40)

plt.show()
plt.close(fig)