In [1]:
# Silence every warning for the rest of the session. Note that this also hides
# pandas warnings (e.g. SettingWithCopyWarning) raised by later cells.
import warnings
warnings.filterwarnings("ignore")

import re
import os
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch

# Project-local modules (data loading, RL agent wrappers, training, evaluation).
from DataLoader.DataLoader import DataLoader
from DataLoader.DataBasedAgent import DataBasedAgent
from DataLoader.DataRLAgent import DataRLAgent
import DeepRLAgent.VanillaInput.Train as Train
from PatternDetectionInCandleStick.Evaluation import Evaluation
import distinctipy
import talib


from importlib import reload

# Re-import the Train module so that edits made on disk are picked up without
# restarting the kernel, then alias its Train class as DeepRL.
Train = reload(Train)
DeepRL = Train.Train
from utils_best_arm import add_train_portfo, add_test_portfo, plot_return, calc_return, plot_action_point, setup_logger
pd.options.display.max_colwidth = 100
from scipy.optimize import minimize

# The device is pinned to the CPU; the automatic CUDA selection is kept below,
# commented out.
device = "cpu"
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): duplicate of the `import talib` above; harmless but redundant.
import talib
# Working directory at the time this cell runs (not referenced by the cells
# visible in this notebook).
CURRENT_PATH = os.getcwd()

In [2]:
def train(
    DATASET_NAME,
    split_point='2018-01-01',
    begin_date='2010-01-01',
    end_date='2020-08-24',
    model_start_date="",
    model_end_date="",
    initial_investment=1000,
    transaction_cost=0.0001,
    load_from_file=True,
    reward_type="profit",
    seed=42,
    state_mode=1,
    n_episodes=5,
    lamb=0.0001,
    GAMMA=0.7,
    n_step=5,
    BATCH_SIZE=10,
    ReplayMemorySize=20,
    TARGET_UPDATE=5,
    window_size=None,
    train_portfolios={},
    test_portfolios={},
    arms=[],
    show_all = False,
    ratio_threshold=0.9,
):
    """Evaluate a previously saved model on the test split of one dataset.

    Despite its name, this function does not run a training loop. It builds
    the data loader and the train/test data agents, constructs the ``DeepRL``
    agent, and calls ``agent.test_MV`` with the path of a saved model file.
    The resulting daily portfolio values are converted to daily percentage
    returns.

    Parameters
    ----------
    DATASET_NAME : str
        Dataset (ticker) name; passed to ``DataLoader`` and used in the model
        path and in the returned model name.
    split_point, begin_date, end_date : str
        Date bounds forwarded to ``DataLoader``.
    model_start_date, model_end_date : str
        Only used to build the results directory
        ``./Results/{DATASET_NAME}/{model_start_date}~{model_end_date}/{seed}/train``.
    initial_investment : number
        Forwarded to ``agent.test_MV``.
    transaction_cost, state_mode, GAMMA, n_step, BATCH_SIZE, window_size :
        Forwarded to ``DataRLAgent`` and/or ``DeepRL``.
    ReplayMemorySize, TARGET_UPDATE :
        Forwarded to ``DeepRL``.
    load_from_file : bool
        Forwarded to ``DataLoader``.
    seed : int
        Seeds ``numpy``, ``random`` and ``torch``; also part of the model path.
    arms : list of dict
        Must contain at least one dict with keys ``"name"`` and ``"lamb"``.
        Only the first arm is used to locate the model file; the full list is
        forwarded to ``DeepRL``.
    reward_type, n_episodes, lamb, train_portfolios, test_portfolios,
    show_all, ratio_threshold :
        Accepted for call-site compatibility but not read by this function.
        (The mutable ``{}`` defaults are therefore never mutated here.)

    Returns
    -------
    tuple
        ``(data_loader, {"name": <model name>, "portfo": <list of float>})``
        where ``"portfo"`` holds the day-over-day percentage change of the
        test portfolio value, with the first entry set to 0.

    Raises
    ------
    IndexError
        If ``arms`` is empty. Raised before any data is loaded.
    """
    # Fail fast: the default `arms=[]` cannot work because the first arm names
    # the model file. Previously this surfaced as a bare IndexError only after
    # the data had been loaded and the agents constructed.
    if len(arms) == 0:
        raise IndexError(
            "train() needs at least one arm, e.g. "
            "arms=[{'name': 'profit', 'lamb': 0}]; got an empty `arms`."
        )

    data_loader = DataLoader(
        DATASET_NAME,
        split_point=split_point,
        begin_date=begin_date,
        end_date=end_date,
        load_from_file=load_from_file,
    )

    dataTrain_agent = DataRLAgent(
        data_loader.data_train, state_mode, 'action_encoder_decoder', device,
        GAMMA, n_step, BATCH_SIZE, window_size, transaction_cost,
    )
    dataTest_agent = DataRLAgent(
        data_loader.data_test, state_mode, 'action_encoder_decoder', device,
        GAMMA, n_step, BATCH_SIZE, window_size, transaction_cost,
    )

    # Seed every RNG in use before the agent is built (same position as in the
    # original code: after the data agents, before DeepRL).
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    agent = DeepRL(data_loader, dataTrain_agent, dataTest_agent,
                DATASET_NAME,  state_mode, window_size, transaction_cost,
                BATCH_SIZE=BATCH_SIZE, GAMMA=GAMMA, ReplayMemorySize=ReplayMemorySize,
                TARGET_UPDATE=TARGET_UPDATE, n_step=n_step, arms=arms)

    # Locate the saved model that belongs to the first arm and this seed.
    path = f"./Results/{DATASET_NAME}/{model_start_date}~{model_end_date}/{seed}/train"
    arm = arms[0]
    name = f'{arm["name"]}_{arm["lamb"]}'
    model_path = f"{path}/model_{name}_{seed}.pkl"

    # presumably test_MV loads the model at `model_path` and replays the test
    # split -- confirm against DeepRLAgent.VanillaInput.Train.
    agent_test = agent.test_MV(initial_investment=initial_investment, test_type='test', model_path=model_path, symbol=DATASET_NAME)

    # Convert absolute portfolio values to daily returns; the first day has no
    # predecessor, so its NaN is replaced with 0.
    test_portfolio = agent_test.get_daily_portfolio_value()
    test_portfolio = pd.Series(test_portfolio).pct_change(1).fillna(0).values.tolist()
    model_name = f'DQN-stock:{DATASET_NAME}-reward:{name}-seed:{seed}'
    return data_loader, {"name": model_name, "portfo": test_portfolio}

In [4]:
# Experiment driver: for every dataset under ./Data/, every arm and every seed,
# evaluate the saved models on three consecutive test windows, concatenate the
# daily returns, and collect summary metrics next to a buy-and-hold baseline.
initial_investment = 1000


# Keyword arguments shared by every train() call; the per-window entries are
# added in place with kwargs.update() inside the loop below.
kwargs = {
    "load_from_file": True, 
    "transaction_cost": 0.0000,
    "initial_investment": initial_investment,
    "state_mode": 1,
    "GAMMA": 0.7, 
    "n_step": 5, 
    "BATCH_SIZE": 10, 
    "ReplayMemorySize": 20,
    "TARGET_UPDATE": 5,
    "window_size": None, 
    "lamb": 0.0,
}

# NOTE(review): this value is overwritten by the `for _file in files` loop
# below, and the three date templates are not used in the visible code.
_file = "AAPL"

_begin_date = '20{}-01-01'
_end_date = '20{}-01-01'
_split_point = '20{}-01-01' 

# Each arm is a dict with a "name" and a "lamb" value; both are used to build
# result keys and column names.
arms = [
    { "name": "profit", "lamb": 0},
]

# One tuple per test window:
# (split_point, end_date, model_start_date, model_end_date).
# NOTE(review): "2020-06-22" is both the end of the first window and the start
# of the second -- confirm whether that day is meant to appear twice.
dates = [
    ("2019-06-23", "2020-06-22", "2016-01-01", "2019-01-01"),
    ("2020-06-22", "2021-06-22", "2017-01-01", "2020-01-01"),
    ("2021-06-23", "2022-06-23", "2018-01-01", "2021-01-01"),
]

# Number of seeds to evaluate; seeds are 0 .. random_seeds - 1.
random_seeds = 1
# One two-column (or one-column, for idx > 0) metrics frame per (file, arm).
results = []
# portfolios_saved[file]["{arm name}-{arm lamb}"] -> list (one per seed) of
# concatenated daily returns.
portfolios_saved = {}
files = sorted(os.listdir("./Data/"))
for _file in files[:]:
    print(_file)
    portfolios_saved[_file] = {}

    for idx, arm in enumerate(arms[:]):
        # Metrics returned by calc_return, one entry per seed.
        results2 = []
        portfolios_saved[_file][f"{arm['name']}-{arm['lamb']}"] = []
        
        for seed in tqdm(range(random_seeds)):
            
            # ls: model daily returns across all windows; bhs: close-price
            # series of each test window (buy-and-hold baseline).
            ls = []
            bhs = []
        
            # Fresh containers per seed; train_portfolios is passed to train()
            # but never filled in the visible code.
            train_portfolios = {}
            test_portfolios = {}
            # Receives exactly the same values as `ls`.
            tmp_result = []
            
            for date in dates:
                model_start_date = date[2]
                model_end_date = date[3]
                split_point = date[0]
                end_date = date[1]

                # Mutates the shared kwargs dict in place; begin_date is fixed
                # at 2016-01-01 for every window.
                kwargs.update({
                    "begin_date": "2016-01-01", 
                    "end_date": end_date, 
                    "split_point": split_point,
                    "model_start_date": model_start_date,
                    "model_end_date": model_end_date,
                    "DATASET_NAME": _file,
                    "reward_type": "",
                    "seed": seed,
                    "n_episodes": 140,
                    "arms": [arm],
                    "show_all": True,
                    "ratio_threshold": 3,
                    "train_portfolios": train_portfolios,
                    "test_portfolios": test_portfolios,
                })

                data_loader, model = train(**kwargs)
                ls.extend(model["portfo"])
                bh = data_loader.data_test_with_date["close"]
                bhs.append(bh)
                tmp_result.extend(model["portfo"])

            add_test_portfo(test_portfolios, seed, ls)
            # The buy-and-hold returns are computed only for seed 0 and
            # `bh_percentage` is then reused for later seeds. Because
            # test_portfolios is recreated per seed, the 'B&H' entry is only
            # added for seed 0.
            if seed == 0: 
                bhs = pd.concat(bhs, axis=0)
                bh_percentage = bhs.pct_change(1).fillna(0).values
                add_test_portfo(test_portfolios, 'B&H', bh_percentage)
            # presumably returns a metrics frame with one column per portfolio
            # in test_portfolios -- confirm against utils_best_arm.calc_return.
            indexes = calc_return(bh_percentage, test_portfolios)
            results2.append(indexes)
            portfolios_saved[_file][f"{arm['name']}-{arm['lamb']}"].append(tmp_result)

        # path = f"./Results/{_file}/exp3_concat"
        # if not os.path.exists(path):
        #     os.mkdir(path)
        
        # save_path = f"{path}/MV.csv"

        # portfolios_saved[_file][f"{arm['name']}-{arm['lamb']}"].insert(0, bh_percentage.tolist())
        # _df_ = pd.DataFrame(portfolios_saved[_file][f"{arm['name']}-{arm['lamb']}"]).T.fillna(0)
        # _df_.to_csv(save_path, index=False)

        # Separate the baseline column from the per-seed columns and summarise
        # the seeds with their row-wise median.
        results2_df = pd.concat(results2, axis=1)
        results2_bh = results2_df["B&H"]
        del results2_df["B&H"]
        final = pd.concat([
            results2_bh,
            results2_df.median(axis=1)
        ], axis=1)
        final.columns = [f"{_file}-B&H", f"{_file}-{arm['name']}-{arm['lamb']}"]
        # Keep the baseline column only once per file (for the first arm).
        if idx > 0:
            del final[f"{_file}-B&H"]
        results.append(final)


AAPL


100%|██████████| 252/252 [00:00<00:00, 263.13it/s]
100%|██████████| 253/253 [00:00<00:00, 284.41it/s]
100%|██████████| 253/253 [00:00<00:00, 277.30it/s]
100%|██████████| 1/1 [00:03<00:00,  3.32s/it]


AMGN


100%|██████████| 252/252 [00:00<00:00, 303.65it/s]
100%|██████████| 253/253 [00:01<00:00, 210.91it/s]
100%|██████████| 253/253 [00:01<00:00, 159.42it/s]
100%|██████████| 1/1 [00:04<00:00,  4.43s/it]


AXP


100%|██████████| 252/252 [00:01<00:00, 141.26it/s]
100%|██████████| 253/253 [00:01<00:00, 135.54it/s]
100%|██████████| 253/253 [00:01<00:00, 162.71it/s]
100%|██████████| 1/1 [00:06<00:00,  6.12s/it]


BA


100%|██████████| 252/252 [00:01<00:00, 179.91it/s]
100%|██████████| 253/253 [00:01<00:00, 172.42it/s]
100%|██████████| 253/253 [00:01<00:00, 183.09it/s]
100%|██████████| 1/1 [00:05<00:00,  5.01s/it]


CAT


100%|██████████| 252/252 [00:01<00:00, 198.65it/s]
100%|██████████| 253/253 [00:01<00:00, 175.86it/s]
100%|██████████| 253/253 [00:01<00:00, 178.97it/s]
100%|██████████| 1/1 [00:04<00:00,  4.91s/it]


CRM


100%|██████████| 252/252 [00:01<00:00, 184.99it/s]
100%|██████████| 253/253 [00:01<00:00, 175.74it/s]
100%|██████████| 253/253 [00:01<00:00, 180.64it/s]
100%|██████████| 1/1 [00:04<00:00,  4.98s/it]


CSCO


100%|██████████| 252/252 [00:01<00:00, 185.65it/s]
100%|██████████| 253/253 [00:01<00:00, 185.94it/s]
100%|██████████| 253/253 [00:01<00:00, 173.60it/s]
100%|██████████| 1/1 [00:04<00:00,  4.98s/it]


CVX


100%|██████████| 252/252 [00:01<00:00, 209.59it/s]
100%|██████████| 253/253 [00:01<00:00, 180.51it/s]
100%|██████████| 253/253 [00:01<00:00, 182.90it/s]
100%|██████████| 1/1 [00:04<00:00,  4.75s/it]


DIS


100%|██████████| 252/252 [00:01<00:00, 183.89it/s]
100%|██████████| 253/253 [00:01<00:00, 187.32it/s]
100%|██████████| 253/253 [00:01<00:00, 179.93it/s]
100%|██████████| 1/1 [00:04<00:00,  4.95s/it]


GS


100%|██████████| 252/252 [00:01<00:00, 186.43it/s]
100%|██████████| 253/253 [00:01<00:00, 190.32it/s]
100%|██████████| 253/253 [00:01<00:00, 181.96it/s]
100%|██████████| 1/1 [00:04<00:00,  4.89s/it]


HD


100%|██████████| 252/252 [00:01<00:00, 186.14it/s]
100%|██████████| 253/253 [00:01<00:00, 194.74it/s]
100%|██████████| 253/253 [00:01<00:00, 178.87it/s]
100%|██████████| 1/1 [00:04<00:00,  4.86s/it]


HON


100%|██████████| 252/252 [00:01<00:00, 184.25it/s]
100%|██████████| 253/253 [00:01<00:00, 179.54it/s]
100%|██████████| 253/253 [00:01<00:00, 186.04it/s]
100%|██████████| 1/1 [00:04<00:00,  4.96s/it]


IBM


100%|██████████| 252/252 [00:01<00:00, 194.36it/s]
100%|██████████| 253/253 [00:01<00:00, 186.71it/s]
100%|██████████| 253/253 [00:01<00:00, 180.40it/s]
100%|██████████| 1/1 [00:04<00:00,  4.84s/it]


INTC


100%|██████████| 252/252 [00:01<00:00, 189.39it/s]
100%|██████████| 253/253 [00:01<00:00, 191.12it/s]
100%|██████████| 253/253 [00:01<00:00, 193.02it/s]
100%|██████████| 1/1 [00:04<00:00,  4.78s/it]


JNJ


100%|██████████| 252/252 [00:01<00:00, 181.99it/s]
100%|██████████| 253/253 [00:01<00:00, 189.77it/s]
100%|██████████| 253/253 [00:01<00:00, 174.62it/s]
100%|██████████| 1/1 [00:04<00:00,  4.95s/it]


JPM


100%|██████████| 252/252 [00:01<00:00, 180.39it/s]
100%|██████████| 253/253 [00:01<00:00, 187.14it/s]
100%|██████████| 253/253 [00:01<00:00, 190.73it/s]
100%|██████████| 1/1 [00:04<00:00,  4.86s/it]


KO


100%|██████████| 252/252 [00:01<00:00, 182.63it/s]
100%|██████████| 253/253 [00:01<00:00, 187.20it/s]
100%|██████████| 253/253 [00:01<00:00, 172.96it/s]
100%|██████████| 1/1 [00:04<00:00,  4.95s/it]


MCD


100%|██████████| 252/252 [00:01<00:00, 189.20it/s]
100%|██████████| 253/253 [00:01<00:00, 183.97it/s]
100%|██████████| 253/253 [00:01<00:00, 175.93it/s]
100%|██████████| 1/1 [00:04<00:00,  4.95s/it]


MMM


100%|██████████| 252/252 [00:01<00:00, 186.73it/s]
100%|██████████| 253/253 [00:01<00:00, 189.97it/s]
100%|██████████| 253/253 [00:01<00:00, 169.46it/s]
100%|██████████| 1/1 [00:04<00:00,  4.96s/it]


MRK


100%|██████████| 252/252 [00:01<00:00, 185.72it/s]
100%|██████████| 253/253 [00:01<00:00, 185.77it/s]
100%|██████████| 253/253 [00:01<00:00, 174.63it/s]
100%|██████████| 1/1 [00:04<00:00,  4.97s/it]


MSFT


100%|██████████| 252/252 [00:01<00:00, 186.75it/s]
100%|██████████| 253/253 [00:01<00:00, 170.84it/s]
100%|██████████| 253/253 [00:01<00:00, 181.83it/s]
100%|██████████| 1/1 [00:04<00:00,  5.00s/it]


NKE


100%|██████████| 252/252 [00:01<00:00, 187.30it/s]
100%|██████████| 253/253 [00:01<00:00, 180.26it/s]
100%|██████████| 253/253 [00:01<00:00, 183.82it/s]
100%|██████████| 1/1 [00:04<00:00,  4.92s/it]


PG


100%|██████████| 252/252 [00:01<00:00, 183.34it/s]
100%|██████████| 253/253 [00:01<00:00, 196.63it/s]
100%|██████████| 253/253 [00:01<00:00, 191.53it/s]
100%|██████████| 1/1 [00:04<00:00,  4.77s/it]


TRV


100%|██████████| 252/252 [00:01<00:00, 182.11it/s]
100%|██████████| 253/253 [00:01<00:00, 183.19it/s]
100%|██████████| 253/253 [00:01<00:00, 187.01it/s]
100%|██████████| 1/1 [00:04<00:00,  4.92s/it]


UNH


100%|██████████| 252/252 [00:01<00:00, 180.58it/s]
100%|██████████| 253/253 [00:01<00:00, 174.68it/s]
100%|██████████| 253/253 [00:01<00:00, 183.37it/s]
100%|██████████| 1/1 [00:04<00:00,  5.00s/it]


V


100%|██████████| 252/252 [00:01<00:00, 183.07it/s]
100%|██████████| 253/253 [00:01<00:00, 187.71it/s]
100%|██████████| 253/253 [00:01<00:00, 177.06it/s]
100%|██████████| 1/1 [00:04<00:00,  4.93s/it]


VZ


100%|██████████| 252/252 [00:01<00:00, 193.82it/s]
100%|██████████| 253/253 [00:01<00:00, 175.60it/s]
100%|██████████| 253/253 [00:01<00:00, 187.84it/s]
100%|██████████| 1/1 [00:04<00:00,  4.91s/it]


WBA


100%|██████████| 252/252 [00:01<00:00, 184.12it/s]
100%|██████████| 253/253 [00:01<00:00, 185.37it/s]
100%|██████████| 253/253 [00:01<00:00, 179.99it/s]
100%|██████████| 1/1 [00:04<00:00,  4.93s/it]


WMT


100%|██████████| 252/252 [00:01<00:00, 194.16it/s]
100%|██████████| 253/253 [00:01<00:00, 178.45it/s]
100%|██████████| 253/253 [00:01<00:00, 196.09it/s]
100%|██████████| 1/1 [00:04<00:00,  4.75s/it]


In [7]:
def output_bh(symbol, dates=None):
    """Return the buy-and-hold daily returns of ``symbol`` over the test windows.

    Reads ``./Data/{symbol}/{symbol}.csv`` (columns ``Date`` and ``Close`` are
    required), computes the day-over-day percentage change of ``Close``, and
    stacks the rows that fall inside each window, in window order.

    Parameters
    ----------
    symbol : str
        Ticker; used to build the CSV path.
    dates : sequence of tuples, optional
        Each entry's first two items are the inclusive start and end date of a
        window, as ``YYYY-MM-DD`` strings. Extra items are ignored. Defaults to
        the three experiment windows used elsewhere in this notebook.

    Returns
    -------
    pandas.DataFrame
        Columns ``Date`` and ``pct`` with a fresh 0..n-1 index.

    Notes
    -----
    Dates are compared as strings, which is only correct for ISO-formatted
    (``YYYY-MM-DD``) values. Windows are inclusive on both ends, so a date
    shared by two windows (``2020-06-22`` in the defaults) appears twice in
    the result. This is kept on purpose because callers align this frame
    positionally with the model's concatenated returns.
    NOTE(review): confirm that double-counting that day is intended.
    """
    if dates is None:
        dates = [
            ("2019-06-23", "2020-06-22", "2016-01-01", "2019-01-01"),
            ("2020-06-22", "2021-06-22", "2017-01-01", "2020-01-01"),
            ("2021-06-23", "2022-06-23", "2018-01-01", "2021-01-01"),
        ]

    prices = pd.read_csv(f"./Data/{symbol}/{symbol}.csv", usecols=["Date", "Close"])

    # The percentage change is computed on the full history *before* filtering
    # so that the first retained row still has a valid previous close.
    # `assign` builds a new frame instead of writing into a slice, which avoids
    # the SettingWithCopy pattern of the previous implementation.
    returns = prices.assign(pct=prices["Close"].pct_change(1))
    returns = returns.loc[returns["Date"] >= "2016-01-01", ["Date", "pct"]]

    # `between` is inclusive on both ends, matching `>= start & <= end`.
    pieces = [
        returns[returns["Date"].between(window[0], window[1])]
        for window in dates
    ]
    if not pieces:
        # No windows requested: return an empty frame with the same columns.
        return returns.iloc[0:0].reset_index(drop=True)

    return pd.concat(pieces, ignore_index=True)

symbol = "AAPL"
# output_bh(symbol)

In [9]:
# Build, for every symbol, the daily and cumulative returns of buy-and-hold
# versus the evaluated model, and export both tables.
symbols = list(portfolios_saved.keys())[:]
# plt.style.use("ggplot")
# plt.rcParams["text.color"] = "black"

# to_csv does not create missing directories.
os.makedirs("./ts-run-results", exist_ok=True)

returns_ls = []  # per-symbol daily returns
ls = []          # per-symbol cumulative returns
for symbol in symbols:
    bh = output_bh(symbol)
    # Positional (row-number) alignment of the baseline with the model returns;
    # rows where either side is missing are dropped.
    # NOTE(review): the three-column rename below assumes exactly one seed was
    # stored under "profit-0".
    res = pd.concat([
        bh,
        pd.DataFrame(portfolios_saved[symbol]["profit-0"]).T
    ], axis=1).dropna()
    res.columns = ["Date", f"{symbol}-B&H", f"{symbol}-MV"]
    # The dates of the last processed symbol label the combined tables below.
    date_ls = res["Date"].tolist()
    del res["Date"]
    returns_ls.append(res)
    # NOTE cumreturn
    ls.append((1 + res).cumprod() - 1)


# Fix: the "-returns" file used to be written from the cumulative-return frame,
# which made it identical to the "-cumreturns" file. It now holds the
# per-period returns.
results_return = pd.concat(returns_ls, axis=1)
results_return["date"] = date_ls
results_return = results_return.set_index("date")
results_return.to_csv("./ts-run-results/[exp3]MV-returns.csv")

results_cumreturn = pd.concat(ls, axis=1)
results_cumreturn["date"] = date_ls
results_cumreturn = results_cumreturn.set_index("date")
results_cumreturn.to_csv("./ts-run-results/[exp3]MV-cumreturns.csv")
results_cumreturn

Unnamed: 0_level_0,AAPL-B&H,AAPL-MV,AMGN-B&H,AMGN-MV,AXP-B&H,AXP-MV,BA-B&H,BA-MV,CAT-B&H,CAT-MV,...,UNH-B&H,UNH-MV,V-B&H,V-MV,VZ-B&H,VZ-MV,WBA-B&H,WBA-MV,WMT-B&H,WMT-MV
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2019-06-24,-0.001006,0.000000,-0.012185,0.000000,-0.004730,0.000000,0.005782,0.000000,0.003585,0.000000,...,-0.010465,0.000000,0.002364,0.000000,0.008655,0.000000,-0.009533,0.000000,0.000990,0.000000
2019-06-25,-0.016148,0.000000,-0.014270,0.000000,-0.012587,0.000000,-0.006777,0.000000,-0.001344,0.000000,...,-0.018313,0.000000,-0.012454,0.000000,0.003981,0.000000,0.009533,0.000000,-0.003689,0.000000
2019-06-26,0.005131,0.021629,-0.021057,-0.006886,-0.009140,0.003491,0.008337,0.015217,0.009411,0.010769,...,-0.034922,-0.016918,-0.013722,-0.001284,-0.013502,-0.017414,-0.001335,-0.010765,-0.008729,-0.005058
2019-06-27,0.004829,0.021322,-0.016514,-0.002277,-0.006334,0.006333,-0.021031,-0.014351,0.012025,0.013387,...,-0.023545,-0.005330,-0.012742,-0.000292,-0.009001,-0.012931,0.039466,0.029651,-0.009269,-0.005600
2019-06-28,-0.004326,0.012016,-0.015125,-0.000868,-0.010342,0.002273,-0.021057,-0.014378,0.017925,0.019296,...,-0.032781,-0.014738,0.000634,0.013253,-0.011078,-0.015000,0.042326,0.032484,-0.005759,-0.002077
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2022-06-16,1.751635,1.291420,0.320943,0.046533,0.132091,0.125285,-0.631919,-0.488534,0.548349,0.012026,...,0.901021,0.366777,0.126761,0.250470,-0.043836,-0.061948,-0.165308,-0.262867,0.157549,0.176075
2022-06-17,1.783370,1.291420,0.343902,0.046533,0.187090,0.125285,-0.622424,-0.488534,0.531259,0.012026,...,0.884224,0.366777,0.132483,0.250470,-0.040116,-0.061948,-0.172672,-0.262867,0.135189,0.176075
2022-06-21,1.874555,1.291420,0.365029,0.046533,0.191866,0.125285,-0.622562,-0.488534,0.560908,0.012026,...,1.002014,0.366777,0.158588,0.250470,-0.008401,-0.061948,-0.156260,-0.262867,0.172424,0.176075
2022-06-22,1.863553,1.291420,0.374934,0.046533,0.188819,0.125285,-0.621430,-0.488534,0.493025,0.012026,...,1.041027,0.366777,0.155191,0.250470,-0.005660,-0.061948,-0.148896,-0.262867,0.158221,0.176075


In [10]:
# Put the per-symbol metric tables side by side, shorten every non-baseline
# column name to "<symbol>-MV", and persist the combined table.
results_df = pd.concat(results, axis=1)
cols = results_df.columns
ls = []
for col in cols:
    if "B&H" not in col:
        # Model column: keep only the ticker prefix.
        symbol = col.split("-")[0]
        ls.append(f"{symbol}-MV")
        continue
    # Baseline column: name is kept unchanged.
    ls.append(col)

results_df.columns = ls
results_df.to_csv(f"./ts-run-results/[exp3]MV_reward.csv")

def find(symbol, frame=None):
    """Return the columns of the metrics table that belong to ``symbol``.

    Column names are expected to look like ``"<symbol>-<suffix>"``. A column
    is selected only when it starts with ``"<symbol>-"``. The previous
    substring test (``symbol in col``) also matched unrelated tickers, e.g.
    ``find("V")`` returned the ``CVX-*`` and ``VZ-*`` columns.

    Parameters
    ----------
    symbol : str
        Ticker to look up.
    frame : pandas.DataFrame, optional
        Table to search. Defaults to the notebook-level ``results_df``.

    Returns
    -------
    pandas.DataFrame
        The matching columns in their original order; a frame with zero
        columns if nothing matches.
    """
    if frame is None:
        frame = results_df
    prefix = f"{symbol}-"
    matched = [col for col in frame.columns if str(col).startswith(prefix)]
    return frame[matched]

# Show the metrics of the first symbol that has results on disk.
# os.listdir returns entries in arbitrary order, so the listing is sorted to
# make "first" deterministic across machines and runs.
# NOTE(review): assumes every entry under ./Results/ is a symbol directory.
files = sorted(os.listdir("./Results/"))
df = find(files[0])
df


Unnamed: 0,AAPL-B&H,AAPL-MV
sortino_test,1.617304,1.714784
sharpe_test,1.17039,1.247741
risk_test,0.350629,0.244047
mdd_test,0.314273,0.230526
downrisk_test,0.253739,0.177578
cumreturn_test,1.853616,1.29142
