### Imports

In [1]:
import numpy as np
import pandas as pd

import gym
import gym_anytrading
import quantstats as qs

import trading.bars_db as bars_db

from stable_baselines3 import A2C
from stable_baselines3.common.vec_env import DummyVecEnv

import matplotlib.pyplot as plt
# Notebook-wide matplotlib font settings, applied to every figure drawn below.
# NOTE(review): matplotlib appears to consult 'font.serif' only when
# 'font.family' is the generic 'serif' alias, so the second entry may have no
# effect with a concrete family name — confirm the intended font.
plt.rcParams.update({
    'font.family': 'DeJavu Serif',
    'font.serif': ['Times New Roman'],
})

### Create Env

In [None]:
# Work on a private copy of gym_anytrading's STOCKS_GOOGL dataset so the
# packaged object is never modified, then preview the first rows.
stock_df = gym_anytrading.datasets.STOCKS_GOOGL.copy(deep=True)
stock_df.head(n=5)

In [None]:
# Work on a private copy of gym_anytrading's FOREX_EURUSD_1H_ASK dataset so the
# packaged object is never modified, then preview the first rows.
forex_df = gym_anytrading.datasets.FOREX_EURUSD_1H_ASK.copy(deep=True)
forex_df.head(n=5)

In [2]:
# Build `df`, the frame handed to the trading environment further down.
# NOTE(review): bars_db is project-local; judging by the helper names this
# loads the "cryptocompare" bars, selects the BTC/USD pair and reshapes it for
# gym_anytrading — confirm against trading/bars_db.
timeframe = "h1"  # the same timeframe key is passed to both bars_db calls

gen_df = bars_db.load_from_parquet("cryptocompare", timeframe)
btc_usd_df = bars_db.get_df_for_symbols(gen_df, "BTC", "USD", timeframe)
df = bars_db.prepare_for_gym_anytrading(btc_usd_df)

df.head(n=5)

Unnamed: 0_level_0,Open,Close,High,Low,Volume
Time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
2016-01-01 00:00:00+00:00,430.08,431.62,432.27,429.2,1069.97
2016-01-01 01:00:00+00:00,431.62,430.06,431.71,429.52,859.6
2016-01-01 02:00:00+00:00,430.06,430.38,432.3,429.43,918.8
2016-01-01 03:00:00+00:00,430.38,431.2,431.73,429.85,1021.05
2016-01-01 04:00:00+00:00,431.2,435.53,436.56,430.64,5493.42


In [None]:
# Environment settings. The frame bound starts at `window_size` and ends at
# the last row of `df`; both values are passed straight to the environment.
window_size = 10
start_index = window_size
end_index = len(df)


def env_maker():
    """Return a new 'crypto-v0' environment built over `df` with the settings above."""
    return gym.make(
        'crypto-v0',
        df=df,
        window_size=window_size,
        frame_bound=(start_index, end_index),
    )


# Wrap the factory in a single-environment DummyVecEnv for training.
env = DummyVecEnv([env_maker])

### Train Agent

In [None]:
# Fixed seed so that Restart & Run All trains the same policy every time;
# without it each run yields a different model and different downstream results.
SEED = 42

# Network layout handed to the MlpPolicy.
# NOTE(review): the `[shared, dict(vf=..., pi=...)]` form (a 64-unit layer
# followed by separate value/policy branches) is presumably only accepted by
# older Stable-Baselines3 releases — confirm against the installed version.
policy_kwargs = dict(net_arch=[64, dict(vf=[128, 128, 128], pi=[64, 64])])

model = A2C(
    'MlpPolicy',
    env,
    verbose=1,
    policy_kwargs=policy_kwargs,
    seed=SEED,
)
model.learn(total_timesteps=10000)

### Test Agent

In [None]:
# Step the trained model through a fresh, unwrapped environment until the
# episode reports `done`, then print the final `info` payload.
env = env_maker()
observation = env.reset()

done = False
while not done:
    # Prepend a length-1 leading axis before handing the observation to the model.
    batched_observation = np.expand_dims(observation, axis=0)
    action, _states = model.predict(batched_observation)
    observation, reward, done, info = env.step(action)

print("info:", info)

### Plot Results

In [None]:
# Open a wide figure, let the environment draw its full history on it, and
# display the result.
figure_size = (16, 6)
plt.figure(figsize=figure_size)
env.render_all()
plt.show()

### Analysis Using `quantstats`

In [None]:
# Enable the quantstats extensions on pandas objects.
qs.extend_pandas()

# Pair the recorded profit history with the matching slice of the bar index.
profit_history = env.history['total_profit']
history_index = df.index[start_index + 1:end_index]
net_worth = pd.Series(profit_history, index=history_index)

# pct_change() has no predecessor for the first row, so that entry is dropped.
returns = net_worth.pct_change().iloc[1:]

# Sanity check: number of recorded history entries.
len(profit_history)

In [None]:
# Row count of the bar frame, to compare against the profit-history length.
df.shape[0]

In [None]:
# Produce the full quantstats report for the strategy's per-bar returns.
# NOTE(review): `returns` are computed on hourly ("h1") bars; quantstats
# presumably annualises assuming daily periods unless told otherwise — confirm
# whether a periods-per-year setting is needed for the ratios to be meaningful.
report_font = "DejaVu Sans"
qs.reports.full(returns, font_name=report_font)

# Alternative: write the same report to an HTML file instead.
# qs.reports.html(returns, output='a2c_quantstats.html')