In [None]:
import os
# For Colab/Google Drive integration:
from google.colab import drive
# Mount Google Drive and make the project folder the working directory.
# The loaders further down use relative paths such as "output/<timeframe>/..."
# and "data/<timeframe>/...", which resolve against this directory.
drive.mount('/content/drive')
os.chdir('/content/drive/MyDrive/FinRL/final')  # Change to your project folder in Drive

Mounted at /content/drive


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from typing import List, Tuple
from collections import Counter
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Agg')  # For non-GUI environments like Colab
import json

# ==================== IMPROVED EVALUATION PARAMETERS ====================
# Data selection for evaluation - Choose which dataset to evaluate on.
# TradeSimulator.load_data looks for "<split>_predictions.npy" first, but its
# candidate list also contains train/unsplit prediction files as fallbacks.
EVAL_DATA_SPLIT = "valid"  # Options: "train", "valid", "test"

# Feature toggles
# When True, TradeSimulator.load_data min-max scales each LLM column using the
# min/max of the whole loaded series, then clamps it to [0, 1].
NORMALIZE_LLM_SIGNALS = True  # min-max normalize sentiment and risk

# Multiplier applied to the raw environment reward in multi_trade.
REWARD_SCALE = 1e-1
# NOTE(review): ENTROPY_COEF and TARGET_KL look like training hyperparameters;
# they are not referenced by the evaluation code visible here — TODO confirm.
ENTROPY_COEF = 0.05
TARGET_KL = 0.02

# NOTE(review): not referenced by the visible code; agents are loaded directly
# from the evaluator's save_path — confirm whether this is still needed.
AGENT_DIR_NAME = "PPO"  # Subdirectory where PPO agents are saved

# ==================== NETWORK ARCHITECTURE ====================

# Hidden layer widths; must match the checkpoint exactly or load_state_dict fails.
EVAL_NET_DIMS = [256, 256, 128, 128, 64]
# The agents are built from trade_env.state_dim / action_dim; these two
# constants document the expected values.
EVAL_STATE_DIM = 12  # 2 (position, holding) + 8 (Alpha101) + 2 (LLM signals)
# NOTE(review): TradeSimulator.step turns an env action `a` into a position
# change of `a - 1`, so env actions {0, 1, 2} mean {short, hold, long}.
# multi_trade adds 1 to the agent's action before stepping. That maps agent
# actions {0, 1, 2} to {hold, long, long (clipped)}, so shorts are never taken.
# Confirm against the training loop.
EVAL_ACTION_DIM = 3  # Short, Hold, Long

# Environment
EVAL_NUM_SIMS = 1  # single env during evaluation
EVAL_MAX_POSITION = 1  # absolute position limit used to clip position changes
EVAL_STEP_GAP = 2  # data rows advanced per environment step
# Fractional execution cost: buys fill at mid * (1 + s), sells at mid * (1 - s).
EVAL_SLIPPAGE = 7e-7
# NOTE(review): not passed to TradeSimulator in the visible evaluator; the
# simulator's own default num_ignore_step is also 60.
EVAL_NUM_IGNORE_STEP = 60

# Evaluation settings
# Evaluator-side starting cash; TradeSimulator.reset uses 1e6 as well — keep in sync.
EVAL_STARTING_CASH = 1e6
# NOTE(review): not referenced by the visible code.
EVAL_THRESH = 0.001

# Device
GPU_ID = 0

# Versioning (training model directory version prefix)
PPO_VERSION = "5_7_3"

# ==================== PPO NETWORK ARCHITECTURE ====================

class SimplePPOActor(nn.Module):
    """Actor (policy) network for PPO: an MLP mapping states to action logits.

    Every hidden block is Linear -> ReLU -> Dropout(0.1). A LayerNorm follows
    each hidden block except the last one. The resulting ``nn.Sequential``
    indices are the ``state_dict`` keys of saved checkpoints, so this order
    must not change or ``load_state_dict`` will fail.

    Dropout is active in train mode; call ``.eval()`` for deterministic outputs.
    """

    def __init__(self, state_dim: int, action_dim: int, net_dims: List[int]):
        super().__init__()
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Simple MLP layers
        layers = []
        input_dim = state_dim

        for i, dim in enumerate(net_dims):
            layers.extend([
                nn.Linear(input_dim, dim),
                nn.ReLU(),
                nn.Dropout(0.1)
            ])
            # Layer normalization on every hidden block except the last
            if i < len(net_dims) - 1:
                layers.append(nn.LayerNorm(dim))
            input_dim = dim

        # Output layer: one logit per action
        layers.append(nn.Linear(input_dim, action_dim))

        self.net = nn.Sequential(*layers)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal init (gain 0.01) and zero bias for every Linear layer."""
        # NOTE(review): the small 0.01 gain is applied to the hidden layers as
        # well as the output head. It only affects freshly built networks, not
        # loaded checkpoints.
        if isinstance(module, nn.Linear):
            nn.init.orthogonal_(module.weight, gain=0.01)
            nn.init.constant_(module.bias, 0.0)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return unnormalised action logits (last dim = ``action_dim``)."""
        return self.net(state)

    def get_action_probs(self, state: torch.Tensor) -> torch.Tensor:
        """Get action probabilities using softmax over the last dimension."""
        logits = self.forward(state)
        return F.softmax(logits, dim=-1)

    def get_action_log_probs(self, state: torch.Tensor, action: torch.Tensor) -> torch.Tensor:
        """Return ``log pi(action | state)`` for each row.

        Args:
            state: input states; the last dimension is ``state_dim``.
            action: integer (long) action indices with the state's batch shape
                (e.g. ``(B,)`` for states of shape ``(B, state_dim)``).

        Returns:
            Tensor with the same shape as ``action``.
        """
        # log_softmax is computed in a numerically stable way. The previous
        # log(softmax + 1e-8) floored small probabilities at about -18.4.
        log_probs = F.log_softmax(self.forward(state), dim=-1)
        return log_probs.gather(-1, action.unsqueeze(-1)).squeeze(-1)

class SimplePPOCritic(nn.Module):
    """Critic (state-value) network for PPO.

    Architecture mirrors the actor: hidden blocks of Linear -> ReLU ->
    Dropout(0.1), with a LayerNorm after every hidden block except the last,
    followed by a single-unit Linear head. Module order defines the
    ``state_dict`` keys, so it is kept identical to the saved checkpoints.
    """

    def __init__(self, state_dim: int, net_dims: List[int]):
        super().__init__()
        self.state_dim = state_dim

        widths = [state_dim, *net_dims]
        last_hidden = len(net_dims) - 1
        modules = []
        for depth, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules += [nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.1)]
            if depth != last_hidden:
                modules.append(nn.LayerNorm(fan_out))

        # Scalar value head
        modules.append(nn.Linear(widths[-1], 1))

        self.net = nn.Sequential(*modules)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Orthogonal init (gain 1.0) and zero bias for Linear layers only."""
        if not isinstance(module, nn.Linear):
            return
        nn.init.orthogonal_(module.weight, gain=1.0)
        nn.init.constant_(module.bias, 0.0)

    def forward(self, state: torch.Tensor) -> torch.Tensor:
        """Return one value per state, with the trailing size-1 axis removed."""
        values = self.net(state)
        return values.squeeze(-1)

# ==================== PPO AGENT ====================

class ImprovedPPO:
    """PPO agent holding an actor (policy) and a critic (value) network.

    In the visible notebook code it is used to load saved weights
    (``save_or_load_agent``) and to pick greedy actions (``select_action``).
    """

    def __init__(self, state_dim: int, action_dim: int, net_dims: List[int],
                 device: str = 'cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.state_dim = state_dim
        self.action_dim = action_dim

        # Networks - using simplified architecture
        self.actor = SimplePPOActor(state_dim, action_dim, net_dims).to(device)
        self.critic = SimplePPOCritic(state_dim, net_dims).to(device)

        # Training stats
        self.total_steps = 0
        self.episode_rewards = []
        self.episode_lengths = []

    def select_action(self, state: torch.Tensor, training: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Select an action for ``state`` without tracking gradients.

        Args:
            state: batch of states accepted by the actor and critic.
            training: if True, sample from a temperature-scaled (T=1.2)
                categorical policy. Otherwise take the greedy argmax action.

        Returns:
            ``(action, log_prob, value)``. In greedy mode ``log_prob`` is all zeros.
        """
        if training:
            with torch.no_grad():
                temperature = 1.2  # Slightly higher temperature for more exploration
                logits = self.actor(state)
                scaled_probs = F.softmax(logits / temperature, dim=-1)
                dist = torch.distributions.Categorical(scaled_probs)
                action = dist.sample()
                # NOTE(review): this log-prob is under the T=1.2 policy, whereas
                # SimplePPOActor.get_action_log_probs uses T=1. If these are used
                # as "old" log-probs in PPO ratios the two do not match — confirm.
                log_prob = dist.log_prob(action)
                value = self.critic(state)
            return action, log_prob, value

        # Greedy evaluation must be deterministic. Both networks contain
        # Dropout, which is active in train mode (the default after
        # construction/loading). Switch to eval mode for this call and restore
        # the caller's flags afterwards.
        actor_was_training = self.actor.training
        critic_was_training = self.critic.training
        self.actor.eval()
        self.critic.eval()
        try:
            with torch.no_grad():
                action_probs = self.actor.get_action_probs(state)
                action = torch.argmax(action_probs, dim=-1)
                log_prob = torch.zeros_like(action, dtype=torch.float)
                value = self.critic(state)
        finally:
            self.actor.train(actor_was_training)
            self.critic.train(critic_was_training)
        return action, log_prob, value

    def save_or_load_agent(self, cwd: str, if_save: bool):
        """Save or load ``actor.pth`` / ``critic.pth`` state dicts in ``cwd``.

        Args:
            cwd: directory containing (or receiving) the two weight files.
            if_save: True to save (returns None), False to load.

        Returns:
            When loading: True on success, False if a file is missing or the
            weights could not be loaded. A non-strict load counts as success
            when no keys are missing, even if there are unexpected keys.
        """
        if if_save:
            os.makedirs(cwd, exist_ok=True)
            torch.save(self.actor.state_dict(), os.path.join(cwd, 'actor.pth'))
            torch.save(self.critic.state_dict(), os.path.join(cwd, 'critic.pth'))
            print(f"💾 Saved Improved PPO agent to {cwd}")
        else:
            actor_path = os.path.join(cwd, 'actor.pth')
            critic_path = os.path.join(cwd, 'critic.pth')

            if os.path.exists(actor_path) and os.path.exists(critic_path):
                try:
                    # Try loading as state_dict first (strict key matching)
                    self.actor.load_state_dict(torch.load(actor_path, map_location=self.device))
                    self.critic.load_state_dict(torch.load(critic_path, map_location=self.device))
                    print(f"Loaded Improved PPO agent from {cwd}")
                    return True
                except Exception as e:
                    print(f"Error loading Improved PPO agent: {e}")
                    print(f"Debug: Let's inspect the saved model structure...")

                    try:
                        actor_state = torch.load(actor_path, map_location=self.device)
                        critic_state = torch.load(critic_path, map_location=self.device)

                        print(f"Saved Actor keys: {list(actor_state.keys())}")
                        print(f"Saved Critic keys: {list(critic_state.keys())}")
                        print(f"Current Actor keys: {list(self.actor.state_dict().keys())}")
                        print(f"Current Critic keys: {list(self.critic.state_dict().keys())}")

                        # Try loading with strict=False (still raises on shape mismatch)
                        missing_keys, unexpected_keys = self.actor.load_state_dict(actor_state, strict=False)
                        missing_keys2, unexpected_keys2 = self.critic.load_state_dict(critic_state, strict=False)

                        print(f"Actor missing keys: {missing_keys}")
                        print(f"Actor unexpected keys: {unexpected_keys}")
                        print(f"Critic missing keys: {missing_keys2}")
                        print(f"Critic unexpected keys: {unexpected_keys2}")

                        if len(missing_keys) == 0 and len(missing_keys2) == 0:
                            print(f"Loaded Improved PPO agent with strict=False from {cwd}")
                            return True
                        else:
                            print(f"Still missing keys even with strict=False")
                            return False

                    except Exception as e2:
                        print(f"Error during debug loading: {e2}")
                        return False
            else:
                print(f"Model files not found in {cwd}")
                return False

# ==================== TRADING ENVIRONMENT ====================

class TradeSimulator:
    """Single-asset trading environment replayed over pre-computed factors.

    Each of ``num_sims`` simulations starts at a random row of the loaded
    series (see ``reset``) and advances ``step_gap`` rows per ``step``.

    The observation concatenates:
      * the normalised position and normalised holding time,
      * the factor row from the predictions file,
      * the two LLM signals.
    ``state_dim`` is 2 + 8 + 2, so the factor file is expected to have 8 columns.

    Env actions are {0, 1, 2}. ``step`` turns them into position changes of
    {-1, 0, +1}, clipped to ``[-max_position, max_position]``.

    Args:
        num_sims: number of parallel simulations.
        slippage: fractional execution cost; buys fill at ``mid * (1 + slippage)``
            and sells at ``mid * (1 - slippage)``.
        max_position: absolute position limit.
        step_gap: data rows advanced per step.
        delay_step: stored, but not used by the code in this class.
        num_ignore_step: rows subtracted from ``seq_len`` when computing ``max_step``.
        device: device used when ``gpu_id < 0``.
        gpu_id: if >= 0, tensors are placed on ``cuda:{gpu_id}``.
        timeframe: data sub-directory and ``seq_len`` key ('1sec', '1min', '5min').
        starting_cash: initial cash and net asset value set by ``reset``.
    """

    def __init__(self, num_sims=1, slippage=5e-5, max_position=1, step_gap=1,
                 delay_step=1, num_ignore_step=60, device=torch.device("cpu"), gpu_id=-1, timeframe='1sec',
                 starting_cash=1e6):
        self.device = torch.device(f"cuda:{gpu_id}") if gpu_id >= 0 else device
        self.num_sims = num_sims
        self.slippage = slippage
        self.delay_step = delay_step
        self.max_holding = 60 * 60 // step_gap
        self.max_position = max_position
        self.step_gap = step_gap
        self.sim_ids = torch.arange(self.num_sims, device=self.device)
        self.timeframe = timeframe
        self.starting_cash = float(starting_cash)

        self.load_data()

        self.env_name = "TradeSimulator-v0"
        self.state_dim = 2 + 8 + 2
        self.action_dim = 3
        self.if_discrete = True
        self.max_step = (self.seq_len - num_ignore_step) // step_gap
        self.target_return = +np.inf

        # Use self.device (not the `device` argument) so this tensor lives with
        # every other state tensor when gpu_id selects a CUDA device.
        self.best_price = torch.zeros((num_sims,), dtype=torch.float32, device=self.device)
        self.stop_loss_thresh = 1e-3

    def load_data(self):
        """Load factor predictions, price data and LLM signals for ``self.timeframe``.

        Sets the following attributes:
          * ``factor_ary``
          * ``price_ary`` (columns: bid price, ask price, midpoint)
          * ``llm_signals``
          * ``seq_len``
          * ``full_seq_len``
        The three arrays are truncated to their common length.

        Raises:
            FileNotFoundError: no predictions file or no price CSV was found.
            ValueError: the factor array is empty or required CSV columns are missing.
        """
        requested_name = f"{EVAL_DATA_SPLIT}_predictions.npy"
        # Load data based on evaluation configuration with multiple fallback options
        possible_paths = [
            f"output/{self.timeframe}/{EVAL_DATA_SPLIT}_predictions.npy",
            f"output/{self.timeframe}/train_predictions.npy",
            f"output/{self.timeframe}/predictions.npy",
            f"../output/{self.timeframe}/{EVAL_DATA_SPLIT}_predictions.npy",
            f"../output/{self.timeframe}/train_predictions.npy",
            f"../output/{self.timeframe}/predictions.npy"
        ]

        factor_path = None
        data_type = EVAL_DATA_SPLIT
        for path in possible_paths:
            if os.path.exists(path):
                factor_path = path
                print(f"Found {data_type} data at: {path}")
                break

        if factor_path is None:
            # Quit evaluation if no predictions found
            print("Error: No prediction files found in any of the expected locations:")
            for path in possible_paths:
                print(f"   - {path}")
            print("Please ensure the prediction files exist before running evaluation.")
            raise FileNotFoundError("Required prediction files not found. Cannot proceed with evaluation.")

        # The candidate list includes train/unsplit files. Say so explicitly
        # when the requested split was not the file actually loaded.
        if os.path.basename(factor_path) != requested_name:
            print(f"WARNING: {requested_name} not found; using fallback file {factor_path} instead.")
            print("Metrics may not reflect the requested split (and may include training data).")

        self.factor_ary = torch.tensor(np.load(factor_path), dtype=torch.float32)

        print(f"Loaded {data_type} data: {self.factor_ary.shape}")

        # Data leakage warning
        if EVAL_DATA_SPLIT == "train":
            print("WARNING: Evaluating on training data - this will give overly optimistic results!")
            print("Consider setting EVAL_DATA_SPLIT to 'valid' or 'test' for proper evaluation.")

        if self.factor_ary.shape[0] == 0:
            raise ValueError(f"Alpha101 factors are empty for {self.timeframe}.")

        # Try multiple paths for price data.
        # NOTE(review): every candidate is a *train* CSV regardless of
        # EVAL_DATA_SPLIT. Non-train predictions are therefore paired row-by-row
        # with train-period prices and LLM signals — confirm this is intended.
        possible_csv_paths = [
            f"data/{self.timeframe}/BTC_{self.timeframe}_with_sentiment_risk_train_{self.timeframe}_train_70.csv",
            f"../data/{self.timeframe}/BTC_{self.timeframe}_with_sentiment_risk_train_{self.timeframe}_train_70.csv",
            f"data/{self.timeframe}/BTC_{self.timeframe}_with_sentiment_risk_train.csv",
            f"../data/{self.timeframe}/BTC_{self.timeframe}_with_sentiment_risk_train.csv"
        ]

        csv_path = None
        for path in possible_csv_paths:
            if os.path.exists(path):
                csv_path = path
                print(f"Found price data at: {path}")
                break

        if csv_path is None:
            # Quit evaluation if no price data found
            print("Error: No price data files found in any of the expected locations:")
            for path in possible_csv_paths:
                print(f"   - {path}")
            print("Please ensure the price data files exist before running evaluation.")
            raise FileNotFoundError("Required price data files not found. Cannot proceed with evaluation.")
        try:
            data_df = pd.read_csv(csv_path, engine='python')
        except Exception:
            data_df = pd.read_csv(csv_path)

        required_columns = ["bids_distance_3", "asks_distance_3", "midpoint", "sentiment_score", "risk_score"]
        missing_columns = [col for col in required_columns if col not in data_df.columns]
        if missing_columns:
            raise ValueError(f"Missing required columns: {missing_columns}.")

        # Explicit writable copy: `.values` can be a read-only view under pandas
        # Copy-on-Write, and the columns are modified in place below.
        price_ary = data_df[["bids_distance_3", "asks_distance_3", "midpoint"]].to_numpy(
            dtype=np.float64, copy=True)
        midpoint = price_ary[:, 2]
        # Distances are relative to the midpoint; convert them to absolute prices.
        price_ary[:, 0] = midpoint * (1 + price_ary[:, 0])  # bid price
        # Ask price from the ask distance. This previously read column 0, which
        # had already been overwritten with the bid price.
        price_ary[:, 1] = midpoint * (1 + price_ary[:, 1])
        self.llm_signals = data_df[["sentiment_score", "risk_score"]].values

        min_len = min(self.factor_ary.shape[0], price_ary.shape[0], self.llm_signals.shape[0])
        self.factor_ary = self.factor_ary[:min_len]
        self.price_ary = torch.tensor(price_ary[:min_len], dtype=torch.float32)
        self.llm_signals = torch.tensor(self.llm_signals[:min_len], dtype=torch.float32)

        if NORMALIZE_LLM_SIGNALS:
            # Column-wise min/max over the whole loaded series (including rows
            # after any given step).
            llm_min = self.llm_signals.amin(dim=0, keepdim=True)
            llm_max = self.llm_signals.amax(dim=0, keepdim=True)
            llm_range = (llm_max - llm_min)
            self.llm_signals = (self.llm_signals - llm_min) / (llm_range + 1e-6)
            self.llm_signals = torch.clamp(self.llm_signals, 0.0, 1.0)

        self.seq_len = {'1sec': 1800, '1min': 64, '5min': 16}.get(self.timeframe, 1800)

        self.full_seq_len = self.price_ary.shape[0]

    def reset(self, slippage=None):
        """Start a new episode and return the initial state ``(num_sims, state_dim)``.

        Args:
            slippage: optional float overriding ``self.slippage``. Non-float
                values, including None, keep the current value.
        """
        self.slippage = slippage if isinstance(slippage, float) else self.slippage

        # Start rows are drawn from the global NumPy RNG, which is not seeded
        # here. The evaluated window differs between runs unless the caller seeds it.
        min_start = self.seq_len
        max_start = self.full_seq_len - self.seq_len * 2
        if min_start >= max_start:
            if self.full_seq_len > self.seq_len:
                max_start = self.full_seq_len - self.seq_len
                min_start = 0
                i0s = np.random.randint(min_start, max_start, size=self.num_sims)
            else:
                if self.full_seq_len > 0:
                    i0s = np.arange(min(self.num_sims, self.full_seq_len))
                    if len(i0s) < self.num_sims:
                        i0s = np.pad(i0s, (0, self.num_sims - len(i0s)), mode='constant', constant_values=0)
                else:
                    i0s = np.zeros(self.num_sims, dtype=np.int64)
        else:
            i0s = np.random.randint(min_start, max_start, size=self.num_sims)

        self.step_i = 0
        self.step_is = torch.tensor(i0s, dtype=torch.long, device=self.device)
        # Cash and asset are kept in float64. At ~1e6 the float32 spacing is
        # 0.0625, larger than a typical per-trade slippage cost (mid * slippage),
        # so float32 would silently round transaction costs away.
        self.cash = torch.full((self.num_sims,), self.starting_cash, dtype=torch.float64, device=self.device)
        self.asset = torch.full((self.num_sims,), self.starting_cash, dtype=torch.float64, device=self.device)

        # Debug logging for asset initialization
        print(f"Debug: Environment reset - Cash: {self.cash[0].item():.2f}, Asset: {self.asset[0].item():.2f}")
        self.holding = torch.zeros((self.num_sims,), dtype=torch.long, device=self.device)
        self.position = torch.zeros((self.num_sims,), dtype=torch.long, device=self.device)
        self.empty_count = torch.zeros((self.num_sims,), dtype=torch.long, device=self.device)
        self.best_price = torch.zeros((self.num_sims,), dtype=torch.float32, device=self.device)

        step_is = self.step_is + self.step_i
        state = self.get_state(step_is)
        return state

    def step(self, action):
        """Advance ``step_gap`` rows and apply ``action``.

        Args:
            action: integer tensor of env actions in {0, 1, 2}, one per
                simulation. It is flattened, so shapes ``(num_sims,)`` and
                ``(num_sims, 1)`` both work. Each action becomes a position
                change of ``action - 1``.

        Returns:
            ``(state, reward, terminal, info)``, where reward is the change in
            net asset value.
            On the final step (``step_i >= max_step * step_gap``) every position
            is closed and ``terminal`` is True. Any exception is printed, and
            zero state/reward with ``terminal`` False is returned instead.
        """
        try:
            if not hasattr(self, 'step_i'):
                self.step_i = 0
            self.step_i += self.step_gap
            step_is = self.step_is + self.step_i

            # Flatten to (num_sims,). A (1, 1) input would become 0-d under
            # .squeeze(), which breaks the boolean-mask assignments below.
            action = action.to(self.device).reshape(-1)
            action_int = action - 1

            old_cash = self.cash
            old_asset = self.asset
            old_position = self.position

            max_idx = self.price_ary.shape[0] - 1
            step_is_cpu = torch.clamp(step_is.cpu(), 0, max_idx)
            mid_price = self.price_ary[step_is_cpu, 2].to(self.device)

            truncated = self.step_i >= (self.max_step * self.step_gap)
            if truncated:
                action_int = -old_position
            else:
                new_position = (old_position + action_int).clip(-self.max_position, self.max_position)
                action_int = new_position - old_position

            self.holding = self.holding + 1
            mask_max_holding = self.holding.gt(self.max_holding)
            if mask_max_holding.sum() > 0:
                action_int[mask_max_holding] = -old_position[mask_max_holding]
            self.holding[old_position == 0] = 0

            # NOTE(review): best_price is reset to 0 only in reset(), never when
            # a position is closed.
            # * For shorts, min(0, mid) stays 0 for positive prices, so
            #   stop_loss_mask2 fires and closes every short on the step after
            #   it is opened.
            # * For longs, a stale peak from an earlier trade can close a new long.
            # * stop_loss_thresh is compared to an absolute price difference.
            # Confirm against the training environment before changing this, since
            # the agents were presumably trained under the same rule.
            direction_mask1 = old_position.gt(0)
            if direction_mask1.sum() > 0:
                _best_price = torch.max(
                    torch.stack([self.best_price[direction_mask1], mid_price[direction_mask1]]),
                    dim=0,
                )[0]
                self.best_price[direction_mask1] = _best_price

            direction_mask2 = old_position.lt(0)
            if direction_mask2.sum() > 0:
                _best_price = torch.min(
                    torch.stack([self.best_price[direction_mask2], mid_price[direction_mask2]]),
                    dim=0,
                )[0]
                self.best_price[direction_mask2] = _best_price

            stop_loss_mask1 = torch.logical_and(direction_mask1, (self.best_price - mid_price).gt(self.stop_loss_thresh))
            stop_loss_mask2 = torch.logical_and(direction_mask2, (mid_price - self.best_price).gt(self.stop_loss_thresh))
            stop_loss_mask = torch.logical_or(stop_loss_mask1, stop_loss_mask2)
            if stop_loss_mask.sum() > 0:
                action_int[stop_loss_mask] = -old_position[stop_loss_mask]

            new_position = old_position + action_int
            direction = action_int.gt(0)

            # Account in the cash dtype (float64 after reset) so slippage survives.
            acc_dtype = old_cash.dtype
            price = mid_price.to(acc_dtype)
            cost = action_int.to(acc_dtype) * price
            # +1 for buys, -1 otherwise (sells; hold has zero cost anyway)
            slip_sign = direction.to(acc_dtype) * 2 - 1
            new_cash = old_cash - cost * (1.0 + self.slippage * slip_sign)
            new_asset = new_cash + new_position.to(acc_dtype) * price
            reward = new_asset - old_asset

            # Debug logging for asset changes
            if self.step_i % 100 == 0:  # Debug every 100 steps
                print(f"Debug: Step {self.step_i} - Old Asset: {old_asset[0].item():.2f}, "
                      f"New Asset: {new_asset[0].item():.2f}, Reward: {reward[0].item():.6f}")

            self.cash = new_cash
            self.asset = new_asset
            self.position = new_position

            state = self.get_state(step_is)
            if truncated:
                # Don't reset during evaluation - let caller handle episode termination
                terminal = torch.ones_like(self.position, dtype=torch.bool)
            else:
                terminal = torch.zeros_like(self.position, dtype=torch.bool)

            return state, reward, terminal, {}

        except Exception as e:
            # NOTE(review): the caller receives a zero state with terminal=False
            # and may keep stepping on it; consider re-raising instead.
            print(f"Error in environment step: {e}")
            state = torch.zeros((self.num_sims, self.state_dim), dtype=torch.float32).to(self.device)
            reward = torch.zeros((self.num_sims,), dtype=torch.float64).to(self.device)
            terminal = torch.zeros((self.num_sims,), dtype=torch.bool).to(self.device)
            return state, reward, terminal, {}

    def get_state(self, step_is):
        """Build the observation at rows ``step_is``, clamped to the valid range.

        Returns a ``(num_sims, 2 + n_factors + 2)`` float tensor: position /
        max_position, holding / max_holding, the factor row and the LLM signals.
        """
        step_is_cpu = step_is.cpu()
        max_idx = min(self.factor_ary.shape[0], self.price_ary.shape[0], self.llm_signals.shape[0]) - 1
        step_is_cpu = torch.clamp(step_is_cpu, 0, max_idx)
        factor_ary = self.factor_ary[step_is_cpu, :].to(self.device)
        llm_signals = self.llm_signals[step_is_cpu, :].to(self.device)
        return torch.concat(
            (
                (self.position.float() / self.max_position)[:, None],
                (self.holding.float() / self.max_holding)[:, None],
                factor_ary,
                llm_signals,
            ),
            dim=1,
        )

# ==================== FINANCIAL METRICS ====================

def cumulative_returns(returns_pct):
    """Return the compounded growth curve ``cumprod(1 + r)`` of simple returns.

    Args:
        returns_pct: per-period simple returns (0.01 == +1%), given as a
            NumPy array, a pandas object, or a plain list/tuple.

    Returns:
        A NumPy array for array/list/tuple input. Other inputs (e.g. pandas
        Series) use their own ``.cumprod()``, so a Series stays a Series.
    """
    # `1 + list` raises TypeError, so convert plain Python sequences first.
    if isinstance(returns_pct, (list, tuple)):
        returns_pct = np.asarray(returns_pct, dtype=float)
    if isinstance(returns_pct, np.ndarray):
        return np.cumprod(1 + returns_pct)
    return (1 + returns_pct).cumprod()

def sharpe_ratio(returns_pct, risk_free=0):
    """Per-period (non-annualised) Sharpe ratio ``(mean - risk_free) / std``.

    Args:
        returns_pct: array-like of per-period returns.
        risk_free: per-period risk-free rate subtracted from the mean.

    Returns:
        0.0 for empty input. When the standard deviation is exactly zero the
        ratio is undefined, and the result is:
          * +inf or -inf, following the sign of the mean excess return;
          * 0.0 when that excess is zero (e.g. an agent that never trades).
    """
    returns = np.asarray(returns_pct, dtype=float)
    if returns.size == 0:
        return 0.0
    excess = returns.mean() - risk_free
    std = returns.std()
    if std == 0:
        # Previously always +inf, which rewarded flat or constantly losing curves.
        return float(np.sign(excess)) * np.inf if excess != 0 else 0.0
    return excess / std

def max_drawdown(returns_pct):
    """Return the largest peak-to-trough decline of the compounded return curve.

    The value is ``min((wealth - running_peak) / running_peak)``, normally <= 0
    (e.g. -0.25 for a 25% drawdown). Empty input yields 0.0. Running peaks
    equal to zero are replaced by 1e-8 to avoid division by zero.
    """
    if not len(returns_pct):
        return 0.0

    wealth = cumulative_returns(returns_pct)
    if not isinstance(wealth, np.ndarray):
        # pandas path: the expanding max keeps the index aligned.
        peaks = wealth.expanding().max()
        return ((wealth - peaks) / peaks.replace(0, 1e-8)).min()

    peaks = np.maximum.accumulate(wealth)
    safe_peaks = np.where(peaks == 0, 1e-8, peaks)
    return ((wealth - peaks) / safe_peaks).min()

def return_over_max_drawdown(returns_pct):
    """Total compounded return divided by the magnitude of the maximum drawdown.

    Returns 0.0 for empty input, and also when the drawdown magnitude is not
    strictly positive (no drawdown).
    """
    if not len(returns_pct):
        return 0.0

    drawdown_size = abs(max_drawdown(returns_pct))
    growth = cumulative_returns(returns_pct)

    final_value = growth[-1] if isinstance(growth, np.ndarray) else growth.iloc[-1]
    total_return = final_value - 1

    if drawdown_size > 0:
        return total_return / drawdown_size
    return 0.0

# ==================== IMPROVED ENSEMBLE EVALUATOR ====================

class ImprovedEnsembleEvaluator:
    """Ensemble Evaluator"""

    def __init__(self, save_path: str, agent_classes: List, timeframe: str = '1sec', gpu_id: int = 0):
        """Create the evaluation environment and empty result buffers.

        Args:
            save_path: directory expected to contain ``actor.pth`` and ``critic.pth``.
            agent_classes: stored on the instance. The visible ``load_agents``
                does not use it; it always builds ``ImprovedPPO``.
            timeframe: data timeframe forwarded to ``TradeSimulator``.
            gpu_id: CUDA device index. CPU is used instead when CUDA is
                unavailable or ``gpu_id < 0``.
        """
        self.save_path = save_path
        self.agent_classes = agent_classes
        self.timeframe = timeframe
        self.gpu_id = gpu_id
        # A negative index with CUDA available would otherwise produce 'cuda:-1'.
        use_cuda = torch.cuda.is_available() and gpu_id >= 0
        self.device = torch.device(f'cuda:{gpu_id}' if use_cuda else 'cpu')

        # Initialize environment with same parameters as training.
        # The resolved device is passed explicitly; gpu_id=-1 makes TradeSimulator
        # use `device`. Otherwise it would pick cuda:{gpu_id} for any gpu_id >= 0,
        # even on CPU-only hosts, leaving env and agents on different devices.
        self.trade_env = TradeSimulator(
            num_sims=EVAL_NUM_SIMS,
            slippage=EVAL_SLIPPAGE,
            max_position=EVAL_MAX_POSITION,
            step_gap=EVAL_STEP_GAP,
            num_ignore_step=EVAL_NUM_IGNORE_STEP,
            device=self.device,
            gpu_id=-1,
            timeframe=timeframe
        )

        # Debug environment info
        print(f"Environment initialized:")
        print(f"   - max_step: {self.trade_env.max_step}")
        print(f"   - seq_len: {self.trade_env.seq_len}")
        print(f"   - full_seq_len: {self.trade_env.full_seq_len}")
        print(f"   - state_dim: {self.trade_env.state_dim}")
        print(f"   - action_dim: {self.trade_env.action_dim}")

        # Initialize agents
        self.agents = []
        self.agent_paths = []

        # Set starting cash
        self.starting_cash = EVAL_STARTING_CASH

        # Results storage (filled by multi_trade)
        self.net_assets = []
        self.positions = []
        self.cash = []
        self.btc_positions = []
        self.midpoints = []
        self.cum_returns = []

    def load_agents(self):
        """Load trained agents from save_path"""
        print(f"Loading agents from: {self.save_path}")

        # For PPO, we expect actor.pth and critic.pth files directly in save_path
        actor_path = os.path.join(self.save_path, "actor.pth")
        critic_path = os.path.join(self.save_path, "critic.pth")

        if os.path.exists(actor_path) and os.path.exists(critic_path):
            # Create agent instance with correct state dimension from environment
            agent = ImprovedPPO(
                state_dim=self.trade_env.state_dim,
                action_dim=self.trade_env.action_dim,
                net_dims=EVAL_NET_DIMS,
                device=self.device
            )

            # Load trained weights directly from save_path
            if agent.save_or_load_agent(self.save_path, if_save=False):
                # Check if agent is already loaded to prevent duplicates
                if not any(agent_path == self.save_path for agent_path in self.agent_paths):
                    self.agents.append(agent)
                    self.agent_paths.append(self.save_path)
                    print(f"Loaded PPO agent from {self.save_path}")
                else:
                    print(f"PPO agent already loaded from {self.save_path}, skipping duplicate")
            else:
                print(f"Failed to load PPO agent from {self.save_path}")
        else:
            print(f"PPO model files not found in {self.save_path}")
            print(f"Expected: {actor_path}")
            print(f"Expected: {critic_path}")
            if os.path.exists(self.save_path):
                print(f"  Available files in {self.save_path}:")
                for item in os.listdir(self.save_path):
                    item_path = os.path.join(self.save_path, item)
                    if os.path.isdir(item_path):
                        print(f"{item}/")
                    else:
                        print(f"{item}")

        if not self.agents:
            raise ValueError("No agents loaded successfully!")

        print(f"Loaded {len(self.agents)} agents successfully")
        print(f"Agent details:")
        for i, agent in enumerate(self.agents):
            print(f"Agent {i}: {type(agent).__name__}")
            print(f"State dim: {agent.state_dim}")
            print(f"Action dim: {agent.action_dim}")

    def multi_trade(self) -> dict:
        """Run improved multi-agent trading simulation"""
        print(f"Starting multi-agent trading simulation...")

        # Initialize environment
        initial_state = self.trade_env.reset()

        # Initialize arrays for tracking
        num_steps = self.trade_env.max_step
        num_agents = len(self.agents)

        # Debug information
        print(f"Debug: num_steps={num_steps}, num_agents={num_agents}")

        # Ensure num_steps is positive
        if num_steps <= 0:
            num_steps = 1000  # Fallback value
            print(f"Warning: max_step was {self.trade_env.max_step}, using fallback {num_steps}")

        # Ensure num_agents is positive
        if num_agents <= 0:
            print(f"Error: No agents available for evaluation!")
            raise ValueError(f"Expected at least 1 agent, but got {num_agents}")

        print(f"Debug: Creating arrays with shape ({num_steps}, {num_agents})")
        self.net_assets = np.zeros((num_steps, num_agents))
        self.positions = np.zeros((num_steps, num_agents))
        self.cash = np.zeros((num_steps, num_agents))
        self.btc_positions = np.zeros((num_steps, num_agents))
        self.midpoints = []

        print(f"Debug: Array shapes - net_assets: {self.net_assets.shape}, positions: {self.positions.shape}")

        # Initialize starting values
        for i in range(num_agents):
            self.net_assets[0, i] = self.starting_cash
            self.cash[0, i] = self.starting_cash
            self.btc_positions[0, i] = 0

        #Debug logging for initial setup
        print(f"🔍 Debug: Initial setup complete. Starting cash: ${self.starting_cash}")
        print(f"🔍 Debug: Net assets[0]: {self.net_assets[0]}")

        # Trading loop
        step = 0

        try:
            while step < num_steps - 1:
                # Get current state (first simulation)
                state = self.trade_env.get_state(self.trade_env.step_is)
                state_single = state[0:1]  # Take first simulation for single agent evaluation

                # Collect actions from all agents
                actions = []
                for agent in self.agents:
                    action, _, _ = agent.select_action(state_single, training=False)
                    actions.append(action.item())

                # Majority voting for ensemble decision
                action_counts = Counter(actions)
                ensemble_action = action_counts.most_common(1)[0][0]

                # Convert action from [0,1,2] to [1,2,3] for environment
                action_env = torch.tensor([ensemble_action + 1], dtype=torch.long, device=self.trade_env.device).expand(self.trade_env.num_sims)

                # Execute ensemble action
                next_state, reward, terminal, info = self.trade_env.step(action_env)

                #Apply reward scaling and shaping
                base_reward = reward[0].item() * REWARD_SCALE

                # Reward shaping to encourage trading
                current_position = self.trade_env.position[0].item()
                prev_position = getattr(self.trade_env, '_prev_position', 0)
                position_change = abs(current_position - prev_position)
                trading_bonus = 0.001 * position_change  # Small bonus for taking positions

                # Profit incentive
                profit_incentive = 0.01 * base_reward if base_reward > 0 else 0

                shaped_reward = base_reward + trading_bonus + profit_incentive

                # Store previous position for next step
                self.trade_env._prev_position = current_position

                # Check terminal condition BEFORE updating arrays
                # This prevents using reset values from the environment
                if terminal[0].item():
                    print(f"🔍 Debug: Episode terminated at step {step}")
                    break

                # Update tracking arrays - only if episode hasn't terminated
                if step < num_steps - 1:
                    for i in range(num_agents):
                        # Update net assets (use first simulation)
                        current_asset = self.trade_env.asset[0].item()
                        self.net_assets[step + 1, i] = current_asset

                        # Update positions (use first simulation)
                        self.positions[step + 1, i] = self.trade_env.position[0].item()

                        # Update cash and BTC positions (use first simulation)
                        current_price = self.trade_env.price_ary[self.trade_env.step_is[0] + self.trade_env.step_i, 2].item()
                        self.cash[step + 1, i] = self.trade_env.cash[0].item()
                        self.btc_positions[step + 1, i] = self.trade_env.position[0].item() * current_price

                        # Debug: Log unusual asset values
                        if step % 100 == 0:  # Only log every 100 steps to avoid spam
                            print(f"🔍 Debug: Step {step}, Agent {i}: Asset=${current_asset:.2f}")

                    # Store midpoint price (use first simulation)
                    current_price = self.trade_env.price_ary[self.trade_env.step_is[0] + self.trade_env.step_i, 2].item()
                    self.midpoints.append(current_price)

                    # Prediction accuracy tracking (was only used for win/loss rates)

                # Move to next step
                step += 1
                state = next_state

        except Exception as e:
            print(f"Warning: Trading loop stopped at step {step}: {e}")
            # Truncate arrays to actual steps completed
            actual_steps = max(1, step + 1)
            print(f"Truncating arrays to {actual_steps} steps")

            if actual_steps > 0 and self.net_assets.shape[1] > 0:
                self.net_assets = self.net_assets[:actual_steps, :]
                self.positions = self.positions[:actual_steps, :]
                self.cash = self.cash[:actual_steps, :]
                self.btc_positions = self.btc_positions[:actual_steps, :]
                print(f"Debug: Arrays truncated to shapes - net_assets: {self.net_assets.shape}")
            else:
                print(f"Error: Invalid array shapes after truncation")
                raise ValueError("Array truncation failed - invalid shapes")

        # Debug: Check array shapes after trading loop
        print(f"Debug: After trading loop - Array shapes:")
        print(f"   net_assets: {self.net_assets.shape}")
        print(f"   positions: {self.positions.shape}")
        print(f"   cash: {self.cash.shape}")
        print(f"   btc_positions: {self.btc_positions.shape}")
        print(f"   midpoints length: {len(self.midpoints)}")

        # Debug: Check final asset values
        print(f"Debug: Final net_assets[-1]: {self.net_assets[-1] if len(self.net_assets) > 0 else 'No data'}")
        print(f"Debug: Final cash[-1]: {self.cash[-1] if len(self.cash) > 0 else 'No data'}")
        print(f"Debug: Final btc_positions[-1]: {self.btc_positions[-1] if len(self.btc_positions) > 0 else 'No data'}")

        # Calculate cumulative returns
        self.cum_returns = self.net_assets / self.starting_cash
        print(f"Debug: cum_returns shape: {self.cum_returns.shape}")
        print(f"Debug: Final cum_returns[-1]: {self.cum_returns[-1] if len(self.cum_returns) > 0 else 'No data'}")

        # Save results
        results_dir = f"{self.save_path}_evaluation_results"
        os.makedirs(results_dir, exist_ok=True)

        # Save arrays
        np.save(f"{results_dir}/net_assets.npy", self.net_assets)
        np.save(f"{results_dir}/positions.npy", self.positions)
        np.save(f"{results_dir}/cash.npy", self.cash)
        np.save(f"{results_dir}/btc_positions.npy", self.btc_positions)
        np.save(f"{results_dir}/midpoints.npy", np.array(self.midpoints))
        np.save(f"{results_dir}/cum_returns.npy", self.cum_returns)

        # Calculate and return metrics
        return self.calculate_metrics(results_dir)

    def calculate_metrics(self, results_dir: str) -> dict:
        """Compute evaluation metrics from the recorded trajectory and save a JSON summary.

        Per-step simple returns are derived from ``self.net_assets`` (one row
        per step, one column per agent) and fed to the module-level
        ``sharpe_ratio``, ``max_drawdown`` and ``return_over_max_drawdown``
        helpers. The headline final-asset and total-return figures use the
        first agent column only.

        Args:
            results_dir: Existing directory that receives ``metrics_summary.json``.

        Returns:
            The metrics dictionary written to disk, or ``None`` when
            ``self.net_assets`` is not a 2-D array with at least one column,
            or has fewer than two rows (no return can be formed).
        """
        # Debug: Check array shapes at start of calculate_metrics
        print(f"   Debug: In calculate_metrics - Array shapes:")
        print(f"   net_assets: {self.net_assets.shape}")
        print(f"   positions: {self.positions.shape}")
        print(f"   cash: {self.cash.shape}")
        print(f"   btc_positions: {self.btc_positions.shape}")

        # Validate BEFORE computing anything. The reductions below
        # (returns.min()/max(), indexing [-1, 0]) raise on a non-2-D array
        # or on an array with zero agent columns.
        if self.net_assets.ndim != 2 or self.net_assets.shape[1] == 0:
            print(f"Error: Invalid net_assets array shape: {self.net_assets.shape}")
            return None
        # At least two steps are needed to form a single return.
        if len(self.net_assets) < 2:
            print("Error: Not enough data points for metrics calculation")
            return None

        # Compute metrics - MATCHING TASK1_EVAL.PY
        print(f"Debug: Before np.diff - net_assets shape: {self.net_assets.shape}")

        # np.diff allocates a new array and np.where below builds a new
        # denominator, so self.net_assets is never modified in place.
        net_assets_diff = np.diff(self.net_assets, axis=0)
        net_assets_prev = self.net_assets[:-1]

        print(f"Debug: net_assets_diff shape: {net_assets_diff.shape}")
        print(f"Debug: net_assets_prev shape: {net_assets_prev.shape}")

        # Avoid division by zero. A zero previous asset is replaced by 1e-8,
        # so that step's return is very large but finite (nan_to_num below
        # does not clip it).
        net_assets_prev_safe = np.where(net_assets_prev == 0, 1e-8, net_assets_prev)
        returns = net_assets_diff / net_assets_prev_safe

        # Neutralise NaN/inf arising from non-finite entries in net_assets.
        returns = np.nan_to_num(returns, nan=0.0, posinf=0.0, neginf=0.0)

        print(f"Debug: After safe division - returns shape: {returns.shape}")
        # Safe: the checks above guarantee returns is non-empty.
        print(f"Debug: Returns range: min={returns.min():.6f}, max={returns.max():.6f}")

        final_sharpe_ratio = sharpe_ratio(returns)
        final_max_drawdown = max_drawdown(returns)
        final_roma = return_over_max_drawdown(returns)

        # Headline figures are taken from the first agent column.
        final_net_assets = float(self.net_assets[-1, 0])
        total_return = (final_net_assets - self.starting_cash) / self.starting_cash

        print(f"\n{'='*60}")
        print(f"EVALUATION RESULTS - {self.timeframe.upper()}")
        print(f"{'='*60}")
        print(f" Starting Cash: ${self.starting_cash:,.2f}")
        print(f" Final Net Assets: ${final_net_assets:,.2f}")
        print(f" Total Return: {total_return:.4f} ({total_return*100:.2f}%)")
        print(f" Sharpe Ratio: {final_sharpe_ratio:.4f}")
        print(f" Max Drawdown: {final_max_drawdown:.4f}")
        print(f" Return over Max Drawdown: {final_roma:.4f}")
        print(f" Ensemble Size: {len(self.agents)} agents")
        print(f" Results saved to: {results_dir}")
        print(f"{'='*60}")

        # Save metrics summary (plain Python types so json can serialise them).
        metrics_summary = {
            'timeframe': self.timeframe,
            'starting_cash': float(self.starting_cash),
            'final_net_assets': final_net_assets,
            'total_return': float(total_return),
            'sharpe_ratio': float(final_sharpe_ratio),
            'max_drawdown': float(final_max_drawdown),
            'return_over_max_drawdown': float(final_roma),
            'ensemble_size': int(len(self.agents)),
            'agent_classes': [agent.__class__.__name__ for agent in self.agents]
        }

        with open(f"{results_dir}/metrics_summary.json", 'w', encoding='utf-8') as f:
            json.dump(metrics_summary, f, indent=2)

        return metrics_summary

    def plot_results(self, results_dir: str):
        """Render diagnostic plots from the arrays saved in ``results_dir``.

        Loads ``net_assets.npy``, ``positions.npy``, ``cash.npy`` and
        ``btc_positions.npy`` (required), plus ``midpoints.npy`` and
        ``cum_returns.npy`` (optional, left as ``None`` when absent). Writes:

        * ``evaluation_plots.png``: a 2x2 overview of net assets, cash vs BTC
          value, positions, and a histogram of per-step returns.
        * ``cum_returns_vs_midpoint.png``: cumulative return and midpoint price
          on twin y-axes, when midpoint data is available.

        Plotting is best-effort. Any exception is reported and swallowed so a
        plotting problem does not discard the evaluation results.

        Args:
            results_dir: Directory containing the ``.npy`` arrays; the PNG
                files are written there as well.
        """
        try:
            net_assets = np.load(f"{results_dir}/net_assets.npy")
            positions = np.load(f"{results_dir}/positions.npy")
            cash = np.load(f"{results_dir}/cash.npy")
            btc_positions = np.load(f"{results_dir}/btc_positions.npy")

            # Debug: Check array shapes after loading
            print(f"   Debug: Plotting - Array shapes after loading:")
            print(f"   net_assets: {net_assets.shape}")
            print(f"   positions: {positions.shape}")
            print(f"   cash: {cash.shape}")
            print(f"   btc_positions: {btc_positions.shape}")

            # Optional arrays: a missing file leaves the value as None.
            midpoints_path = f"{results_dir}/midpoints.npy"
            cum_returns_path = f"{results_dir}/cum_returns.npy"
            midpoints = np.load(midpoints_path) if os.path.exists(midpoints_path) else None
            cum_returns = np.load(cum_returns_path) if os.path.exists(cum_returns_path) else None

            self._plot_overview(results_dir, net_assets, positions, cash, btc_positions)
            print(f"Plots saved to: {results_dir}/evaluation_plots.png")

            self._plot_cum_returns_vs_midpoint(results_dir, net_assets, midpoints, cum_returns)

            # Add note about limited data for higher timeframes. A missing
            # midpoints file counts as zero points rather than raising a
            # TypeError on len(None).
            num_midpoints = 0 if midpoints is None else len(midpoints)
            if num_midpoints <= 5:
                print(f"   Note: Very limited data points ({num_midpoints}) for {self.timeframe} timeframe")
                print(f"   This suggests the trading environment may need configuration adjustments")
                print(f"   Consider: increasing seq_len, checking data availability, or using lower timeframes")

        except Exception as e:
            # Best-effort boundary: report and continue.
            print(f"  Failed to create plots: {str(e)}")

    @staticmethod
    def _plot_overview(results_dir: str, net_assets, positions, cash, btc_positions):
        """Save the 2x2 overview figure to ``<results_dir>/evaluation_plots.png``."""
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 10))
        try:
            # Net Assets over time
            ax1.plot(net_assets, label='Net Assets', color='blue')
            ax1.set_title('Net Assets Over Time')
            ax1.set_xlabel('Steps')
            ax1.set_ylabel('Net Assets ($)')
            ax1.grid(True)
            ax1.legend()

            # Cash and BTC positions
            ax2.plot(cash, label='Cash', color='green', alpha=0.7)
            ax2.plot(btc_positions, label='BTC Value', color='orange', alpha=0.7)
            ax2.set_title('Cash vs BTC Positions')
            ax2.set_xlabel('Steps')
            ax2.set_ylabel('Value ($)')
            ax2.grid(True)
            ax2.legend()

            # Trading positions
            ax3.plot(positions, label='Position', color='red')
            ax3.set_title('Trading Positions')
            ax3.set_xlabel('Steps')
            ax3.set_ylabel('Position')
            ax3.grid(True)
            ax3.legend()

            # Returns distribution: per-step simple returns with a zero-safe
            # denominator; non-finite values are mapped to 0.
            if net_assets.ndim == 2 and net_assets.shape[1] > 0:
                net_assets_prev = net_assets[:-1]
                net_assets_prev_safe = np.where(net_assets_prev == 0, 1e-8, net_assets_prev)
                returns = np.diff(net_assets, axis=0) / net_assets_prev_safe
                returns = np.nan_to_num(returns, nan=0.0, posinf=0.0, neginf=0.0)
                ax4.hist(returns.flatten(), bins=50, alpha=0.7, color='purple')
                ax4.grid(True)
            else:
                ax4.text(0.5, 0.5, 'No valid returns data', ha='center', va='center', transform=ax4.transAxes)
            ax4.set_title('Returns Distribution')
            ax4.set_xlabel('Returns')
            ax4.set_ylabel('Frequency')

            fig.tight_layout()
            fig.savefig(f"{results_dir}/evaluation_plots.png", dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even if plotting fails part-way through.
            plt.close(fig)

    def _plot_cum_returns_vs_midpoint(self, results_dir: str, net_assets, midpoints, cum_returns):
        """Save cumulative return vs midpoint price to ``cum_returns_vs_midpoint.png``.

        Skipped with a warning when there are no midpoints or nothing to align.
        If ``cum_returns`` is ``None`` it is rebuilt as ``net_assets / net_assets[0]``.
        """
        print(f"🔍 Debug: Plotting cumulative returns vs midpoint:")
        print(f"   cum_returns shape: {cum_returns.shape if cum_returns is not None else 'None'}")
        print(f"   midpoints length: {len(midpoints) if midpoints is not None else 'None'}")

        if cum_returns is None:
            cum_returns = net_assets / net_assets[0]
            print(f"   Created cum_returns with shape: {cum_returns.shape}")

        if midpoints is None or len(midpoints) == 0:
            print(f"   Warning: No midpoints data available for cumulative returns plot")
            print(f"   midpoints: {midpoints}")
            return

        # cum_returns[0] is the starting point, while multi_trade records each
        # midpoint alongside net_assets[step + 1]; so midpoints[k] pairs with
        # cum_returns[k + 1].
        n = min(len(midpoints), max(0, len(cum_returns) - 1))
        print(f"   Aligned length n: {n}")
        if n == 0:
            print(f"Warning: No valid data points for cumulative returns plot (n={n})")
            return

        x = np.arange(n)
        cr_plot = cum_returns[1:1 + n]
        mp_plot = midpoints[:n]
        print(f"   cr_plot shape: {cr_plot.shape}, mp_plot length: {len(mp_plot)}")

        fig, ax1 = plt.subplots(figsize=(12, 5))
        try:
            color1 = 'tab:blue'
            ax1.set_xlabel('Timesteps')
            ax1.set_ylabel('Cumulative Return (x)', color=color1)

            if n == 1:
                # A single point is invisible as a line: draw markers plus a
                # dashed reference line. Flatten so one or more agent columns
                # scatter correctly and axhline gets a scalar.
                cr_single = np.ravel(cr_plot)
                ax1.scatter(np.zeros_like(cr_single), cr_single, color=color1, s=100,
                            label='Cumulative Return', zorder=5)
                ax1.axhline(y=float(cr_single[0]), color=color1, alpha=0.3, linestyle='--')
                print(f"   Single data point detected - using scatter plot for visibility")
            else:
                ax1.plot(x, cr_plot, color=color1, label='Cumulative Return')

            ax1.tick_params(axis='y', labelcolor=color1)
            ax1.grid(True, which='both', axis='both', alpha=0.3)

            ax2 = ax1.twinx()
            color2 = 'tab:orange'
            ax2.set_ylabel('Midpoint Price', color=color2)

            if n == 1:
                ax2.scatter(x, mp_plot, color=color2, s=100, alpha=0.8, label='Midpoint', zorder=5)
                ax2.axhline(y=float(mp_plot[0]), color=color2, alpha=0.3, linestyle='--')
            else:
                ax2.plot(x, mp_plot, color=color2, alpha=0.6, label='Midpoint')

            ax2.tick_params(axis='y', labelcolor=color2)

            # Set the title on the primary axes explicitly rather than relying
            # on pyplot's "current axes" (which is the twin after twinx()).
            ax1.set_title(f'Cumulative Returns and Midpoint over Timesteps - {self.timeframe}')
            fig.tight_layout()
            fig.savefig(f"{results_dir}/cum_returns_vs_midpoint.png", dpi=300, bbox_inches='tight')
        finally:
            plt.close(fig)
        print(f"Saved: {results_dir}/cum_returns_vs_midpoint.png")

# ==================== MAIN EVALUATION FUNCTIONS ====================

def evaluate_improved_ensemble(save_path: str, agent_classes: List, timeframe: str = '1sec', gpu_id: int = 0):
    """Load a trained PPO ensemble, run one evaluation pass and plot the results.

    Args:
        save_path: Directory holding the trained agents; must exist.
        agent_classes: Agent classes forwarded to ``ImprovedEnsembleEvaluator``.
        timeframe: Data timeframe label (e.g. ``'1sec'``) forwarded to the evaluator.
        gpu_id: Device index forwarded to the evaluator.

    Returns:
        The metrics dict produced by ``evaluator.multi_trade()``. Returns
        ``None`` if ``save_path`` does not exist, if evaluation produced no
        metrics, or if any step raised. Failures are printed with a traceback
        rather than propagated.
    """
    import traceback

    print(f"  Starting PPO Ensemble Evaluation - {timeframe}")
    print("=" * 60)

    # Check if trained agents exist
    if not os.path.exists(save_path):
        print(f"  Error: Trained agents directory not found: {save_path}")
        print(f"  Please ensure you have trained PPO agents.")
        return None

    try:
        evaluator = ImprovedEnsembleEvaluator(
            save_path=save_path,
            agent_classes=agent_classes,
            timeframe=timeframe,
            gpu_id=gpu_id
        )

        evaluator.load_agents()
        metrics = evaluator.multi_trade()

        # Create plots after evaluation
        if metrics is not None:
            # multi_trade saves into f"{self.save_path}_evaluation_results";
            # derive the directory from the evaluator so the two cannot drift.
            results_dir = f"{evaluator.save_path}_evaluation_results"
            evaluator.plot_results(results_dir)

        return metrics

    except Exception as e:
        # Top-level best-effort boundary: keep returning None, but show where
        # the failure happened instead of hiding it behind a one-line message.
        print(f"Evaluation failed: {str(e)}")
        traceback.print_exc()
        return None

In [None]:
# ---- Evaluation driver: edit these defaults to evaluate a different run ----
save_path = "./trained_agents"
agent_classes = [ImprovedPPO]
gpu_id = -1
timeframe = "1sec"

# Trained agents live under a versioned, per-timeframe directory.
timeframe_save_path = f"{save_path}/{PPO_VERSION}/{timeframe}/{AGENT_DIR_NAME}"

print(" Improved Ensemble Evaluation Script")
print(f" Save Path: {save_path}")
print(f" Timeframe: {timeframe}")
print(f"  GPU ID: {gpu_id}")
print(f" Agents: {[agent_cls.__name__ for agent_cls in agent_classes]}")
print(f" Data Split: {EVAL_DATA_SPLIT.upper()}")
print(f" LLM Normalization: {'Enabled' if NORMALIZE_LLM_SIGNALS else 'Disabled'}")
print(f" Expected Agent Path: {timeframe_save_path}")

# Run only when the trained agents are actually present.
if os.path.exists(timeframe_save_path):
    evaluate_improved_ensemble(timeframe_save_path, agent_classes, timeframe, gpu_id)
else:
    print(f"   No trained agents found for {timeframe}")
    print(f"   Expected path: {timeframe_save_path}")
    print("   Please train agents first using the improved training script!")


🎯 Improved Ensemble Evaluation Script
📁 Save Path: ./trained_agents
⏰ Timeframe: 1sec
🖥️  GPU ID: -1
🤖 Agents: ['ImprovedPPO']
📊 Data Split: VALID
🔧 LLM Normalization: ✅ Enabled
🎯 Expected Agent Path: ./trained_agents/5_7_3/1sec/PPO
🚀 Starting IMPROVED PPO Ensemble Evaluation - 1sec
✅ Found valid data at: output/4_1/1sec/valid_predictions.npy
✅ Loaded valid data: torch.Size([74212, 8])
✅ Found price data at: data/1sec/BTC_1sec_with_sentiment_risk_train_1sec_train_70.csv
🔍 Environment initialized:
   - max_step: 870
   - seq_len: 1800
   - full_seq_len: 74212
   - state_dim: 12
   - action_dim: 3
🔍 Loading agents from: ./trained_agents/5_7_3/1sec/PPO
✅ Loaded Improved PPO agent from ./trained_agents/5_7_3/1sec/PPO
✅ Loaded Improved PPO agent from ./trained_agents/5_7_3/1sec/PPO
🤖 Loaded 1 agents successfully
🔍 Agent details:
   Agent 0: ImprovedPPO
   State dim: 12
   Action dim: 3
🚀 Starting improved multi-agent trading simulation...
🔍 Debug: Environment reset - Cash: 1000000.00, Asset