In [1]:
import numpy as np
import pandas as pd

# Helper classes for data processing and state representation
from reinforcetrader.data_pipeline import RawDataLoader, FeatureBuilder
from reinforcetrader.state import EpisodeStateLoader
from reinforcetrader.dqn_agent import RLAgent

# Helper method for displaying large dataframes
from IPython.display import HTML
# Disable pandas' row truncation so rendered DataFrames show every row.
# NOTE(review): a bare DataFrame as a cell's last expression will now print in
# full (potentially the whole multi-year price history); display_df limits
# output with head() — prefer it, or scope this via pd.option_context.
pd.set_option("display.max_rows", None)
from IPython.display import HTML, display

def display_df(df, rows=10, height=300):
    """Render the first rows of a DataFrame inside a scrollable HTML box.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to preview.
    rows : int, default 10
        Number of leading rows passed to ``df.head``.
    height : int, default 300
        Height of the scroll container in pixels; taller content gets a
        vertical scrollbar.

    Returns
    -------
    None
        The HTML is emitted via IPython's ``display``.
    """
    # A fixed-height, overflow:auto <div> puts the scrollbar next to the
    # DataFrame instead of stretching the notebook output area.
    table_html = df.head(rows).to_html()
    display(HTML(
        f"<div style='height: {height}px; overflow: auto; width: 98%'>"
        f"{table_html}</div>"
    ))

2025-08-22 14:00:04.317286: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


# Data Pre-processing and Feature Building

In [None]:
# Loader for the price history between the two ISO (yyyy-mm-dd) dates
data_loader = RawDataLoader(
    start_date='2000-02-20',
    end_date='2025-08-11',
)

# Fetch the DJI prices via the Yahoo Finance API, or reuse the local cache
raw_data = data_loader.get_hist_prices()

# Preview the leading rows of the raw close and volume data
display_df(raw_data)

In [None]:
# Turn the raw price history into the features that make up the state space
feature_builder = FeatureBuilder(
    hist_prices=raw_data,
)
feature_builder.build_features()                # compute the features first ...
features_data = feature_builder.get_features()  # ... then retrieve them

# Preview only the first five feature rows
display_df(features_data, rows=5)

# State Representation

In [None]:
# Episode definitions are read from this JSON file (path relative to the
# notebook's working directory)
episodes_config_path = 'configs/episodes.json'
esl = EpisodeStateLoader(features_data, episodes_config_path)

In [None]:
# Preview one state matrix; the leading args look like (split='train',
# episode=1, ticker='AAPL') — TODO confirm.
# NOTE(review): the meaning of the trailing positional args (2, 4) is not
# visible here; consider passing them by keyword once confirmed against
# EpisodeStateLoader.get_state_matrix.
esl.get_state_matrix('train', 1, 'AAPL', 2, 4)

In [None]:
# Preview OHLCV data for a state; the leading args look like (split='train',
# episode=1, ticker='AAPL') — TODO confirm.
# NOTE(review): the meaning of the final positional arg (1) is not visible
# here; confirm against EpisodeStateLoader.get_state_OHLCV.
esl.get_state_OHLCV('train', 1, 'AAPL', 1)

# Deep Q-Network and RL Agent

In [None]:
# Number of time steps the agent looks back over in each state
window_size = 26

# Build the RL agent together with its DQN model.
# NOTE(review): num_features=(3, 8) is presumably tied to the feature layout
# produced above — confirm it stays in sync with FeatureBuilder's output.
agent = RLAgent(num_features=(3, 8), window_size=window_size)

In [None]:
# Training hyper-parameters and output locations for the agent.
# The comma after every entry (including the last) keeps this literal valid
# and makes adding keys a one-line change.
training_config = {
    'batch_size': 256,
    'val_group_size': 5,  # presumably episodes per validation group — confirm
    'model_dir': 'model_checkpoints/',
    'plots_dir': 'plots/validation/',
}

In [None]:
# Train the agent on episodes from the state loader.
# NOTE(review): `episode_list` (episodes 1..7) is built here but never used —
# the call below trains on episode [1] only. `training_config` is not passed
# either: the positional 64 and 5 are presumably batch size and validation
# group size (the config says batch_size=256). Confirm against
# RLAgent.train's signature and wire in episode_list / training_config if
# that was the intent.
episode_list = np.arange(1, 8)
agent.train(esl, [1], 64, 5)