  # Distributed Deep Reinforcement Learning for Multiple Stock Trading

  <a id='0'></a>
  # Part 1. Problem Definition

  This problem is to design an automated trading solution for multiple stock trading. We model the stock trading process as a Markov Decision Process (MDP) and formulate our trading goal as a maximization problem.

  The agent is trained with a Deep Reinforcement Learning (DRL) algorithm. The components of the reinforcement learning environment are:


  * Action: The action space describes the actions through which the agent interacts with the
  environment. Normally, a ∈ A includes three actions: a ∈ {−1, 0, 1}, where −1, 0, 1 represent
  selling, holding, and buying one share. An action can also be carried out on multiple shares. We use
  an action space {−k, ..., −1, 0, 1, ..., k}, where k denotes the maximum number of shares per trade. For example, "Buy
  10 shares of AAPL" or "Sell 10 shares of AAPL" are 10 or −10, respectively.

  * Reward function: r(s, a, s′) is the incentive mechanism for an agent to learn a better action. It is the change in portfolio value when action a is taken at state s and the agent arrives at the new state s′, i.e., r(s, a, s′) = v′ − v, where v′ and v represent the portfolio
  values at states s′ and s, respectively.

  * State: The state space describes the observations that the agent receives from the environment. Just as a human trader needs to analyze various information before executing a trade, so
  our trading agent observes many different features to better learn in an interactive environment.

  * Environment: Dow 30 constituents
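
  The reward definition above can be sketched in a few lines. This is a minimal illustration, not FinRL's environment code: the helper names `portfolio_value` and `step_reward` are hypothetical, and transaction costs are ignored for simplicity.

```python
import numpy as np


def portfolio_value(cash, holdings, prices):
    """Cash plus the market value of all held shares."""
    return float(cash + np.dot(holdings, prices))


def step_reward(cash, holdings, prices, next_prices, action):
    """Reward r(s, a, s') = v' - v for one step (hypothetical helper).

    `action` holds the signed number of shares traded per stock:
    positive values buy, negative values sell. Costs are ignored.
    """
    holdings = np.asarray(holdings, dtype=float)
    action = np.asarray(action, dtype=float)
    prices = np.asarray(prices, dtype=float)
    next_prices = np.asarray(next_prices, dtype=float)

    v = portfolio_value(cash, holdings, prices)          # value at state s
    new_holdings = holdings + action                     # trade at current prices
    new_cash = cash - float(np.dot(action, prices))
    v_next = portfolio_value(new_cash, new_holdings, next_prices)  # value at s'
    return v_next - v


# Buy 10 shares of the first stock and sell 5 shares of the second.
reward = step_reward(
    cash=1000.0,
    holdings=[0, 5],
    prices=[10.0, 20.0],
    next_prices=[11.0, 19.0],
    action=[10, -5],
)
print(reward)
```

  In this toy step the agent ends up holding only the first stock, so the reward is simply that stock's price change times the ten shares held.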


  The data for the Dow 30 stocks used in this case study is obtained from the Yahoo Finance API. The data contains Open-High-Low-Close prices and volume.


  <a id='1'></a>
  # Part 2. Getting Started: Load Python Packages

  <a id='1.1'></a>
  ## 2.1. Install all required packages


In [None]:
!pip install git+https://github.com/AI4Finance-LLC/FinRL-Library.git
!git clone https://github.com/facebookresearch/torchbeast.git
!pip install -r torchbeast/requirements.txt



  <a id='1.2'></a>
  ## 2.2. Check that the additional required packages are present; if not, install them.
  * Yahoo Finance API
  * pandas
  * numpy
  * matplotlib
  * stockstats
  * OpenAI gym
  * stable-baselines
  * tensorflow
  * pyfolio

  <a id='1.3'></a>
  ## 2.3. Import Packages

In [None]:
import sys, os
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import datetime

%reload_ext autoreload
%autoreload 2
%matplotlib inline

from finrl.apps import config
from finrl.finrl_meta.preprocessor.yahoodownloader import YahooDownloader
from finrl.finrl_meta.preprocessor.preprocessors import FeatureEngineer, data_split
from finrl.finrl_meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.plot import backtest_stats, backtest_plot, get_baseline

import impala

from pprint import pprint
from absl import flags

FLAGS = flags.FLAGS
FLAGS(sys.argv, known_only=True)

# `sys` is already imported above; make the local FinRL checkout importable
sys.path.append("../FinRL-Library")

import itertools
import logging

logging.basicConfig(level=logging.INFO)

  'Module "zipline.assets" not found; multipliers will not be applied'


  <a id='1.4'></a>
  ## 2.4. Create Folders

In [None]:
if not os.path.exists("./" + config.DATA_SAVE_DIR):
    os.makedirs("./" + config.DATA_SAVE_DIR)
if not os.path.exists("./" + config.TRAINED_MODEL_DIR):
    os.makedirs("./" + config.TRAINED_MODEL_DIR)
if not os.path.exists("./" + config.TENSORBOARD_LOG_DIR):
    os.makedirs("./" + config.TENSORBOARD_LOG_DIR)
if not os.path.exists("./" + config.RESULTS_DIR):
    os.makedirs("./" + config.RESULTS_DIR)


  <a id='2'></a>
  # Part 3. Download Data
  Yahoo Finance is a website that provides stock data, financial news, financial reports, etc. All the data provided by Yahoo Finance is free.
  * FinRL uses the class **YahooDownloader** to fetch data from the Yahoo Finance API
  * Call Limit: Using the Public API (without authentication), you are limited to 2,000 requests per hour per IP (or up to a total of 48,000 requests a day).




  -----
  class YahooDownloader:
      Provides methods for retrieving daily stock data from
      Yahoo Finance API

      Attributes
      ----------
          start_date : str
              start date of the data (modified from config.py)
          end_date : str
              end date of the data (modified from config.py)
          ticker_list : list
              a list of stock tickers (modified from config.py)

      Methods
      -------
      fetch_data()
          Fetches data from yahoo API


In [None]:
# from config.py start_date is a string
config.START_DATE

'2009-01-01'

In [None]:
# from config.py end_date is a string
config.END_DATE

'2021-10-31'

In [None]:
print(config.DOW_30_TICKER)

['AXP', 'AMGN', 'AAPL', 'BA', 'CAT', 'CSCO', 'CVX', 'GS', 'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'KO', 'JPM', 'MCD', 'MMM', 'MRK', 'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'CRM', 'VZ', 'V', 'WBA', 'WMT', 'DIS', 'DOW']


In [None]:
data_filename = 'processed_data.csv'
data_path = os.path.join(config.DATA_SAVE_DIR, data_filename)

if os.path.exists(data_path):
    data = pd.read_csv(data_path)
else:
    data = None

In [None]:
if data is None:
    df = YahooDownloader(
        start_date = '2009-01-01',
        end_date = '2021-10-31',
        ticker_list = config.DOW_30_TICKER
    ).fetch_data()

Shape of DataFrame:  (94331, 8)


In [None]:
if data is None:
    # An expression inside an `if` block is not displayed, so print it
    print(df.shape)

In [None]:
if data is None:
    # An expression inside an `if` block is not displayed, so print it
    print(df.sort_values(['date','tic'],ignore_index=True).head())

  # Part 4. Preprocess Data
  Data preprocessing is a crucial step for training a high-quality machine learning model. We need to check for missing data and perform feature engineering to convert the data into a model-ready state.
  * Add technical indicators. In practical trading, various information needs to be taken into account, for example historical stock prices, current holdings, and technical indicators. In this notebook, we use the indicators listed in `config.TECHNICAL_INDICATORS_LIST`, which include MACD and RSI.
  * Add turbulence index. Risk aversion reflects whether an investor prefers to preserve capital. It also influences one's trading strategy when facing different levels of market volatility. To control the risk in a worst-case scenario, such as the financial crisis of 2007–2008, FinRL employs the financial turbulence index, which measures extreme asset price fluctuations.
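
  As a rough illustration of the turbulence index, the sketch below computes the squared Mahalanobis distance of one day's returns from their historical mean. The function name `turbulence_index` and the synthetic return data are assumptions made for illustration; FinRL's `FeatureEngineer` may differ in windowing and other details.

```python
import numpy as np


def turbulence_index(returns_hist, returns_today):
    """Squared Mahalanobis distance of today's returns from the historical mean.

    returns_hist:  array of shape (n_days, n_stocks) with past daily returns
    returns_today: array of shape (n_stocks,) with the current day's returns
    """
    returns_hist = np.asarray(returns_hist, dtype=float)
    returns_today = np.asarray(returns_today, dtype=float)
    mu = returns_hist.mean(axis=0)
    cov = np.cov(returns_hist, rowvar=False)
    diff = returns_today - mu
    # The pseudo-inverse keeps the sketch usable if the covariance is singular.
    return float(diff @ np.linalg.pinv(cov) @ diff)


# Synthetic history: 250 days of returns for 3 stocks (illustrative only).
rng = np.random.default_rng(0)
history = rng.normal(0.0, 0.01, size=(250, 3))
mean_returns = history.mean(axis=0)

calm_day = turbulence_index(history, mean_returns + 0.01)
stressed_day = turbulence_index(history, mean_returns + 0.05)
print(calm_day, stressed_day)
```

  A day whose returns sit far from the historical mean, relative to the historical covariance, produces a much larger value than a typical day.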

In [None]:
if data is None:
    fe = FeatureEngineer(
        use_technical_indicator=True,
        tech_indicator_list = config.TECHNICAL_INDICATORS_LIST,
        use_vix=True,
        use_turbulence=True,
        user_defined_feature = False)

    processed = fe.preprocess_data(df)

Successfully added technical indicators
[*********************100%***********************]  1 of 1 completed
Shape of DataFrame:  (3229, 8)
Successfully added vix
Successfully added turbulence index


In [None]:
if data is None:
    list_ticker = processed["tic"].unique().tolist()
    list_date = list(pd.date_range(processed['date'].min(),processed['date'].max()).astype(str))
    combination = list(itertools.product(list_date,list_ticker))

    processed_full = pd.DataFrame(combination,columns=["date","tic"]).merge(processed,on=["date","tic"],how="left")
    processed_full = processed_full[processed_full['date'].isin(processed['date'])]
    processed_full = processed_full.sort_values(['date','tic'])

    processed_full = processed_full.fillna(0)
    processed_full.to_csv(data_path, index=False)
    
else:
    processed_full = data
    

In [None]:
processed_full['date'] = pd.to_datetime(processed_full['date'])
processed_full.sort_values(['date','tic'],ignore_index=True).head(10)

Unnamed: 0,date,tic,open,high,low,close,volume,day,macd,boll_ub,boll_lb,rsi_30,cci_30,dx_30,close_30_sma,close_60_sma,vix,turbulence
0,2009-01-02,AAPL,3.067143,3.251429,3.041429,2.778781,746015200.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,2.778781,2.778781,39.189999,0.0
1,2009-01-02,AMGN,58.59,59.080002,57.75,45.615864,6547900.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,45.615864,45.615864,39.189999,0.0
2,2009-01-02,AXP,18.57,19.52,18.4,15.579443,10955700.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,15.579443,15.579443,39.189999,0.0
3,2009-01-02,BA,42.799999,45.560001,42.779999,33.941101,7010200.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,33.941101,33.941101,39.189999,0.0
4,2009-01-02,CAT,44.91,46.98,44.709999,32.4758,7117200.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,32.4758,32.4758,39.189999,0.0
5,2009-01-02,CRM,8.025,8.55,7.9125,8.505,4069200.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,8.505,8.505,39.189999,0.0
6,2009-01-02,CSCO,16.41,17.0,16.25,12.349072,40980600.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,12.349072,12.349072,39.189999,0.0
7,2009-01-02,CVX,74.230003,77.300003,73.580002,45.650551,13695900.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,45.650551,45.650551,39.189999,0.0
8,2009-01-02,DIS,22.76,24.030001,22.5,20.597496,9796600.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,20.597496,20.597496,39.189999,0.0
9,2009-01-02,GS,84.019997,87.620003,82.190002,71.587761,14088500.0,4.0,0.0,3.003271,2.671567,100.0,66.666667,100.0,71.587761,71.587761,39.189999,0.0


  <a id='4'></a>
  # Part 5. Design Environment
  Considering the stochastic and interactive nature of automated stock trading, the task is modeled as a **Markov Decision Process (MDP)** problem. The training process involves observing stock price changes, taking an action, and calculating the reward so that the agent adjusts its strategy accordingly. By interacting with the environment, the trading agent derives a trading strategy that maximizes rewards over time.

  Our trading environments, based on the OpenAI Gym framework, simulate live stock markets with real market data according to the principle of time-driven simulation.

  The action space describes the actions through which the agent interacts with the environment. Normally, an action a takes one of three values {-1, 0, 1}, where -1, 0, 1 represent selling, holding, and buying one share. An action can also be carried out on multiple shares. We use an action space {-k, …, -1, 0, 1, …, k}, where k denotes the number of shares to buy and -k denotes the number of shares to sell. For example, "Buy 10 shares of AAPL" or "Sell 10 shares of AAPL" are 10 or -10, respectively. Because the policy is defined on a Gaussian distribution, the continuous action space is normalized to the symmetric interval [-1, 1].
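
  To make the normalization concrete, the sketch below maps a policy output in [-1, 1] back to a signed number of shares. The helper `to_share_orders` is hypothetical, and the scaling by a maximum trade size `hmax` is an assumption about how such an environment might decode actions, not a copy of FinRL's code.

```python
import numpy as np


def to_share_orders(raw_action, hmax=100):
    """Map normalized actions in [-1, 1] to signed integer share counts.

    hmax is the assumed maximum number of shares traded per stock per step.
    Values outside [-1, 1] are clipped; fractions are truncated toward zero.
    """
    clipped = np.clip(np.asarray(raw_action, dtype=float), -1.0, 1.0)
    return (clipped * hmax).astype(int)


orders = to_share_orders([0.5, -0.25, 0.0, 2.0], hmax=100)
print(orders)
```

  With `hmax=100`, a normalized action of 0.5 corresponds to buying 50 shares and -0.25 to selling 25 shares.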

  ## Training data split: 2009-01-01 to 2020-07-01
  ## Trade data split: 2020-07-01 to 2021-10-31

In [None]:
train = data_split(processed_full, '2009-01-01','2020-07-01')
trade = data_split(processed_full, '2020-07-01','2021-10-31')
print(len(train))
print(len(trade))

83897
9744
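
  The split above is half-open in time: the training set ends before 2020-07-01 and the trade set starts on it. A minimal sketch of such a split follows; `split_by_date` is a hypothetical stand-in, and FinRL's `data_split` may differ in details. Re-indexing by trading day is an assumption that matches the repeated index values visible in the outputs below.

```python
import pandas as pd


def split_by_date(df, start, end, date_col="date"):
    """Keep rows with start <= date < end and index them by trading day."""
    mask = (df[date_col] >= pd.Timestamp(start)) & (df[date_col] < pd.Timestamp(end))
    out = df.loc[mask].sort_values([date_col, "tic"], ignore_index=True)
    # All tickers on the same day share one integer index (0, 1, 2, ...).
    out.index = out[date_col].factorize()[0]
    return out


# Tiny illustrative frame: three days, two tickers (made-up prices).
demo = pd.DataFrame({
    "date": pd.to_datetime(["2020-06-29", "2020-06-29", "2020-06-30",
                            "2020-06-30", "2020-07-01", "2020-07-01"]),
    "tic": ["AAPL", "MSFT"] * 3,
    "close": [90.0, 198.0, 91.0, 203.0, 90.5, 204.0],
})

train_demo = split_by_date(demo, "2020-06-29", "2020-07-01")
trade_demo = split_by_date(demo, "2020-07-01", "2020-07-02")
print(len(train_demo), len(trade_demo))
```

  Because the end date is exclusive, no trading day appears in both splits.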


In [None]:
train.tail()

Unnamed: 0,date,tic,open,high,low,close,volume,day,macd,boll_ub,boll_lb,rsi_30,cci_30,dx_30,close_30_sma,close_60_sma,vix,turbulence
2892,2020-06-30,UNH,288.570007,296.450012,287.660004,288.628418,2932900.0,1.0,-0.019532,304.825261,272.053935,52.413055,-25.815168,1.846804,288.873002,281.832984,30.43,12.918684
2892,2020-06-30,V,191.490005,193.75,190.160004,191.41243,9040100.0,1.0,1.052496,199.454099,185.696427,53.021028,-51.516799,2.013358,192.162885,182.320811,30.43,12.918684
2892,2020-06-30,VZ,54.919998,55.290001,54.360001,50.990154,17414800.0,1.0,-0.442433,54.574966,49.322677,48.097035,-50.972382,8.508886,51.633274,52.091341,30.43,12.918684
2892,2020-06-30,WBA,42.119999,42.580002,41.759998,39.89254,4782100.0,1.0,-0.085828,43.544545,37.287959,48.830188,-14.445144,1.500723,39.994176,39.789724,30.43,12.918684
2892,2020-06-30,WMT,119.220001,120.129997,118.540001,116.994629,6836400.0,1.0,-0.893233,120.371811,114.363667,48.159688,-69.914537,3.847271,118.672997,120.623193,30.43,12.918684


In [None]:
trade.head()

Unnamed: 0,date,tic,open,high,low,close,volume,day,macd,boll_ub,boll_lb,rsi_30,cci_30,dx_30,close_30_sma,close_60_sma,vix,turbulence
0,2020-07-01,AAPL,91.279999,91.839996,90.977501,90.151405,110737200.0,2.0,3.022879,92.953798,80.40006,62.807126,107.487537,29.730532,84.164181,77.930892,28.620001,53.068271
0,2020-07-01,AMGN,235.520004,256.230011,232.580002,244.159134,6575800.0,2.0,3.697037,236.273228,203.551995,61.279641,271.769248,46.806139,218.441967,219.532867,28.620001,53.068271
0,2020-07-01,AXP,95.25,96.959999,93.639999,92.347763,3301000.0,2.0,-0.390268,110.737393,88.008444,48.50481,-66.334579,3.142448,97.520675,90.952973,28.620001,53.068271
0,2020-07-01,BA,185.880005,190.610001,180.039993,180.320007,49036700.0,2.0,5.443193,220.721139,160.932863,50.925771,24.220608,15.93292,176.472335,155.614168,28.620001,53.068271
0,2020-07-01,CAT,129.380005,129.399994,125.879997,121.818474,2807800.0,2.0,1.284936,131.887567,114.449385,52.865416,35.546899,14.457404,120.567698,114.745772,28.620001,53.068271


In [None]:
config.TECHNICAL_INDICATORS_LIST

['macd',
 'boll_ub',
 'boll_lb',
 'rsi_30',
 'cci_30',
 'dx_30',
 'close_30_sma',
 'close_60_sma']
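
  As an example of what one of these indicators measures, the `macd` line is commonly defined as the difference between a fast and a slow exponential moving average of the closing price. The sketch below uses the conventional 12- and 26-period spans; the function name `macd_line` is hypothetical, and the exact smoothing conventions of the library FinRL uses may differ.

```python
import pandas as pd


def macd_line(close, fast=12, slow=26):
    """MACD line: fast EMA minus slow EMA of the closing price."""
    ema_fast = close.ewm(span=fast, adjust=False).mean()
    ema_slow = close.ewm(span=slow, adjust=False).mean()
    return ema_fast - ema_slow


flat = pd.Series([10.0] * 50)                 # no trend
rising = pd.Series(range(1, 101), dtype=float)  # steady uptrend

print(macd_line(flat).iloc[-1], macd_line(rising).iloc[-1])
```

  A flat price series gives a MACD of zero, while a steady uptrend gives a positive value because the fast average follows the price more closely than the slow one.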

In [None]:
stock_dimension = len(train.tic.unique())
state_space = 1 + 2*stock_dimension + len(config.TECHNICAL_INDICATORS_LIST)*stock_dimension
print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

Stock Dimension: 29, State Space: 291
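
  The state dimension of 291 follows from the formula above: one cash balance, one price and one holding per stock, and one value per technical indicator per stock. The sketch below assembles such a vector; `build_state` and the ordering of the parts are assumptions for illustration rather than the environment's actual layout.

```python
import numpy as np


def build_state(cash, prices, holdings, indicators):
    """Concatenate cash, prices, holdings, and per-stock indicator values."""
    parts = [
        np.array([cash], dtype=float),
        np.asarray(prices, dtype=float),
        np.asarray(holdings, dtype=float),
    ]
    parts += [np.asarray(values, dtype=float) for values in indicators.values()]
    return np.concatenate(parts)


n_stocks = 29
indicator_names = ["macd", "boll_ub", "boll_lb", "rsi_30",
                   "cci_30", "dx_30", "close_30_sma", "close_60_sma"]

state = build_state(
    cash=1_000_000,
    prices=np.ones(n_stocks),        # placeholder prices
    holdings=np.zeros(n_stocks),     # no shares held initially
    indicators={name: np.zeros(n_stocks) for name in indicator_names},
)
print(state.shape)  # 1 + 2*29 + 8*29 entries
```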


  ## Environment for Training



In [None]:
env_kwargs = {
    "hmax": 100, 
    "initial_amount": 1000000, 
    "buy_cost_pct": 0.001,
    "sell_cost_pct": 0.001,
    "state_space": state_space, 
    "stock_dim": stock_dimension, 
    "tech_indicator_list": config.TECHNICAL_INDICATORS_LIST, 
    "action_space": stock_dimension, 
    "reward_scaling": 1e-4
    
}

e_train_gym = StockTradingEnv(df = train, **env_kwargs)

  ## Trading
  Assume that we have $1,000,000 of initial capital on 2020-07-01. We use the IMPALA agent to trade the Dow Jones 30 stocks.

  ### Set turbulence threshold
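  Choose the threshold from the in-sample distribution of the risk indicator; if the current value of the indicator exceeds the threshold, we assume that the current market is volatile. The trading environment below uses the VIX as the risk indicator with a threshold of 70, which lies between the in-sample 99.6th percentile (about 57.4) and the in-sample maximum (82.69).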

In [None]:
data_risk_indicator = processed_full[(processed_full.date<'2020-07-01') & (processed_full.date>='2009-01-01')]
insample_risk_indicator = data_risk_indicator.drop_duplicates(subset=['date'])

In [None]:
insample_risk_indicator.vix.describe()

count    2893.000000
mean       18.824245
std         8.489311
min         9.140000
25%        13.330000
50%        16.139999
75%        21.309999
max        82.690002
Name: vix, dtype: float64

In [None]:
insample_risk_indicator.vix.quantile(0.996)

57.40400183105453

In [None]:
insample_risk_indicator.turbulence.describe()

count    2893.000000
mean       34.567955
std        43.790795
min         0.000000
25%        14.962946
50%        24.124388
75%        39.162099
max       652.504095
Name: turbulence, dtype: float64

In [None]:
insample_risk_indicator.turbulence.quantile(0.996)

276.44975706280064

In [None]:
e_trade_gym = StockTradingEnv(df = trade, turbulence_threshold = 70,risk_indicator_col='vix', **env_kwargs)
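
  One way a risk threshold can act on trading is sketched below: when the risk indicator reaches the threshold, the agent's proposed orders are replaced by orders that liquidate all positions. The helper `gate_actions` and the liquidation rule are assumptions for illustration; the actual behavior of `StockTradingEnv` may differ.

```python
import numpy as np


def gate_actions(actions, holdings, risk_value, threshold):
    """Return the orders to execute given the current risk level.

    Below the threshold the proposed actions pass through unchanged.
    At or above it, every held share is sold (assumed rule).
    """
    actions = np.asarray(actions, dtype=int)
    holdings = np.asarray(holdings, dtype=int)
    if risk_value >= threshold:
        return -holdings
    return actions


calm = gate_actions([10, -5], holdings=[3, 7], risk_value=30.43, threshold=70)
stressed = gate_actions([10, -5], holdings=[3, 7], risk_value=82.69, threshold=70)
print(calm, stressed)
```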

  # Part 6. Implement DRL Algorithms

  ## Training

In [None]:
impala.set_env(e_train_gym)
impala.train(FLAGS)

TypeError: ignored

  ### Trade

  A DRL model needs to be updated periodically to take full advantage of the data; ideally, we retrain the model yearly, quarterly, or monthly and tune the parameters along the way. In this notebook, we use only the in-sample data from 2009-01 to 2020-07 to tune the parameters once, so there is some alpha decay as the trading period extends.

  Numerous hyperparameters – e.g. the learning rate, the total number of samples to train on – influence the learning process and are usually determined by testing some variations.
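
  Periodic retraining can be organized as a list of expanding training windows, each followed by a short trading window. The sketch below builds quarterly windows over this notebook's trading period; `retraining_windows` is a hypothetical helper and is not part of FinRL or of the training code used here.

```python
import pandas as pd


def retraining_windows(train_start, trade_start, trade_end, freq="QS"):
    """Return (train_start, train_end, trade_end) tuples for periodic retraining.

    Each model is trained on [train_start, train_end) and then trades on
    [train_end, trade_end). Window boundaries fall on period starts, so
    trade_start should itself be a period start (e.g. a quarter start).
    """
    starts = list(pd.date_range(trade_start, trade_end, freq=freq))
    final_end = pd.Timestamp(trade_end)
    windows = []
    for i, start in enumerate(starts):
        stop = starts[i + 1] if i + 1 < len(starts) else final_end
        windows.append((pd.Timestamp(train_start), start, stop))
    return windows


windows = retraining_windows("2009-01-01", "2020-07-01", "2021-10-31")
for train_from, train_to, trade_to in windows:
    print(train_from.date(), train_to.date(), trade_to.date())
```

  Each row describes one retraining round: the model sees all data up to the start of the quarter and is then evaluated only on that quarter.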

In [None]:
impala.set_env(e_trade_gym)
df_account_value, df_actions = impala.test(FLAGS)

In [None]:
df_account_value.shape

In [None]:
df_account_value.tail()

  <a id='6'></a>
  # Part 7. Backtest Our Strategy
  Backtesting plays a key role in evaluating the performance of a trading strategy. An automated backtesting tool is preferred because it reduces human error. We use the Quantopian pyfolio package to backtest our trading strategies. It is easy to use and provides various individual plots that give a comprehensive picture of a trading strategy's performance.

  <a id='6.1'></a>
  ## 7.1 BackTestStats
  Pass in `df_account_value`; this information is stored in the environment class.
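
  For intuition about what such statistics contain, the sketch below computes three common ones directly from a series of account values. It is a simplified stand-in, not the pyfolio implementation: `simple_backtest_stats` is a hypothetical helper, the risk-free rate is assumed to be zero, and 252 trading days per year are assumed.

```python
import numpy as np
import pandas as pd


def simple_backtest_stats(account_value, periods_per_year=252):
    """Cumulative return, annualized Sharpe ratio, and maximum drawdown."""
    values = pd.Series(account_value, dtype=float)
    returns = values.pct_change().dropna()

    cumulative_return = values.iloc[-1] / values.iloc[0] - 1.0
    # Sharpe ratio with a zero risk-free rate (assumption).
    sharpe = returns.mean() / returns.std() * np.sqrt(periods_per_year)
    # Drawdown: relative drop from the running peak of the account value.
    max_drawdown = (values / values.cummax() - 1.0).min()

    return {
        "cumulative_return": float(cumulative_return),
        "sharpe": float(sharpe),
        "max_drawdown": float(max_drawdown),
    }


demo_stats = simple_backtest_stats([100.0, 110.0, 99.0, 120.0])
print(demo_stats)
```

  In the toy series the account ends 20% above its starting value, and the largest peak-to-trough loss is the 10% drop from 110 to 99.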


In [None]:
print("==============Get Backtest Results===========")
now = datetime.datetime.now().strftime('%Y%m%d-%Hh%M')

perf_stats_all = backtest_stats(account_value=df_account_value)
perf_stats_all = pd.DataFrame(perf_stats_all)
perf_stats_all.to_csv("./"+config.RESULTS_DIR+"/perf_stats_all_"+now+'.csv')

In [None]:
#baseline stats
print("==============Get Baseline Stats===========")
baseline_df = get_baseline(
        ticker="^DJI", 
        start = df_account_value.loc[0,'date'],
        end = df_account_value.loc[len(df_account_value)-1,'date'])

stats = backtest_stats(baseline_df, value_col_name = 'close')

In [None]:
df_account_value.loc[0,'date']

In [None]:
df_account_value.loc[len(df_account_value)-1,'date']

  <a id='6.2'></a>
  ## 7.2 BackTestPlot

In [None]:
print("==============Compare to DJIA===========")

# S&P 500: ^GSPC
# Dow Jones Index: ^DJI
# NASDAQ 100: ^NDX
backtest_plot(df_account_value, 
             baseline_ticker = '^DJI', 
             baseline_start = df_account_value.loc[0,'date'],
             baseline_end = df_account_value.loc[len(df_account_value)-1,'date'])

  <a id='6.3'></a>
  ## 7.3 TransactionPlot

In [None]:
def trx_plot(df_trade, df_actions, tics=None):
    """Plot transactions."""
    import matplotlib.dates as mdates

    df_trx = df_actions

    if tics is None:
        tics = list(df_trx)

    for tic in tics:
        df_trx_temp = df_trx[tic]
        df_trx_temp_sign = np.sign(df_trx_temp)
        # Boolean masks marking the days with buy and sell transactions
        buying_signal = df_trx_temp_sign > 0
        selling_signal = df_trx_temp_sign < 0

        tic_plot = df_trade[
            (df_trade["tic"] == df_trx_temp.name)
            & (df_trade["date"].isin(df_trx.index))
        ]["close"]
        tic_plot.index = df_trx_temp.index

        plt.figure(figsize=(10, 8))
        plt.plot(tic_plot, color="g", lw=2.0)
        plt.plot(
            tic_plot,
            "^",
            markersize=10,
            color="m",
            label="buying signal",
            markevery=buying_signal,
        )
        plt.plot(
            tic_plot,
            "v",
            markersize=10,
            color="k",
            label="selling signal",
            markevery=selling_signal,
        )
        # Count the days on which a buy or a sell took place
        num_transactions = int(buying_signal.sum() + selling_signal.sum())
        plt.title(f"{df_trx_temp.name} Num Transactions: {num_transactions}")
        plt.legend()
        plt.gca().xaxis.set_major_locator(mdates.DayLocator(interval=25))
        plt.xticks(rotation=45, ha="right")
        plt.show()

trx_plot(trade, df_actions)