In [None]:
import numpy as np
from typing import Dict, List, Tuple, Optional
import gymnasium as gym
from gymnasium import spaces
import pandas as pd
from dataclasses import dataclass
from enum import Enum


class Actions(Enum):
    """Discrete actions accepted by `ForexTradingEnv.step`.

    The integer values are the indices of the environment's
    `spaces.Discrete(len(Actions))` action space.
    """
    NO_POSITION = 0  # be flat: closes the open position, if any
    LONG = 1  # be long: open, keep, or flip from a short
    SHORT = 2  # be short: open, keep, or flip from a long


class MarketSession(Enum):
    """Trading-session labels produced by `ForexTradingEnv._get_market_session`."""
    # Values 0-2 are also used as indices into the 3-slot session one-hot vector.
    TOKYO = 0
    LONDON = 1
    NEW_YORK = 2
    OFF_HOURS = 3  # has no one-hot slot; encoded as all zeros


@dataclass
class Position:
    """Represents an open trading position.

    Created by `ForexTradingEnv._open_position` and discarded when the
    position is closed.
    """
    type: str  # 'long' or 'short'
    entry_price: float  # entry price; the env adjusts it by its transaction cost before storing
    size: float  # position size; the env passes its fixed `trade_size`
    entry_time: pd.Timestamp  # index label of the bar at which the position was opened
    base_currency: str  # first part of the pair string (e.g. 'EUR' in 'EUR_USD')
    quote_currency: str  # second part of the pair string (e.g. 'USD' in 'EUR_USD')
    take_profit: Optional[float] = None  # optional price level; the env's open logic does not set it
    stop_loss: Optional[float] = None  # optional price level; used by the risk features when set

@dataclass
class RewardParams:
    """Parameters controlling the reward function behavior.

    Read by `ForexTradingEnv._calculate_reward`.

    NOTE(review): a second `RewardParams` dataclass with a different
    `realized_pnl_weight` default is defined later in this file; when both are
    executed in the same namespace the later definition replaces this one.
    """
    realized_pnl_weight: float = 0.1  # profitable realized PnL is multiplied by (1 + this weight)
    unrealized_pnl_weight: float = 0.8  # multiplier on the normalized unrealized PnL
    holding_time_threshold: int = 7*12  # hours a position may be held before the holding penalty starts
    holding_penalty_factor: float = -0.00001  # added per hour held beyond the threshold (negative = penalty)
    max_trades_per_day: int = 6  # trades per day above which the overtrading penalty applies
    overtrading_penalty_factor: float = -0.0001  # added per excess trade per day (negative = penalty)
    win_rate_threshold: float = 0.4  # win rate above which a bonus is granted
    win_rate_bonus_factor: float = 0.0005  # bonus per unit of win rate above the threshold
    drawdown_penalty_factor: float = -0.0001  # multiplies the squared drawdown (in percent) below initial balance

class ForexTradingEnv(gym.Env):
    def __init__(
        self,
        df: pd.DataFrame,
        pair: str,
        initial_balance: float = 1_000_000.0,
        trade_size: float = 100_000.0,
        max_position_size: float = 1.0,
        transaction_cost: float = 0.0001,
        reward_scaling: float = 1e-4,
        sequence_length: int = 10,
        random_start: bool = True,
        margin_rate_pct:float = 0.01,
        trading_history_size: int = 50,  # Keep track of last 50 trades
        reward_params: Optional[RewardParams] = None,
    ):
        """Store the configuration, declare the spaces, and start an episode.

        Args:
            df: Market data. Every column except 'timestamp' and 'volume' is
                exposed as a market feature. Other methods read the 'close'
                and 'atr' columns and treat the index as timestamps.
            pair: Instrument name in 'BASE_QUOTE' form, e.g. 'EUR_USD'.
            initial_balance: Account balance at the start of each episode.
            trade_size: Size given to every position that is opened.
            max_position_size: Stored on the instance.
            transaction_cost: Price adjustment applied on entry and exit.
            reward_scaling: Factor applied to the PnL returned on close.
            sequence_length: Number of rows in the 'market' observation.
            random_start: Whether `reset` picks a random starting row.
            margin_rate_pct: Stored on the instance.
            trading_history_size: Maximum number of closed trades retained.
            reward_params: Reward settings; defaults to `RewardParams()`.
        """
        super().__init__()

        # --- Static configuration -------------------------------------
        currencies = pair.split('_')
        self.df = df
        self.pair = pair
        self.base_currency = currencies[0]
        self.quote_currency = currencies[1]
        self.trade_size = trade_size
        self.initial_balance = initial_balance
        self.balance = self.initial_balance
        self.max_position_size = max_position_size
        self.transaction_cost = transaction_cost
        self.reward_scaling = reward_scaling
        self.sequence_length = sequence_length
        self.random_start = random_start
        self.margin_rate_pct = margin_rate_pct
        self.reward_params = reward_params or RewardParams()
        self._last_trade_info = None

        # --- Bookkeeping that feeds the enriched observations ---------
        self.trading_history_size = trading_history_size
        self.trade_history = []  # records of closed trades
        self.session_trades = {session: [] for session in MarketSession}
        self.peak_balance = initial_balance
        self.session_start_balance = initial_balance

        # --- Observation layout ---------------------------------------
        excluded_columns = ('timestamp', 'volume')
        self.feature_columns = [
            column for column in df.columns if column not in excluded_columns
        ]
        self.market_features = len(self.feature_columns)
        self.account_features = 7  # length of the vector from _get_account_state
        self.risk_features = 5     # length of the vector from _get_risk_metrics
        self.context_features = 7  # hour sin/cos, day sin/cos, 3-slot session one-hot
        self.history_features = 5  # length of the vector from _get_trading_history

        def unbounded_box(shape):
            # Every observation component is an unbounded float32 box.
            return spaces.Box(
                low=-np.inf, high=np.inf, shape=shape, dtype=np.float32
            )

        # One discrete action per member of Actions.
        self.action_space = spaces.Discrete(len(Actions))
        self.observation_space = spaces.Dict({
            'market': unbounded_box((sequence_length, self.market_features)),
            'account': unbounded_box((self.account_features,)),
            'risk': unbounded_box((self.risk_features,)),
            'context': unbounded_box((self.context_features,)),
            'history': unbounded_box((self.history_features,)),
        })

        # Put the environment into a valid initial state.
        self.reset()

    def reset(
        self,
        seed: Optional[int] = None,
        options: Optional[Dict] = None,
    ) -> Tuple[Dict[str, np.ndarray], Dict]:
        """Reset the environment to its initial state and start a new episode.

        Args:
            seed: Optional seed forwarded to the base class, which seeds
                `self.np_random`. The random start row is drawn from that
                generator so a seeded reset is reproducible.
            options: Accepted for compatibility with the Gymnasium `reset`
                signature; currently unused.

        Returns:
            The first observation and the info dict of the new episode.
        """
        super().reset(seed=seed)

        # Account and position state.
        self.balance = self.initial_balance
        self.position: Optional[Position] = None
        self.current_step = self.sequence_length

        # Per-episode statistics. These are cleared here so that drawdown and
        # session features of a new episode are not computed from trades and
        # balance peaks of the previous one.
        self.total_pnl = 0.0
        self.total_trades = 0
        self.winning_trades = 0
        self.trade_history = []
        self.session_trades = {session: [] for session in MarketSession}
        self.peak_balance = self.initial_balance
        self.session_start_balance = self.initial_balance
        self._last_trade_info = None

        # Optionally start somewhere inside the data, leaving at least 100
        # rows after the start row.
        if self.random_start and len(self.df) > self.sequence_length + 100:
            self.current_step = int(
                self.np_random.integers(
                    self.sequence_length,
                    len(self.df) - 100
                )
            )

        return self._get_observation(), self._get_info()

    def _print_after_episode(self):
        """Print a summary of the finished episode to stdout."""
        return_pct = ((self.balance / self.initial_balance) - 1) * 100
        win_rate_pct = (self.winning_trades / max(1, self.total_trades)) * 100

        summary_lines = (
            "\nEpisode Summary:",
            f"Final Return: {return_pct:.2f}%",
            f"Total PnL: {self.total_pnl:.2f}",
            f"Total Trades: {self.total_trades}",
            f"Winning Trades: {self.winning_trades}",
            f"Win Rate: {win_rate_pct:.2f}%",
            f"Initial Balance: {self.initial_balance:.2f}",
            f"Final Balance: {self.balance:.2f}",
            "-" * 50,
        )
        for line in summary_lines:
            print(line)

    def step(self, action: int) -> Tuple[Dict[str, np.ndarray], float, bool, bool, Dict]:
        """Apply one action and advance the environment by one row.

        Trades triggered by the action are executed at the close of the row
        the environment is on when `step` is called. The step counter is
        advanced before the reward is computed, so reward terms that read the
        current row use the following one.

        Args:
            action: Integer value of an `Actions` member.

        Returns:
            Tuple of (observation, reward, terminated, truncated, info).
            `truncated` is always False.
        """
        chosen = Actions(action)
        reward = 0.0

        # Price used for any trade made in this step, then move forward.
        execution_price = self.df.iloc[self.current_step]['close']
        self.current_step += 1

        if self.balance == 0 or self.initial_balance == 0:
            print(f"0 Value balance: {self.balance} self.initial_balance: {self.initial_balance} at step: {self.current_step}")

        # For each directional action: (side to hold, side that must be flipped).
        sides = {
            Actions.LONG: ('long', 'short'),
            Actions.SHORT: ('short', 'long'),
        }

        if chosen == Actions.NO_POSITION:
            # Going flat only does something when a position is open.
            if self.position is not None:
                reward = self._calculate_reward(self._close_position(execution_price))
        elif chosen in sides:
            wanted, opposite = sides[chosen]
            if self.position is None:
                # Nothing open yet: enter in the requested direction.
                self._open_position(wanted, execution_price)
                reward = self._calculate_reward()
            elif self.position.type == opposite:
                # Reverse: reward the close first, then enter the new side.
                reward = self._calculate_reward(self._close_position(execution_price))
                self._open_position(wanted, execution_price)
            elif self.position.type == wanted:
                # Already on the requested side: reward for holding.
                reward = self._calculate_reward()

        # The episode ends on the last usable row or when the balance is gone.
        terminated = self.current_step >= len(self.df) - 1 or self.balance <= 0
        truncated = False
        if terminated:
            self._print_after_episode()

        return self._get_observation(), reward, terminated, truncated, self._get_info()
    
    def _get_market_sequence(self) -> np.ndarray:
        """Return the `sequence_length` feature rows that precede the current step.

        The window covers rows `[current_step - sequence_length, current_step)`.
        When fewer rows exist before the current step, the window is
        left-padded with zero rows.

        Raises:
            IndexError: If the current step lies beyond the end of the data.
        """
        if self.current_step >= len(self.df):
            raise IndexError("Current step exceeds dataset length")

        window_end = self.current_step
        window_start = window_end - self.sequence_length

        if window_start >= 0:
            # Enough history: take the window as is.
            return self.df.iloc[window_start:window_end][self.feature_columns].values

        # Not enough history: zero rows first, then whatever data exists.
        available = self.df.iloc[0:window_end][self.feature_columns].values
        zero_rows = np.zeros((-window_start, len(self.feature_columns)))
        return np.vstack([zero_rows, available])

    def _get_account_state(self) -> np.ndarray:
        """Return the 7-element account feature vector.

        Elements, in order: balance ratio, position direction (1 long,
        -1 short, 0 flat), position size ratio, unrealized PnL ratio,
        total PnL ratio, trade count / 1000, win rate. The ratios are taken
        relative to the initial balance.
        """
        direction = 0.0
        size_ratio = 0.0
        open_pnl_ratio = 0.0

        held = self.position
        if held is not None:
            direction = 1.0 if held.type == 'long' else -1.0
            size_ratio = held.size / self.initial_balance

            # Mark the open position at the close of the current row.
            mark_price = self.df.iloc[self.current_step]['close']
            open_pnl = self._calculate_pnl(
                held.type, held.entry_price, mark_price, held.size
            )
            open_pnl_ratio = open_pnl / self.initial_balance

        balance_ratio = self.balance / self.initial_balance
        total_pnl_ratio = self.total_pnl / self.initial_balance
        trade_count_scaled = self.total_trades / 1000.0
        win_rate = self.winning_trades / max(1, self.total_trades)

        return np.array([
            balance_ratio,
            direction,
            size_ratio,
            open_pnl_ratio,
            total_pnl_ratio,
            trade_count_scaled,
            win_rate,
        ])

    def _get_observation(self) -> Dict[str, np.ndarray]:
        """Assemble the dict observation for the current step.

        Returns:
            Dict with keys 'market', 'account', 'risk', 'context' and
            'history', each converted to float32.
        """
        step = self.current_step
        now = self.df.index[step]
        close_price = self.df.iloc[step]['close']

        components = {
            'market': self._get_market_sequence(),         # recent feature rows
            'account': self._get_account_state(),          # balance / position summary
            'risk': self._get_risk_metrics(close_price),   # open-position risk figures
            'context': self._get_market_context(now),      # time-of-day / session encoding
            'history': self._get_trading_history(),        # recent trade statistics
        }
        return {
            name: values.astype(np.float32)
            for name, values in components.items()
        }

    def _get_risk_metrics(self, current_price: float) -> np.ndarray:
        """Return the 5-element risk feature vector for the open position.

        Elements, in order: time in position (days), drawdown from the peak
        balance, relative distance to the stop loss (1.0 when none is set),
        ATR * size / balance, and unrealized PnL per unit of size. All zeros
        when no position is open.

        Args:
            current_price: Price used for the stop-loss distance and the
                unrealized PnL.
        """
        held = self.position
        if held is None:
            return np.zeros(5)

        # Time in position, expressed in days.
        now = self.df.index[self.current_step]
        days_in_position = (now - held.entry_time).total_seconds() / (24 * 3600)

        # Realized drawdown relative to the highest recorded balance.
        drawdown = (self.peak_balance - self.balance) / self.peak_balance

        # Relative distance to the stop loss; 1.0 stands for "no stop loss".
        if held.stop_loss:
            stop_distance = abs(current_price - held.stop_loss) / current_price
        else:
            stop_distance = 1.0

        # Volatility exposure relative to the balance.
        atr = self.df.iloc[self.current_step]['atr']
        atr_ratio = atr * held.size / self.balance if self.balance > 0 else 0.0

        # Unrealized PnL per unit of position size.
        pnl_per_unit = self._calculate_pnl(
            held.type, held.entry_price, current_price, held.size
        ) / held.size

        return np.array([
            days_in_position,
            drawdown,
            stop_distance,
            atr_ratio,
            pnl_per_unit,
        ])

    def _get_market_context(self, current_time: pd.Timestamp) -> np.ndarray:
        """Return the 7-element time/session context vector.

        Layout: [hour sin, hour cos, weekday sin, weekday cos, Tokyo, London,
        New York]. The last three slots are a one-hot session flag and stay
        zero during off hours.
        """
        context = np.zeros(7)

        # Cyclical encoding of the time of day (fractional hours).
        hour_of_day = current_time.hour + current_time.minute / 60.0
        hour_angle = 2 * np.pi * hour_of_day / 24.0
        context[0] = np.sin(hour_angle)
        context[1] = np.cos(hour_angle)

        # Cyclical encoding of the weekday.
        weekday_angle = 2 * np.pi * current_time.weekday() / 7.0
        context[2] = np.sin(weekday_angle)
        context[3] = np.cos(weekday_angle)

        # Session one-hot in slots 4..6, indexed by the session's enum value.
        active_session = self._get_market_session(current_time)
        if active_session != MarketSession.OFF_HOURS:
            context[4 + active_session.value] = 1.0

        return context

    def _get_trading_history(self) -> np.ndarray:
        """Return the 5-element trade-history feature vector.

        Elements, in order: win ratio and average PnL (relative to the initial
        balance) over the most recent trades, drawdown relative to the session
        start balance, trade count of the current session type / 20, and the
        win ratio within the current session type. All zeros when no trade has
        been recorded yet.
        """
        if not self.trade_history:
            return np.zeros(5)

        window = self.trade_history[-self.trading_history_size:]
        window_pnls = [trade['pnl'] for trade in window]

        # Share of winners and mean PnL over the recent window.
        recent_win_ratio = len([p for p in window_pnls if p > 0]) / len(window)
        recent_avg_pnl = np.mean(window_pnls) / self.initial_balance

        # Loss relative to the session start balance.
        balance_drop = self.session_start_balance - self.balance
        session_drawdown = balance_drop / self.session_start_balance

        # Statistics of the trades recorded for the current session type.
        session = self._get_market_session(self.df.index[self.current_step])
        trades_in_session = self.session_trades[session]
        session_count_scaled = len(trades_in_session) / 20.0
        if trades_in_session:
            session_wins = len([t for t in trades_in_session if t['pnl'] > 0])
            session_win_ratio = session_wins / len(trades_in_session)
        else:
            session_win_ratio = 0.0

        return np.array([
            recent_win_ratio,
            recent_avg_pnl,
            session_drawdown,
            session_count_scaled,
            session_win_ratio,
        ], dtype=np.float32)

    def _get_market_session(self, timestamp: pd.Timestamp) -> MarketSession:
        """Classify a timestamp's hour into a market session.

        The hour is shifted by a fixed offset per session and compared with
        that session's local opening window. Sessions are checked in the order
        Tokyo, London, New York, so an hour that falls into two windows is
        assigned to the earlier entry. Offsets are fixed (no daylight-saving
        adjustment).

        NOTE(review): the offsets only make sense if timestamps are UTC —
        confirm against the data source.
        """
        hour = timestamp.hour

        # (session, hour offset, local open hour, local close hour - exclusive)
        schedule = (
            (MarketSession.TOKYO, 9, 9, 15),
            (MarketSession.LONDON, 0, 8, 16),
            (MarketSession.NEW_YORK, -4, 8, 17),
        )
        for session, hour_offset, opens_at, closes_at in schedule:
            local_hour = (hour + hour_offset) % 24
            if opens_at <= local_hour < closes_at:
                return session

        return MarketSession.OFF_HOURS

    def _on_trade_closed(self, pnl: float) -> None:
        """Record a closed trade in the history and session statistics.

        Must be called while `self.position` still refers to the position
        being closed; does nothing when no position is set.

        Args:
            pnl: PnL of the trade being closed.
        """
        closing = self.position
        if closing is None:
            return

        now = self.df.index[self.current_step]
        close_at_now = self.df.iloc[self.current_step]['close']
        session = self._get_market_session(now)

        record = {
            'pnl': pnl,
            'type': closing.type,
            'entry_price': closing.entry_price,
            'exit_price': close_at_now,
            'trade_closed': True,
            'size': closing.size,
            'entry_time': closing.entry_time,
            'exit_time': now,
            'duration': (now - closing.entry_time).total_seconds() / 3600,  # hours
            'session': session,
        }

        # Keep only the most recent `trading_history_size` trades.
        self.trade_history.append(record)
        if len(self.trade_history) > self.trading_history_size:
            del self.trade_history[0]

        # The per-session list is not trimmed.
        self.session_trades[session].append(record)

        # Track the highest balance reached so far.
        self.peak_balance = max(self.peak_balance, self.balance)

    def _calculate_pnl(
        self,
        position_type: str,
        entry_price: float,
        exit_price: float,
        position_size: float
    ) -> float:
        """
        Calculate PnL in base currency terms.

        For example:
        - EUR/USD: PnL in EUR
        - USD/JPY: PnL in USD

        The price difference times the position size is a quote-currency
        amount for every pair (USD for EUR/USD, JPY for USD/JPY). Dividing by
        the exit price converts it into the base currency. The conversion is
        therefore applied regardless of which currency is the quote currency;
        skipping it for pairs such as USD/JPY would report a JPY amount as if
        it were USD.

        Args:
            position_type: 'long'; any other value is treated as short.
            entry_price: Price at which the position was opened.
            exit_price: Price at which the position is (or would be) closed.
            position_size: Position size in units of the base currency.

        Returns:
            The profit (positive) or loss (negative) in base currency.
        """
        if position_type == 'long':
            quote_pnl = (exit_price - entry_price) * position_size
        else:  # short
            quote_pnl = (entry_price - exit_price) * position_size

        # Quote currency -> base currency at the exit price.
        return quote_pnl / exit_price

    def _open_position(self, position_type: str, current_price: float) -> None:
        """Open a new position of the fixed `trade_size`.

        The transaction cost is folded into the entry price: a long enters
        slightly above the given price, a short slightly below.

        Note: the position is always sized with `self.trade_size`;
        `max_position_size` and `margin_rate_pct` are not applied here.

        Args:
            position_type: 'long'; any other value opens a short-style entry.
            current_price: Price at which the position is entered.
        """
        if position_type == 'long':
            entry_price = current_price + self.transaction_cost
        else:
            entry_price = current_price - self.transaction_cost

        self.position = Position(
            type=position_type,
            entry_price=entry_price,
            size=self.trade_size,
            entry_time=self.df.index[self.current_step],
            base_currency=self.base_currency,
            quote_currency=self.quote_currency
        )
      

    def _close_position(self, current_price: float) -> float:
        """Close the open position and return its PnL times `reward_scaling`.

        Updates balance, total PnL and trade counters, stores the trade
        details in `_last_trade_info` (picked up by the next info dict), and
        records the trade in the history. Returns 0.0 when nothing is open.

        Args:
            current_price: Price at which the position is closed, before the
                transaction cost is applied.
        """
        held = self.position
        if not held:
            return 0.0

        # The transaction cost worsens the exit price for either side.
        cost = self.transaction_cost
        exit_price = current_price - cost if held.type == 'long' else current_price + cost

        pnl = self._calculate_pnl(held.type, held.entry_price, exit_price, held.size)
        closed_at = self.df.index[self.current_step]

        # Snapshot taken before the account metrics below are updated.
        account_snapshot = {
            'session': self._get_market_session(closed_at).name,
            'balance': self.balance,
            'total_trades': self.total_trades,
            'win_rate': self.winning_trades / max(1, self.total_trades)
        }
        self._last_trade_info = {
            'trade_closed': True,  # Must be True to trigger trade recording
            'trade_pnl': pnl,
            'entry_time': held.entry_time,
            'exit_time': closed_at,
            'entry_price': held.entry_price,
            'exit_price': exit_price,
            'position_type': held.type,
            'position_size': held.size,
            'market_state': account_snapshot
        }

        # Account metrics.
        self.total_pnl += pnl
        self.balance += pnl
        self.total_trades += 1
        if pnl > 0:
            self.winning_trades += 1

        # The history hook needs the position, so clear it only afterwards.
        self._on_trade_closed(pnl)
        self.position = None

        return pnl * self.reward_scaling
    

 
    def _calculate_reward(self, realized_pnl: float = 0.0) -> float:
        """
        Combine the reward components for the current step.

        1. Realized PnL: `realized_pnl / trade_size`, multiplied by
           `1 + realized_pnl_weight` when the value is positive.

        The remaining components are added only while a position is open:

        2. Unrealized PnL of the open position, normalized by `trade_size`
           and weighted by `unrealized_pnl_weight`.
        3. Holding penalty for every hour beyond `holding_time_threshold`.
        4. Overtrading penalty when trades per day exceed
           `max_trades_per_day`.
        5. Win-rate bonus above `win_rate_threshold`, once at least 10 trades
           have been made.
        6. Drawdown penalty, quadratic in the percentage drawdown below the
           initial balance.

        NOTE(review): components 4-6 sit inside the open-position branch, so
        they do not apply on steps where a position was just closed — confirm
        this is intended.

        Args:
            realized_pnl: Value returned by `_close_position` for a trade
                closed in this step, or 0.0 when none was closed.

        Returns:
            float: Calculated reward
        """
        cfg = self.reward_params
        step = self.current_step
        mark_price = self.df.iloc[step]['close']
        total = 0.0

        # 1. Realized component; winners get an extra boost.
        if realized_pnl != 0:
            profit_boost = cfg.realized_pnl_weight if realized_pnl > 0 else 0
            total += (realized_pnl / self.trade_size) * (1 + profit_boost)

        held = self.position
        if held is None:
            return float(total)

        now = self.df.index[step]

        # 2. Unrealized component.
        open_pnl = self._calculate_pnl(
            held.type, held.entry_price, mark_price, held.size
        )
        total += (open_pnl / self.trade_size) * cfg.unrealized_pnl_weight

        # 3. Holding penalty, growing with the hours beyond the threshold.
        hours_held = (now - held.entry_time).total_seconds() / 3600
        if hours_held > cfg.holding_time_threshold:
            total += cfg.holding_penalty_factor * (hours_held - cfg.holding_time_threshold)

        # 4. Overtrading penalty; days are measured from the first data row
        #    and counted as at least one.
        if self.total_trades > 0:
            elapsed_days = (now - self.df.index[0]).total_seconds() / (24 * 3600)
            trades_per_day = self.total_trades / max(1, elapsed_days)
            if trades_per_day > cfg.max_trades_per_day:
                total += cfg.overtrading_penalty_factor * (trades_per_day - cfg.max_trades_per_day)

        # 5. Win-rate bonus, linear above the threshold, never negative.
        minimum_trades = 10
        if self.total_trades >= minimum_trades:
            hit_rate = self.winning_trades / self.total_trades
            total += max(0, (hit_rate - cfg.win_rate_threshold) * cfg.win_rate_bonus_factor)

        # 6. Drawdown penalty, quadratic in the percentage lost.
        if self.balance < self.initial_balance:
            drawdown_fraction = (self.initial_balance - self.balance) / self.initial_balance
            total += cfg.drawdown_penalty_factor * (drawdown_fraction * 100) ** 2

        return float(total)
    
    def _get_info(self) -> Dict:
        """Build the info dict for the current step.

        Besides account, position and progress figures, the details of the
        most recently closed trade are merged in once and then cleared, so
        they appear in exactly one info dict.
        """
        step = self.current_step
        row = self.df.iloc[step]
        now = self.df.index[step]
        price = row['close']
        held = self.position

        # Open-position figures; defaults describe a flat account.
        open_pnl = 0.0
        hours_open = 0
        side = 'none'
        if held is not None:
            side = held.type
            open_pnl = self._calculate_pnl(
                held.type, held.entry_price, price, held.size
            )
            hours_open = (now - held.entry_time).total_seconds() / 3600
        open_size = held.size if held else 0.0

        # Drawdown of the equity (balance plus open PnL) from its peak.
        equity = self.balance + open_pnl
        peak = max(self.peak_balance, equity)
        drawdown = (peak - equity) / peak if peak > 0 else 0.0

        closed_trades = max(1, self.total_trades)
        info = {
            # Account metrics
            'balance': self.balance,
            'total_pnl': self.total_pnl,
            'unrealized_pnl': open_pnl,
            'total_trades': self.total_trades,
            'trade_count': self.total_trades,
            'win_rate': self.winning_trades / closed_trades,
            'drawdown': drawdown,

            # Position info
            'position_type': side,
            'position_size': open_size,
            'position_duration': hours_open,

            # Trading costs and metrics
            'trading_costs': self.transaction_cost * open_size,
            'avg_trade_pnl': self.total_pnl / closed_trades,

            # Episode progress
            'current_step': step,
            'total_steps': len(self.df),
            'timestamp': now,

            # Market info; falls back to the transaction cost without a 'spread' column
            'current_price': price,
            'spread': row.get('spread', self.transaction_cost)
        }

        # Hand over the last closed trade exactly once.
        pending_trade = self._last_trade_info
        if pending_trade is not None:
            info.update(pending_trade)
            self._last_trade_info = None
        return info
    
    @property
    def win_rate(self) -> float:
        """Fraction of closed trades that were profitable (0 when none closed)."""
        closed_trades = max(1, self.total_trades)
        return self.winning_trades / closed_trades

    @property
    def avg_trade_duration(self) -> float:
        """Mean 'duration' of the trades kept in `trade_history` (0.0 if empty)."""
        history = self.trade_history
        if history:
            durations = [trade['duration'] for trade in history]
            return sum(durations) / len(history)
        return 0.0

    @property
    def max_drawdown(self) -> float:
        """Drawdown of the current balance from `peak_balance`.

        Despite the name, this is the drawdown at the moment of the call, not
        a running maximum. Returns 0.0 when the peak balance is not positive.
        """
        peak = self.peak_balance
        if peak > 0:
            return (peak - self.balance) / peak
        return 0.0

    @property
    def position_ratios(self) -> Dict[str, float]:
        """Share of long, short and other trades among those in `trade_history`.

        With an empty history the result is all 'none'.
        """
        history = self.trade_history
        if not history:
            return {'long': 0.0, 'short': 0.0, 'none': 1.0}

        # Count both sides in a single pass over the history.
        long_count = 0
        short_count = 0
        for trade in history:
            side = trade.get('type')
            if side == 'long':
                long_count += 1
            if side == 'short':
                short_count += 1

        trade_count = len(history)
        other_count = trade_count - long_count - short_count
        return {
            'long': long_count / trade_count,
            'short': short_count / trade_count,
            'none': other_count / trade_count
        }


In [None]:
# Load the dataset from a parquet file in the current working directory.
# presumably EUR/USD market data, judging by the file name — verify the schema
df = pd.read_parquet('./EUR_USD.parquet')

def split_dataset(
    df: pd.DataFrame,
    train_ratio: float = 0.7,
    val_ratio: float = 0.15,
    test_ratio: float = 0.15,
    shuffle: bool = False
) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Split dataset into train, validation and test sets.

    Without shuffling the three parts are consecutive row ranges in the
    original order. With shuffling, rows are assigned at random (using the
    global NumPy random state) and each part is then sorted by index again.
    An empty DataFrame yields three empty parts.

    Args:
        df: Input DataFrame
        train_ratio: Proportion for training (default: 0.7)
        val_ratio: Proportion for validation (default: 0.15)
        test_ratio: Proportion for testing (default: 0.15)
        shuffle: Whether to shuffle before splitting (default: False for time series)

    Returns:
        Tuple of (train_df, val_df, test_df)

    Raises:
        AssertionError: If the three ratios do not sum to 1.
    """
    assert np.isclose(train_ratio + val_ratio + test_ratio, 1.0), "Ratios must sum to 1"

    n = len(df)
    indices = np.arange(n)

    if shuffle:
        np.random.shuffle(indices)

    # Row positions at which the train and validation parts end.
    train_end = int(n * train_ratio)
    val_end = int(n * (train_ratio + val_ratio))

    parts = [
        df.iloc[indices[:train_end]],
        df.iloc[indices[train_end:val_end]],
        df.iloc[indices[val_end:]],
    ]

    # Restore index order inside each part after a shuffled assignment.
    if shuffle:
        parts = [part.sort_index() for part in parts]

    train_df, val_df, test_df = parts

    def share(part: pd.DataFrame) -> float:
        # Guard the percentage against an empty input frame.
        return len(part) / n if n else 0.0

    print("Dataset split sizes:")
    print(f"Training: {len(train_df)} samples ({share(train_df):.1%})")
    print(f"Validation: {len(val_df)} samples ({share(val_df):.1%})")
    print(f"Test: {len(test_df)} samples ({share(test_df):.1%})")

    return train_df, val_df, test_df

In [None]:
import optuna
from optuna.samplers import TPESampler
from optuna.pruners import MedianPruner
import numpy as np
import pandas as pd
from typing import Dict
from dataclasses import dataclass
import sqlite3
from datetime import datetime

from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import VecNormalize, DummyVecEnv
from stable_baselines3.common.callbacks import EvalCallback
from stable_baselines3.common.monitor import Monitor
import os, sys

from datetime import datetime, timedelta
from pathlib import Path

import logging

# Configure logging: INFO-level records are written both to a log file and to
# the console. Note that logging.basicConfig does nothing if the root logger
# already has handlers configured.
log_file = "optuna_trials.log"  # Path to log file (relative to the working directory)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(message)s",  # timestamp, then the message text
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()  # For console output
    ]
)



@dataclass
class RewardParams:
    """Parameters controlling the reward function behavior.

    NOTE(review): this file defines a `RewardParams` dataclass earlier as well,
    with `realized_pnl_weight` defaulting to 0.1 instead of 1.1. When both
    definitions run in the same namespace, this later one replaces the earlier
    one for all code executed afterwards — confirm which default is intended.
    """
    realized_pnl_weight: float = 1.1  # profitable realized PnL is multiplied by (1 + this weight)
    unrealized_pnl_weight: float = 0.8  # multiplier on the normalized unrealized PnL
    holding_time_threshold: int = 7*12  # hours
    holding_penalty_factor: float = -0.00001  # added per hour held beyond the threshold (negative = penalty)
    max_trades_per_day: int = 6  # trades per day above which the overtrading penalty applies
    overtrading_penalty_factor: float = -0.0001  # added per excess trade per day (negative = penalty)
    win_rate_threshold: float = 0.4  # win rate above which a bonus is granted
    win_rate_bonus_factor: float = 0.0005  # bonus per unit of win rate above the threshold
    drawdown_penalty_factor: float = -0.0001  # multiplies the squared drawdown (in percent) below initial balance

@dataclass
class OptimizationResult:
    """Stores results of a single trial."""
    trial_number: int  # identifier of the trial
    params: Dict  # parameters used by the trial — presumably the sampled reward parameters; verify against caller
    final_balance: float  # account balance at the end of the trial's evaluation
    total_trades: int  # number of trades made
    win_rate: float  # presumably a fraction in [0, 1] as in the environment — TODO confirm
    max_drawdown: float  # drawdown figure reported for the trial
    training_time: float  # duration of training; unit not shown here — TODO confirm (seconds?)

class RewardOptimizer:
    """Search reward-shaping parameters for the forex PPO agent with Optuna.

    Every trial samples one ``RewardParams`` configuration, trains a PPO agent
    on ``train_df`` and is scored by the account balance reached at the end of
    one deterministic episode on ``val_df``.

    NOTE(review): the study is created with a ``MedianPruner``, but nothing in
    this class calls ``trial.report`` / ``trial.should_prune``, so no trial is
    ever pruned. Wire intermediate values in if pruning is wanted.
    """

    def __init__(
        self,
        train_df: pd.DataFrame,
        val_df: pd.DataFrame,
        study_name: str = "forex_reward_optimization1",
        n_timesteps: int = 500_000
    ):
        """Create (or resume) the Optuna study.

        Args:
            train_df: Market data the agent is trained on.
            val_df: Market data used for periodic evaluation and final scoring.
            study_name: Name of the study inside ``optuna_trials.db``; an
                existing study with this name is resumed.
            n_timesteps: PPO training budget per trial.
        """
        self.train_df = train_df
        self.val_df = val_df
        self.study_name = study_name
        self.n_timesteps = n_timesteps

        # Setup study with TPE sampler and Median pruner
        self.study = optuna.create_study(
            study_name=study_name,
            storage="sqlite:///optuna_trials.db",
            load_if_exists=True,
            sampler=TPESampler(seed=42),
            pruner=MedianPruner(
                n_startup_trials=5,  # Wait for at least 5 trials to complete before pruning
                n_warmup_steps=50_000,  # Let each trial run for 50k timesteps before pruning
                interval_steps=50_000
            ),
            direction="maximize"
        )

    def _make_raw_env(self, df: pd.DataFrame, params: Dict):
        """Build one unwrapped ``ForexTradingEnv`` on a private copy of ``df``."""
        return ForexTradingEnv(
            df=df.copy(),
            pair='EUR_USD',
            initial_balance=1_000_000,
            trade_size=100_000,
            reward_params=RewardParams(**params)
        )

    def _create_env(self, df: pd.DataFrame, params: Dict, is_eval: bool = False) -> VecNormalize:
        """Create vectorized and normalized environment.

        Args:
            df: Market data for the environment.
            params: Keyword arguments for ``RewardParams``.
            is_eval: When True the wrapper neither normalizes rewards nor
                updates its running statistics.

        Returns:
            A single-environment ``VecNormalize`` wrapper.
        """
        def _init():
            return Monitor(self._make_raw_env(df, params))

        return VecNormalize(
            DummyVecEnv([_init]),
            # An evaluation wrapper must not learn statistics from validation
            # data; EvalCallback copies the training statistics into it
            # before each evaluation run.
            training=not is_eval,
            norm_obs=True,
            norm_reward=not is_eval,
            clip_obs=10.,
            clip_reward=10.,
            gamma=0.99,
            epsilon=1e-08
        )

    def _validation_metrics(self, model, train_env: VecNormalize, params: Dict):
        """Roll out one deterministic episode on ``val_df`` and read its outcome.

        The episode runs on an unwrapped environment on purpose: a VecEnv
        resets itself as soon as an episode ends, so attributes read from it
        afterwards describe the freshly reset environment, not the episode.

        Args:
            model: Trained agent exposing ``predict``.
            train_env: Training wrapper whose observation statistics the agent
                was trained with.
            params: Keyword arguments for ``RewardParams``.

        Returns:
            Tuple ``(final_balance, total_trades, winning_trades)``.

        NOTE(review): the environment is built with its default
        ``random_start`` setting, so the score may depend on the sampled
        start -- confirm whether a fixed start is preferable for scoring.
        """
        env = self._make_raw_env(self.val_df, params)
        try:
            obs, _ = env.reset()
            done = False
            while not done:
                action, _ = model.predict(
                    train_env.normalize_obs(obs), deterministic=True
                )
                # DummyVecEnv hands the env ``actions[0]``; ``[()]`` turns the
                # 0-d array predict() returns for a single observation into
                # the same NumPy scalar and leaves array actions untouched.
                obs, _, terminated, truncated, _ = env.step(np.asarray(action)[()])
                done = terminated or truncated
            return float(env.balance), env.total_trades, env.winning_trades
        finally:
            env.close()

    def _report_trial(
        self,
        trial_number: int,
        params: Dict,
        final_balance: float,
        total_trades: int,
        win_rate: float,
        training_time: float,
    ) -> None:
        """Write one trial's summary to the log and to stdout."""
        # Log trial results
        logging.info(f"Trial {trial_number} completed:")
        logging.info(f"Final Balance: ${final_balance:,.2f}")
        logging.info(f"Parameters: {params}")
        logging.info(f"Total Trades: {total_trades}")
        logging.info(f"Win Rate: {win_rate:.2%}")
        logging.info(f"Training Time: {training_time:.1f}s")
        logging.info("-" * 80)

        # Print trial results
        print(f"\nTrial {trial_number} completed:")
        print(f"Final Balance: ${final_balance:,.2f}")
        print("Parameters:")
        for key, value in params.items():
            print(f"    {key}: {value}")
        print(f"Total Trades: {total_trades}")
        print(f"Win Rate: {win_rate:.2%}")
        print(f"Training Time: {training_time:.1f}s")
        print("-" * 80)

    def objective(self, trial: optuna.Trial) -> float:
        """Optimization objective function.

        Trains PPO with the sampled reward parameters and returns the final
        validation balance. Any failure is logged with its traceback and
        scored as ``-inf`` so the study keeps running.
        """
        train_env = None
        eval_env = None
        try:
            # Sample parameters
            params = self._sample_parameters(trial)

            # Create environments
            train_env = self._create_env(self.train_df, params)
            eval_env = self._create_env(self.val_df, params, is_eval=True)

            start_time = datetime.now()

            # Create model
            model = PPO(
                "MultiInputPolicy",
                train_env,
                verbose=0,
                tensorboard_log=f"./tensorboard/trial_{trial.number}"
            )

            # Setup evaluation callback
            eval_callback = EvalCallback(
                eval_env,
                best_model_save_path=f"./models/trial_{trial.number}",
                log_path=f"./logs/trial_{trial.number}",
                eval_freq=25_000,
                deterministic=True,
                render=False
            )

            # Train model
            model.learn(
                total_timesteps=self.n_timesteps,
                callback=eval_callback
            )

            # The target directory is not created anywhere else; without this
            # the save raises after the whole training run has finished.
            stats_dir = f'./optuna/best_model_trial_{trial.number}'
            os.makedirs(stats_dir, exist_ok=True)
            train_env.save(f'{stats_dir}/vecnormalize.pkl')
            training_time = (datetime.now() - start_time).total_seconds()

            # Score the trained agent on a dedicated validation episode
            final_balance, total_trades, winning_trades = self._validation_metrics(
                model, train_env, params
            )
            win_rate = winning_trades / max(1, total_trades)

            self._report_trial(
                trial.number, params, final_balance,
                total_trades, win_rate, training_time
            )

            return final_balance

        except Exception:
            # Best effort: keep the study alive, but record the traceback so
            # the cause of the failure is not lost.
            logging.exception(f"Trial {trial.number} failed")
            return float('-inf')

        finally:
            # Release environments so resources do not pile up across trials.
            for env in (train_env, eval_env):
                if env is not None:
                    env.close()

    def _sample_parameters(self, trial: optuna.Trial) -> Dict:
        """Sample reward parameters for trial.

        Returns:
            Mapping whose keys match the ``RewardParams`` field names.
        """
        return {
            'realized_pnl_weight': trial.suggest_float('realized_pnl_weight', 0.5, 2.0),
            'unrealized_pnl_weight': trial.suggest_float('unrealized_pnl_weight', 0.3, 1.0),
            'holding_time_threshold': trial.suggest_int('holding_time_threshold', 24, 96),
            'holding_penalty_factor': trial.suggest_float('holding_penalty_factor', -0.0001, 0.0),
            'max_trades_per_day': trial.suggest_int('max_trades_per_day', 3, 12),
            'overtrading_penalty_factor': trial.suggest_float('overtrading_penalty_factor', -0.001, 0.0),
            'win_rate_threshold': trial.suggest_float('win_rate_threshold', 0.3, 0.5),
            'win_rate_bonus_factor': trial.suggest_float('win_rate_bonus_factor', 0.0001, 0.001, log=True),
            'drawdown_penalty_factor': trial.suggest_float('drawdown_penalty_factor', -0.001, 0.0)
        }

    def optimize(self, n_trials: int = 100, n_jobs: int = 6) -> None:
        """Run optimization using Optuna's built-in parallelization.

        Args:
            n_trials: Number of trials to run.
            n_jobs: Number of parallel jobs passed to ``study.optimize``.
        """
        self.study.optimize(
            self.objective,
            n_trials=n_trials,
            n_jobs=n_jobs,  # Number of parallel jobs
            show_progress_bar=True
        )

        print("\nOptimization completed!")

        try:
            trial = self.study.best_trial
        except ValueError:
            # Optuna raises ValueError when the study has no completed trial.
            print("\nNo completed trials; nothing to report.")
            return

        # Print best trial after completion
        print("\nBest trial:")
        print(f"Value: ${trial.value:,.2f}")
        print("Best parameters:")
        for key, value in trial.params.items():
            print(f"    {key}: {value}")


# Build the reward-parameter search over the train/validation split with a
# 300k-timestep PPO budget per trial, then run 100 trials on 32 parallel jobs.
# NOTE(review): the study is stored in SQLite; confirm that backend copes
# with this many concurrent writers.
optimizer = RewardOptimizer(train_df, val_df, n_timesteps=300_000)

optimizer.optimize(n_jobs=32, n_trials=100)