In [1]:
import os
import pandas as pd
import logging
import datetime
import joblib
import mlflow.keras
from tensorflow.keras.models import load_model
from keras.callbacks import ModelCheckpoint, EarlyStopping, CSVLogger
from ml_investing_wne.xtb.xAPIConnector import APIClient, APIStreamClient, loginCommand
from ml_investing_wne.data_engineering.prepare_dataset import prepare_processed_dataset
import ml_investing_wne.config as config
from ml_investing_wne.train_test_val_split import train_test_val_split
from ml_investing_wne.helper import confusion_matrix_plot, compute_profitability_classes, check_hours
import importlib

logger = logging.getLogger()
logger.setLevel(logging.INFO)

2022-09-25 13:25:35.964986: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-09-25 13:25:35.965030: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [2]:
pd.options.display.max_columns = None

In [3]:
symbol = 'US500'
auxiliary_symbols = ['VIX', 'TNOTE']
period = 1440

build_model = getattr(importlib.import_module('ml_investing_wne.cnn.{}'.format(config.model)),
                      'build_model')

In [69]:
df = pd.read_csv(os.path.join(config.raw_data_path_xtb, symbol + '_' + str(period) + '.csv' ), parse_dates=['datetime'])

In [70]:
df.head()

Unnamed: 0,ctm,ctmString,open,close,high,low,vol,datetime
0,1136242800000,"Jan 3, 2006, 12:00:00 AM",12520.0,12688.0,12702.0,12457.0,0.0,2006-01-02 23:00:00
1,1136329200000,"Jan 4, 2006, 12:00:00 AM",12690.0,12735.0,12754.0,12677.0,0.0,2006-01-03 23:00:00
2,1136415600000,"Jan 5, 2006, 12:00:00 AM",12735.0,12735.0,12769.0,12703.0,0.0,2006-01-04 23:00:00
3,1136502000000,"Jan 6, 2006, 12:00:00 AM",12746.0,12855.0,12861.0,12746.0,0.0,2006-01-05 23:00:00
4,1136761200000,"Jan 9, 2006, 12:00:00 AM",12856.0,12902.0,12908.0,12848.0,0.0,2006-01-08 23:00:00


In [71]:
def xtb_preprocess(df, resample=True):
    df = df.set_index('datetime')
    df.drop(columns=['ctm', 'ctmString'], inplace=True)
    df = df[['high', 'low', 'close','vol']]
    if resample:
        df = df.resample(config.freq).agg({
                                                       'high': 'max',
                                                       'low': 'min',
                                                       'close': 'last',
                                                        'vol': 'sum'
                                                       })
    df['high'] = df['high'] - df['close']
    df['low'] = df['low'] - df['close']

    df.dropna(inplace=True)
    return df

In [72]:
df = xtb_preprocess(df, resample=False)
df['y_pred'] = df['close'].shift(-config.steps_ahead) / df['close']
df['y_pred'] = [1 if y > 1 else 0 for y in df['y_pred']]
df['datetime'] = df.index

In [73]:
df['y_pred'].mean()

0.5396648044692738

In [74]:
df

Unnamed: 0_level_0,high,low,close,vol,y_pred,datetime
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
2006-01-02 23:00:00,14.0,-231.0,12688.0,0.0,1,2006-01-02 23:00:00
2006-01-03 23:00:00,19.0,-58.0,12735.0,0.0,0,2006-01-03 23:00:00
2006-01-04 23:00:00,34.0,-32.0,12735.0,0.0,1,2006-01-04 23:00:00
2006-01-05 23:00:00,6.0,-109.0,12855.0,0.0,1,2006-01-05 23:00:00
2006-01-08 23:00:00,6.0,-54.0,12902.0,0.0,0,2006-01-08 23:00:00
...,...,...,...,...,...,...
2022-09-18 22:00:00,42.0,-767.0,39224.0,6044747.0,0,2022-09-18 22:00:00
2022-09-19 22:00:00,566.0,-361.0,38791.0,6849877.0,0,2022-09-19 22:00:00
2022-09-20 22:00:00,1281.0,-44.0,37963.0,7085644.0,0,2022-09-20 22:00:00
2022-09-21 22:00:00,588.0,-106.0,37736.0,7604778.0,0,2022-09-21 22:00:00


In [46]:
files_dict = {symbol: os.path.join(config.raw_data_path_xtb, symbol + '_' + str(period) + '.csv' ) for symbol in auxiliary_symbols}
for n, d in enumerate(files_dict.items()):
    print(n, d)
    if n == 0:
        df_aux = pd.read_csv(d[1], parse_dates=['datetime'])
        df_aux = xtb_preprocess(df_aux, resample=False)
        df_aux = df_aux.add_prefix(d[0] + '_')
    else:
        df_temp = pd.read_csv(d[1], parse_dates=['datetime'])
        df_temp = xtb_preprocess(df_temp, resample=False)
        df_temp = df_temp.add_prefix(d[0] + '_')
        df_aux = df_aux.merge(df_temp, left_index=True, right_index=True, how='outer')
        
    
    #yield pd.read_csv(file, **kwargs, parse_dates=['datetime'], names=names)

0 ('VIX', '/home/jupyter/ml_investing_wne/src/ml_investing_wne/data/raw/xtb/VIX_1440.csv')
1 ('TNOTE', '/home/jupyter/ml_investing_wne/src/ml_investing_wne/data/raw/xtb/TNOTE_1440.csv')


In [47]:
df_aux[df_aux.isna().any(axis=1)]

Unnamed: 0_level_0,VIX_high,VIX_low,VIX_close,VIX_vol,TNOTE_high,TNOTE_low,TNOTE_close,TNOTE_vol
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
2009-10-20 22:00:00,60.0,0.0,2630.0,0.0,,,,
2009-10-21 22:00:00,0.0,0.0,2650.0,0.0,,,,
2009-10-25 23:00:00,0.0,0.0,2610.0,0.0,,,,
2009-10-26 23:00:00,75.0,0.0,2585.0,0.0,,,,
2009-11-11 23:00:00,224.0,-51.0,2651.0,0.0,,,,
...,...,...,...,...,...,...,...,...
2018-03-13 23:00:00,20.0,-80.0,1721.0,615539.0,,,,
2018-03-14 23:00:00,86.0,-15.0,1663.0,324926.0,,,,
2019-10-26 22:00:00,15.0,-5.0,1525.0,720.0,,,,
2020-01-19 23:00:00,,,,,6.0,-2.0,12911.0,251879.0


In [75]:
df = prepare_processed_dataset(df=df)

[2022-09-25 13:40:00,012][prepare_processed_dataset:89] exported to /home/jupyter/ml_investing_wne/src/ml_investing_wne/data/processed/US500/US500_processed_1440min.csv


In [76]:
df

Unnamed: 0_level_0,high,low,close,vol,y_pred,SMA_3,EMA_3,VAR_3,SMA_5,EMA_5,VAR_5,SMA_10,EMA_10,VAR_10,SMA_13,EMA_13,VAR_13,SMA_20,EMA_20,VAR_20,MACD_12_26_9,MACDh_12_26_9,MACDs_12_26_9,RSI_14,RSI_10,RSI_6,STOCHk_14_3_3,STOCHd_14_3_3,WILLR_14,BBL_5_2.0,BBM_5_2.0,BBU_5_2.0,BBB_5_2.0,BBP_5_2.0,roc_1,hour,weekday,hour_sin,hour_cos,weekday_sin,weekday_cos
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1
2006-02-20 23:00:00,89.0,-17.0,12830.0,0.0,1.007560,12865.333333,12838.376377,1057.333333,12830.2,12816.730397,3096.2,12729.3,12775.835517,1.369979e+04,12714.692308,12765.283374,11270.064103,12733.00,12757.894817,9956.842105,0.000772,39.148612,-39.147839,55.855596,57.431194,60.284813,4928.914141,4921.464646,4815.530303,12730.661867,12830.2,12929.738133,1.551622,0.498995,0.996737,23,0,-2.449294e-16,1.000000,0.000000e+00,1.0
2006-02-21 23:00:00,15.0,-97.0,12927.0,0.0,0.996209,12876.333333,12882.688189,2366.333333,12864.6,12853.486931,2545.8,12767.2,12803.319968,1.279440e+04,12731.538462,12788.385749,14715.102564,12745.90,12774.000072,11546.936842,12.710798,41.486910,-28.776112,61.190579,64.183618,70.351860,4933.080808,4929.040404,4852.272727,12774.341704,12864.6,12954.858296,1.403204,0.845675,1.007560,23,1,-2.449294e-16,1.000000,8.660254e-01,0.5
2006-02-22 23:00:00,60.0,-27.0,12878.0,0.0,1.001242,12878.333333,12880.344094,2352.333333,12880.2,12861.657954,1243.2,12789.3,12816.898156,1.226646e+04,12749.846154,12801.187785,15441.307692,12757.45,12783.904827,11810.155263,18.615110,37.912978,-19.297868,57.415749,58.936349,60.981633,4978.632139,4946.875696,4968.093385,12817.126709,12880.2,12943.273291,0.979384,0.482560,0.996209,23,2,-2.449294e-16,1.000000,8.660254e-01,-0.5
2006-02-23 23:00:00,27.0,-38.0,12894.0,0.0,1.003645,12899.666667,12887.172047,624.333333,12880.2,12872.438636,1243.2,12814.9,12830.916673,1.021277e+04,12768.615385,12814.446673,15960.589744,12765.25,12794.390082,12707.565789,24.305211,34.882463,-10.577252,58.319918,60.119265,62.916962,5031.561726,4981.091558,4974.319066,12817.126709,12880.2,12943.273291,0.979384,0.609397,1.001242,23,3,-2.449294e-16,1.000000,1.224647e-16,-1.0
2006-02-26 23:00:00,35.0,-57.0,12941.0,0.0,0.989645,12904.333333,12914.086024,1072.333333,12894.0,12895.292424,1912.5,12842.0,12850.931823,8.830667e+03,12798.846154,12832.525719,13390.974359,12770.45,12808.352931,14033.839474,32.235573,34.250260,-2.014687,60.943287,63.546693,68.435716,5078.339818,5029.511228,4992.607004,12815.769571,12894.0,12972.230429,1.213439,0.800395,1.003645,23,6,-2.449294e-16,1.000000,-2.449294e-16,1.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-09-15 22:00:00,164.0,-350.0,38879.0,7437122.0,1.008874,39175.333333,39161.687832,117616.333333,39613.2,39380.560286,833257.2,39705.2,39680.450273,5.526288e+05,39659.846154,39823.849059,430273.141024,40127.70,40077.323793,874411.799998,-420.834707,-128.245595,-292.589112,38.827100,36.447879,33.120404,1262.962963,1311.732657,1153.635505,37980.281435,39613.2,41246.118565,8.244315,0.275188,0.994450,22,3,-2.697968e-01,0.962917,1.224647e-16,-1.0
2022-09-18 22:00:00,42.0,-767.0,39224.0,6044747.0,0.988961,39066.333333,39192.843916,30416.333333,39221.8,39328.373524,65134.7,39691.6,39597.459315,5.649112e+05,39631.230769,39738.156336,444876.525640,40015.85,39996.054860,810663.186840,-433.843272,-113.003327,-320.839944,42.170974,41.151116,41.171921,1259.532053,1263.498059,1164.494806,38765.257516,39221.8,39678.342484,2.328004,0.502409,1.008874,22,6,-2.697968e-01,0.962917,-2.449294e-16,1.0
2022-09-19 22:00:00,566.0,-361.0,38791.0,6849877.0,0.978655,38964.666667,38991.921958,52376.333333,39108.2,39149.249016,90694.7,39662.8,39450.830348,6.124120e+05,39563.153846,39602.848288,498521.141024,39891.70,39881.287731,790086.957892,-473.632394,-122.233960,-351.398434,39.269618,37.301484,34.852614,1256.331969,1259.608995,1150.865596,38569.476472,39108.2,39646.923528,2.755041,0.205600,0.988961,22,0,-2.697968e-01,0.962917,0.000000e+00,1.0
2022-09-20 22:00:00,1281.0,-44.0,37963.0,7085644.0,0.994020,38659.333333,38477.460979,410532.333333,38790.6,38753.832677,243460.3,39481.0,39180.315740,8.951716e+05,39463.846154,39368.584247,693246.974358,39715.05,39698.593661,817560.681577,-565.459970,-171.249229,-394.210741,34.396286,31.116623,25.774706,1246.721225,1254.195083,1124.803274,37907.948902,38790.6,39673.251098,4.550850,0.031185,0.978655,22,1,-2.697968e-01,0.962917,8.660254e-01,0.5


In [50]:
df = df.merge(df_aux, left_index=True, right_index=True, how='outer')

In [77]:
df[df.isna().any(axis=1)]

Unnamed: 0_level_0,high,low,close,vol,y_pred,SMA_3,EMA_3,VAR_3,SMA_5,EMA_5,VAR_5,SMA_10,EMA_10,VAR_10,SMA_13,EMA_13,VAR_13,SMA_20,EMA_20,VAR_20,MACD_12_26_9,MACDh_12_26_9,MACDs_12_26_9,RSI_14,RSI_10,RSI_6,STOCHk_14_3_3,STOCHd_14_3_3,WILLR_14,BBL_5_2.0,BBM_5_2.0,BBU_5_2.0,BBB_5_2.0,BBP_5_2.0,roc_1,hour,weekday,hour_sin,hour_cos,weekday_sin,weekday_cos
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1


In [52]:
df.loc[(df.isna().any(axis=1)) & (df.index>datetime.datetime(2010,1,1,0,0,0))]

Unnamed: 0_level_0,high,low,close,vol,y_pred,SMA_3,EMA_3,VAR_3,SMA_5,EMA_5,VAR_5,SMA_10,EMA_10,VAR_10,SMA_13,EMA_13,VAR_13,SMA_20,EMA_20,VAR_20,MACD_12_26_9,MACDh_12_26_9,MACDs_12_26_9,RSI_14,RSI_10,RSI_6,STOCHk_14_3_3,STOCHd_14_3_3,WILLR_14,BBL_5_2.0,BBM_5_2.0,BBU_5_2.0,BBB_5_2.0,BBP_5_2.0,roc_1,hour,weekday,hour_sin,hour_cos,weekday_sin,weekday_cos,VIX_high,VIX_low,VIX_close,VIX_vol,TNOTE_high,TNOTE_low,TNOTE_close,TNOTE_vol
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1
2010-01-05 23:00:00,127.0,-339.0,11224.0,0.0,1.012830,10469.666667,10651.224630,427954.333333,10284.8,10445.912849,280072.7,10161.0,10216.433779,142572.888889,10085.000000,10126.055741,1.285828e+05,9937.00,9959.926278,1.331742e+05,305.508689,68.285167,237.223522,84.129642,87.586368,91.367729,3283.584978,3292.316056,2381.330472,9338.104357,10284.8,11231.495643,18.409607,0.996041,1.108324,23.0,1.0,-2.449294e-16,1.000000,8.660254e-01,0.5,,,,,,,,
2010-01-06 23:00:00,15.0,-99.0,11368.0,0.0,1.003519,10906.333333,11009.612315,460704.333333,10544.2,10753.275232,477865.2,10295.2,10425.809456,282409.288889,10209.307692,10303.476349,2.397761e+05,10035.20,10094.028537,2.158483e+05,379.687490,113.971175,265.716315,85.176660,88.580699,92.255308,2896.894845,3195.635986,2412.231760,9307.604415,10544.2,11780.795585,23.455465,0.833092,1.012830,23.0,2.0,-2.449294e-16,1.000000,8.660254e-01,-0.5,,,,,,,,
2010-01-07 23:00:00,5.0,-101.0,11408.0,0.0,0.947055,11333.333333,11208.806158,9365.333333,10837.0,10971.516822,467178.0,10430.3,10604.389555,393416.455555,10327.769231,10461.265442,3.346045e+05,10128.30,10219.168676,2.933173e+05,436.668744,136.761943,299.906801,85.463546,88.856199,92.511951,2504.792561,2895.090794,2420.815451,9614.310505,10837.0,12059.689495,22.565092,0.733502,1.003519,23.0,3.0,-2.449294e-16,1.000000,1.224647e-16,-1.0,,,,,,,,
2010-01-10 23:00:00,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,330.0,-50.0,1989.0,0.0,,,,
2010-01-11 23:00:00,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,65.0,-35.0,2009.0,0.0,89.0,-97.0,11593.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2020-01-19 23:00:00,54.0,-62.0,33244.0,1074634.0,0.998406,33194.000000,33177.903832,7351.000000,33067.6,33090.894781,35432.3,32866.8,32905.260549,74636.622225,32768.230769,32809.300768,9.416003e+04,32599.75,32606.778381,1.170506e+05,392.028712,32.372741,359.655971,76.860041,79.305775,83.960417,3391.117764,3379.618541,3296.107784,32730.875543,33067.6,33404.324457,2.036582,0.761935,1.000030,23.0,6.0,-2.449294e-16,1.000000,-2.449294e-16,1.0,,,,,6.0,-2.0,12911.0,251879.0
2020-03-07 23:00:00,1436.0,-117.0,28400.0,1373229.0,0.973275,29270.333333,29102.080123,844260.333333,29762.0,29511.350783,1059306.5,30102.7,30249.113606,914536.455558,30794.923077,30596.514528,2.507221e+06,31837.85,31168.449597,3.713023e+06,-980.719053,-352.640850,-628.078203,29.585571,27.864225,25.620492,913.840558,933.910023,788.174580,27920.864264,29762.0,31603.135736,12.372393,0.130120,0.973269,23.0,5.0,-2.449294e-16,1.000000,-8.660254e-01,0.5,,,,,16.0,-36.0,13929.0,65843.0
2020-03-21 23:00:00,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,282.0,-663.0,6963.0,3339.0,20.0,-9.0,13784.0,0.0
2021-04-01 22:00:00,108.0,-144.0,40269.0,188451.0,1.010703,40022.333333,40061.788451,103714.333333,39836.8,39894.262793,118610.7,39502.5,39647.075462,218788.277781,39383.230769,39554.219365,2.161959e+05,39417.30,39395.510124,1.555477e+05,296.989777,83.669157,213.320619,64.949418,69.036427,77.475935,3057.210936,3039.913051,2975.771257,39220.720265,39836.8,40452.879735,3.093018,0.850766,1.003214,22.0,3.0,-2.697968e-01,0.962917,1.224647e-16,-1.0,,,,,,,,


In [53]:
temp = df.loc[(df.index>datetime.datetime(2010,1,1,0,0,0))]
temp.loc[(temp.isna().any(axis=1))]

Unnamed: 0_level_0,high,low,close,vol,y_pred,SMA_3,EMA_3,VAR_3,SMA_5,EMA_5,VAR_5,SMA_10,EMA_10,VAR_10,SMA_13,EMA_13,VAR_13,SMA_20,EMA_20,VAR_20,MACD_12_26_9,MACDh_12_26_9,MACDs_12_26_9,RSI_14,RSI_10,RSI_6,STOCHk_14_3_3,STOCHd_14_3_3,WILLR_14,BBL_5_2.0,BBM_5_2.0,BBU_5_2.0,BBB_5_2.0,BBP_5_2.0,roc_1,hour,weekday,hour_sin,hour_cos,weekday_sin,weekday_cos,VIX_high,VIX_low,VIX_close,VIX_vol,TNOTE_high,TNOTE_low,TNOTE_close,TNOTE_vol
datetime,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1
2010-01-05 23:00:00,127.0,-339.0,11224.0,0.0,1.012830,10469.666667,10651.224630,427954.333333,10284.8,10445.912849,280072.7,10161.0,10216.433779,142572.888889,10085.000000,10126.055741,1.285828e+05,9937.00,9959.926278,1.331742e+05,305.508689,68.285167,237.223522,84.129642,87.586368,91.367729,3283.584978,3292.316056,2381.330472,9338.104357,10284.8,11231.495643,18.409607,0.996041,1.108324,23.0,1.0,-2.449294e-16,1.000000,8.660254e-01,0.5,,,,,,,,
2010-01-06 23:00:00,15.0,-99.0,11368.0,0.0,1.003519,10906.333333,11009.612315,460704.333333,10544.2,10753.275232,477865.2,10295.2,10425.809456,282409.288889,10209.307692,10303.476349,2.397761e+05,10035.20,10094.028537,2.158483e+05,379.687490,113.971175,265.716315,85.176660,88.580699,92.255308,2896.894845,3195.635986,2412.231760,9307.604415,10544.2,11780.795585,23.455465,0.833092,1.012830,23.0,2.0,-2.449294e-16,1.000000,8.660254e-01,-0.5,,,,,,,,
2010-01-07 23:00:00,5.0,-101.0,11408.0,0.0,0.947055,11333.333333,11208.806158,9365.333333,10837.0,10971.516822,467178.0,10430.3,10604.389555,393416.455555,10327.769231,10461.265442,3.346045e+05,10128.30,10219.168676,2.933173e+05,436.668744,136.761943,299.906801,85.463546,88.856199,92.511951,2504.792561,2895.090794,2420.815451,9614.310505,10837.0,12059.689495,22.565092,0.733502,1.003519,23.0,3.0,-2.449294e-16,1.000000,1.224647e-16,-1.0,,,,,,,,
2010-01-10 23:00:00,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,330.0,-50.0,1989.0,0.0,,,,
2010-01-11 23:00:00,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,65.0,-35.0,2009.0,0.0,89.0,-97.0,11593.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2020-01-19 23:00:00,54.0,-62.0,33244.0,1074634.0,0.998406,33194.000000,33177.903832,7351.000000,33067.6,33090.894781,35432.3,32866.8,32905.260549,74636.622225,32768.230769,32809.300768,9.416003e+04,32599.75,32606.778381,1.170506e+05,392.028712,32.372741,359.655971,76.860041,79.305775,83.960417,3391.117764,3379.618541,3296.107784,32730.875543,33067.6,33404.324457,2.036582,0.761935,1.000030,23.0,6.0,-2.449294e-16,1.000000,-2.449294e-16,1.0,,,,,6.0,-2.0,12911.0,251879.0
2020-03-07 23:00:00,1436.0,-117.0,28400.0,1373229.0,0.973275,29270.333333,29102.080123,844260.333333,29762.0,29511.350783,1059306.5,30102.7,30249.113606,914536.455558,30794.923077,30596.514528,2.507221e+06,31837.85,31168.449597,3.713023e+06,-980.719053,-352.640850,-628.078203,29.585571,27.864225,25.620492,913.840558,933.910023,788.174580,27920.864264,29762.0,31603.135736,12.372393,0.130120,0.973269,23.0,5.0,-2.449294e-16,1.000000,-8.660254e-01,0.5,,,,,16.0,-36.0,13929.0,65843.0
2020-03-21 23:00:00,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,282.0,-663.0,6963.0,3339.0,20.0,-9.0,13784.0,0.0
2021-04-01 22:00:00,108.0,-144.0,40269.0,188451.0,1.010703,40022.333333,40061.788451,103714.333333,39836.8,39894.262793,118610.7,39502.5,39647.075462,218788.277781,39383.230769,39554.219365,2.161959e+05,39417.30,39395.510124,1.555477e+05,296.989777,83.669157,213.320619,64.949418,69.036427,77.475935,3057.210936,3039.913051,2975.771257,39220.720265,39836.8,40452.879735,3.093018,0.850766,1.003214,22.0,3.0,-2.697968e-01,0.962917,1.224647e-16,-1.0,,,,,,,,


In [78]:
df.dropna(inplace=True)

In [79]:
df.shape

(3546, 41)

In [80]:
X, y, X_val, y_val, X_test, y_test, y_cat, y_val_cat, y_test_cat, train = train_test_val_split(df)

[2022-09-25 13:40:18,769][split_sequences:24] first sequence begins: 2006-02-20 23:00:00
[2022-09-25 13:40:18,770][split_sequences:25] first sequence ends: 2006-04-27 22:00:00
[2022-09-25 13:40:18,784][split_sequences:30] last sequence begins: 2020-10-25 23:00:00
[2022-09-25 13:40:18,785][split_sequences:31] last sequence ends: 2020-12-30 23:00:00
[2022-09-25 13:40:18,804][split_sequences:24] first sequence begins: 2020-11-15 23:00:00
[2022-09-25 13:40:18,805][split_sequences:25] first sequence ends: 2021-01-21 23:00:00
[2022-09-25 13:40:18,807][split_sequences:30] last sequence begins: 2021-10-26 22:00:00
[2022-09-25 13:40:18,808][split_sequences:31] last sequence ends: 2021-12-30 23:00:00
[2022-09-25 13:40:18,810][split_sequences:24] first sequence begins: 2021-11-14 23:00:00
[2022-09-25 13:40:18,811][split_sequences:25] first sequence ends: 2022-01-19 23:00:00
[2022-09-25 13:40:18,812][split_sequences:30] last sequence begins: 2022-07-18 22:00:00
[2022-09-25 13:40:18,813][split_sequ

In [94]:
config.patience = 20

In [105]:
mlflow.set_experiment(experiment_name='xtb' + '_' + config.model + '_' +
                                      str(config.nb_classes) + '_' +
                                      config.freq + '_' + str(config.steps_ahead) + '_' +
                                      str(config.seq_len))
early_stop = EarlyStopping(monitor='val_accuracy', patience=config.patience, restore_best_weights=True)
model_path_final = os.path.join(config.package_directory, 'models',
                                '{}_{}_{}_{}_{}.h5'.format(config.model, 'xtb',
                                                           config.currency, config.freq,
                                                           config.steps_ahead))
model_checkpoint = ModelCheckpoint(filepath=model_path_final, monitor='val_accuracy', verbose=1,
                                   save_best_only=True)
csv_logger = CSVLogger(os.path.join(config.package_directory, 'logs', 'keras_log.csv'), append=True,
                       separator=';')
callbacks = [early_stop, model_checkpoint, csv_logger]

In [106]:
X.shape

(3047, 48, 40)

In [107]:
model = build_model(input_shape=(config.seq_len, X.shape[2]), head_size=64, num_heads=4, ff_dim=64,
                    num_transformer_blocks=2, mlp_units=[128], mlp_dropout=0.25, dropout=0.25)

In [108]:
model.summary()

Model: "model_7"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_8 (InputLayer)            [(None, 48, 40)]     0                                            
__________________________________________________________________________________________________
position_embedding_layer_7 (Pos (None, 48, 40)       1920        input_8[0][0]                    
__________________________________________________________________________________________________
layer_normalization_28 (LayerNo (None, 48, 40)       80          position_embedding_layer_7[0][0] 
__________________________________________________________________________________________________
multi_head_attention_14 (MultiH (None, 48, 40)       41768       layer_normalization_28[0][0]     
                                                                 layer_normalization_28[0][0

In [109]:
history = model.fit(X, y_cat, batch_size=64, epochs=config.epochs, verbose=2,
                    validation_data=(X_val, y_val_cat), callbacks=callbacks)

Epoch 1/100
48/48 - 6s - loss: 0.7195 - accuracy: 0.5021 - val_loss: 0.6820 - val_accuracy: 0.5806

Epoch 00001: val_accuracy improved from -inf to 0.58065, saving model to /home/jupyter/ml_investing_wne/src/ml_investing_wne/models/transformer_learnable_encoding_xtb_US500_1440min_1.h5
Epoch 2/100
48/48 - 4s - loss: 0.6947 - accuracy: 0.5323 - val_loss: 0.6926 - val_accuracy: 0.5242

Epoch 00002: val_accuracy did not improve from 0.58065
Epoch 3/100
48/48 - 5s - loss: 0.6938 - accuracy: 0.5267 - val_loss: 0.6777 - val_accuracy: 0.5887

Epoch 00003: val_accuracy improved from 0.58065 to 0.58871, saving model to /home/jupyter/ml_investing_wne/src/ml_investing_wne/models/transformer_learnable_encoding_xtb_US500_1440min_1.h5
Epoch 4/100
48/48 - 4s - loss: 0.6937 - accuracy: 0.5310 - val_loss: 0.6760 - val_accuracy: 0.5887

Epoch 00004: val_accuracy did not improve from 0.58871
Epoch 5/100
48/48 - 4s - loss: 0.6939 - accuracy: 0.5290 - val_loss: 0.6836 - val_accuracy: 0.5726

Epoch 00005: va

In [111]:
model.evaluate(X_test, y_test_cat)



[0.7526242733001709, 0.47457626461982727]

In [112]:
y_pred = model.predict(X_test)
y_pred_class = y_pred.argmax(axis=-1)

if 'JPY' in config.currency:
    df['cost'] = (config.pips / 100) / df['close']
else:
    df['cost'] = (config.pips / 10000) / df['close']

start_date = joblib.load(os.path.join(config.package_directory, 'models',
                                      'first_sequence_ends_{}_{}_{}.save'.format('test',
                                                                                 config.currency,
                                                                                 config.freq)))
end_date = joblib.load(os.path.join(config.package_directory, 'models',
                                    'last_sequence_ends_{}_{}_{}.save'.format('test',
                                                                              config.currency,
                                                                              config.freq)))
lower_bounds = [0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5]
upper_bounds = [1 - lower for lower in lower_bounds]

for lower_bound, upper_bound in zip(lower_bounds, upper_bounds):
    portfolio_result, hit_ratio, time_active = compute_profitability_classes(df, y_pred, start_date,
                                                                             end_date, lower_bound,
                                                                             upper_bound)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  'Europe/Warsaw').dt.tz_localize(None)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  prediction['hour_waw'] = prediction['datetime_waw'].dt.time
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  prediction['prediction'] = y_pred[:, 1]
A value is trying to be set on a copy of a slice from a DataFrame.


In [113]:
y_pred

array([[0.3425434 , 0.6574567 ],
       [0.32381877, 0.67618126],
       [0.31618667, 0.68381333],
       [0.23020212, 0.7697979 ],
       [0.3261649 , 0.67383516],
       [0.38393235, 0.61606765],
       [0.32500538, 0.67499465],
       [0.4629349 , 0.537065  ],
       [0.41527194, 0.58472806],
       [0.3809217 , 0.6190783 ],
       [0.3650233 , 0.63497674],
       [0.37059683, 0.6294031 ],
       [0.43136576, 0.5686343 ],
       [0.39708212, 0.60291785],
       [0.49377564, 0.50622433],
       [0.35678613, 0.64321387],
       [0.3447483 , 0.6552517 ],
       [0.43035328, 0.5696468 ],
       [0.4010485 , 0.5989515 ],
       [0.40953344, 0.5904665 ],
       [0.35486674, 0.64513326],
       [0.30248764, 0.6975124 ],
       [0.33466938, 0.6653306 ],
       [0.3077956 , 0.69220436],
       [0.32982114, 0.67017883],
       [0.31683165, 0.68316835],
       [0.30651727, 0.6934827 ],
       [0.30401242, 0.69598764],
       [0.29965428, 0.7003457 ],
       [0.33384892, 0.66615105],
       [0.