In [1]:
from collections import defaultdict

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
from utils.utilities import fetch_data
from utils.envs import TradingEnv9
from utils.td3 import Agent
from utils.pred import Predictor

In [3]:
import warnings

# Suppress every Python warning for the rest of the session. This keeps cell
# output clean, but it also hides deprecation and pandas copy warnings.
warnings.simplefilter('ignore')

In [4]:
%matplotlib inline

---

In [5]:
# Seed handed to torch, numpy and the environment further down.
seed = 101

# Name of the environment class to build, and the checkpoint name to load.
env_name = 'TradingEnv9'
file_name = 'TD3_TradingEnv9_main_42'

# Path passed to fetch_data() to read the price history.
db_name = './data/HistoricalPriceData.db'

In [6]:
# Fix the global RNG state of both numerical libraries so runs are repeatable.
np.random.seed(seed)
torch.manual_seed(seed)

In [7]:
# Date threshold for the evaluation data (presumably its first day; confirm).
START = pd.Timestamp('2018-12-01')

In [9]:
# Display the checkpoint name that the agent and predictor are loaded from below.
file_name

'TD3_TradingEnv9_main_42'

---
---
---

##### Fetch the data

In [10]:
# Load the price history from the file at `db_name`. The result is used below
# as a mapping from ticker to a table that has a 'date' column.
data = fetch_data(db_name)

In [None]:
# Keep only the rows dated on or after START for every ticker.
# The original compared against HOLDOUT, which is never defined in this
# notebook (NameError on a fresh kernel). START is the only date threshold
# configured above, so it is used here. Confirm this matches the intended window.
for tick in data:

    frame = data[tick]

    # Build a new, re-indexed frame instead of calling reset_index(inplace=True)
    # on a filtered slice, which mutates a possible copy.
    data[tick] = frame[frame['date'] >= START].reset_index(drop=True)
    

##### Initialize environment and set seeds

In [None]:
# Resolve the environment class through an explicit name -> class table rather
# than eval()-ing a formatted string: only registered classes can be built, and
# a bad env_name fails with a clear KeyError.
env_classes = {'TradingEnv9': TradingEnv9}
env = env_classes[env_name](data)

# Seed the environment with the same seed used for torch and numpy.
env.seed(seed)

In [None]:
# Action-space figures handed to the agent constructor below.
max_action = float(env.action_space.high[0])  # first entry of the upper bound
action_dim = env.action_space.shape[0]

# Total number of entries in one observation. The agent below is built with
# agent_state_dim, not with this value.
state_dim = np.prod(env.observation_space.shape)

##### Initialize agent

In [None]:
# Input size of the agent: observation rows times (observation columns + 1).
# NOTE(review): the extra column is presumably added by the predictor; confirm.
agent_state_dim = (
    env.observation_space.shape[0]
    * (env.observation_space.shape[1] + 1)
)

In [None]:
# Build the TD3 agent with the dimensions computed above.
agent = Agent(state_dim=agent_state_dim, action_dim=action_dim, max_action=max_action)


In [None]:
# Load the saved agent identified by `file_name` from the ./models directory.
agent.load(file_name, './models')

##### Initialize predictor model

In [None]:
# Predictor output size: one less than the number of action components.
pred_output = env.action_space.shape[0] - 1

# Predictor input size: (rows - 1) * (columns - 2) of the observation.
# NOTE(review): which row/columns are excluded is not visible here; confirm.
pred_input = (
    (env.observation_space.shape[0] - 1)
    * (env.observation_space.shape[1] - 2)
)

In [None]:
# Build the predictor with the sizes computed above.
predictor = Predictor(input_dim=pred_input, output_dim=pred_output)


In [None]:
# Load the saved predictor from ./models, using the same `file_name` as the agent.
predictor.load(file_name, './models')

##### Test:

In [None]:
# Positions exposed by the environment; passed to env.format_action() each step.
positions = env.positions

# Per-step diagnostics, keyed by metric name. Missing keys start as empty lists.
# The original used `collections.defaultdict` without importing `collections`,
# which raises NameError; `defaultdict` is imported in the top import cell.
reward_trace = defaultdict(list)

In [None]:
# Run the loaded agent through the environment for one full episode and record
# per-step diagnostics in reward_trace.
for episode in range(1):

    done = False
    obs = env.reset()

    while not done:

        # The predictor turns the raw observation into the agent's input.
        agent_obs = predictor.predict(obs)

        action = agent.select_action(agent_obs)
        action_fmt = env.format_action(positions, action)

        new_obs, reward, done, info = env.step(action_fmt)

        obs = new_obs

        env.render()

        reward_trace['rewards'].append(reward)
        # Was `actions.squeeze()`, an undefined name (NameError on the first
        # step). Log the agent's raw action instead.
        reward_trace['actions'].append(action.squeeze())
        # Gap between env.net_worth and env.net_worth_long
        # (presumably a long-only baseline; confirm in TradingEnv9).
        reward_trace['net_worth_diff'].append(env.net_worth - env.net_worth_long)
    