In [1]:
from tqdm import tqdm
from pathlib import Path
import warnings
import sys
import logging
from pprint import pformat

import pandas as pd
import numpy as np
import matplotlib as mpl
import dynamic_yaml
import yaml

sys.path.append("/workspace/correlation-change-predict/ywt_library")
import data_generation
from data_generation import data_gen_cfg, gen_corr_dist_mat
from stl_decompn import stl_decompn
from corr_property import calc_corr_ser_property


with open('../config/data_config.yaml') as f:
    data = dynamic_yaml.load(f)
    data_cfg = yaml.full_load(dynamic_yaml.dump(data))

warnings.simplefilter("ignore")
logging.basicConfig(level=logging.INFO)
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.ERROR)
mpl.rcParams[u'font.sans-serif'] = ['simhei']
mpl.rcParams['axes.unicode_minus'] = False
# logger_list = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
# print(logger_list)

# %load_ext pycodestyle_magic
# %pycodestyle_on --ignore E501
logging.debug(pformat(data_cfg, indent=1, width=100, compact=True))
logging.info(pformat(data_gen_cfg, indent=1, width=100, compact=True))

INFO:root:{'CORR_STRIDE': 5, 'CORR_WINDOW': 5, 'DATA_DIV_STRIDE': 20, 'MAX_DATA_DIV_START_ADD': 0}


time: 876 ms (started: 2023-01-28 18:21:47 +00:00)


# Prepare data

## Data implement & output setting & testset setting

In [2]:
# setting of output files
save_corr_data = False
# data implement setting
data_implement = "SP500_20082017_CORR_SER_REG_CORR_MAT_HRCHY_11_CLUSTER"  # watch options by operate: print(data_cfg["DATASETS"].keys())
# data split period setting, only suit for only settings of Korean paper
data_split_setting = "-data_sp_test2"
# train set setting
train_items_setting = "-train_train"  # -train_train|-train_all
# Decide format of corr_distance_matrix
dist_mat_format = True

time: 413 µs (started: 2023-01-28 18:21:47 +00:00)


In [3]:
# data loading & implement setting
dataset_df = pd.read_csv(data_cfg["DATASETS"][data_implement]['FILE_PATH'])
dataset_df = dataset_df.set_index('Date')
all_set = list(dataset_df.columns)  # all data
train_set = data_cfg["DATASETS"][data_implement]['TRAIN_SET']
test_set = data_cfg['DATASETS'][data_implement]['TEST_SET'] if data_cfg['DATASETS'][data_implement].get('TEST_SET') else [p for p in all_set if p not in train_set]  # all data - train data
logging.info(f"===== len(train_set): {len(train_set)}, len(all_set): {len(all_set)}, len(test_set): {len(test_set)} =====")

# train items implement settings
items_implement = train_set if train_items_setting == "-train_train" else all_set
logging.info(f"===== len(train set): {len(items_implement)} =====")

# setting of name of output files and pictures title
output_file_name = data_cfg["DATASETS"][data_implement]['OUTPUT_FILE_NAME_BASIS'] + train_items_setting
logging.info(f"===== file_name basis:{output_file_name} =====")
display(dataset_df)

# output folder settings
corr_data_dir = Path(data_cfg["DIRS"]["PIPELINE_DATA_DIR"])/f"{output_file_name}-corr_data"
res_dir = Path(data_cfg["DIRS"]["PIPELINE_DATA_DIR"])/f"{output_file_name}-graph_data"
corr_data_dir.mkdir(parents=True, exist_ok=True)
res_dir.mkdir(parents=True, exist_ok=True)

INFO:root:===== len(train_set): 66, len(all_set): 97, len(test_set): 31 =====
INFO:root:===== len(train set): 66 =====
INFO:root:===== file_name basis:sp500_20082017_corr_ser_reg_corr_mat_hrchy_11_cluster-train_train =====


Unnamed: 0_level_0,FE,KEY,ROK,WDC,CLX,MSCI,RSG,FIS,LEG,ISRG,...,NFLX,TDG,HRB,AON,WU,MON,CL,DISCA,PWR,AZO
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2008-01-02,44.431971,18.352796,52.709816,25.208257,47.021449,32.562262,23.065471,19.381723,11.169138,107.983333,...,3.764286,24.191702,13.312927,40.964039,18.549406,94.930829,30.588703,11.310171,25.98,116.21
2008-01-03,44.920234,18.248153,52.804675,25.014347,46.690260,29.618672,22.945025,19.487476,11.063769,107.250000,...,3.724286,24.400714,12.697714,40.646488,18.726604,102.978701,30.806233,11.492667,25.47,113.72
2008-01-04,45.933845,17.580047,50.662449,22.590476,46.690260,29.003019,22.350321,19.165409,10.747661,101.666667,...,3.515714,24.416791,12.289956,40.108416,17.671470,101.880103,31.158237,11.150487,23.86,110.58
2008-01-07,47.948704,17.893976,49.958914,22.281984,46.925772,27.069484,22.779411,18.559730,10.708148,99.993333,...,3.554286,24.164906,12.475951,39.905537,17.607035,102.569920,31.628893,11.159612,21.96,112.65
2008-01-08,46.582802,17.314414,48.172408,20.765963,46.447388,25.616928,22.689076,18.295347,10.536923,90.576667,...,3.328571,22.846528,12.161191,39.261614,17.212367,102.212237,31.632848,11.036428,22.73,108.38
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-12-22,30.480000,20.440000,194.570000,80.680000,149.520000,128.330000,66.930000,94.320000,47.060000,363.160000,...,189.940000,270.650000,26.230000,133.350000,19.080000,116.280000,75.150000,23.730000,39.41,710.09
2017-12-26,30.320000,20.330000,194.530000,80.000000,149.690000,125.600000,66.770000,94.130000,47.540000,365.830000,...,187.760000,273.110000,26.060000,133.300000,18.890000,116.020000,75.480000,23.570000,39.50,714.48
2017-12-27,30.320000,20.240000,195.990000,80.960000,149.290000,125.470000,66.930000,94.000000,47.550000,368.060000,...,186.240000,273.810000,26.450000,133.920000,18.800000,116.360000,75.370000,23.050000,39.48,707.00
2017-12-28,30.450000,20.350000,197.530000,80.640000,149.000000,126.380000,67.410000,94.280000,47.530000,368.870000,...,192.710000,275.770000,26.570000,134.520000,19.070000,116.270000,75.140000,22.550000,39.33,718.38


time: 39.9 ms (started: 2023-01-28 18:21:47 +00:00)


## Load or Create Correlation Data

In [4]:
# DEFAULT SETTING: data_gen_cfg["DATA_DIV_STRIDE"] == 20, data_gen_cfg["CORR_WINDOW"]==100, data_gen_cfg["CORR_STRIDE"]==100
# data_length = int(len(dataset_df)/data_gen_cfg["CORR_WINDOW"])*data_gen_cfg["CORR_WINDOW"]
# corr_ser_len_max = int((data_length-data_gen_cfg["CORR_WINDOW"])/data_gen_cfg["CORR_STRIDE"])
# max_data_div_start_add = 0  # In the Korea paper, each pair has 5 corr_series(due to diversifing train data).
#                             # BUT we only need to take one, so take 0 as arg, add 20 for each corr_series
# corr_ind = []

# data_end_init = corr_ser_len_max * data_gen_cfg["CORR_STRIDE"]
# for i in range(0, max_data_div_start_add+1, data_gen_cfg["DATA_DIV_STRIDE"]):
#     corr_ind.extend(list(range(data_gen_cfg["CORR_WINDOW"]-1+i, data_end_init+bool(i)*data_gen_cfg["CORR_STRIDE"], data_gen_cfg["CORR_STRIDE"])))  # only suit for settings of paper
# print(corr_ind)

train_df_path = corr_data_dir/f"{output_file_name}-corr_train.csv"
dev_df_path = corr_data_dir/f"{output_file_name}-corr_dev.csv"
test1_df_path = corr_data_dir/f"{output_file_name}-corr_test1.csv"
test2_df_path = corr_data_dir/f"{output_file_name}-corr_test2.csv"
all_corr_df_paths = dict(zip(["train_df", "dev_df", "test1_df", "test2_df"],
                             [train_df_path, dev_df_path, test1_df_path, test2_df_path]))
if all([df_path.exists() for df_path in all_corr_df_paths.values()]):
    corr_datasets = [pd.read_csv(df_path, index_col=["items"]) for df_path in all_corr_df_paths.values()]
else:
    # corr_datasets = data_generation.gen_train_data(items_implement, raw_data_df=dataset_df, corr_df_paths=all_corr_df_paths, corr_ser_len_max=corr_ser_len_max, corr_ind=corr_ind, max_data_div_start_add=max_data_div_start_add, save_file=save_corr_data)
    corr_datasets = data_generation.gen_train_data(items_implement, raw_data_df=dataset_df, corr_df_paths=all_corr_df_paths, save_file=save_corr_data)

if data_split_setting == "-data_sp_test2":
    corr_dataset = corr_datasets[3]
    display(corr_dataset.head())

2145it [00:07, 288.54it/s]


Date,2008-01-30,2008-02-06,2008-02-13,2008-02-21,2008-02-28,2008-03-06,2008-03-13,2008-03-20,2008-03-28,2008-04-04,...,2017-10-12,2017-10-19,2017-10-26,2017-11-02,2017-11-09,2017-11-16,2017-11-24,2017-12-01,2017-12-08,2017-12-15
items,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ABT & ADI_0,0.935536,0.014252,-0.705454,-0.86195,0.796494,0.094265,0.866668,0.409627,0.905134,-0.335217,...,-0.203856,-0.284118,-0.918336,0.199069,-0.291867,0.731095,-0.746275,0.390113,-0.574039,0.378756
ABT & ADS_0,0.092408,0.448559,-0.973413,0.681387,0.247138,0.828805,0.871309,0.551438,0.649894,-0.480391,...,0.088078,0.347341,0.485302,0.8878,-0.446216,-0.144184,-0.617587,0.067992,0.19397,0.723458
ABT & AFL_0,0.643285,-0.37846,-0.446946,-0.121289,0.289814,-0.172447,0.469206,0.921605,-0.149132,0.118007,...,0.008749,0.194807,-0.174537,-0.303427,0.712038,0.463823,0.582008,-0.078458,-0.132805,0.739538
ABT & AMP_0,-0.037215,-0.147084,0.073672,0.237572,-0.050944,0.965579,0.494077,0.890634,0.924817,-0.523432,...,-0.034917,-0.614384,-0.475963,0.6991,0.485259,-0.458973,-0.149781,-0.318449,-0.651522,0.830477
ABT & AMT_0,0.530335,0.53282,-0.793227,0.497505,0.362947,0.871597,0.718157,0.898772,0.364504,-0.651521,...,-0.079929,-0.885545,0.788655,-0.695732,-0.393646,-0.492134,0.51704,-0.173881,-0.419797,-0.591587


time: 7.45 s (started: 2023-01-28 18:21:47 +00:00)


In [5]:
corr_spatial = corr_dataset.iloc[::,495]
corr_dist_mat_df = dataset_df.loc[::,train_set]
distance_mat = gen_corr_dist_mat(corr_spatial, corr_dist_mat_df, output_similarity_mat=dist_mat_format)
# test
# test_stock_tickers = ["ED", "BAC", "XEL", "MA"]
# test_distance_mat = distance_mat.loc[test_stock_tickers, test_stock_tickers]
# display(test_distance_mat)  # comlpete: (ED, BAC), (XEL), (MA) -> (ED, BAC), (XEL, MA)  -> (ED, BAC, XEL, MA)
#                             # single: (ED, BAC), (XEL), (MA) -> (ED, BAC, XEL), (MA)  -> (ED, BAC, XEL, MA)
logging.info(f"Min of distance_mat:{distance_mat.min()}")
display(distance_mat.shape)
display(distance_mat.head())

INFO:root:Min of distance_mat:items
ABT   -0.875093
ADI   -0.952448
ADS   -0.966402
AFL   -0.996633
AMP   -0.988824
         ...   
WHR   -0.780714
WU    -0.926991
WYN   -0.932910
XEC   -0.932910
XRX   -0.772948
Length: 66, dtype: float32


(66, 66)

items,ABT,ADI,ADS,AFL,AMP,AMT,ANTM,AON,AZO,BWA,...,TROW,TWX,UNP,URI,WDC,WHR,WU,WYN,XEC,XRX
items,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ABT,1.0,-0.746275,-0.617587,0.582008,-0.149781,0.51704,0.29512,-0.588272,-0.02754,0.665695,...,0.865344,0.821394,0.776663,0.761971,0.844786,0.942031,-0.072622,0.312561,0.040625,-0.005414
ADI,-0.746275,1.0,0.852055,-0.315555,0.484425,0.034567,-0.707926,0.834696,-0.198707,-0.573085,...,-0.698962,-0.952448,-0.741365,-0.50443,-0.904011,-0.531834,0.601285,0.043479,-0.281865,0.568819
ADS,-0.617587,0.852055,1.0,-0.587822,0.050095,-0.050353,-0.364372,0.960042,-0.561797,-0.833379,...,-0.79804,-0.812085,-0.923908,-0.647402,-0.917338,-0.428912,0.778329,0.009871,-0.172824,0.330816
AFL,0.582008,-0.315555,-0.587822,1.0,0.654185,0.281571,-0.445007,-0.407367,0.673438,0.935669,...,0.881209,0.513804,0.810321,0.962152,0.688771,0.693784,-0.342426,0.708187,-0.566051,0.160644
AMP,-0.149781,0.484425,0.050095,0.654185,1.0,0.198067,-0.946773,0.195084,0.578682,0.433342,...,0.228007,-0.29971,0.186106,0.440425,-0.077427,0.086514,0.011071,0.581684,-0.683531,0.542464


time: 19.1 ms (started: 2023-01-28 18:21:55 +00:00)


## concate correlation matrix across time

In [6]:
tmp_graph_list = []
for i in range(corr_dataset.shape[1]):
    corr_spatial = corr_dataset.iloc[::,i]
    corr_dist_mat_df = dataset_df.loc[::,train_set]
    distance_arr = gen_corr_dist_mat(corr_spatial, corr_dist_mat_df, output_similarity_mat=dist_mat_format).to_numpy()
    tmp_graph_list.append(distance_arr)

graph_list = np.stack(tmp_graph_list, axis=0)
np.save(res_dir/f"corr_calc_reg-corr_graph", graph_list)

time: 2.75 s (started: 2023-01-28 18:21:55 +00:00)
