W O R K O U T

In [None]:
# Full pretrain + fine-tune pipeline

import os
import torch
import pandas as pd
import importlib
from torch.utils.data import DataLoader
from RL1_ES import (init_trading_agent, generate_labels,
                    RLPretrainDataset, pretrain_macro, pretrain_vol, pretrain_price,
                    unfreeze_all)
from training_module import run_training
from RL1 import load_rl_data, TradingEnv

# === User-configurable parameters ===
ticker = 'ES'                   # selects "RL - <ticker>.csv" and the checkpoint file name
rl_folder = 'RL Data'           # folder holding the per-ticker CSV files
episodes_per_run = 500          # passed to run_training as num_episodes
checkpoint_dir = 'checkpoints'
batch_size = 32                 # mini-batch size for all three pretraining loaders
pretrain_epochs = 10            # epochs per pretraining branch

# Ensure checkpoint directory exists
os.makedirs(checkpoint_dir, exist_ok=True)
ckpt_path = os.path.join(checkpoint_dir, f'{ticker.lower()}_agent.pth')

# Decide once, up front, whether this run resumes from a saved agent.
# When it does, the supervised pretraining below is skipped: load_state_dict
# would overwrite whatever weights pretraining produced, so running it first
# only burns compute.
# NOTE(review): this assumes the pretrain_* helpers have no lasting effect
# beyond the agent's state_dict -- confirm against RL1_ES.
resume_from_checkpoint = os.path.exists(ckpt_path)

# === STEP 1: Load Data and Generate Labels ===
df = pd.read_csv(os.path.join(rl_folder, f"RL - {ticker}.csv"))
df = generate_labels(df)

feature_cols = ['Macro', 'Vol', 'Security', 'Security Proba', 'AVWAP', 'Base', 'R1', 'R2', 'S1', 'S2']

# === STEP 2: Build Pretraining Datasets ===
# One dataset per label type; all three share the same feature columns.
macro_dataset = RLPretrainDataset(df, feature_cols, label_type='direction')
vol_dataset   = RLPretrainDataset(df, feature_cols, label_type='size')
price_dataset = RLPretrainDataset(df, feature_cols, label_type='stop')

macro_loader = DataLoader(macro_dataset, batch_size=batch_size, shuffle=True)
vol_loader   = DataLoader(vol_dataset, batch_size=batch_size, shuffle=True)
price_loader = DataLoader(price_dataset, batch_size=batch_size, shuffle=True)

# === STEP 3: Initialize Agent ===
agent = init_trading_agent(input_size=len(feature_cols))

# === STEP 4: Pretrain Each Branch (fresh runs only) ===
if not resume_from_checkpoint:
    print("Starting Macro Pretraining...")
    pretrain_macro(agent, macro_loader, epochs=pretrain_epochs)

    print("Starting Vol Pretraining...")
    pretrain_vol(agent, vol_loader, epochs=pretrain_epochs)

    print("Starting Price Pretraining...")
    pretrain_price(agent, price_loader, epochs=pretrain_epochs)

# === STEP 5: Unfreeze Agent for RL Fine-tuning ===
# Called on both paths so the agent is in the same trainable state either way.
unfreeze_all(agent)

# === STEP 6: Load Checkpoint if available ===
if resume_from_checkpoint:
    print(f"Loading checkpoint from {ckpt_path}")
    # map_location='cpu' lets a checkpoint saved from a GPU run be read on a
    # CPU-only machine; load_state_dict then copies the values into the
    # agent's existing parameters.
    agent.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
else:
    print("No checkpoint found, starting from pretrained weights.")

# === STEP 7: Create TradingEnv ===
df_env = load_rl_data(rl_folder, ticker)
env = TradingEnv(df_env)

# === STEP 8: Run RL Training ===
results = run_training(env, agent, num_episodes=episodes_per_run)

# === STEP 9: Save updated checkpoint ===
torch.save(agent.state_dict(), ckpt_path)
print(f"Checkpoint saved to {ckpt_path}")

# === STEP 10: Save results ===
results.to_csv('results.csv')


Loading checkpoint from checkpoints\es_agent.pth
Checkpoint saved to checkpoints\es_agent.pth


In [5]:
# Display the last five rows of the training results as a quick sanity check.
results.tail(n=5)

Unnamed: 0,total_return,sharpe,max_drawdown,episode,trades
495,-0.00286,-6.162969,-286.0,495,1
496,-0.00574,-17.236217,-574.0,496,1
497,-0.0075,-95.793241,-750.0,497,1
498,0.00502,13.23145,-78.0,498,1
499,0.0039,9.052214,-86.0,499,1
