In [1]:
from tqdm import tqdm
from pathlib import Path
import warnings
import sys
import logging
from pprint import pformat

import pandas as pd
import numpy as np
import matplotlib as mpl
import dynamic_yaml
import yaml

sys.path.append("/workspace/correlation-change-predict/ywt_library")
import data_generation
from data_generation import data_gen_cfg, gen_corr_dist_mat, gen_corr_graph
from stl_decompn import stl_decompn
from corr_property import calc_corr_ser_property


# Load the data config. dynamic_yaml parses the file first; its dump is then
# re-parsed with PyYAML so that data_cfg is whatever yaml.full_load returns
# (presumably plain, fully-resolved Python containers -- confirm).
with open('../config/data_config.yaml') as cfg_file:
    data = dynamic_yaml.load(cfg_file)
    data_cfg = yaml.full_load(dynamic_yaml.dump(data))

# Notebook-wide warning / logging setup.
warnings.simplefilter("ignore")  # suppresses every warning category
logging.basicConfig(level=logging.INFO)
matplotlib_logger = logging.getLogger("matplotlib")
matplotlib_logger.setLevel(logging.ERROR)  # keep matplotlib quiet below ERROR

# Matplotlib rendering options (the 'simhei' font is presumably for CJK glyphs -- confirm).
mpl.rcParams['font.sans-serif'] = ['simhei']
mpl.rcParams['axes.unicode_minus'] = False

# Optional debugging helpers, left disabled:
# logger_list = [logging.getLogger(name) for name in logging.root.manager.loggerDict]
# print(logger_list)
# %load_ext pycodestyle_magic
# %pycodestyle_on --ignore E501

# Pretty-print both configs; the data config is logged at DEBUG, the
# data-generation config at INFO.
pformat_kwargs = dict(indent=1, width=100, compact=True)
logging.debug(pformat(data_cfg, **pformat_kwargs))
logging.info(pformat(data_gen_cfg, **pformat_kwargs))

INFO:root:{'CORR_STRIDE': 1, 'CORR_WINDOW': 10, 'DATA_DIV_STRIDE': 20, 'MAX_DATA_DIV_START_ADD': 0}


time: 905 ms (started: 2023-02-12 16:11:56 +00:00)


# Prepare data

## Dataset selection, output settings & test-set settings

In [2]:
# --- Output switches: whether generated artifacts are written to disk ---
save_corr_data = False       # correlation data files
save_corr_graph_arr = False  # correlation graph arrays

# --- Dataset selection ---
# List the available options with: print(data_cfg["DATASETS"].keys())
data_implement = "SP500_20082017_CORR_SER_REG_CORR_MAT_HRCHY_11_CLUSTER"

# --- Data split period ---
# Only applicable to the split settings taken from the Korean paper.
data_split_setting = "-data_sp_test2"

# --- Training items: "-train_train" or "-train_all" ---
train_items_setting = "-train_train"

# --- Composition of graph_matrix ---
graph_mat_compo = "sim"

time: 453 µs (started: 2023-02-12 16:11:57 +00:00)


In [3]:
# Read the raw dataset for the selected data implementation and index it by date.
dataset_cfg = data_cfg["DATASETS"][data_implement]
dataset_df = pd.read_csv(dataset_cfg['FILE_PATH']).set_index('Date')

# Resolve the item sets: every column, the configured train items, and the test
# items (taken from the config when present and non-empty, otherwise every
# column that is not a train item).
all_set = list(dataset_df.columns)
train_set = dataset_cfg['TRAIN_SET']
if dataset_cfg.get('TEST_SET'):
    test_set = dataset_cfg['TEST_SET']
else:
    test_set = [p for p in all_set if p not in train_set]
logging.info(f"===== len(train_set): {len(train_set)}, len(all_set): {len(all_set)}, len(test_set): {len(test_set)} =====")

# Items actually used downstream: the train items, or every item in the dataset.
items_implement = all_set if train_items_setting != "-train_train" else train_set
logging.info(f"===== len(train set): {len(items_implement)} =====")

# Base name shared by output files and picture titles.
output_file_name = dataset_cfg['OUTPUT_FILE_NAME_BASIS'] + train_items_setting
logging.info(f"===== file_name basis:{output_file_name} =====")
display(dataset_df)

# Output folders for the correlation data and the graph data; create both if missing.
pipeline_data_dir = Path(data_cfg["DIRS"]["PIPELINE_DATA_DIR"])
corr_data_dir = pipeline_data_dir/f"{output_file_name}-corr_data"
res_dir = pipeline_data_dir/f"{output_file_name}-graph_data"
for out_dir in (corr_data_dir, res_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

INFO:root:===== len(train_set): 66, len(all_set): 97, len(test_set): 31 =====
INFO:root:===== len(train set): 66 =====
INFO:root:===== file_name basis:sp500_20082017_corr_ser_reg_corr_mat_hrchy_11_cluster-train_train =====


Unnamed: 0_level_0,FE,KEY,ROK,WDC,CLX,MSCI,RSG,FIS,LEG,ISRG,...,NFLX,TDG,HRB,AON,WU,MON,CL,DISCA,PWR,AZO
Date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2008-01-02,44.431971,18.352796,52.709816,25.208257,47.021449,32.562262,23.065471,19.381723,11.169138,107.983333,...,3.764286,24.191702,13.312927,40.964039,18.549406,94.930829,30.588703,11.310171,25.98,116.21
2008-01-03,44.920234,18.248153,52.804675,25.014347,46.690260,29.618672,22.945025,19.487476,11.063769,107.250000,...,3.724286,24.400714,12.697714,40.646488,18.726604,102.978701,30.806233,11.492667,25.47,113.72
2008-01-04,45.933845,17.580047,50.662449,22.590476,46.690260,29.003019,22.350321,19.165409,10.747661,101.666667,...,3.515714,24.416791,12.289956,40.108416,17.671470,101.880103,31.158237,11.150487,23.86,110.58
2008-01-07,47.948704,17.893976,49.958914,22.281984,46.925772,27.069484,22.779411,18.559730,10.708148,99.993333,...,3.554286,24.164906,12.475951,39.905537,17.607035,102.569920,31.628893,11.159612,21.96,112.65
2008-01-08,46.582802,17.314414,48.172408,20.765963,46.447388,25.616928,22.689076,18.295347,10.536923,90.576667,...,3.328571,22.846528,12.161191,39.261614,17.212367,102.212237,31.632848,11.036428,22.73,108.38
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2017-12-22,30.480000,20.440000,194.570000,80.680000,149.520000,128.330000,66.930000,94.320000,47.060000,363.160000,...,189.940000,270.650000,26.230000,133.350000,19.080000,116.280000,75.150000,23.730000,39.41,710.09
2017-12-26,30.320000,20.330000,194.530000,80.000000,149.690000,125.600000,66.770000,94.130000,47.540000,365.830000,...,187.760000,273.110000,26.060000,133.300000,18.890000,116.020000,75.480000,23.570000,39.50,714.48
2017-12-27,30.320000,20.240000,195.990000,80.960000,149.290000,125.470000,66.930000,94.000000,47.550000,368.060000,...,186.240000,273.810000,26.450000,133.920000,18.800000,116.360000,75.370000,23.050000,39.48,707.00
2017-12-28,30.450000,20.350000,197.530000,80.640000,149.000000,126.380000,67.410000,94.280000,47.530000,368.870000,...,192.710000,275.770000,26.570000,134.520000,19.070000,116.270000,75.140000,22.550000,39.33,718.38


time: 41 ms (started: 2023-02-12 16:11:57 +00:00)


## Load or Create Correlation Data

In [4]:
# Load the per-split correlation datasets from disk when all four CSV files
# exist; otherwise generate them from the raw dataset.
# DEFAULT SETTING: data_gen_cfg["DATA_DIV_STRIDE"] == 20, data_gen_cfg["CORR_WINDOW"]==100, data_gen_cfg["CORR_STRIDE"]==100
# NOTE(review): confirm these defaults against the data_gen_cfg values logged at startup.
train_df_path = corr_data_dir/f"{output_file_name}-corr_train.csv"
dev_df_path = corr_data_dir/f"{output_file_name}-corr_dev.csv"
test1_df_path = corr_data_dir/f"{output_file_name}-corr_test1.csv"
test2_df_path = corr_data_dir/f"{output_file_name}-corr_test2.csv"
all_corr_df_paths = dict(zip(["train_df", "dev_df", "test1_df", "test2_df"],
                             [train_df_path, dev_df_path, test1_df_path, test2_df_path]))
if all(df_path.exists() for df_path in all_corr_df_paths.values()):
    # Every split is cached on disk: read them back in train/dev/test1/test2 order.
    corr_datasets = [pd.read_csv(df_path, index_col=["items"]) for df_path in all_corr_df_paths.values()]
else:
    # At least one split file is missing: regenerate from the raw data
    # (files are written only when save_corr_data is True).
    corr_datasets = data_generation.gen_train_data(items_implement, raw_data_df=dataset_df, corr_df_paths=all_corr_df_paths, save_file=save_corr_data)

# Pick the dataset used by the following cells. Fail loudly on an unhandled
# split setting: otherwise corr_dataset would stay undefined (NameError later)
# or silently keep a stale value from an earlier run of this cell.
if data_split_setting != "-data_sp_test2":
    raise ValueError(
        f"Unsupported data_split_setting: {data_split_setting!r}; "
        "only '-data_sp_test2' is handled, so corr_dataset cannot be selected."
    )
corr_dataset = corr_datasets[3]  # index 3 follows the train/dev/test1/test2 order -> test2
display(corr_dataset.head())

2145it [00:19, 109.61it/s]


Date,2008-01-18,2008-01-22,2008-01-23,2008-01-24,2008-01-25,2008-01-28,2008-01-29,2008-01-30,2008-01-31,2008-02-01,...,2017-12-01,2017-12-04,2017-12-05,2017-12-06,2017-12-07,2017-12-08,2017-12-11,2017-12-12,2017-12-13,2017-12-14
items,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ABT & ADI_0,-0.784505,-0.240349,-0.363866,-0.65393,-0.602313,-0.575414,-0.596802,-0.530457,-0.552757,-0.146133,...,-0.519675,0.052708,0.608385,0.632594,0.628044,0.642353,0.541167,0.370197,0.358859,0.028016
ABT & ADS_0,-0.358393,0.115923,0.243212,0.073157,-0.105883,0.198989,0.29602,0.351901,0.302613,0.43293,...,0.203268,-0.292588,-0.413235,-0.325598,-0.323925,-0.217494,0.057787,0.568626,0.620431,0.337456
ABT & AFL_0,0.236135,0.506458,0.537641,0.562178,0.607391,0.474709,0.335618,0.295664,-0.060492,-0.327804,...,0.405225,-0.251836,-0.568895,-0.571712,-0.644675,-0.65883,-0.617539,-0.236051,-0.157858,0.076534
ABT & AMP_0,0.031137,0.570727,0.623278,0.619504,0.59382,0.471022,0.152043,-0.228536,-0.565895,-0.267176,...,0.122752,-0.544952,-0.710082,-0.739657,-0.811352,-0.791884,-0.75369,-0.51325,-0.48484,-0.270711
ABT & AMT_0,0.041613,0.729295,0.817976,0.656269,0.691261,0.653618,0.381787,0.20572,-0.014559,-0.228932,...,0.004459,0.626742,0.78392,0.785602,0.728067,0.702016,0.61734,0.513632,0.471368,0.365495


time: 19.6 s (started: 2023-02-12 16:11:57 +00:00)


## Concatenate correlation matrices across time

In [6]:
# Keep every row of the raw dataset but only the train_set columns, then build
# the correlation graphs from the selected correlation dataset.
# NOTE(review): the columns come from train_set, not items_implement -- confirm
# this is intended when train_items_setting selects all items.
corr_dist_mat_df = dataset_df.loc[:, train_set]
gen_corr_graph(
    corr_dataset,
    corr_dist_mat_df,
    save_dir=res_dir,
    save_file=save_corr_graph_arr,
    show_mat_i_info=12,
)


INFO:root:correlation graph.shape:(66, 66)
INFO:root:number of correlation graph:2497
INFO:root:
Min of corr_mat:items
ABT   -0.495273
ADI   -0.353823
ADS   -0.647468
AFL   -0.794561
AMP   -0.366663
         ...   
WHR   -0.897595
WU    -0.570998
WYN   -0.299710
XEC   -0.460472
XRX   -0.647468
Length: 66, dtype: float32
INFO:root:
(66, 66)
INFO:root:
items       ABT       ADI       ADS       AFL       AMP       AMT      ANTM  \
items                                                                         
ABT    1.000000  0.268177  0.046521 -0.495273  0.376682  0.509588  0.774907   
ADI    0.268177  1.000000  0.224511  0.490118  0.538007  0.617613  0.162379   
ADS    0.046521  0.224511  1.000000 -0.089325 -0.366663 -0.239311 -0.250221   
AFL   -0.495273  0.490118 -0.089325  1.000000 -0.071470  0.158193 -0.626531   
AMP    0.376682  0.538007 -0.366663 -0.071470  1.000000  0.599940  0.701920   

items       AON       AZO       BWA  ...      TROW       TWX       UNP  \
items              

time: 12.2 s (started: 2023-02-12 16:15:14 +00:00)
