# Training and Comparing RL Agents (PPO vs. DQN)

This notebook provides a complete, end-to-end example of using the `rl_trading_project` framework with a robust data pipeline. We will perform the following steps:

1.  **Data Ingestion:** Create a high-performance DuckDB database from raw, gzipped CSV files. This simulates a real-world scenario where you have large amounts of historical data.
2.  **Data Loading & Preparation:** Load the prepared data into a Pandas DataFrame, ready for our RL environments.
3.  **Train a PPO Agent:** Train a Proximal Policy Optimization (PPO) agent to manage a **multi-asset portfolio**. PPO suits this task because it handles continuous, multi-dimensional action spaces.
4.  **Train a DQN Agent:** Train a Dueling Deep Q-Network (Dueling DQN) agent on a **single-asset** task. This highlights the strengths of DQN in discrete action spaces and shows how different agents suit different problems.
5.  **Backtesting & Comparison:** Evaluate both trained agents on out-of-sample data and compare their performance metrics and equity curves.

## Part 1: Data Ingestion with DuckDB

First, we need data. The `RAW_DIR` directory holds monthly `csv.gz` files for multiple assets. We pass it to our `ingest_raw_data_to_duckdb` function, which builds a persistent, columnar database named `market_data.duckdb`.

In [None]:
import os
import gzip
import pandas as pd
import numpy as np
import duckdb
import shutil

from rl_trading_project.data.duckdb_loader import ingest_raw_data_to_duckdb

RAW_DIR = '../AlphaVantage Data/raw'
DB_PATH = 'market_data.duckdb'

ingest_summary = ingest_raw_data_to_duckdb(raw_dir=RAW_DIR, db_path=DB_PATH, source_timezone='UTC')
print("\n--- Ingestion Summary ---")
print(ingest_summary)


## Part 2: Load Data and Set Up Environments

With our database created, we can now easily query the data we need for our training and testing periods. We'll load all the data, then keep only the assets with enough history to form a portfolio. The first two-thirds of the timeline is used for training and the final third for backtesting.

In [None]:
import matplotlib.pyplot as plt

# Connect to the database and load all data into pandas
con = duckdb.connect(DB_PATH, read_only=True)
portfolio_df = con.execute("SELECT * FROM ohlcv ORDER BY timestamp, asset").fetchdf()
con.close()

# Convert timestamp to timezone-aware UTC and set the multi-index required by PortfolioEnv.
# utc=True handles both naive and tz-aware input; .dt.tz_convert alone raises on naive timestamps.
portfolio_df['timestamp'] = pd.to_datetime(portfolio_df['timestamp'], utc=True)
portfolio_df = portfolio_df.set_index(['timestamp', 'asset'])

print(f"Loaded {len(portfolio_df)} total rows for assets: {portfolio_df.index.get_level_values('asset').unique().tolist()}")
print("Data Head:")
print(portfolio_df.head())

# Visualize the data
fig, ax = plt.subplots(figsize=(15, 7))
for asset in portfolio_df.index.get_level_values('asset').unique():
    asset_prices = portfolio_df.xs(asset, level='asset')['close']
    ax.plot(asset_prices, label=asset)
ax.set_title('Loaded Asset Price History')
ax.set_xlabel('Time')
ax.set_ylabel('Price')
ax.legend()
ax.grid(True)
plt.show()

# --- Data Filtering and Splitting ---
# In a real-world scenario, you first filter for assets with sufficient history.
# We will demonstrate this by filtering for assets with at least 23 months of data.
MIN_MONTHS = 23
month_counts = portfolio_df.reset_index().groupby('asset')['timestamp'].apply(lambda x: (x.max() - x.min()).days / 30.44)
assets_with_enough_data = month_counts[month_counts >= MIN_MONTHS].index.tolist()

print(f"Assets with >= {MIN_MONTHS} months of data: {assets_with_enough_data}")

# Filter the main DataFrame to only include these assets
portfolio_df = portfolio_df[portfolio_df.index.get_level_values('asset').isin(assets_with_enough_data)]

# CRITICAL: Check if any assets remain after filtering before proceeding.
if portfolio_df.empty:
    raise ValueError("No assets met the minimum data requirement. Cannot proceed with training.")

# Get unique timestamps for the filtered set of assets
unique_timestamps = portfolio_df.index.get_level_values('timestamp').unique().sort_values()

# Calculate the split point (2/3 for training, 1/3 for testing)
split_index = int(len(unique_timestamps) * (2/3))
split_date = unique_timestamps[split_index]

print(f"Total unique timestamps: {len(unique_timestamps)}")
print(f"Calculated split date for 2/3 train, 1/3 test: {split_date}")

# Split the DataFrame based on the calculated timestamp
train_df = portfolio_df[portfolio_df.index.get_level_values('timestamp') < split_date]
test_df = portfolio_df[portfolio_df.index.get_level_values('timestamp') >= split_date]

print(f"\nTraining data from {train_df.index.get_level_values('timestamp').min()} to {train_df.index.get_level_values('timestamp').max()}")
print(f"Testing data from {test_df.index.get_level_values('timestamp').min()} to {test_df.index.get_level_values('timestamp').max()}")

## Part 3: Train the PPO Agent for Multi-Asset Portfolio Management

PPO is well suited to our `PortfolioEnv` because the environment's action space is a continuous vector representing the target allocation for each asset. We will train it on our multi-asset portfolio.
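
How `PortfolioEnv` turns a raw action into allocations is not shown here. The sketch below is one plausible mapping, stated purely as an assumption: clip each component to [-1, 1], then rescale so that gross exposure stays within the `max_leverage` setting. The function name `action_to_weights` is hypothetical.

```python
import numpy as np


def action_to_weights(action, max_leverage=2.0):
    """Map a raw continuous action to target weights with sum(|w|) <= max_leverage.

    Assumed mapping for illustration; the real environment may differ.
    """
    # Clip each per-asset component to the range [-1 (full short), +1 (full long)].
    weights = np.clip(np.asarray(action, dtype=float), -1.0, 1.0)
    # Gross exposure is the sum of absolute weights across assets.
    gross = np.abs(weights).sum()
    if gross > max_leverage:
        # Scale all weights down proportionally to respect the leverage cap.
        weights = weights * (max_leverage / gross)
    return weights


example_weights = action_to_weights([1.0, 1.0, 1.0, -1.0], max_leverage=2.0)
```

In this example the clipped action has gross exposure 4.0, so every weight is halved to fit under a leverage cap of 2.0.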

In [None]:
from rl_trading_project.agents import PPOAgent
from rl_trading_project.envs import PortfolioEnv, GymWrapper
import torch

# --- PPO Setup ---
SEED = 42
WINDOW_SIZE = 30

# Create the multi-asset environment with the training data
ppo_env = PortfolioEnv(
    df=train_df,
    window_size=WINDOW_SIZE,
    initial_balance=100_000,
    max_leverage=2.0,
    commission=0.0005, # 5 bps
    reward_type='risk_adjusted',
    drawdown_penalty=100.0
)
ppo_wrapped_env = GymWrapper(ppo_env)

# Instantiate the PPO agent
ppo_agent = PPOAgent(
    obs_dim=ppo_wrapped_env.observation_space.shape[0],
    action_dim=ppo_wrapped_env.action_space.shape[0],
    lr=3e-4, 
    epochs=10,
    minibatch_size=128,
    seed=SEED
)

# --- PPO Training Loop ---
TRAIN_STEPS = 50000 
ROLLOUT_LEN = 512

print("Starting PPO training...")
obs, _ = ppo_wrapped_env.reset(start_index=WINDOW_SIZE)
trajectories = []
total_steps_done = 0

while total_steps_done < TRAIN_STEPS:
    # Rollout Phase
    for _ in range(ROLLOUT_LEN):
        action, logp, value = ppo_agent.act(obs, deterministic=False)
        next_obs, reward, terminated, truncated, info = ppo_wrapped_env.step(action)
        done = terminated or truncated
        trajectories.append({'obs': obs, 'act': action, 'rew': reward, 'done': done, 'logp': logp, 'value': value})
        obs = next_obs
        total_steps_done += 1
        if done: 
            obs, _ = ppo_wrapped_env.reset(start_index=WINDOW_SIZE)
        if total_steps_done >= TRAIN_STEPS: break

    # Update Phase
    stats = ppo_agent.update(trajectories)
    trajectories.clear()
    if total_steps_done % (ROLLOUT_LEN * 10) == 0: # Print less frequently
        print(f"Step: {min(total_steps_done, TRAIN_STEPS)}/{TRAIN_STEPS}, Policy Loss: {stats['policy_loss']:.4f}, Value Loss: {stats['value_loss']:.4f}")

print("\nPPO Training finished!")
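
Each rollout step above stores a reward, a done flag, and a value estimate. PPO implementations commonly turn these into advantages with Generalized Advantage Estimation (GAE) before the policy update. Whether `ppo_agent.update` does exactly this is an assumption; the sketch below only illustrates the standard recursion, and `compute_gae` with its `gamma` and `lam` defaults is hypothetical.

```python
import numpy as np


def compute_gae(rewards, values, dones, last_value, gamma=0.99, lam=0.95):
    """Return (advantages, returns) for one rollout using GAE."""
    rewards = np.asarray(rewards, dtype=float)
    values = np.asarray(values, dtype=float)
    dones = np.asarray(dones, dtype=bool)
    advantages = np.zeros_like(rewards)
    gae = 0.0
    next_value = last_value  # value estimate for the state after the final step
    for t in reversed(range(len(rewards))):
        not_done = 1.0 - float(dones[t])  # stop bootstrapping at episode ends
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        gae = delta + gamma * lam * not_done * gae
        advantages[t] = gae
        next_value = values[t]
    returns = advantages + values  # regression targets for the value head
    return advantages, returns


# Toy rollout using the same keys as the trajectory dicts in the training loop.
toy_rollout = [
    {"rew": 1.0, "value": 0.0, "done": False},
    {"rew": 1.0, "value": 0.0, "done": True},
]
toy_adv, toy_ret = compute_gae(
    [s["rew"] for s in toy_rollout],
    [s["value"] for s in toy_rollout],
    [s["done"] for s in toy_rollout],
    last_value=5.0,
    gamma=1.0,
    lam=1.0,
)
```

Because the second step ends the episode, the bootstrap value of 5.0 is ignored and the advantages reduce to the undiscounted reward-to-go.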

## Part 4: Train the DQN Agent for Single-Asset Trading

Our `DuelingDQNAgent` is a value-based agent designed for problems with a **discrete action space**. It outputs a Q-value for each possible action (e.g., full short, half short, hold, half long, full long). This is fundamentally different from PPO, which can output a continuous action vector.

Therefore, we cannot directly apply our DQN agent to the multi-asset `PortfolioEnv`. Instead, we will train it on a simplified, **single-asset** version of the environment using only the `QQQ` data. Single-asset trading with a discretized position size is a common use case for DQN-style agents.
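
A dueling network splits its output into a scalar state value and a per-action advantage, then recombines them into Q-values. The sketch below shows the usual mean-subtracted aggregation from the dueling DQN literature. That `DuelingDQNAgent` uses exactly this form is an assumption, and `dueling_q_values` is a hypothetical name.

```python
import numpy as np


def dueling_q_values(state_value, advantages):
    """Combine V(s) and A(s, a) as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""
    advantages = np.asarray(advantages, dtype=float)
    # Subtracting the mean advantage keeps the V/A decomposition identifiable.
    return state_value + advantages - advantages.mean()


q = dueling_q_values(1.0, [1.0, 2.0, 3.0])
greedy_action = int(np.argmax(q))  # index of the action with the highest Q-value
```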

In [None]:
from rl_trading_project.agents import DuelingDQNAgent
from rl_trading_project.envs import TradingEnv  # Single-asset environment; a single-asset PortfolioEnv would also work

# --- DQN Setup ---
# Filter the training data for our single asset
qqq_train_df = train_df.xs('QQQ', level='asset').reset_index()

# Create a single-asset environment
dqn_env = TradingEnv(
    df=qqq_train_df,
    window_size=WINDOW_SIZE,
    initial_balance=100_000,
    max_position=100.0, # Max position in units of the asset
    commission=0.0005
)
dqn_wrapped_env = GymWrapper(dqn_env)

# Instantiate the DQN agent
dqn_agent = DuelingDQNAgent(
    obs_dim=dqn_wrapped_env.observation_space.shape[0],
    action_bins=11, # Discretize action into 11 bins from -1 (full short) to +1 (full long)
    lr=5e-4,
    buffer_size=100_000,
    batch_size=128,
    seed=SEED
)

# --- DQN Training Loop ---
print("\nStarting DQN training...")
obs, _ = dqn_wrapped_env.reset(start_index=WINDOW_SIZE)

for step in range(TRAIN_STEPS):
    action = dqn_agent.act(obs, deterministic=False)
    next_obs, reward, terminated, truncated, info = dqn_wrapped_env.step(action)
    done = terminated or truncated
    
    dqn_agent.add_experience(obs, action, reward, next_obs, done)
    stats = dqn_agent.update(sync_freq=200)
    
    obs = next_obs
    if done:
        obs, _ = dqn_wrapped_env.reset(start_index=WINDOW_SIZE)
        
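    # Guard against a missing stats dict (e.g. before the replay buffer has enough samples).
    if (step + 1) % 2000 == 0 and stats:
        print(f"Step {step+1}/{TRAIN_STEPS}, Loss={stats.get('loss', float('nan')):.4f}, Epsilon={stats.get('eps', float('nan')):.2f}")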

print("DQN Training finished!")
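
The agent above is configured with `action_bins=11`, described as a grid from -1 (full short) to +1 (full long). One way to read that, assuming the bins are evenly spaced, is sketched below. The helper `bin_to_position` is hypothetical and the real environment may decode actions differently.

```python
import numpy as np


def bin_to_position(action_index, action_bins=11):
    """Map a discrete action index to a target position fraction in [-1, 1]."""
    # Evenly spaced grid: index 0 is full short, the last index is full long.
    grid = np.linspace(-1.0, 1.0, action_bins)
    return float(grid[action_index])


positions = [bin_to_position(i) for i in range(11)]
```

With 11 bins the step between neighbouring actions is 0.2, and the middle index (5) corresponds to a flat position.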

## Part 5: Backtesting and Comparison

Now for the moment of truth. We will use the `Backtester` module to run both of our trained agents on the out-of-sample test data (the final third of the timeline). For the evaluation, we always use `deterministic=True` so that each agent exploits its learned policy without random exploration.
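
To make the role of the `deterministic` flag concrete, here is a minimal sketch of action selection for a value-based agent, assuming epsilon-greedy exploration during training. The function `select_action` is hypothetical and is not the agent's actual `act` method.

```python
import numpy as np


def select_action(q_values, epsilon, rng, deterministic=False):
    """Pick the greedy action, or a random one with probability epsilon when exploring."""
    q_values = np.asarray(q_values, dtype=float)
    if not deterministic and rng.random() < epsilon:
        # Exploration: sample an action index uniformly at random.
        return int(rng.integers(len(q_values)))
    # Exploitation: take the action with the highest estimated Q-value.
    return int(np.argmax(q_values))


rng = np.random.default_rng(0)
eval_action = select_action([0.1, 0.9, 0.3], epsilon=1.0, rng=rng, deterministic=True)
```

Even with `epsilon=1.0`, the deterministic call always returns the greedy action, which is the behaviour a backtest needs.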

In [None]:
from rl_trading_project.trainers import Backtester, compare_strategies, reporting

# --- 1. PPO Backtest (Multi-Asset) ---
def ppo_policy_fn(obs, t):
    action, _, _ = ppo_agent.act(obs, deterministic=True)
    return action

ppo_test_env_factory = lambda: GymWrapper(PortfolioEnv(df=test_df, window_size=WINDOW_SIZE))

print("Running PPO backtest...")
ppo_backtester = Backtester(ppo_test_env_factory, start_index=WINDOW_SIZE)
ppo_results = ppo_backtester.run(ppo_policy_fn, max_steps=len(test_df.index.get_level_values('timestamp').unique()) - WINDOW_SIZE - 1)

# --- 2. DQN Backtest (Single-Asset on QQQ) ---
def dqn_policy_fn(obs, t):
    action = dqn_agent.act(obs, deterministic=True)
    return action

qqq_test_df = test_df.xs('QQQ', level='asset').reset_index()
# Use TradingEnv, the same environment class the DQN agent was trained on
dqn_test_env_factory = lambda: GymWrapper(TradingEnv(df=qqq_test_df, window_size=WINDOW_SIZE))

print("Running DQN backtest...")
dqn_backtester = Backtester(dqn_test_env_factory, start_index=WINDOW_SIZE)
dqn_results = dqn_backtester.run(dqn_policy_fn, max_steps=len(qqq_test_df) - WINDOW_SIZE - 1)

# --- 3. Comparison ---
comparison = compare_strategies({
    'PPO_MultiAsset': ppo_results,
    'DQN_SingleAsset': dqn_results
})

print("\n--- Backtest Comparison Summary ---")
summary_df = pd.DataFrame(comparison).T[['total_return', 'sharpe_ratio', 'max_drawdown', 'end_value']]
summary_df['total_return'] = summary_df['total_return'].apply(lambda x: f"{x:.2%}")
summary_df['max_drawdown'] = summary_df['max_drawdown'].apply(lambda x: f"{x:.2%}")
print(summary_df)

# --- 4. Plotting Equity Curves ---
ppo_history_df = pd.DataFrame(ppo_results['history'])
dqn_history_df = pd.DataFrame(dqn_results['history'])

fig, ax = plt.subplots(figsize=(15, 8))
ax.plot(ppo_history_df['total_value'], label='PPO (Multi-Asset Portfolio)', lw=2)
ax.plot(dqn_history_df['total_value'], label='DQN (Single-Asset: QQQ)', lw=2)
ax.set_title('Agent Equity Curves (Out-of-Sample)')
ax.set_xlabel('Test Steps')
ax.set_ylabel('Portfolio Value ($)')
ax.legend()
ax.grid(True)
plt.show()
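
The summary table reports total return, Sharpe ratio, and maximum drawdown. How `compare_strategies` computes them is not shown, so the sketch below gives common definitions as an assumption. The name `equity_metrics`, the `periods_per_year=252` default (daily bars), and the negative sign convention for drawdown are all hypothetical choices.

```python
import numpy as np


def equity_metrics(equity, periods_per_year=252):
    """Compute basic performance metrics from an equity curve."""
    equity = np.asarray(equity, dtype=float)
    # Simple per-step returns of the equity curve.
    returns = equity[1:] / equity[:-1] - 1.0
    total_return = equity[-1] / equity[0] - 1.0
    std = returns.std(ddof=1) if len(returns) > 1 else 0.0
    # Annualized Sharpe ratio with a zero risk-free rate; NaN if returns are constant.
    sharpe = np.sqrt(periods_per_year) * returns.mean() / std if std > 0 else float("nan")
    # Drawdown is the relative drop from the running peak; the maximum is the most negative.
    running_max = np.maximum.accumulate(equity)
    max_drawdown = ((equity - running_max) / running_max).min()
    return {
        "total_return": float(total_return),
        "sharpe_ratio": float(sharpe),
        "max_drawdown": float(max_drawdown),
        "end_value": float(equity[-1]),
    }


toy_metrics = equity_metrics([100.0, 110.0, 99.0, 120.0])
```

For intraday bars, `periods_per_year` would need to match the bar frequency, otherwise the annualized Sharpe ratio is mis-scaled.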

## Conclusion

This notebook demonstrated the full workflow: from raw data ingestion to training two distinct RL agents and comparing their out-of-sample performance.

The **PPO agent** was trained to manage a multi-asset portfolio, using its ability to output continuous allocation vectors.

The **DQN agent**, while not suitable for the multi-asset task in its current form, was applied to a single-asset trading problem with a discretized set of actions (11 position levels from full short to full long).

Because the two agents trade different universes, the summary table compares two setups rather than the algorithms head-to-head. The key principle still holds: in applied RL, the agent architecture should match the problem and its action space.