# 05_Regime_Testing

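Evaluate the RL hedging agent's robustness across different market regimes (bull, bear, high volatility, sideways). For each regime, a fresh PPO agent is trained on that regime's data and then evaluated on the same regime.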


In [None]:
from src.agents.ppo_agent import PPOAgent
from src.agents.evaluation import evaluate_model
from src.environment.multi_asset_env import MultiAsset21DeepHedgingEnv
from src.environment.option_pricing import create_synthetic_option_chain
from src.utils.data_utils import download_market_data
from src.environment.market_data import MarketDataHandler
from src.config.settings import get_config
import matplotlib.pyplot as plt

# Prepare spot data, the full-period option chain, the tradable option universe,
# and the per-regime slices of the spot data.
cfg = get_config('data')
df = download_market_data(**cfg)

# Read the option config once so the option chain and the asset universe are
# built from the same snapshot (instead of three separate lookups).
option_cfg = get_config('option')
option_chain = create_synthetic_option_chain(df, option_cfg)
# NOTE(review): the regime loop builds its own per-regime chains from each
# regime's slice; confirm whether this full-period chain is still needed.

strikes = option_cfg['strike_offsets']
expiries = option_cfg['expiry_days']
types_ = option_cfg['option_types']
# One instrument spec per (expiry, strike offset, option type) combination,
# ordered by expiry first, then strike offset, then option type.
asset_universe = [
    {'strike_offset': s, 'expiry_days': e, 'type': t}
    for e in expiries
    for s in strikes
    for t in types_
]

handler = MarketDataHandler()
# Presumably a mapping of regime name -> spot-data slice (it is iterated with
# .items() and each value is passed where a spot frame is expected) — confirm.
regime_data = handler.get_regime_data(df)


In [None]:
# For every regime: build an environment from that regime's data, train a fresh
# PPO agent on it, evaluate the agent, and collect one metrics row per regime.
# NOTE(review): the agent is trained and evaluated on the same regime data (the
# same env instance), so these are in-sample results. They measure how well PPO
# fits each regime, not how one trained agent transfers across regimes.
TRAIN_TIMESTEPS = 15000  # PPO training budget per regime
EVAL_EPISODES = 10       # number of evaluation episodes per regime

# Loop-invariant: fetch the option config once rather than per regime.
option_cfg = get_config('option')
metrics_by_regime = []

for regime_name, df_reg in regime_data.items():
    # A regime with no rows cannot drive an environment; skip it instead of
    # aborting the whole sweep partway through.
    if df_reg is None or len(df_reg) == 0:
        print(f"{regime_name} | skipped: no data for this regime")
        continue

    opt_chain_reg = create_synthetic_option_chain(df_reg, option_cfg)
    env = MultiAsset21DeepHedgingEnv(df_reg, opt_chain_reg, asset_universe)
    agent = PPOAgent(env)
    model = agent.create_model()
    model.learn(total_timesteps=TRAIN_TIMESTEPS)
    metrics = evaluate_model(model, env, episodes=EVAL_EPISODES)

    # Build a new row rather than mutating the dict returned by evaluate_model.
    row = {**metrics, 'regime': regime_name}
    metrics_by_regime.append(row)
    print(f"{regime_name} | Sharpe: {row['sharpe_ratio']:.2f}, Mean reward: {row['mean_reward']:.2f}")


In [None]:
import pandas as pd

# Tabulate the per-regime metrics: each dict in metrics_by_regime becomes one
# row, and its keys (the evaluation metrics plus 'regime') become the columns.
regime_summary = pd.DataFrame.from_records(metrics_by_regime)
display(regime_summary)


## 5.2 Visualize Regime Performance


In [None]:
# One bar chart per headline metric, with market regimes on the x-axis.
by_regime = regime_summary.set_index('regime')
chart_specs = (
    # (column, chart title, y-axis label)
    ('sharpe_ratio', 'Sharpe Ratio by Market Regime', 'Sharpe Ratio'),
    ('mean_reward', 'Mean Reward by Market Regime', 'Mean Reward per Episode'),
)
for column, chart_title, y_label in chart_specs:
    by_regime[column].plot(kind='bar', title=chart_title)
    plt.ylabel(y_label)
    plt.show()
