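# First Glance: A First Program That Runs the Whole Workflow End to End
* **The official version uses Yahoo Finance, which cannot be accessed from mainland China, so the data source has been switched to Tushare.**
* **Adapted from Deep Reinforcement Learning for Stock Trading from Scratch: Multiple Stock Trading**
* **PyTorch Version**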



# Contents

* [1. Task Description](#0)
* [2. Install Python packages](#1)
    * [2.1. Install Packages](#1.1)
    * [2.2. A List of Python Packages](#1.2)
    * [2.3. Import Packages](#1.3)
    * [2.4. Create Folders](#1.4)
* [3. Download Data](#2)
* [4. Preprocess Data](#3)
    * [4.1. Technical Indicators](#3.1)
    * [4.2. Perform Feature Engineering](#3.2)
* [5. Build Market Environment in OpenAI Gym-style](#4)
    * [5.1. Data Split](#4.1)
    * [5.2. Environment for Training](#4.2)
* [6. Train DRL Agents](#5)
* [7. Backtesting Performance](#6)
    * [7.1. BackTestStats](#6.1)
    * [7.2. BackTestPlot](#6.2)
  

<a id='0'></a>
# Part 1. Task Description

We train a DRL agent for stock trading. The task is modeled as a Markov Decision Process (MDP), and the objective is to maximize the expected cumulative return.

We specify the state-action-reward as follows:

* **State s**: The state space represents an agent's perception of the market environment. Just like a human trader analyzing various information, here our agent passively observes many features and learns by interacting with the market environment (usually by replaying historical data).

* **Action a**: The action space includes the allowed actions that an agent can take at each state. For example, a ∈ {−1, 0, 1}, where −1, 0, 1 represent selling, holding, and buying, respectively. When an action operates on multiple shares, a ∈ {−k, ..., −1, 0, 1, ..., k}; e.g., "Buy 10 shares of AAPL" or "Sell 10 shares of AAPL" correspond to 10 or −10, respectively.

* **Reward function r(s, a, s′)**: The reward is an incentive for the agent to learn a better policy. For example, it can be the change in portfolio value when taking action a at state s and arriving at the new state s′, i.e., r(s, a, s′) = v′ − v, where v′ and v represent the portfolio values at states s′ and s, respectively.
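As a worked instance of the reward definition above, the sketch below computes the portfolio value v as cash plus the market value of all holdings, then takes r = v′ − v between two consecutive states. The helper names and the numbers are made up for illustration; FinRL's environment additionally deducts transaction costs and scales the reward (see `reward_scaling` in `env_kwargs`).

```python
def portfolio_value(cash, prices, shares):
    """Portfolio value v = cash + sum_i(price_i * shares_i)."""
    return cash + sum(p * s for p, s in zip(prices, shares))


def reward(cash, prices, shares, next_cash, next_prices, next_shares):
    """Reward r(s, a, s') = v' - v: the change in portfolio value over one step."""
    v = portfolio_value(cash, prices, shares)
    v_next = portfolio_value(next_cash, next_prices, next_shares)
    return v_next - v


# State s: 1,000 in cash, 10 shares of stock A (price 50), 0 shares of stock B (price 20).
# Action a: buy 5 shares of B for 100 (transaction fees ignored).
# State s': prices move to 52 and 21.
r = reward(cash=1000.0, prices=[50.0, 20.0], shares=[10, 0],
           next_cash=900.0, next_prices=[52.0, 21.0], next_shares=[10, 5])
print(r)  # 25.0  (v = 1500, v' = 900 + 520 + 105 = 1525)
```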


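**Market environment**: the original notebook used the 30 constituent stocks of the Dow Jones Industrial Average (DJIA) index, accessed at the starting date of the testing period. This adapted version uses the tickers in `config_tickers.SSE_50_TICKER` (SSE 50 index constituents) instead.

The data for this case study is obtained from the Tushare API (the original notebook used the Yahoo Finance API). The data contains Open-High-Low-Close prices, volume, and turnover.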


<a id='1'></a>
# Part 2. Install Python Packages

<a id='1.1'></a>
## 2.1. Install packages


In [None]:
# install the data-source libraries (Tushare, Baostock) from the Tsinghua PyPI mirror
#!pip install tushare  -i https://pypi.tuna.tsinghua.edu.cn/simple/ --trusted-host pypi.tuna.tsinghua.edu.cn
#!pip install baostock -i https://pypi.tuna.tsinghua.edu.cn/simple/ --trusted-host pypi.tuna.tsinghua.edu.cn


<a id='1.2'></a>
## 2.2. A list of Python packages 
* Tushare API (replaces the Yahoo Finance API used in the original notebook)
* pandas
* numpy
* matplotlib
* stockstats
* OpenAI gym
* stable-baselines3
* PyTorch
* pyfolio

<a id='1.3'></a>
## 2.3. Import Packages

In [4]:
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# matplotlib.use('Agg')
import datetime
import sys
sys.path.append("../../FinRL")

%matplotlib inline
from finrl.meta.preprocessor.yahoodownloader import YahooDownloader
from finrl.meta.preprocessor.tusharedownloader import TushareDownloader
from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split
from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv
from finrl.agents.stablebaselines3.models import DRLAgent
from stable_baselines3.common.logger import configure
from finrl.meta.data_processor import DataProcessor

from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline
from pprint import pprint



import itertools



<a id='1.4'></a>
## 2.4. Create Folders

In [5]:
from finrl import config
from finrl import config_tickers
import os
from finrl.main import check_and_make_directories
from finrl.config import (
    DATA_SAVE_DIR,
    TRAINED_MODEL_DIR,
    TENSORBOARD_LOG_DIR,
    RESULTS_DIR,
    INDICATORS,
    TRAIN_START_DATE,
    TRAIN_END_DATE,
    TEST_START_DATE,
    TEST_END_DATE,
    TRADE_START_DATE,
    TRADE_END_DATE,
)
check_and_make_directories([DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR])



<a id='2'></a>
# Part 3. Download Data
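Yahoo Finance provides free stock data, financial news, financial reports, etc. The original notebook fetched its data this way:
* FinRL uses the class **YahooDownloader** in FinRL-Meta to fetch data via the Yahoo Finance API.
* Call limit: using the public API (without authentication), you are limited to 2,000 requests per hour per IP (or up to a total of 48,000 requests a day).
* Because Yahoo Finance is not reachable from mainland China, this version uses **TushareDownloader** (`finrl.meta.preprocessor.tusharedownloader`) instead, called with the same `start_date`, `end_date`, and `ticker_list` arguments.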



-----
class YahooDownloader:
    Retrieving daily stock data from
    Yahoo Finance API

    Attributes
    ----------
        start_date : str
            start date of the data (modified from config.py)
        end_date : str
            end date of the data (modified from config.py)
        ticker_list : list
            a list of stock tickers (modified from config.py)

    Methods
    -------
    fetch_data()


In [6]:
# from config.py, TRAIN_START_DATE is a string
TRAIN_START_DATE
# from config.py, TRAIN_END_DATE is a string
TRAIN_END_DATE

'2020-07-31'

In [7]:
TRAIN_START_DATE = '2020-07-22'
TRAIN_END_DATE = '2021-01-01'
TRADE_START_DATE = '2021-01-01'
TRADE_END_DATE = '2021-10-31'


In [8]:
df = TushareDownloader(start_date = TRAIN_START_DATE,
                     end_date = TRADE_END_DATE,
                     ticker_list = config_tickers.SSE_50_TICKER).fetch_data()

  0%|          | 0/62 [00:00<?, ?it/s]

This interface will soon stop being updated; please switch to the Pro API as soon as possible: https://tushare.pro/document/2


100%|██████████| 62/62 [00:11<00:00,  5.23it/s]

Shape of DataFrame:  (19121, 9)
             date    open    high   close     low     volume  turnover  \
0      2020-07-22  213.00  215.82  212.66  210.55   65675.81      1.46   
1      2020-07-23  209.98  212.50  208.60  205.80   60449.68      1.34   
2      2020-07-24  208.00  215.01  206.11  203.88   72763.22      1.61   
3      2020-07-27  207.65  209.20  204.18  202.00   43012.66      0.95   
4      2020-07-28  206.04  207.50  204.96  203.13   34134.48      0.75   
...           ...     ...     ...     ...     ...        ...       ...   
19116  2021-10-25    9.03    9.06    9.03    9.02  265855.16      0.09   
19117  2021-10-26    9.06    9.09    9.03    9.01  307518.28      0.10   
19118  2021-10-27    9.01    9.02    8.99    8.96  341972.81      0.12   
19119  2021-10-28    8.99    9.01    8.96    8.95  281651.81      0.10   
19120  2021-10-29    8.96    9.00    8.94    8.93  359907.31      0.12   

          tic  day  
0      603160    2  
1      603160    3  
2      603160   




In [9]:
print(config_tickers.SSE_50_TICKER)

['600000.XSHG', '600036.XSHG', '600104.XSHG', '600030.XSHG', '601628.XSHG', '601166.XSHG', '601318.XSHG', '601328.XSHG', '601088.XSHG', '601857.XSHG', '601601.XSHG', '601668.XSHG', '601288.XSHG', '601818.XSHG', '601989.XSHG', '601398.XSHG', '600048.XSHG', '600028.XSHG', '600050.XSHG', '600519.XSHG', '600016.XSHG', '600887.XSHG', '601688.XSHG', '601186.XSHG', '601988.XSHG', '601211.XSHG', '601336.XSHG', '600309.XSHG', '603993.XSHG', '600690.XSHG', '600276.XSHG', '600703.XSHG', '600585.XSHG', '603259.XSHG', '601888.XSHG', '601138.XSHG', '600196.XSHG', '601766.XSHG', '600340.XSHG', '601390.XSHG', '601939.XSHG', '601111.XSHG', '600029.XSHG', '600019.XSHG', '601229.XSHG', '601800.XSHG', '600547.XSHG', '601006.XSHG', '601360.XSHG', '600606.XSHG', '601319.XSHG', '600837.XSHG', '600031.XSHG', '601066.XSHG', '600009.XSHG', '601236.XSHG', '601012.XSHG', '600745.XSHG', '600588.XSHG', '601658.XSHG', '601816.XSHG', '603160.XSHG']


In [10]:
df.shape

(19121, 9)

In [11]:
df.sort_values(['date','tic'],ignore_index=True).head()

|   | date | open | high | close | low | volume | turnover | tic | day |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 2020-07-22 | 11.55 | 11.75 | 11.62 | 11.5 | 986876.38 | 0.35 | 600000 | 2 |
| 1 | 2020-07-22 | 73.5 | 73.87 | 72.0 | 71.76 | 210526.88 | 1.93 | 600009 | 2 |
| 2 | 2020-07-22 | 5.67 | 5.72 | 5.64 | 5.63 | 1487994.12 | 0.42 | 600016 | 2 |
| 3 | 2020-07-22 | 5.16 | 5.17 | 5.11 | 5.08 | 748670.12 | 0.34 | 600019 | 2 |
| 4 | 2020-07-22 | 4.05 | 4.13 | 4.08 | 4.03 | 2547167.75 | 0.27 | 600028 | 2 |


# Part 4: Preprocess Data
We need to check for missing data and do feature engineering to convert the data points into states.
* **Adding technical indicators**. In practical trading, various information needs to be taken into account, such as historical prices, current holdings, and technical indicators. Here, we use the indicators listed in `INDICATORS` from `config.py`: MACD, the upper and lower Bollinger Bands, the 30-day RSI, CCI, and DX, and the 30- and 60-day simple moving averages of the close price.
* **Adding turbulence index**. Risk aversion reflects whether an investor prefers to protect capital. It also influences one's trading strategy when facing different levels of market volatility. To control risk in a worst-case scenario, such as the 2007–2008 financial crisis, FinRL employs the turbulence index, which measures extreme fluctuations in asset prices.
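The turbulence index is commonly defined as a Mahalanobis distance: turbulence_t = (y_t − μ) Σ⁻¹ (y_t − μ)ᵀ, where y_t is the vector of asset returns on day t, and μ and Σ are the mean and covariance of returns over a historical window. The sketch below computes that quantity with NumPy on synthetic returns; it is an illustration, not FinRL's `FeatureEngineer` code. To our understanding, FinRL uses a lookback of about one year (252 trading days) and sets turbulence to 0 when there is not enough history, which would be consistent with the all-zero `turbulence` columns in the outputs below for data starting on 2020-07-22.

```python
import numpy as np


def turbulence_index(hist_returns, current_returns):
    """Mahalanobis-style turbulence: (y_t - mu) @ pinv(Sigma) @ (y_t - mu).

    hist_returns: shape (n_days, n_assets), daily returns in the lookback window.
    current_returns: shape (n_assets,), returns on the current day.
    """
    mu = hist_returns.mean(axis=0)
    sigma = np.cov(hist_returns, rowvar=False)  # asset-by-asset covariance matrix
    diff = current_returns - mu
    # pinv instead of inv keeps the result defined when sigma is singular
    return float(diff @ np.linalg.pinv(sigma) @ diff)


rng = np.random.default_rng(0)
hist = rng.normal(0.0, 0.01, size=(252, 3))  # one year of synthetic daily returns, 3 assets
calm_day = np.array([0.001, -0.002, 0.0015])
crash_day = np.array([-0.08, -0.09, -0.07])
calm = turbulence_index(hist, calm_day)
crash = turbulence_index(hist, crash_day)
print(calm, crash)  # the crash day scores far higher than the calm day
```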

In [12]:
fe = FeatureEngineer(
                    use_technical_indicator=True,
                    tech_indicator_list = INDICATORS,
                    use_vix=False,
                    use_turbulence=True,
                    user_defined_feature = False)

processed = fe.preprocess_data(df)


Successfully added technical indicators
Successfully added turbulence index


In [13]:
list_ticker = processed["tic"].unique().tolist()
list_date = list(pd.date_range(processed['date'].min(),processed['date'].max()).astype(str))
combination = list(itertools.product(list_date,list_ticker))

processed_full = pd.DataFrame(combination,columns=["date","tic"]).merge(processed,on=["date","tic"],how="left")
processed_full = processed_full[processed_full['date'].isin(processed['date'])]
processed_full = processed_full.sort_values(['date','tic'])

processed_full = processed_full.fillna(0)
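The cell above builds every (date, ticker) pair, left-joins the processed data onto that grid so a stock with no row on a trading day (for example, a trading halt) still gets one, keeps only dates that actually appear in the data, and fills the gaps with 0. The toy example below repeats the same steps on invented tickers and prices so the effect is visible.

```python
import itertools

import pandas as pd

# Toy "processed" data: stock B has no row on Monday 2021-01-11.
processed = pd.DataFrame({
    "date": ["2021-01-08", "2021-01-08", "2021-01-11", "2021-01-12", "2021-01-12"],
    "tic": ["A", "B", "A", "A", "B"],
    "close": [10.0, 20.0, 10.5, 10.2, 19.5],
})

list_ticker = processed["tic"].unique().tolist()
# Every calendar day from first to last date, as strings matching the "date" column.
list_date = list(pd.date_range(processed["date"].min(), processed["date"].max()).astype(str))
combination = list(itertools.product(list_date, list_ticker))

full = pd.DataFrame(combination, columns=["date", "tic"]).merge(
    processed, on=["date", "tic"], how="left")
full = full[full["date"].isin(processed["date"])]  # drop the weekend (non-trading) days
full = full.sort_values(["date", "tic"]).fillna(0)  # the missing (date, tic) row gets 0
print(full)
```

Note that the filled row has a close price of 0, so downstream code has to treat a zero price as "no data" rather than as a real quote.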

In [14]:
processed_full.sort_values(['date','tic'],ignore_index=True).head(10)

|   | date | tic | open | high | close | low | volume | turnover | day | macd | boll_ub | boll_lb | rsi_30 | cci_30 | dx_30 | close_30_sma | close_60_sma | turbulence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2020-07-22 | 600000 | 11.55 | 11.75 | 11.62 | 11.5 | 986876.38 | 0.35 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 11.62 | 11.62 | 0.0 |
| 1 | 2020-07-22 | 600016 | 5.67 | 5.72 | 5.64 | 5.63 | 1487994.12 | 0.42 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 5.64 | 5.64 | 0.0 |
| 2 | 2020-07-22 | 600019 | 5.16 | 5.17 | 5.11 | 5.08 | 748670.12 | 0.34 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 5.11 | 5.11 | 0.0 |
| 3 | 2020-07-22 | 600028 | 4.05 | 4.13 | 4.08 | 4.03 | 2547167.75 | 0.27 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 4.08 | 4.08 | 0.0 |
| 4 | 2020-07-22 | 600029 | 5.56 | 5.78 | 5.68 | 5.55 | 1150432.75 | 1.42 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 5.68 | 5.68 | 0.0 |
| 5 | 2020-07-22 | 600030 | 30.45 | 31.76 | 30.5 | 30.16 | 3460703.75 | 3.53 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 30.5 | 30.5 | 0.0 |
| 6 | 2020-07-22 | 600031 | 22.45 | 22.55 | 21.79 | 21.61 | 1350300.0 | 1.6 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 21.79 | 21.79 | 0.0 |
| 7 | 2020-07-22 | 600036 | 36.78 | 37.1 | 36.44 | 36.26 | 998811.81 | 0.48 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 36.44 | 36.44 | 0.0 |
| 8 | 2020-07-22 | 600048 | 16.51 | 16.89 | 16.54 | 16.44 | 1047299.75 | 0.88 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 16.54 | 16.54 | 0.0 |
| 9 | 2020-07-22 | 600050 | 5.26 | 5.3 | 5.24 | 5.21 | 1675865.38 | 0.78 | 2.0 | 0.0 | 12.333087 | 10.126913 | 0.0 | -66.666667 | 100.0 | 5.24 | 5.24 | 0.0 |


<a id='4'></a>
# Part 5. Build A Market Environment in OpenAI Gym-style
The training process involves observing stock price changes, taking an action, and calculating the reward. By interacting with the market environment, the agent will eventually derive a trading strategy that may maximize (expected) rewards.

Our market environment, based on OpenAI Gym, simulates stock markets with historical market data.
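To make the observe/act/reward loop concrete, here is a hypothetical, minimal Gym-style environment for a single stock that replays a list of historical prices. It is not FinRL's `StockTradingEnv`: there are no transaction costs, no technical indicators, and actions are plain integer share counts. It follows the classic four-value `step` return (observation, reward, done, info) of older Gym versions; Gymnasium's newer API returns five values instead.

```python
class ToySingleStockEnv:
    """Hypothetical minimal Gym-style environment (not FinRL's StockTradingEnv).

    Observation: (cash, price, shares). Action: integer shares to buy (>0) or sell (<0).
    Reward: change in portfolio value, r = v' - v. No transaction costs.
    """

    def __init__(self, prices, initial_cash=1000.0):
        self.prices = list(prices)
        self.initial_cash = initial_cash

    def reset(self):
        self.t = 0
        self.cash = self.initial_cash
        self.shares = 0
        return (self.cash, self.prices[self.t], self.shares)

    def _value(self):
        return self.cash + self.shares * self.prices[self.t]

    def step(self, action):
        v = self._value()
        price = self.prices[self.t]
        if action > 0:  # buy, limited by the available cash
            action = min(action, int(self.cash // price))
        else:           # sell (or hold), limited by the current holdings
            action = max(action, -self.shares)
        self.cash -= action * price
        self.shares += action
        self.t += 1     # move on to the next historical day
        reward = self._value() - v
        done = self.t == len(self.prices) - 1
        return (self.cash, self.prices[self.t], self.shares), reward, done, {}


env = ToySingleStockEnv(prices=[10.0, 11.0, 9.0, 12.0])
obs = env.reset()
total = 0.0
for action in [5, 0, -5]:  # a fixed "policy": buy 5, hold, sell 5
    obs, reward, done, info = env.step(action)
    total += reward
print(total)  # -5.0  (+5 from 10 -> 11, -10 from 11 -> 9, 0 after selling)
```

A DRL agent replaces the fixed action list with a policy that maps each observation to an action and is trained to maximize the summed rewards.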

## Data Split
We split the data into a training set and a trading (testing) set as follows:

Training data period: 2020-07-22 to 2021-01-01

Trading data period: 2021-01-01 to 2021-10-31
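The two periods share the boundary date 2021-01-01 without overlapping if the split keeps `start <= date < end`. The sketch below assumes that `data_split` behaves this way, sorts by date and ticker, and re-indexes rows by trading-day number, which would match the repeated index values (e.g., 110) in the `train.tail()` output below. `split_by_date` is a hypothetical stand-in, not FinRL's function, and the rows are invented.

```python
import pandas as pd


def split_by_date(df, start, end):
    """Hypothetical split: keep start <= date < end (ISO date strings compare correctly
    as text), sort by date and ticker, and index every row by its trading-day number."""
    data = df[(df["date"] >= start) & (df["date"] < end)]
    data = data.sort_values(["date", "tic"], ignore_index=True)
    data.index = data["date"].factorize()[0]  # same index for all tickers on one day
    return data


df = pd.DataFrame({
    "date": ["2020-12-30", "2020-12-30", "2020-12-31", "2020-12-31", "2021-01-04", "2021-01-04"],
    "tic": ["A", "B", "A", "B", "A", "B"],
})
train = split_by_date(df, "2020-07-22", "2021-01-01")
trade = split_by_date(df, "2021-01-01", "2021-10-31")
print(train.index.tolist())  # [0, 0, 1, 1]
print(trade.index.tolist())  # [0, 0]
```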


In [15]:
train = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE)
trade = data_split(processed_full, TRADE_START_DATE,TRADE_END_DATE)
print(len(train))
print(len(trade))

6216
11088


In [16]:
train.tail()

|   | date | tic | open | high | close | low | volume | turnover | day | macd | boll_ub | boll_lb | rsi_30 | cci_30 | dx_30 | close_30_sma | close_60_sma | turbulence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 110 | 2020-12-31 | 601988 | 3.16 | 3.19 | 3.18 | 3.16 | 1109060.62 | 0.05 | 3.0 | -0.019428 | 3.261187 | 3.129813 | 44.429169 | -69.643401 | 0.676494 | 3.225667 | 3.222333 | 0.0 |
| 110 | 2020-12-31 | 601989 | 4.08 | 4.19 | 4.19 | 4.07 | 1567744.38 | 0.86 | 3.0 | -0.043229 | 4.296003 | 3.953997 | 46.798216 | -26.811686 | 7.639092 | 4.196667 | 4.241833 | 0.0 |
| 110 | 2020-12-31 | 603160 | 156.01 | 156.98 | 155.55 | 153.53 | 41837.67 | 0.92 | 3.0 | -4.690169 | 167.483387 | 140.566613 | 43.767288 | -33.148016 | 14.37835 | 159.046 | 166.7935 | 0.0 |
| 110 | 2020-12-31 | 603259 | 131.85 | 136.55 | 134.72 | 131.85 | 150642.12 | 1.07 | 3.0 | 5.759525 | 137.639091 | 103.130909 | 67.204283 | 144.603498 | 56.233451 | 115.576667 | 113.540833 | 0.0 |
| 110 | 2020-12-31 | 603993 | 5.83 | 6.38 | 6.25 | 5.81 | 7357855.5 | 4.17 | 3.0 | 0.369526 | 6.125415 | 3.885585 | 69.228686 | 251.231527 | 70.811246 | 4.888667 | 4.499 | 0.0 |


In [17]:
trade.head()

|   | date | tic | open | high | close | low | volume | turnover | day | macd | boll_ub | boll_lb | rsi_30 | cci_30 | dx_30 | close_30_sma | close_60_sma | turbulence |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 2021-01-04 | 600000 | 9.64 | 9.73 | 9.69 | 9.55 | 629069.38 | 0.21 | 0.0 | -0.055067 | 9.989325 | 9.415675 | 47.80706 | -57.628649 | 0.553621 | 9.827 | 9.6695 | 0.0 |
| 0 | 2021-01-04 | 600016 | 5.21 | 5.22 | 5.2 | 5.16 | 862805.12 | 0.24 | 0.0 | -0.024029 | 5.333192 | 5.074808 | 46.432949 | -45.34005 | 6.4935 | 5.237667 | 5.252 | 0.0 |
| 0 | 2021-01-04 | 600019 | 5.96 | 6.13 | 6.09 | 5.9 | 1029537.0 | 0.46 | 0.0 | -0.022786 | 6.225068 | 5.800932 | 53.384355 | -33.732318 | 18.850172 | 6.116667 | 5.924 | 0.0 |
| 0 | 2021-01-04 | 600028 | 4.03 | 4.05 | 4.03 | 3.99 | 1100524.38 | 0.12 | 0.0 | -0.022261 | 4.133346 | 3.955654 | 48.815865 | -67.240748 | 11.970337 | 4.102333 | 4.0385 | 0.0 |
| 0 | 2021-01-04 | 600029 | 5.9 | 5.94 | 5.9 | 5.84 | 445395.34 | 0.55 | 0.0 | -0.081066 | 6.296693 | 5.664307 | 47.826087 | -73.479853 | 5.853039 | 6.114333 | 5.9965 | 0.0 |


In [18]:
INDICATORS

['macd',
 'boll_ub',
 'boll_lb',
 'rsi_30',
 'cci_30',
 'dx_30',
 'close_30_sma',
 'close_60_sma']

In [19]:
stock_dimension = len(train.tic.unique())
state_space = 1 + 2*stock_dimension + len(INDICATORS)*stock_dimension
print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")


Stock Dimension: 56, State Space: 561
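The state-space size follows from the formula in the cell above: one slot for the cash balance, one close price and one holding per stock, then each technical indicator per stock. The sketch below redoes the arithmetic for this run (56 stocks, 8 indicators) and builds a flat placeholder vector, assuming the environment orders the state as cash, prices, holdings, indicators; check `StockTradingEnv` if the exact order matters to you.

```python
stock_dimension = 56  # from the output above
indicators = ["macd", "boll_ub", "boll_lb", "rsi_30",
              "cci_30", "dx_30", "close_30_sma", "close_60_sma"]  # INDICATORS

state_space = 1 + 2 * stock_dimension + len(indicators) * stock_dimension
print(state_space)  # 561 = 1 + 112 + 448

# Illustrative flat state with placeholder values (assumed ordering).
cash = [1_000_000.0]
close_prices = [0.0] * stock_dimension
holdings = [0] * stock_dimension
indicator_values = [0.0] * (len(indicators) * stock_dimension)
state = cash + close_prices + holdings + indicator_values
print(len(state) == state_space)  # True
```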


In [20]:
buy_cost_list = sell_cost_list = [0.001] * stock_dimension
num_stock_shares = [0] * stock_dimension

env_kwargs = {
    "hmax": 100,
    "initial_amount": 1000000,
    "num_stock_shares": num_stock_shares,
    "buy_cost_pct": buy_cost_list,
    "sell_cost_pct": sell_cost_list,
    "state_space": state_space,
    "stock_dim": stock_dimension,
    "tech_indicator_list": INDICATORS,
    "action_space": stock_dimension,
    "reward_scaling": 1e-4
}


e_train_gym = StockTradingEnv(df = train, **env_kwargs)
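In `env_kwargs`, `hmax` caps the number of shares traded per stock per step, and `buy_cost_pct`/`sell_cost_pct` charge 0.1% of the traded value. Our understanding is that `StockTradingEnv` receives one continuous action in [-1, 1] per stock, multiplies it by `hmax`, and truncates it to an integer share count. The hypothetical helper `execute_trade` below sketches that idea for a single stock; it is not FinRL's code and ignores the turbulence rule and other details of the real environment.

```python
def execute_trade(action, price, cash, shares, hmax=100, cost_pct=0.001):
    """Hypothetical single-stock trade: scale an action in [-1, 1] to shares, charge fees."""
    n = int(action * hmax)  # integer share count, truncated toward zero
    if n > 0:               # buy, limited by the cash available including fees
        affordable = int(cash // (price * (1 + cost_pct)))
        n = min(n, affordable)
        cash -= price * n * (1 + cost_pct)
    elif n < 0:             # sell, limited by the current holdings
        n = max(n, -shares)
        cash += price * (-n) * (1 - cost_pct)
    return cash, shares + n


cash1, shares1 = execute_trade(0.5, price=10.0, cash=1000.0, shares=0)   # buys 50 shares
cash2, shares2 = execute_trade(-1.0, price=10.0, cash=0.0, shares=50)    # sells all 50
print(cash1, shares1, cash2, shares2)
```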


## Environment for Training



In [21]:
env_train, _ = e_train_gym.get_sb_env()
print(type(env_train))

<class 'stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv'>


<a id='5'></a>
# Part 6: Train DRL Agents
* The DRL algorithms are from **Stable Baselines 3**. Users are also encouraged to try **ElegantRL** and **Ray RLlib**.
* FinRL includes fine-tuned standard DRL algorithms, such as DQN, DDPG, Multi-Agent DDPG, PPO, SAC, A2C, and TD3. We also allow users to design their own DRL algorithms by adapting these DRL algorithms.

In [22]:
agent = DRLAgent(env = env_train)

if_using_a2c = False
if_using_ddpg = False
if_using_ppo = False
if_using_td3 = False
if_using_sac = True


### Agent Training: 5 algorithms (A2C, DDPG, PPO, TD3, SAC)


### Agent 1: A2C


In [23]:
agent = DRLAgent(env = env_train)
model_a2c = agent.get_model("a2c")

if if_using_a2c:
  # set up logger
  tmp_path = RESULTS_DIR + '/a2c'
  new_logger_a2c = configure(tmp_path, ["stdout", "csv", "tensorboard"])
  # Set new logger
  model_a2c.set_logger(new_logger_a2c)


{'n_steps': 5, 'ent_coef': 0.01, 'learning_rate': 0.0007}
Using cuda device


In [24]:
trained_a2c = agent.train_model(model=model_a2c, 
                             tb_log_name='a2c',
                             total_timesteps=50000) if if_using_a2c else None

### Agent 2: DDPG

In [25]:
agent = DRLAgent(env = env_train)
model_ddpg = agent.get_model("ddpg")

if if_using_ddpg:
  # set up logger
  tmp_path = RESULTS_DIR + '/ddpg'
  new_logger_ddpg = configure(tmp_path, ["stdout", "csv", "tensorboard"])
  # Set new logger
  model_ddpg.set_logger(new_logger_ddpg)

{'batch_size': 128, 'buffer_size': 50000, 'learning_rate': 0.001}
Using cuda device


In [26]:
trained_ddpg = agent.train_model(model=model_ddpg, 
                             tb_log_name='ddpg',
                             total_timesteps=50000) if if_using_ddpg else None

### Agent 3: PPO

In [27]:
agent = DRLAgent(env = env_train)
PPO_PARAMS = {
    "n_steps": 2048,
    "ent_coef": 0.01,
    "learning_rate": 0.00025,
    "batch_size": 128,
}
model_ppo = agent.get_model("ppo",model_kwargs = PPO_PARAMS)

if if_using_ppo:
  # set up logger
  tmp_path = RESULTS_DIR + '/ppo'
  new_logger_ppo = configure(tmp_path, ["stdout", "csv", "tensorboard"])
  # Set new logger
  model_ppo.set_logger(new_logger_ppo)

{'n_steps': 2048, 'ent_coef': 0.01, 'learning_rate': 0.00025, 'batch_size': 128}
Using cuda device


In [28]:
trained_ppo = agent.train_model(model=model_ppo, 
                             tb_log_name='ppo',
                             total_timesteps=50000) if if_using_ppo else None

### Agent 4: TD3

In [29]:
agent = DRLAgent(env = env_train)
TD3_PARAMS = {"batch_size": 100, 
              "buffer_size": 1000000, 
              "learning_rate": 0.001}

model_td3 = agent.get_model("td3",model_kwargs = TD3_PARAMS)

if if_using_td3:
  # set up logger
  tmp_path = RESULTS_DIR + '/td3'
  new_logger_td3 = configure(tmp_path, ["stdout", "csv", "tensorboard"])
  # Set new logger
  model_td3.set_logger(new_logger_td3)

{'batch_size': 100, 'buffer_size': 1000000, 'learning_rate': 0.001}
Using cuda device


In [30]:
trained_td3 = agent.train_model(model=model_td3, 
                             tb_log_name='td3',
                             total_timesteps=30000) if if_using_td3 else None

### Agent 5: SAC

In [31]:
agent = DRLAgent(env = env_train)
SAC_PARAMS = {
    "batch_size": 128,
    "buffer_size": 100000,
    "learning_rate": 0.0001,
    "learning_starts": 100,
    "ent_coef": "auto_0.1",
}

model_sac = agent.get_model("sac",model_kwargs = SAC_PARAMS)

if if_using_sac:
  # set up logger
  tmp_path = RESULTS_DIR + '/sac'
  new_logger_sac = configure(tmp_path, ["stdout", "csv", "tensorboard"])
  # Set new logger
  model_sac.set_logger(new_logger_sac)

{'batch_size': 128, 'buffer_size': 100000, 'learning_rate': 0.0001, 'learning_starts': 100, 'ent_coef': 'auto_0.1'}
Using cuda device
Logging to results/sac


In [32]:
trained_sac = agent.train_model(model=model_sac, 
                             tb_log_name='sac',
                             total_timesteps=40000) if if_using_sac else None

----------------------------------
| time/              |           |
|    episodes        | 4         |
|    fps             | 60        |
|    time_elapsed    | 7         |
|    total_timesteps | 444       |
| train/             |           |
|    actor_loss      | -35.1     |
|    critic_loss     | 2.92e+03  |
|    ent_coef        | 0.102     |
|    ent_coef_loss   | 783       |
|    learning_rate   | 0.0001    |
|    n_updates       | 343       |
|    reward          | 2.0817602 |
----------------------------------
----------------------------------
| time/              |           |
|    episodes        | 8         |
|    fps             | 57        |
|    time_elapsed    | 15        |
|    total_timesteps | 888       |
| train/             |           |
|    actor_loss      | 44.6      |
|    critic_loss     | 1.07e+03  |
|    ent_coef        | 0.105     |
|    ent_coef_loss   | 244       |
|    learning_rate   | 0.0001    |
|    n_updates       | 787       |
|    reward         

## In-sample Performance

Assume that the initial capital is $1,000,000.

### Set turbulence threshold
Set the turbulence threshold at or near the maximum of the in-sample turbulence data (the cells below inspect the 99.6% quantile). If the current turbulence index exceeds the threshold, we assume that the current market is volatile.
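Both `pandas.Series.quantile` and `numpy.quantile` interpolate linearly by default, so the 99.6% quantile of a short series lands just below its maximum. The sketch below derives a threshold that way from invented in-sample values and flags out-of-sample days at or above it. To our understanding, FinRL's `StockTradingEnv` responds to such a day by selling all holdings; the sketch only shows the thresholding step.

```python
import numpy as np

# Invented in-sample turbulence values (one per trading day).
insample_turbulence = np.array([0.0, 5.2, 8.1, 3.3, 12.7, 6.4, 9.9, 40.5, 7.7, 4.1])
threshold = np.quantile(insample_turbulence, 0.996)  # just below the in-sample maximum

# Invented out-of-sample days: flag those at or above the threshold as volatile.
trade_turbulence = np.array([6.0, 15.0, 55.0, 3.0])
volatile = trade_turbulence >= threshold
print(threshold, volatile)
```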

In [None]:
data_risk_indicator = processed_full[(processed_full.date<TRAIN_END_DATE) & (processed_full.date>=TRAIN_START_DATE)]
insample_risk_indicator = data_risk_indicator.drop_duplicates(subset=['date'])

In [None]:
# insample_risk_indicator.vix.describe()  # needs a 'vix' column; FeatureEngineer above used use_vix=False

In [None]:
# insample_risk_indicator.vix.quantile(0.996)  # needs a 'vix' column; FeatureEngineer above used use_vix=False

In [None]:
insample_risk_indicator.turbulence.describe()

In [None]:
insample_risk_indicator.turbulence.quantile(0.996)

### Trading (Out-of-sample Performance)

Ideally, the model is updated periodically to take full advantage of the data, e.g., retrained quarterly, monthly, or weekly, with the parameters tuned along the way. In this notebook, we use the in-sample data from 2020-07-22 to 2021-01-01 to tune the parameters only once, so some alpha decay is expected as the trading period extends.

Numerous hyperparameters (e.g., the learning rate and the total number of training samples) influence the learning process and are usually determined by testing several variations.
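This notebook trains once. As a sketch of the periodic retraining described above, the hypothetical helper below generates expanding-window date ranges using this notebook's dates: each retrain uses all data from the training start up to the next trading window, then trades for `step_months` months. The quarterly step and the expanding window are assumptions; each tuple could be passed to `data_split` before training a fresh agent, a loop not shown here.

```python
from datetime import date


def add_months(d, months):
    """First day of the month `months` months after d's month."""
    y, m = divmod(d.month - 1 + months, 12)
    return date(d.year + y, m + 1, 1)


def rolling_windows(train_start, first_trade, last_trade, step_months=3):
    """Yield (train_start, train_end, trade_start, trade_end), each range half-open [start, end).

    first_trade is assumed to be the first day of a month.
    """
    trade_start = first_trade
    while trade_start < last_trade:
        trade_end = min(add_months(trade_start, step_months), last_trade)
        yield train_start, trade_start, trade_start, trade_end
        trade_start = trade_end


windows = list(rolling_windows(date(2020, 7, 22), date(2021, 1, 1), date(2021, 10, 31)))
for w in windows:
    print([d.isoformat() for d in w])
```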

In [None]:
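# This run has no 'vix' column (use_vix=False), so use the turbulence index as the risk indicator.
# The threshold 70 is carried over from the original notebook, where it applied to VIX;
# consider re-deriving it from the in-sample turbulence quantile above.
e_trade_gym = StockTradingEnv(df = trade, turbulence_threshold = 70, risk_indicator_col='turbulence', **env_kwargs)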
# env_trade, obs_trade = e_trade_gym.get_sb_env()

In [None]:
trade.head()

In [None]:
trained_model = trained_sac  # None unless if_using_sac is True
df_account_value, df_actions = DRLAgent.DRL_prediction(
    model=trained_model, 
    environment = e_trade_gym)

In [None]:
df_account_value.shape

In [None]:
df_account_value.tail()

In [None]:
df_actions.head()

<a id='6'></a>
# Part 7: Backtesting Results
Backtesting plays a key role in evaluating the performance of a trading strategy. An automated backtesting tool is preferred because it reduces human error. We usually use the Quantopian pyfolio package to backtest our trading strategies. It is easy to use and consists of various individual plots that provide a comprehensive picture of a trading strategy's performance.

<a id='6.1'></a>
## 7.1 BackTestStats
Pass in `df_account_value`; this information is stored in the environment class.
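To our understanding, `backtest_stats` delegates to pyfolio. To make two of its headline numbers concrete, the sketch below recomputes an annualized Sharpe ratio (zero risk-free rate, 252 trading days per year, sample standard deviation) and the maximum drawdown from a list of account values. It is a simplified stand-in rather than pyfolio's implementation, and the account values are invented.

```python
import math
import statistics


def daily_returns(values):
    """Simple returns between consecutive account values."""
    return [b / a - 1 for a, b in zip(values, values[1:])]


def sharpe_ratio(returns, periods_per_year=252):
    """Annualized Sharpe ratio with a zero risk-free rate."""
    return statistics.mean(returns) / statistics.stdev(returns) * math.sqrt(periods_per_year)


def max_drawdown(values):
    """Largest peak-to-trough decline, as a negative fraction (0.0 if none)."""
    peak, worst = values[0], 0.0
    for v in values:
        peak = max(peak, v)
        worst = min(worst, v / peak - 1)
    return worst


account_value = [1_000_000, 1_010_000, 1_005_000, 990_000, 1_020_000, 1_030_000]
r = daily_returns(account_value)
print(round(sharpe_ratio(r), 2), round(max_drawdown(account_value), 4))
```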


In [None]:
print("==============Get Backtest Results===========")
now = datetime.datetime.now().strftime('%Y%m%d-%Hh%M')

perf_stats_all = backtest_stats(account_value=df_account_value)
perf_stats_all = pd.DataFrame(perf_stats_all)
perf_stats_all.to_csv("./"+RESULTS_DIR+"/perf_stats_all_"+now+'.csv')

In [None]:
# baseline stats; get_baseline downloads via Yahoo Finance, and ^DJI is a US index rather than an SSE 50 benchmark
print("==============Get Baseline Stats===========")
baseline_df = get_baseline(
        ticker="^DJI", 
        start = df_account_value.loc[0,'date'],
        end = df_account_value.loc[len(df_account_value)-1,'date'])

stats = backtest_stats(baseline_df, value_col_name = 'close')


In [None]:
df_account_value.loc[0,'date']

In [None]:
df_account_value.loc[len(df_account_value)-1,'date']

<a id='6.2'></a>
## 7.2 BackTestPlot

In [None]:
print("==============Compare to DJIA===========")
%matplotlib inline
# S&P 500: ^GSPC
# Dow Jones Index: ^DJI
# NASDAQ 100: ^NDX
backtest_plot(df_account_value, 
             baseline_ticker = '^DJI', 
             baseline_start = df_account_value.loc[0,'date'],
             baseline_end = df_account_value.loc[len(df_account_value)-1,'date'])