# In [None]:
import os
import sys
import numpy as np
import pandas as pd
from typing import Dict, Optional, Tuple, List
from pathlib import Path
import threading
from queue import Queue
import time
import logging
from datetime import datetime, timedelta, timezone
import pytz
from dataclasses import dataclass, field
import json

# Trading components
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from apscheduler.schedulers.background import BackgroundScheduler

# OANDA components
from oandapyV20 import API
import oandapyV20.endpoints.positions as positions
import oandapyV20.endpoints.orders as orders
import oandapyV20.endpoints.instruments as instruments
import oandapyV20.endpoints.trades as trades


# Add project root to path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import local components
from trading.environments.forex_env2_flat import ForexTradingEnv
from data_management.preprocessor import DataPreprocessor
from data_management.indicator_manager import IndicatorManager

# Configure root logging once: every record goes both to a log file in the
# working directory and to the console stream.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
    handlers=[logging.FileHandler('trading_system.log'), logging.StreamHandler()],
)
# Module-wide logger used by the components defined below.
logger = logging.getLogger('trading_system')

# OANDA Configuration
# NOTE(security): credentials belong in the environment, not in source control.
# Each setting is read from an environment variable of the same name; the
# literals are kept only as backward-compatible fallbacks. The token below has
# been committed in plain text and should be rotated.
OANDA_API_KEY = os.environ.get(
    'OANDA_API_KEY',
    '9317ace4596d61e3e98b1a53b2342483-45d3ad4084c80b111727a9fada9ef0ff',
)
OANDA_ACCOUNT_ID = os.environ.get('OANDA_ACCOUNT_ID', '101-004-30348600-001')  # running account
# OANDA_ACCOUNT_ID = '101-004-30348600-002'
OANDA_ENV = os.environ.get('OANDA_ENV', 'practice')

# Initialize OANDA client
client = API(access_token=OANDA_API_KEY, environment=OANDA_ENV)




@dataclass
class TradeRecord:
    """Full description of one trade, from entry through (optional) exit.

    No field has a default, so every value must be supplied on construction.
    Exit-related fields are annotated ``Optional``.
    """
    # Instrument and timing
    pair: str
    entry_time: datetime
    exit_time: Optional[datetime]
    # Prices
    entry_price: float
    exit_price: Optional[float]
    # Direction and size; position_type is 'LONG' or 'SHORT'
    position_type: str
    size: float
    # Outcome
    pnl: float
    pnl_percentage: float
    trade_duration: timedelta
    # Spread observed at entry / exit
    spread_entry: float
    spread_exit: Optional[float]
    # Context in which the trade was taken
    model_version: str
    market_session: str
    # Key indicator values captured at entry and, when available, at exit
    entry_indicators: Dict[str, float]
    exit_indicators: Optional[Dict[str, float]]

@dataclass
class PairPerformanceMetrics:
    """Aggregate performance figures kept for one currency pair.

    Every field has a default, so ``PairPerformanceMetrics()`` yields an
    all-zero record with its own empty per-session dictionary.
    """
    # Trade counters
    total_trades: int = 0
    winning_trades: int = 0
    losing_trades: int = 0
    # Profit-and-loss figures
    total_pnl: float = 0.0
    peak_balance: float = 0.0
    max_drawdown: float = 0.0
    avg_trade_duration: timedelta = timedelta()
    # Derived ratios
    win_rate: float = 0.0
    profit_factor: float = 0.0
    sharpe_ratio: float = 0.0
    # Model bookkeeping
    model_version: str = ""
    last_retrain_date: Optional[datetime] = None
    # PnL keyed by session name; a factory gives each instance its own dict
    performance_by_session: Dict[str, float] = field(default_factory=dict)
    
class PerformanceTracker:
    """Tracks and analyzes trading system performance.

    Completed trades are kept in memory per currency pair. After every
    recorded trade the pair's aggregate metrics are recomputed and written as
    JSON to ``<base_path>/metrics/<pair>_metrics.json``.
    """

    def __init__(self, base_path: Path):
        """Create the storage directories and empty in-memory state.

        Args:
            base_path: Root directory. ``trades`` and ``metrics``
                sub-directories are created beneath it if missing.
        """
        self.base_path = base_path
        self.trades_path = base_path / "trades"
        self.metrics_path = base_path / "metrics"
        self.trades_path.mkdir(parents=True, exist_ok=True)
        self.metrics_path.mkdir(parents=True, exist_ok=True)

        # Initialize storage
        self.trade_history: Dict[str, List[TradeRecord]] = {}
        self.pair_metrics: Dict[str, PairPerformanceMetrics] = {}
        self.error_log: List[Dict] = []
        self.model_versions: Dict[str, str] = {}

        # Performance thresholds for alerts
        self.thresholds = {
            'drawdown_alert': 0.10,  # 10% drawdown
            'win_rate_min': 0.45,    # 45% minimum win rate
            'trade_frequency_max': 50  # Max trades per day
        }

    @staticmethod
    def _entered_within(trade: TradeRecord, days: float) -> bool:
        """Return True if ``trade`` was entered within the last ``days`` days.

        "Now" is taken in the entry timestamp's own timezone (naive local time
        when the timestamp is naive), so both naive and timezone-aware records
        can be compared without raising ``TypeError``.
        """
        entry_time = trade.entry_time
        return entry_time > datetime.now(entry_time.tzinfo) - timedelta(days=days)

    def record_trade(self, trade: TradeRecord) -> None:
        """Record a completed trade, refresh the pair's metrics and run alerts.

        Args:
            trade: The completed trade. Its ``pair`` selects the history and
                metrics bucket, which are created on first use.
        """
        pair = trade.pair

        # Store trade record
        self.trade_history.setdefault(pair, []).append(trade)

        # Update pair metrics
        if pair not in self.pair_metrics:
            self.pair_metrics[pair] = PairPerformanceMetrics()

        metrics = self.pair_metrics[pair]
        metrics.total_trades += 1
        metrics.total_pnl += trade.pnl

        # A break-even trade (pnl == 0) is counted as non-winning, so that
        # winning_trades + losing_trades always equals total_trades.
        if trade.pnl > 0:
            metrics.winning_trades += 1
        else:
            metrics.losing_trades += 1

        # Update win rate and other metrics
        self._update_pair_metrics(pair)

        # Check for performance alerts
        self._check_performance_alerts(pair)

    def _update_pair_metrics(self, pair: str) -> None:
        """Recompute derived metrics for ``pair`` and persist them.

        Updates win rate, profit factor, maximum drawdown and per-session PnL
        from the full trade history of the pair. Does nothing when the pair
        has no recorded trades.
        """
        metrics = self.pair_metrics[pair]
        trades = self.trade_history.get(pair, [])

        if not trades:
            return

        # Calculate basic metrics
        metrics.win_rate = (
            metrics.winning_trades / metrics.total_trades
            if metrics.total_trades else 0.0
        )

        # Profit factor: gross profit over gross loss; infinite without losses
        winning_pnl = sum(t.pnl for t in trades if t.pnl > 0)
        losing_pnl = abs(sum(t.pnl for t in trades if t.pnl < 0))
        metrics.profit_factor = winning_pnl / losing_pnl if losing_pnl != 0 else float('inf')

        # Drawdown relative to the running peak of cumulative PnL. It is only
        # defined while that peak is positive; dividing by a zero or negative
        # peak would produce NaN, inf or negative values, so those points
        # count as zero drawdown.
        cumulative_pnl = np.cumsum(np.array([t.pnl for t in trades], dtype=float))
        peak = np.maximum.accumulate(cumulative_pnl)
        drawdown = np.zeros_like(cumulative_pnl)
        has_peak = peak > 0
        drawdown[has_peak] = (peak[has_peak] - cumulative_pnl[has_peak]) / peak[has_peak]
        # Store a plain float rather than a numpy scalar
        metrics.max_drawdown = float(drawdown.max())

        # Calculate session performance
        session_pnl = {}
        for trade in trades:
            session = trade.market_session
            session_pnl[session] = session_pnl.get(session, 0) + trade.pnl
        metrics.performance_by_session = session_pnl

        # Save updated metrics
        self._save_pair_metrics(pair)

    def _check_performance_alerts(self, pair: str) -> None:
        """Log a warning when ``pair`` breaches any configured threshold.

        Checks maximum drawdown, win rate (only once at least 20 trades
        exist) and the number of trades entered during the last day.
        """
        metrics = self.pair_metrics[pair]
        alerts = []

        # Check drawdown
        if metrics.max_drawdown >= self.thresholds['drawdown_alert']:
            alerts.append(f"High drawdown alert: {metrics.max_drawdown:.1%}")

        # Check win rate
        if metrics.total_trades >= 20 and metrics.win_rate < self.thresholds['win_rate_min']:
            alerts.append(f"Low win rate alert: {metrics.win_rate:.1%}")

        # Check trade frequency
        recent_trades = [t for t in self.trade_history.get(pair, [])
                         if self._entered_within(t, 1)]
        if len(recent_trades) > self.thresholds['trade_frequency_max']:
            alerts.append("High trade frequency alert")

        if alerts:
            logging.getLogger('trading_system').warning(
                f"Performance alerts for {pair}:\n" + "\n".join(alerts)
            )

    def analyze_model_performance(self, pair: str) -> pd.DataFrame:
        """Aggregate trade results of ``pair`` by model version.

        Returns:
            A DataFrame indexed by ``model_version`` with count/sum/mean/std
            of PnL and the mean trade duration, or an empty DataFrame when
            the pair has no recorded trades.
        """
        trades = self.trade_history.get(pair)
        if not trades:
            return pd.DataFrame()

        df = pd.DataFrame([{
            'model_version': t.model_version,
            'entry_time': t.entry_time,
            'pnl': t.pnl,
            'trade_duration': t.trade_duration,
            'market_session': t.market_session
        } for t in trades])

        return df.groupby('model_version').agg({
            'pnl': ['count', 'sum', 'mean', 'std'],
            'trade_duration': 'mean'
        })

    def get_pair_summary(self, pair: str, lookback_days: int = 30) -> Dict:
        """Return a summary dictionary for ``pair``.

        Args:
            pair: Currency pair to summarise.
            lookback_days: Window used for ``recent_trades_count``.

        Returns:
            The summary, or an empty dict when no metrics exist for the pair.
        """
        if pair not in self.pair_metrics:
            return {}

        metrics = self.pair_metrics[pair]
        recent_trades = [t for t in self.trade_history.get(pair, [])
                         if self._entered_within(t, lookback_days)]

        return {
            'total_trades': metrics.total_trades,
            'win_rate': metrics.win_rate,
            'total_pnl': metrics.total_pnl,
            'max_drawdown': metrics.max_drawdown,
            'profit_factor': metrics.profit_factor,
            'performance_by_session': metrics.performance_by_session,
            'recent_trades_count': len(recent_trades),
            'model_version': metrics.model_version,
            'last_retrain': metrics.last_retrain_date
        }

    def _save_pair_metrics(self, pair: str) -> None:
        """Write the metrics of ``pair`` to ``<pair>_metrics.json``."""
        metrics = self.pair_metrics[pair]

        # Convert to serializable format.
        # NOTE: profit_factor is float('inf') when there are no losing trades;
        # json.dump then emits the non-standard token ``Infinity``.
        metrics_dict = {
            'total_trades': metrics.total_trades,
            'winning_trades': metrics.winning_trades,
            'losing_trades': metrics.losing_trades,
            'total_pnl': metrics.total_pnl,
            'max_drawdown': metrics.max_drawdown,
            'win_rate': metrics.win_rate,
            'profit_factor': metrics.profit_factor,
            'model_version': metrics.model_version,
            'last_retrain_date': metrics.last_retrain_date.isoformat()
                if metrics.last_retrain_date else None,
            'performance_by_session': metrics.performance_by_session
        }

        # Save to file
        metrics_file = self.metrics_path / f"{pair}_metrics.json"
        with open(metrics_file, 'w') as f:
            json.dump(metrics_dict, f, indent=2)

    def export_performance_report(self, lookback_days: Optional[int] = None) -> str:
        """Build a plain-text performance report covering every tracked pair.

        Args:
            lookback_days: When given, each pair section additionally lists
                the number of trades entered, and their PnL, within that many
                days. When ``None`` only all-time figures are shown.

        Returns:
            The report as a single newline-joined string.
        """
        report = ["Trading System Performance Report\n"]
        report.append(f"Generated at: {datetime.now()}\n")

        for pair in sorted(self.pair_metrics.keys()):
            metrics = self.pair_metrics[pair]

            report.append(f"\n{pair} Performance:")
            report.append(f"Total Trades: {metrics.total_trades}")

            if lookback_days is not None:
                recent = [t for t in self.trade_history.get(pair, [])
                          if self._entered_within(t, lookback_days)]
                recent_pnl = sum(t.pnl for t in recent)
                report.append(f"Trades in last {lookback_days} days: {len(recent)}")
                report.append(f"PnL in last {lookback_days} days: {recent_pnl:,.2f}")

            report.append(f"Win Rate: {metrics.win_rate:.1%}")
            report.append(f"Total PnL: {metrics.total_pnl:,.2f}")
            report.append(f"Max Drawdown: {metrics.max_drawdown:.1%}")
            report.append(f"Profit Factor: {metrics.profit_factor:.2f}")
            report.append("\nPerformance by Session:")

            for session, pnl in metrics.performance_by_session.items():
                report.append(f"  {session}: {pnl:,.2f}")

            report.append(f"\nCurrent Model: {metrics.model_version}")
            if metrics.last_retrain_date:
                report.append(f"Last Retrain: {metrics.last_retrain_date}")

        return "\n".join(report)

class PositionManager:
    """Manages trading positions with safety features and position tracking."""

    # Column layout of the frame returned by get_position_status()
    _STATUS_COLUMNS = ['pair', 'long_units', 'short_units', 'net_position', 'timestamp']

    def __init__(
        self,
        currency_pairs: Dict[str, float],
        logger: Optional[logging.Logger] = None,
        account_id: str = OANDA_ACCOUNT_ID,
        client: Optional[API] = None,
    ):
        """
        Initialize the position manager.

        Args:
            currency_pairs: Dictionary of currency pairs and their position sizes
            logger: Optional logger instance; defaults to this module's logger
            account_id: OANDA account ID
            client: OANDA API client. When omitted, a client is created from
                the module-level OANDA settings at construction time (rather
                than once at class-definition time, shared by all instances).
        """
        self.client = (
            client if client is not None
            else API(access_token=OANDA_API_KEY, environment=OANDA_ENV)
        )
        self.account_id = account_id
        self.currency_pairs = currency_pairs
        self.logger = logger or logging.getLogger(__name__)
        self.positions = {}
        self.last_sync_time = None

    def _fetch_open_positions(self) -> list:
        """Return the 'positions' list from an OpenPositions request (may raise)."""
        r = positions.OpenPositions(accountID=self.account_id)
        response = self.client.request(r)
        return response.get('positions', [])

    @staticmethod
    def _side_units(pos: dict, side: str) -> float:
        """Return the units on ``side`` ('long' or 'short') of a position entry, 0 if absent."""
        return float(pos.get(side, {}).get('units', 0))

    def _confirm_close(self, open_positions: list) -> bool:
        """Print the open positions and ask the user to confirm closing them."""
        print(f"\nFound {len(open_positions)} open positions:")
        for pos in open_positions:
            pair = pos['instrument']
            long_units = self._side_units(pos, 'long')
            short_units = self._side_units(pos, 'short')
            print(f"- {pair}: Long: {long_units}, Short: {short_units}")

        confirm_input = input("\nClose all positions? (yes/no): ")
        return confirm_input.strip().lower() == 'yes'

    def close_all_positions(self, confirm: bool = True) -> bool:
        """
        Close all open positions with confirmation option.

        A failure to close one instrument is logged and the remaining
        instruments are still attempted. The result is decided by re-querying
        the open positions afterwards.

        Args:
            confirm: If True, requires confirmation before closing positions

        Returns:
            bool: True if no open positions remain; False if the user
            declined, positions remain, or the position queries failed
        """
        try:
            # Get current positions
            open_positions = self._fetch_open_positions()

            if not open_positions:
                self.logger.info("No open positions to close")
                return True

            # Show positions and get confirmation if required
            if confirm and not self._confirm_close(open_positions):
                self.logger.info("Position closing cancelled by user")
                return False

            # Close positions
            for pos in open_positions:
                pair = pos['instrument']

                try:
                    # Close long positions
                    if self._side_units(pos, 'long') > 0:
                        r = positions.PositionClose(
                            accountID=self.account_id,
                            instrument=pair,
                            data={"longUnits": "ALL"}
                        )
                        self.client.request(r)
                        self.logger.info(f"Closed long position for {pair}")

                    # Close short positions (short units are tested as negative)
                    if self._side_units(pos, 'short') < 0:
                        r = positions.PositionClose(
                            accountID=self.account_id,
                            instrument=pair,
                            data={"shortUnits": "ALL"}
                        )
                        self.client.request(r)
                        self.logger.info(f"Closed short position for {pair}")

                    # Small delay to prevent rate limiting
                    time.sleep(0.1)

                except Exception as e:
                    # Keep going: one failing instrument must not leave the
                    # others open.
                    self.logger.error(f"Error closing position for {pair}: {str(e)}")

            # Verify all positions are closed
            remaining_positions = self._fetch_open_positions()

            if not remaining_positions:
                self.logger.info("All positions successfully closed")
                return True

            self.logger.warning(
                f"Some positions remain after closing attempt: {len(remaining_positions)} positions"
            )
            return False

        except Exception as e:
            self.logger.error(f"Error in close_all_positions: {str(e)}")
            return False

    def cancel_all_orders(self) -> bool:
        """Cancel all pending orders.

        Every order returned by the order list request is attempted even if
        an earlier cancellation fails.

        Returns:
            bool: True if there was nothing to cancel or every cancellation
            succeeded; False otherwise
        """
        try:
            # Get all pending orders
            r = orders.OrderList(accountID=self.account_id)
            response = self.client.request(r)
            pending_orders = response.get('orders', [])

            if not pending_orders:
                self.logger.info("No pending orders to cancel")
                return True

            # Cancel each order
            all_cancelled = True
            for order in pending_orders:
                try:
                    r = orders.OrderCancel(
                        accountID=self.account_id,
                        orderID=order['id']
                    )
                    self.client.request(r)
                    self.logger.info(f"Cancelled order {order['id']}")
                    time.sleep(0.1)  # Rate limiting prevention

                except Exception as e:
                    self.logger.error(f"Error cancelling order {order['id']}: {str(e)}")
                    all_cancelled = False

            return all_cancelled

        except Exception as e:
            self.logger.error(f"Error in cancel_all_orders: {str(e)}")
            return False

    def emergency_shutdown(self) -> None:
        """
        Emergency shutdown - closes all positions and cancels all orders.

        Each step is retried once if it fails. The outcome is logged; the
        method does not raise and does not block until positions are gone,
        so callers needing a guarantee must check the positions themselves.
        """
        self.logger.warning("Initiating emergency shutdown...")

        # First attempt
        positions_closed = self.close_all_positions(confirm=False)
        orders_cancelled = self.cancel_all_orders()

        # Retry if necessary
        if not positions_closed:
            self.logger.warning("First closing attempt failed, retrying...")
            time.sleep(1)
            self.close_all_positions(confirm=False)
        if not orders_cancelled:
            self.logger.warning("Order cancellation failed, retrying...")
            self.cancel_all_orders()

        # Final verification
        try:
            remaining_positions = self._fetch_open_positions()
        except Exception as e:
            self.logger.error(
                f"Emergency shutdown could not verify open positions: {str(e)}. "
                "Manual intervention may be required."
            )
            return

        if remaining_positions:
            self.logger.error(
                "Emergency shutdown incomplete - some positions remain. "
                "Manual intervention may be required."
            )
        else:
            self.logger.info("Emergency shutdown completed successfully")

    def get_position_status(self) -> pd.DataFrame:
        """
        Get detailed status of all positions.

        Returns:
            DataFrame with one row per open position and the columns
            pair, long_units, short_units, net_position, timestamp. The frame
            is empty (with those columns) when there are no positions, and a
            column-less empty frame when the request fails.
        """
        try:
            positions_data = []

            for pos in self._fetch_open_positions():
                long_units = self._side_units(pos, 'long')
                short_units = self._side_units(pos, 'short')

                positions_data.append({
                    'pair': pos['instrument'],
                    'long_units': long_units,
                    'short_units': short_units,
                    'net_position': long_units + short_units,
                    'timestamp': pd.Timestamp.now(tz='UTC')
                })

            return pd.DataFrame(positions_data, columns=self._STATUS_COLUMNS)

        except Exception as e:
            self.logger.error(f"Error getting position status: {str(e)}")
            return pd.DataFrame()

class SpreadTracker:
    """Tracks and analyzes spread costs by currency pair and trading session."""

    def __init__(self, save_path: str = "spread_history.parquet"):
        """Start with an empty history and load any existing one from disk.

        Args:
            save_path: Parquet file used to persist the spread history.
        """
        self.save_path = Path(save_path)
        self.spreads = pd.DataFrame(columns=[
            'timestamp', 'pair', 'ask', 'bid', 'spread',  # raw spread, not pips
            'session', 'trade_type'
        ])
        self.load_history()

    def load_history(self):
        """Load existing spread history if available."""
        if self.save_path.exists():
            self.spreads = pd.read_parquet(self.save_path)

    def get_current_prices(self, pair: str) -> Tuple[float, float]:
        """Get current bid/ask prices from OANDA.

        Args:
            pair: Instrument name.

        Returns:
            ``(ask, bid)`` closing prices of the most recent candle returned.

        Raises:
            ValueError: If the response contains no candles.
        """
        params = {
            "count": 1,
            "granularity": "S5",  # 5-second candles for most recent price
            "price": "AB"  # Ask and Bid prices
        }
        r = instruments.InstrumentsCandles(instrument=pair, params=params)
        response = client.request(r)

        candles = response.get('candles')
        if not candles:
            raise ValueError(f"No price data available for {pair}")

        # With count=1 there is a single candle; -1 keeps "latest" explicit
        candle = candles[-1]
        ask = float(candle['ask']['c'])
        bid = float(candle['bid']['c'])
        return ask, bid

    def record_spread(self, pair: str, trade_type: str) -> Optional[float]:
        """
        Record spread at time of trade execution.

        Args:
            pair: Currency pair
            trade_type: 'OPEN' or 'CLOSE'

        Returns:
            Raw spread (ask - bid), or None if the prices could not be
            fetched or recorded. A failure to write the history file is
            logged but does not discard the measured spread.
        """
        try:
            # Get current prices
            ask, bid = self.get_current_prices(pair)

            # Calculate raw spread
            spread = ask - bid

            # Determine current trading session
            now = pd.Timestamp.now(tz='UTC')
            session = self._get_trading_session(now)

            # Record spread
            new_record = pd.DataFrame([{
                'timestamp': now,
                'pair': pair,
                'ask': ask,
                'bid': bid,
                'spread': spread,  # Raw spread value
                'session': session,
                'trade_type': trade_type
            }])

            # Skip concatenating onto an empty frame, and renumber rows so
            # that every record does not end up with index 0.
            if self.spreads.empty:
                self.spreads = new_record
            else:
                self.spreads = pd.concat([self.spreads, new_record], ignore_index=True)

            # Save updated history; the in-memory record is kept on failure
            try:
                self.spreads.to_parquet(self.save_path)
            except Exception as e:
                logger.error(f"Error saving spread history to {self.save_path}: {str(e)}")

            logger.info(f"Recorded spread of {spread:.6f} for {pair} "
                        f"during {session} session ({trade_type})")

            return spread

        except Exception as e:
            logger.error(f"Error recording spread for {pair}: {str(e)}")
            return None

    def get_spread_statistics(
        self,
        pair: Optional[str] = None,
        session: Optional[str] = None
    ) -> pd.DataFrame:
        """Get spread statistics by pair and/or session.

        Args:
            pair: Restrict to this pair when given.
            session: Restrict to this session when given.

        Returns:
            Frame grouped by (pair, session) with mean/std/min/max/count of
            the spread and the first/last timestamp, rounded to 6 decimals.
        """
        df = self.spreads

        if pair:
            df = df[df['pair'] == pair]
        if session:
            df = df[df['session'] == session]

        stats = df.groupby(['pair', 'session']).agg({
            'spread': ['mean', 'std', 'min', 'max', 'count'],
            'timestamp': ['min', 'max']
        }).round(6)  # Round to 6 decimal places for spreads

        return stats

    def _get_trading_session(self, timestamp: pd.Timestamp) -> str:
        """Classify a timestamp into a trading session by its UTC hour.

        Timezone-aware timestamps are converted to UTC first; naive ones are
        taken to be UTC already. Session windows use fixed offsets (Tokyo
        UTC+9, London UTC+0, New York UTC-4) with no daylight-saving
        adjustment, and are checked in the order Tokyo, London, New York.

        Returns:
            'TOKYO', 'LONDON', 'NEW_YORK' or 'OFF_HOURS'.
        """
        if timestamp.tzinfo is not None:
            timestamp = timestamp.tz_convert('UTC')
        hour = timestamp.hour

        # Convert to major session times
        tokyo_hour = (hour + 9) % 24
        ny_hour = (hour - 4) % 24

        if 9 <= tokyo_hour < 15:
            return 'TOKYO'
        elif 8 <= hour < 16:
            return 'LONDON'
        elif 8 <= ny_hour < 17:
            return 'NEW_YORK'
        else:
            return 'OFF_HOURS'

# Usage in TradingSystem class:
# class TradingSystem:
#     def __init__(self):
#         # ... existing initialization ...
#         self.spread_tracker = SpreadTracker()
        
#     def execute_trade(self, pair: str, current_position: str, new_position: str):
#         """Execute trade with spread tracking."""
#         try:
#             # Record spread before closing position
#             if current_position != 'NO_POSITION':
#                 spread_close = self.spread_tracker.record_spread(pair, 'CLOSE')
#                 print(f'Closing spread for {pair}: {spread_close:.1f} pips')
                
#             # Record spread before opening position    
#             if new_position != 'NO_POSITION':
#                 spread_open = self.spread_tracker.record_spread(pair, 'OPEN')
#                 print(f'Opening spread for {pair}: {spread_open:.1f} pips')
                
#             # Execute trade as before...
            
#         except Exception as e:
#             logger.error(f"Error executing trade for {pair}: {str(e)}")
#             raise
            
#     def analyze_trading_costs(self) -> None:
#         """Analyze current trading costs."""
#         stats = self.spread_tracker.get_spread_statistics()
#         print("\nSpread Statistics by Pair and Session:")
#         print(stats)
        
#         # Calculate cost impact
#         total_trades = len(self.spread_tracker.spreads)
#         avg_spread_cost = stats['spread']['mean'].mean()
#         print(f"\nAverage spread cost across all pairs: {avg_spread_cost:.1f} pips")
#         print(f"Total trades analyzed: {total_trades}")

# Position size per instrument, keyed by instrument name.
currency_pairs = dict(
    EUR_USD=94510.0,
    GBP_USD=78500.0,
    USD_JPY=100000.0,
    USD_CHF=100000.0,
    USD_CAD=100000.0,
    AUD_USD=153000.0,
    NZD_USD=171430.0,
    # Cross pairs
    EUR_GBP=94510,
    EUR_CHF=94510,
    EUR_JPY=94510,
    EUR_CAD=94510,
    GBP_CHF=78500.0,
    GBP_JPY=78500.0,
    CHF_JPY=88100.0,
    AUD_JPY=153000.0,
    NZD_JPY=171430.0,
    # Precious metals
    XAU_USD=38.0,
    XAG_USD=3266,
)

def get_current_time():
    """Return the current UTC time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    utc_now = datetime.now(tz=timezone.utc)
    return f"{utc_now:%Y-%m-%d %H:%M:%S}"

class FastDataManager:
    """High-performance data manager optimized for low-latency trading."""
    
    def __init__(
        self,
        base_storage_path: str,
        max_history_size: int = 10000
    ):
        """Set up storage, helpers and the background save thread.

        Args:
            base_storage_path: Directory holding the per-pair parquet files.
            max_history_size: Stored on the instance as ``max_history_size``.
        """
        self.base_storage_path = Path(base_storage_path)
        self.max_history_size = max_history_size

        # Feature columns, grouped by indicator family
        price_and_averages = ['close', 'sma_20', 'sma_50', 'rsi']
        macd_family = ['macd', 'macd_signal', 'macd_hist']
        bollinger = ['bb_upper', 'bb_middle', 'bb_lower', 'bb_bandwidth', 'bb_percent']
        volatility_and_trend = ['atr', 'plus_di', 'minus_di', 'adx']
        ichimoku = ['senkou_span_a', 'senkou_span_b', 'tenkan_sen', 'kijun_sen']
        self.training_features = (
            price_and_averages + macd_family + bollinger
            + volatility_and_trend + ichimoku
        )

        # Per-pair frames: raw data and its normalized counterpart
        self.raw_data: Dict[str, pd.DataFrame] = {}
        self.normalized_data: Dict[str, pd.DataFrame] = {}

        # Lock guarding the two dictionaries above, plus the save queue
        self.data_lock = threading.Lock()
        self.save_queue = Queue()

        # Helper components
        self.indicator_manager = IndicatorManager()
        self.data_processor = DataPreprocessor()

        # Daemon thread that runs self._parquet_save_worker
        worker = threading.Thread(
            name="ParquetSaveWorker",
            target=self._parquet_save_worker,
            daemon=True,
        )
        self.save_worker = worker
        worker.start()
    
    def fetch_missing_candles(self, pair: str, last_timestamp: pd.Timestamp) -> pd.DataFrame:
        """Fetch completed M5 mid-price candles for ``pair`` from OANDA.

        Args:
            pair: Instrument name.
            last_timestamp: Start of the requested range. Timezone-aware
                values are converted to UTC before being sent; naive values
                are taken to be UTC already.

        Returns:
            Frame indexed by UTC ``timestamp`` with ``open``, ``high``,
            ``low`` and ``close`` columns, holding only candles flagged
            complete. Empty when nothing complete is returned.
        """
        # The request time carries a "Z" suffix, so it must be expressed in UTC
        start = pd.Timestamp(last_timestamp)
        if start.tzinfo is not None:
            start = start.tz_convert('UTC')

        logger.info(f"Fetching missing candles for {pair}...")
        params = {
            "from": start.strftime('%Y-%m-%dT%H:%M:%SZ'),
            "granularity": "M5",
            "price": "M"
        }

        r = instruments.InstrumentsCandles(instrument=pair, params=params)
        response = client.request(r)

        # Keep finished candles only; a missing flag counts as not complete
        rows = [{
            'timestamp': pd.to_datetime(candle['time'], utc=True),
            'open': float(candle['mid']['o']),
            'high': float(candle['mid']['h']),
            'low': float(candle['mid']['l']),
            'close': float(candle['mid']['c']),
        } for candle in response.get('candles', []) if candle.get('complete')]

        if not rows:
            return pd.DataFrame()

        df = pd.DataFrame(rows).set_index('timestamp')
        logger.info(f"Fetched {len(df)} candles for {pair}")
        # Full frame dump only at debug level to keep normal output small
        logger.debug("Fetched candles for %s:\n%s", pair, df)
        return df

    def initialize_pair(self, pair: str) -> bool:
        """Load the stored history of ``pair`` and cache raw and normalized frames.

        Reads ``<pair>_5T_indics_1H_not_norm.parquet`` from the base storage
        path, localizes a naive index to UTC and stores the frame together
        with its normalized version under the data lock.

        Returns:
            True on success; False (after logging the error) on any failure.
        """
        try:
            source_file = self.base_storage_path / f"{pair}_5T_indics_1H_not_norm.parquet"
            history = pd.read_parquet(source_file)

            if history.index.tz is None:
                history.index = history.index.tz_localize('UTC')

            # NOTE: trimming to the last ``max_history_size`` rows is
            # intentionally disabled for now; it is unclear whether a shorter
            # history would change the values seen by the agent.

            with self.data_lock:
                self.raw_data[pair] = history
                self.normalized_data[pair] = self.data_processor.normalize_simple(history)

        except Exception as exc:
            logger.error(f"Failed to initialize {pair}: {str(exc)}")
            return False

        return True

    def update_pair_data(self, pair: str) -> bool:
        """Fetch new candles for ``pair`` and refresh the cached frames.

        Nothing is fetched unless at least five minutes have passed since the
        last stored timestamp. New candles are merged into the stored frame
        (later rows win on duplicate timestamps), indicators are recalculated
        on the merged frame, the raw and normalized caches are replaced and
        the result is queued for saving.

        Returns:
            True when the caches were refreshed; False when no fetch was due
            or the fetch returned nothing.

        Raises:
            KeyError: If the pair was never initialized.
            Exception: Any other failure is logged and re-raised.
        """
        print(f"Updating data for {pair} - time {get_current_time()}")

        try:
            with self.data_lock:
                if pair not in self.raw_data:
                    raise KeyError(f"Pair {pair} not initialized")

                stored = self.raw_data[pair]
                last_timestamp = stored.index[-1]
                print(f"Last timestamp for {pair}: {last_timestamp}")

            # Guard: not yet time for a new candle
            elapsed = pd.Timestamp.now(tz='UTC') - last_timestamp
            if not (elapsed >= timedelta(minutes=5)):
                return False

            print(f"Fetching new data for {pair}")
            fresh = self.fetch_missing_candles(pair, last_timestamp)

            # Guard: nothing came back
            if fresh.empty:
                logger.info(f"No new data available for {pair}")
                return False

            with self.data_lock:
                # Both indices must be timezone-aware before merging
                if stored.index.tz is None:
                    stored.index = stored.index.tz_localize('UTC')
                if fresh.index.tz is None:
                    fresh.index = fresh.index.tz_localize('UTC')

                # Merge, let the newest row win on duplicates, order by time
                merged = pd.concat([stored, fresh])
                merged = merged[~merged.index.duplicated(keep='last')].sort_index()

                try:
                    # Indicators are recomputed over the whole merged frame
                    enriched = self.indicator_manager.calculate_indicators(merged)
                    print(f"Successfully calculated indicators for {pair}")

                    self.raw_data[pair] = enriched
                    self.normalized_data[pair] = self.data_processor.normalize_simple(
                        enriched
                    )

                    # Hand the frame to the save queue; a queueing failure is
                    # logged but does not fail the update
                    try:
                        self.save_queue.put((pair, enriched))
                        print(f"Data queued for saving for {pair}")
                    except Exception as e:
                        logger.error(f"Error queuing save operation for {pair}: {str(e)}")

                except Exception as e:
                    logger.error(f"Error calculating indicators for {pair}: {str(e)}")
                    raise

                return True

        except Exception as e:
            logger.error(f"Error updating data for {pair}: {str(e)}")
            raise

    def get_prediction_data(self, pair: str, sequence_length: int, current_position: float) -> np.ndarray:
        """
        Build the flat observation vector fed to the model for ``pair``.

        The vector holds the last ``sequence_length`` rows of the normalized
        frame, restricted to the training feature columns and laid out
        feature-by-feature (all time steps of the first feature, then all
        time steps of the second, ...), followed by ``current_position``.

        Returns:
            1-D ``float32`` array of length
            ``sequence_length * number_of_features + 1``.

        Raises:
            KeyError: no normalized data is stored for ``pair``.
            ValueError: the assembled vector has an unexpected length.
            Any exception is logged and then re-raised.
        """
        # Column order matters: it defines the layout of the observation.
        feature_columns = [
            'close', 'sma_20', 'sma_50', 'rsi', 'macd',
            'macd_signal', 'macd_hist', 'bb_upper', 'bb_middle',
            'bb_lower', 'bb_bandwidth', 'bb_percent', 'atr',
            'plus_di', 'minus_di', 'adx', 'senkou_span_a',
            'senkou_span_b', 'tenkan_sen', 'kijun_sen'
        ]

        try:
            with self.data_lock:
                if pair not in self.normalized_data:
                    raise KeyError(f"No data available for {pair}")

                # Most recent window, features only: shape (rows, features).
                window = self.normalized_data[pair][feature_columns].iloc[-sequence_length:].values

                # Transpose first so each feature's time steps are contiguous.
                flat_market = window.T.flatten()

                # The position flag is the final element.
                observation = np.append(flat_market, current_position)

                wanted_length = len(feature_columns) * sequence_length + 1
                actual_length = observation.shape[0]
                if actual_length != wanted_length:
                    raise ValueError(
                        f"Observation shape mismatch: got {observation.shape[0]}, "
                        f"expected {wanted_length}"
                    )

                return observation.astype(np.float32)

        except Exception as e:
            logger.error(f"Error constructing prediction data for {pair}: {str(e)}")
            raise

    def _parquet_save_worker(self) -> None:
        """Background worker that drains ``save_queue`` and writes parquet files.

        Each queue item is a ``(pair, df)`` tuple; an item whose ``pair`` is
        ``None`` stops the worker. An existing file is moved aside to
        ``<name>.parquet.backup`` before the write; the backup is deleted
        after a successful write and moved back if the write fails.

        Errors are logged and never propagate, so one bad save does not kill
        the worker. ``task_done()`` is called for every item taken from the
        queue, including the stop sentinel.
        """
        while True:
            try:
                pair, df = self.save_queue.get()
                if pair is None:
                    break

                parquet_path = self.base_storage_path / f"{pair}_5T_indics_1H_not_norm.parquet"

                # Reset on every iteration: when there is no previous file
                # there is nothing to back up, and a stale path left over from
                # an earlier pair must never be reused.
                backup_path = None
                if parquet_path.exists():
                    backup_path = parquet_path.with_suffix('.parquet.backup')
                    # replace() also overwrites a leftover backup on platforms
                    # where rename() refuses to clobber an existing file.
                    parquet_path.replace(backup_path)

                try:
                    df.to_parquet(parquet_path)
                except Exception as e:
                    # Only a failed write triggers a restore.
                    if backup_path is not None and backup_path.exists():
                        backup_path.replace(parquet_path)
                        logger.error(f"Error saving data for {pair}, restored from backup: {str(e)}")
                    else:
                        logger.error(f"Error saving data for {pair}, no backup to restore: {str(e)}")
                else:
                    # Cleanup lives outside the guarded write so a failure to
                    # delete the backup cannot overwrite the fresh file.
                    if backup_path is not None and backup_path.exists():
                        backup_path.unlink()
                    logger.info(f"Successfully saved data for {pair}")

            except Exception as e:
                logger.error(f"Error in save worker: {str(e)}")
            finally:
                self.save_queue.task_done()


class TradingSystem:
    """Main trading system coordination.

    Loads one PPO model per configured pair, mirrors the broker's open
    positions locally, and on each scheduled cycle refreshes market data,
    asks the model for a target position and sends the orders needed to
    reach it.

    Relies on module-level names defined elsewhere in this file:
    ``currency_pairs`` (iterated for pair names and indexed for the order
    size in units), ``client``, ``OANDA_ACCOUNT_ID``, ``logger``,
    ``get_current_time`` and ``FastDataManager``.
    """

    def __init__(self):
        """Create an empty system; ``initialize()`` or ``run()`` populates it."""
        self.data_manager = None
        self.models = {}
        # pair -> 'LONG' | 'SHORT' | 'NO_POSITION'; guard with positions_lock.
        self.positions = {}
        self.positions_lock = threading.Lock()
        self.position_entry_prices = {}
        self.position_entry_times = {}
        self.position_entry_indicators = {}
        self.position_entry_spreads = {}
        self.start_time = datetime.now(timezone.utc)

    def position_to_float(self, position_type: str) -> float:
        """Convert position type to float representation.

        'LONG' -> 1.0, 'SHORT' -> -1.0, 'NO_POSITION' and any unknown
        value -> 0.0.
        """
        position_map = {
            'LONG': 1.0,
            'SHORT': -1.0,
            'NO_POSITION': 0.0
        }
        return position_map.get(position_type, 0.0)

    def initialize(self):
        """Initialize the trading system.

        Creates the data manager, then for every pair in ``currency_pairs``
        initializes its data and loads its PPO model together with the saved
        ``VecNormalize`` wrapper. A pair whose data or model fails to load is
        logged and skipped. Finally the local position mirror is synced with
        the broker; a failure there propagates.
        """
        logger.info("Initializing trading system...")

        # Initialize data manager
        self.data_manager = FastDataManager(
            base_storage_path="/Volumes/ssd_fat2/ai6_trading_bot/datasets/5min/best_dataframes_not_norm/to_test_deploy"
        )

        # Load models and initialize data for each pair
        for pair in currency_pairs:
            try:
                if not self.data_manager.initialize_pair(pair):
                    continue

                model_path = f'/Volumes/ssd_fat2/ai6_trading_bot/datasets/5min/best_dataframes_true_cost/models_and_vecs/{pair}_best_model'
                env_path = f'/Volumes/ssd_fat2/ai6_trading_bot/datasets/5min/best_dataframes_true_cost/models_and_vecs/{pair}_vec_normalize.pkl'

                # Create environment for loading model
                vec_env = DummyVecEnv([lambda: ForexTradingEnv(
                    self.data_manager.raw_data[pair], pair
                )])

                # Load environment normalization in inference mode: statistics
                # are frozen and rewards are left unnormalized.
                env = VecNormalize.load(env_path, vec_env)
                env.training = False
                env.norm_reward = False

                # Load the model
                model = PPO.load(model_path, env=env)
                self.models[pair] = model

            except Exception as e:
                logger.error(f"Error initializing {pair}: {str(e)}")
                continue

        # Sync initial positions
        self.sync_positions()

    def _make_prediction(self, pair: str, observation: np.ndarray) -> str:
        """
        Make a prediction using the loaded model.

        Returns position type ('NO_POSITION', 'LONG', or 'SHORT'). On any
        error the currently recorded position for ``pair`` is returned, so a
        failed prediction never triggers a trade.
        """
        try:
            if pair not in self.models:
                raise KeyError(f"No model loaded for {pair}")

            model = self.models[pair]

            # Reject a wrongly sized observation before it reaches the policy.
            expected_shape = model.policy.observation_space.shape[0]
            actual_shape = observation.shape[0]
            if actual_shape != expected_shape:
                raise ValueError(
                    f"Observation shape mismatch for {pair}: "
                    f"expected {expected_shape}, got {actual_shape}"
                )

            # NOTE(review): the observation goes to predict() as built by the
            # data manager; the VecNormalize statistics loaded in initialize()
            # are not applied here - confirm this matches training.
            model_input = observation.reshape(1, -1)
            action, _ = model.predict(model_input, deterministic=True)

            # Map action to position type
            action_map = {0: 'NO_POSITION', 1: 'LONG', 2: 'SHORT'}
            return action_map[action[0]]

        except Exception as e:
            logger.error(f"Error making prediction for {pair}: {str(e)}")
            # Return current position on error to avoid unwanted changes
            with self.positions_lock:
                return self.positions.get(pair, 'NO_POSITION')

    def sync_positions(self):
        """Synchronize positions with broker.

        Replaces the local mirror with the broker's open positions for the
        pairs in ``currency_pairs``. Errors are logged and re-raised.
        """
        try:
            r = positions.OpenPositions(accountID=OANDA_ACCOUNT_ID)
            response = client.request(r)
            print(f'sync_positions response: {response}')

            with self.positions_lock:
                self.positions.clear()
                for pos in response.get('positions', []):
                    pair = pos['instrument']
                    if pair in currency_pairs:
                        # Long units are tested as positive, short as negative.
                        if float(pos.get('long', {}).get('units', 0)) > 0:
                            self.positions[pair] = 'LONG'
                        elif float(pos.get('short', {}).get('units', 0)) < 0:
                            self.positions[pair] = 'SHORT'
                        else:
                            self.positions[pair] = 'NO_POSITION'
            print(f'self.postions after sync {self.positions} at time {get_current_time()}')
            logger.info("Positions synchronized")

        except Exception as e:
            logger.error(f"Error syncing positions: {str(e)}")
            raise

    def execute_trade(self, pair: str, current_position: str, new_position: str):
        """Execute a trade: close ``current_position`` and open ``new_position``.

        The local mirror is updated after each broker call that succeeds, so
        a failure while opening the new position leaves the pair recorded as
        flat instead of still showing the position that was just closed.
        Errors are logged and re-raised.
        """
        print(f'execute_trade called for {pair} with current_position {current_position} and new_position {new_position} at time {get_current_time()}')
        try:
            # Close existing position if any
            if current_position != 'NO_POSITION':
                print(f'Closing existing position for {pair}')
                self.close_position(pair, current_position)
                # Record the flat state immediately. Otherwise a failed open
                # below would leave a stale entry and every later cycle would
                # try to close a position the broker no longer holds.
                with self.positions_lock:
                    self.positions[pair] = 'NO_POSITION'

            # Open new position if not moving to neutral
            if new_position != 'NO_POSITION':
                print(f'Opening new position for {pair}')
                self.open_position(pair, new_position)

            # Update position storage
            with self.positions_lock:
                self.positions[pair] = new_position

            logger.info(f"Executed trade for {pair}: {current_position} -> {new_position}")

        except Exception as e:
            logger.error(f"Error executing trade for {pair}: {str(e)}")
            raise

    def open_position(self, pair: str, position_type: str):
        """Open a new position with a market order.

        The order size is ``currency_pairs[pair]`` units, negated for
        'SHORT'. Exceptions from the broker client propagate to the caller.
        """
        print(f'open_position called for {pair} with position_type {position_type}')
        units = currency_pairs[pair]
        if position_type == 'SHORT':
            units = -units

        data = {
            "order": {
                "instrument": pair,
                "units": str(units),
                "type": "MARKET",
                "positionFill": "DEFAULT"
            }
        }

        r = orders.OrderCreate(accountID=OANDA_ACCOUNT_ID, data=data)
        response = client.request(r)
        # NOTE(review): the response is only printed, never inspected; confirm
        # whether an accepted-but-cancelled order needs handling here.
        print(f'open_position response: {response} at time {get_current_time()}')

    def close_position(self, pair: str, position_type: str):
        """Close an existing position.

        Closes all long units when ``position_type`` is 'LONG', otherwise all
        short units. Exceptions from the broker client propagate.
        """
        print(f'close_position called for {pair} with position_type {position_type}')
        data = {
            "longUnits": "ALL"
        } if position_type == 'LONG' else {
            "shortUnits": "ALL"
        }

        r = positions.PositionClose(
            accountID=OANDA_ACCOUNT_ID,
            instrument=pair,
            data=data
        )
        response = client.request(r)
        print(f'close_position response: {response} at time {get_current_time()}')

    def trading_cycle(self):
        """Execute one trading cycle.

        For each pair with a loaded model: refresh its data and, only when
        new data arrived, build the observation, query the model and trade if
        the target differs from the recorded position. A failure on one pair
        is logged and does not stop the others.
        """
        logger.info("Starting trading cycle")
        print("Available models:", list(self.models.keys()))

        for pair in currency_pairs:
            try:
                if pair not in self.models:
                    logger.error(f"No model loaded for {pair}, skipping.")
                    continue

                # Update market data
                if self.data_manager.update_pair_data(pair):
                    # Get current position
                    with self.positions_lock:
                        current_position_type = self.positions.get(pair, 'NO_POSITION')
                        print(f'current_position_type {current_position_type}')

                    # Convert position to float for observation
                    current_position_float = self.position_to_float(current_position_type)
                    print(f'current_position_float {current_position_float} for pair {pair} at time {get_current_time()}')

                    # Get prediction data
                    observation = self.data_manager.get_prediction_data(
                        pair=pair,
                        sequence_length=5,
                        current_position=current_position_float
                    )

                    # Get model prediction
                    action_name = self._make_prediction(pair, observation)
                    print(f'action_name {action_name} for pair {pair}')

                    # Execute trade if position change needed
                    if current_position_type != action_name:
                        print(f'execute_trade called for {pair} with current_position_type {current_position_type} and action_name {action_name}')
                        self.execute_trade(pair, current_position_type, action_name)

            except Exception as e:
                logger.error(f"Error in trading cycle for {pair}: {str(e)}")
                continue

    def run(self):
        """Run the trading system until interrupted.

        Initializes the system, schedules ``trading_cycle`` on every fifth
        minute at second 0 and then blocks. ``KeyboardInterrupt`` and
        ``SystemExit`` trigger an orderly shutdown and are not re-raised;
        any other exception is logged and re-raised.
        """
        # Bound before the try so the handlers can tell whether the scheduler
        # was ever created (the interrupt may arrive during initialize()).
        scheduler = None
        try:
            self.initialize()
            print('def run _ self.initialize() complete')

            scheduler = BackgroundScheduler()
            scheduler.add_job(
                self.trading_cycle,
                'cron',
                minute='*/5',
                second=0
            )
            scheduler.start()

            logger.info("Trading system started")

            # The scheduler works in a background thread; keep this one alive.
            while True:
                time.sleep(60)

        except (KeyboardInterrupt, SystemExit):
            logger.info("Shutting down trading system...")
            if scheduler is not None and scheduler.running:
                scheduler.shutdown()
            if self.data_manager is not None:
                self.data_manager.shutdown()
            logger.info("Trading system shutdown complete")

        except Exception as e:
            logger.error(f"Fatal error in trading system: {str(e)}")
            # Do not leave a background job placing orders after a fatal error.
            if scheduler is not None and scheduler.running:
                scheduler.shutdown(wait=False)
            raise





In [None]:
# Entry point cell: build the trading system and run it until interrupted.
# Bound up front so the KeyboardInterrupt handler never touches an undefined
# name when the interrupt arrives before construction finishes.
trading_system = None
try:
    # Configure logging
    # (basicConfig does nothing when the root logger already has handlers.)
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler('trading_system.log'),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger('trading_system')

    # Start trading system
    logger.info("Starting trading system...")
    trading_system = TradingSystem()
    trading_system.run()

except KeyboardInterrupt:
    logger.info("Received shutdown signal. Initiating graceful shutdown...")
    # data_manager stays None until TradingSystem.initialize() has run.
    if trading_system is not None and trading_system.data_manager is not None:
        trading_system.data_manager.shutdown()
    logger.info("Trading system shutdown complete.")

except Exception as e:
    logger.error(f"Fatal error: {str(e)}", exc_info=True)
    raise

In [None]:
import pandas as pd
import os
import os
import sys
import numpy as np
import pandas as pd
from typing import Dict, Optional, Tuple, List
from pathlib import Path
import threading
from queue import Queue
import time
import logging
from datetime import datetime, timedelta, timezone
import pytz
from dataclasses import dataclass, field
import json

# Trading components
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from apscheduler.schedulers.background import BackgroundScheduler

# OANDA components
from oandapyV20 import API
import oandapyV20.endpoints.positions as positions
import oandapyV20.endpoints.orders as orders
import oandapyV20.endpoints.instruments as instruments
import oandapyV20.endpoints.trades as trades


# Add project root to path
# The parent of the current working directory is appended to sys.path (once)
# so the local packages imported below can be resolved.
# NOTE(review): this depends on the directory the notebook is started from.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import local components
from trading.environments.forex_env2_flat import ForexTradingEnv
from data_management.preprocessor import DataPreprocessor
from data_management.indicator_manager import IndicatorManager

# Notebook display setting: show the value of every top-level expression in a
# cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"




# OANDA Configuration
# SECURITY: an access token committed to source is exposed to everyone who can
# read the file. The token is now taken from the OANDA_API_KEY environment
# variable when that is set; the literal remains only as a fallback so
# existing runs keep working. Rotate it and remove the literal.
OANDA_API_KEY = os.environ.get('OANDA_API_KEY') or '9317ace4596d61e3e98b1a53b2342483-45d3ad4084c80b111727a9fada9ef0ff'
OANDA_ACCOUNT_ID = '101-004-30348600-001' #running account
# OANDA_ACCOUNT_ID = '101-004-30348600-002'
OANDA_ENV = 'practice'

# Initialize OANDA client
client = API(access_token=OANDA_API_KEY, environment=OANDA_ENV)

def fetch_missing_candles( pair: str, last_timestamp: pd.Timestamp) -> pd.DataFrame:
    """Fetch new candles from OANDA.

    Requests M5 mid-price candles for ``pair`` starting at
    ``last_timestamp`` and keeps only those flagged complete.

    Args:
        pair: Instrument name passed to the candles endpoint.
        last_timestamp: Start of the requested range. A timezone-aware value
            is converted to UTC; a naive value is sent as-is and therefore
            interpreted as UTC.

    Returns:
        DataFrame indexed by UTC ``timestamp`` with ``open``, ``high``,
        ``low`` and ``close`` columns, or an empty DataFrame when no
        completed candle was returned.
    """
    # The request stamps the time with a literal 'Z' (UTC), so an aware
    # timestamp in another zone must be converted first; otherwise the wall
    # clock of that zone would be sent as if it were UTC.
    start = pd.Timestamp(last_timestamp)
    if start.tzinfo is not None:
        start = start.tz_convert('UTC')

    params = {
        "from": start.strftime('%Y-%m-%dT%H:%M:%SZ'),
        "granularity": "M5",
        "price": "M"
    }

    r = instruments.InstrumentsCandles(instrument=pair, params=params)
    response = client.request(r)
    candles = response.get('candles', [])
    print('CANDLE PRINT')
    print(candles)
    print('CANDLE PRINT')

    # Skip the candle that is still forming.
    df_list = [{
        'timestamp': pd.to_datetime(candle['time'], utc=True),
        'open': float(candle['mid']['o']),
        'high': float(candle['mid']['h']),
        'low': float(candle['mid']['l']),
        'close': float(candle['mid']['c']),
    } for candle in candles if candle['complete']]

    # Covers both "no candles at all" and "nothing completed yet".
    if not df_list:
        return pd.DataFrame()

    df = pd.DataFrame(df_list)
    df.set_index('timestamp', inplace=True)

    print(df)
    return df

In [None]:
# pair = 'EUR_USD'
# base_storage_path="/Volumes/ssd_fat2/ai6_trading_bot/datasets/5min/best_dataframes_not_norm/to_test_deploy"
# currency = f"{pair}_5T_indics_1H_not_norm.parquet"
# path = os.path.join(base_storage_path, currency)
# df = pd.read_parquet(path)
# # df

# last_timestamp = df.index[-1]
# new_timestamp = pd.Timestamp(last_timestamp) + pd.Timedelta(hours=11,minutes=30)
# new_timestamp
# candles = fetch_missing_candles(pair=pair,last_timestamp=new_timestamp)
# candles

In [None]:
# import oandapyV20
# import oandapyV20.endpoints.accounts as accounts
# OANDA_API_KEY = '9317ace4596d61e3e98b1a53b2342483-45d3ad4084c80b111727a9fada9ef0ff'

# api = oandapyV20.API(access_token=OANDA_API_KEY)
# r = accounts.AccountList()
# api.request(r)
# print(r.response)

In [None]:
# pair = 'XAU_USD'
# spread_tracker = SpreadTracker()
# spread_close = spread_tracker.record_spread(pair, 'CLOSE')
# stats = spread_tracker.get_spread_statistics()
# stats

In [None]:
   # Keep all your existing methods (sync_positions, execute_trade, etc.)
    # They remain unchanged

# class TradingSystem:
#     """Main trading system coordination."""
    
#     def __init__(self):
#         self.data_manager = None
#         self.models = {}
#         self.positions = {}
#         self.positions_lock = threading.Lock()
#         self.performance_tracker = PerformanceTracker(
#             base_path=Path("./trading_performance")
#         )


#     def position_to_float(self, position_type: str) -> float:
#         """Convert position type to float representation."""
#         position_map = {
#             'LONG': 1.0,
#             'SHORT': -1.0,
#             'NO_POSITION': 0.0
#         }
#         return position_map.get(position_type, 0.0)
        
#     def initialize(self):
#         """Initialize the trading system."""
#         logger.info("Initializing trading system...")
        
#         # Initialize data manager
#         self.data_manager = FastDataManager(
#             base_storage_path="/Volumes/ssd_fat2/ai6_trading_bot/datasets/5min/best_dataframes_not_norm/to_test_deploy"
#         )
        
#         # Load models and initialize data for each pair
#         for pair in currency_pairs:
#             try:
#                 if not self.data_manager.initialize_pair(pair):
#                     continue
                    
#                 model_path = f'/Volumes/ssd_fat2/ai6_trading_bot/datasets/5min/best_dataframes_true_cost/models_and_vecs/{pair}_best_model'
#                 env_path = f'/Volumes/ssd_fat2/ai6_trading_bot/datasets/5min/best_dataframes_true_cost/models_and_vecs/{pair}_vec_normalize.pkl'
                
#                 vec_env = DummyVecEnv([lambda: ForexTradingEnv(
#                     self.data_manager.raw_data[pair], pair
#                 )])
                
#                 env = VecNormalize.load(env_path, vec_env)
#                 env.training = False
#                 env.norm_reward = False
                
#                 model = PPO.load(model_path, env=env)
#                 self.models[pair] = model
#                 print(f"Loaded model for {pair}")
#                 logger.info(f"Models loaded: {list(self.models.keys())}")

                
#                 logger.info(f"Initialized model for {pair}")
                
#             except Exception as e:
#                 logger.error(f"Error initializing {pair}: {str(e)}")
#                 continue
        
#         # Sync initial positions
#         self.sync_positions()
#         print(f"Initialized done")
#         print(f'self.models after init at time {get_current_time()}: {self.models}')
        
#     def sync_positions(self):
#         """Synchronize positions with broker."""
#         try:
#             r = positions.OpenPositions(accountID=OANDA_ACCOUNT_ID)
#             response = client.request(r)
#             print(f'sync_positions response: {response}')
            
#             with self.positions_lock:
#                 self.positions.clear()
#                 for pos in response.get('positions', []):
#                     pair = pos['instrument']
#                     if pair in currency_pairs:
#                         if float(pos.get('long', {}).get('units', 0)) > 0:
#                             self.positions[pair] = 'LONG'
#                         elif float(pos.get('short', {}).get('units', 0)) < 0:
#                             self.positions[pair] = 'SHORT'
#                         else:
#                             self.positions[pair] = 'NO_POSITION'
#             print(f'self.postions after sync {self.positions} at time {get_current_time()}')            
#             logger.info("Positions synchronized")
            
#         except Exception as e:
#             logger.error(f"Error syncing positions: {str(e)}")
#             raise

#     def execute_trade(self, pair: str, current_position: str, new_position: str):
#         """Execute a trade."""
#         print(f'execute_trade called for {pair} with current_position {current_position} and new_position {new_position} at time {get_current_time()}')
#         try:
#             # Close existing position if any
#             if current_position != 'NO_POSITION':
#                 print(f'Closing existing position for {pair}')
#                 #! Trading disabled for now
#                 # self.close_position(pair, current_position)
            
#             # Open new position if not moving to neutral
#             if new_position != 'NO_POSITION':
#                 print(f'Opening new position for {pair}')
#                 #! Trading disabled for now
#                 # self.open_position(pair, new_position)
            
#             # Update position storage
#             with self.positions_lock:
#                 self.positions[pair] = new_position
                
#             logger.info(f"Executed trade for {pair}: {current_position} -> {new_position}")
            
#         except Exception as e:
#             logger.error(f"Error executing trade for {pair}: {str(e)}")
#             raise

#     def open_position(self, pair: str, position_type: str):
#         print(f'open_position called for {pair} with position_type {position_type}')
#         """Open a new position."""
#         units = currency_pairs[pair]
#         if position_type == 'SHORT':
#             units = -units
            
#         data = {
#             "order": {
#                 "instrument": pair,
#                 "units": str(units),
#                 "type": "MARKET",
#                 "positionFill": "DEFAULT"
#             }
#         }
        
#         r = orders.OrderCreate(accountID=OANDA_ACCOUNT_ID, data=data)
#         client.request(r)
#         print(f'open_position response: {r} at time {get_current_time()}')

#     def close_position(self, pair: str, position_type: str):
#         print(f'close_position called for {pair} with position_type {position_type}')
#         """Close an existing position."""
#         data = {
#             "longUnits": "ALL"
#         } if position_type == 'LONG' else {
#             "shortUnits": "ALL"
#         }
        
#         r = positions.PositionClose(
#             accountID=OANDA_ACCOUNT_ID,
#             instrument=pair,
#             data=data
#         )
#         client.request(r)
#         print(f'close_position response: {r} at time {get_current_time()}')

#     def trading_cycle(self):
#         print(f'!!! trading_cycle called')
#         print(f"Available models: {list(self.models.keys())}")

#         """Execute one trading cycle with updated observation handling."""
#         logger.info("Starting trading cycle")
        
#         for pair in currency_pairs:
#             try:
#                 if pair not in self.models:
#                     logger.error(f"No model loaded for {pair}, skipping.")
#                     continue
#                 if self.data_manager.update_pair_data(pair):
#                     # Get current position
#                     with self.positions_lock:
#                         current_position_type = self.positions.get(pair, 'NO_POSITION')
#                         print(f'current_position_type {current_position_type}')
#                     current_position_float = self.position_to_float(current_position_type)
#                     print(f'current_position_float {current_position_float} for pair {pair} at time {get_current_time()}')
                    
#                     # Get prediction data with current position
#                     sequence = self.data_manager.get_prediction_data(
#                         pair=pair,
#                         sequence_length=5,
#                         current_position=current_position_float
#                     )
                    
#                     # Reshape for model
#                     obs_array = sequence.reshape((1, -1))
                    
#                     # Get model prediction
#                     model = self.models[pair]
#                     obs_array = model.env.normalize_obs(obs_array)
#                     action, _ = model.predict(obs_array, deterministic=True)
                    
#                     # Convert action to position type
#                     action_name = {0: 'NO_POSITION', 1: 'LONG', 2: 'SHORT'}[action[0]]
#                     print(f'action_name {action_name} for pair {pair}')
                    
#                     # Execute trade if position change needed
#                     if current_position_type != action_name:
#                         print(f'execute_trade called for {pair} with current_position_type {current_position_type} and action_name {action_name}')
#                         self.execute_trade(pair, current_position_type, action_name)
                        
#                         # Update position tracking
#                         with self.positions_lock:
#                             self.positions[pair] = action_name
            
#             except Exception as e:
#                 logger.error(f"Error in trading cycle for {pair}: {str(e)}")
#                 continue

#     def run(self):
#         """Run the trading system."""
#         try:
#             self.initialize()
#             print('def run _ self.initialize() complete')
            
#             scheduler = BackgroundScheduler()
#             scheduler.add_job(
#                 self.trading_cycle,
#                 'cron',
#                 minute='*/5',
#                 second=0
#             )
#             scheduler.start()
            
#             logger.info("Trading system started")
            
#             while True:
#                 time.sleep(60)
                
#         except (KeyboardInterrupt, SystemExit):
#             logger.info("Shutting down trading system...")
#             scheduler.shutdown()
#             self.data_manager.shutdown()
#             logger.info("Trading system shutdown complete")
            
#         except Exception as e:
#             logger.error(f"Fatal error in trading system: {str(e)}")
#             raise



In [None]:
import os
import sys
import numpy as np
import pandas as pd
from typing import Dict, Optional, Tuple, List
from pathlib import Path
import threading
from queue import Queue
import time
import logging
from datetime import datetime, timedelta, timezone
import pytz
from dataclasses import dataclass, field
import json

# Trading components
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from apscheduler.schedulers.background import BackgroundScheduler

# OANDA components
from oandapyV20 import API
import oandapyV20.endpoints.positions as positions
import oandapyV20.endpoints.orders as orders
import oandapyV20.endpoints.instruments as instruments
import oandapyV20.endpoints.trades as trades


# Add project root to path
# The parent of the current working directory is appended to sys.path (once)
# so the local packages imported below can be resolved.
# NOTE(review): this depends on the directory the notebook is started from.
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

# Import local components
from trading.environments.forex_env2_flat import ForexTradingEnv
from data_management.preprocessor import DataPreprocessor
from data_management.indicator_manager import IndicatorManager


class FastDataManager:
    """High-performance data manager optimized for low-latency trading."""
    
    def __init__(
        self,
        base_storage_path: str,
        max_history_size: int = 10000
    ):
        """Set up in-memory stores, helper components and the save thread.

        Args:
            base_storage_path: Root location of the stored data; kept on the
                instance as a ``Path``.
            max_history_size: Kept on the instance; this constructor does not
                use it any further.
        """
        # Plain configuration.
        self.max_history_size = max_history_size
        self.base_storage_path = Path(base_storage_path)

        # Feature column names, grouped by indicator family (order preserved).
        self.training_features = [
            'close',
            'sma_20', 'sma_50',
            'rsi',
            'macd', 'macd_signal', 'macd_hist',
            'bb_upper', 'bb_middle', 'bb_lower', 'bb_bandwidth', 'bb_percent',
            'atr',
            'plus_di', 'minus_di', 'adx',
            'senkou_span_a', 'senkou_span_b', 'tenkan_sen', 'kijun_sen',
        ]

        # Per-pair frames, keyed by pair name.
        self.raw_data: Dict[str, pd.DataFrame] = {}
        self.normalized_data: Dict[str, pd.DataFrame] = {}

        # Synchronization primitives: a lock for the dicts above and a queue
        # feeding the background save thread.
        self.data_lock = threading.Lock()
        self.save_queue = Queue()

        # Project helpers for indicator calculation and normalization.
        self.indicator_manager = IndicatorManager()
        self.data_processor = DataPreprocessor()

        # Launch the save worker last, once every attribute above exists.
        # It is a daemon thread, so it does not keep the interpreter alive.
        worker = threading.Thread(
            name="ParquetSaveWorker",
            target=self._parquet_save_worker,
            daemon=True,
        )
        self.save_worker = worker
        worker.start()
    
    def fetch_missing_candles(self, pair: str, last_timestamp: pd.Timestamp) -> pd.DataFrame:
        """Fetch completed M5 mid-price candles for ``pair`` from OANDA.

        The request window starts at ``last_timestamp``; only candles flagged
        ``complete`` in the response are kept.

        Args:
            pair: Instrument name passed to the candles endpoint.
            last_timestamp: Start of the requested window. A naive value is
                treated as UTC; an aware value is converted to UTC before it
                is formatted, because the request string carries a literal
                ``Z`` suffix.

        Returns:
            A DataFrame indexed by UTC ``timestamp`` with ``open``, ``high``,
            ``low`` and ``close`` columns, or an empty DataFrame when the
            response contains no completed candles.
        """
        print(f"Fetching missing candles for {pair}...")
        print(f'Fetch missing candles for {pair} - time {get_current_time()}')

        # The format string below hard-codes "Z", so the value must really be
        # in UTC; otherwise the wall-clock time of another zone would be sent
        # as if it were UTC and the wrong window would be requested.
        start = pd.Timestamp(last_timestamp)
        if start.tzinfo is None:
            start = start.tz_localize('UTC')
        else:
            start = start.tz_convert('UTC')

        params = {
            "from": start.strftime('%Y-%m-%dT%H:%M:%SZ'),
            "granularity": "M5",
            "price": "M"
        }

        r = instruments.InstrumentsCandles(instrument=pair, params=params)
        response = client.request(r)
        candles = response.get('candles', [])

        if not candles:
            return pd.DataFrame()

        # Skip the still-forming candle: its OHLC values are not final.
        df_list = [{
            'timestamp': pd.to_datetime(candle['time'], utc=True),
            'open': float(candle['mid']['o']),
            'high': float(candle['mid']['h']),
            'low': float(candle['mid']['l']),
            'close': float(candle['mid']['c']),
        } for candle in candles if candle['complete']]

        if not df_list:
            return pd.DataFrame()

        df = pd.DataFrame(df_list)
        df.set_index('timestamp', inplace=True)
        print(f"Fetched {len(df)} candles for {pair}. at time {get_current_time()}")
        print(df)
        return df

    def initialize_pair(self, pair: str) -> bool:
        """Load the stored history for ``pair`` and cache raw and normalized copies.

        Reads ``<base_storage_path>/<pair>_5T_indics_1H_not_norm.parquet``,
        localizes a timezone-naive index to UTC, then stores the frame and its
        ``normalize_simple`` output while holding ``data_lock``.

        Args:
            pair: Pair name used both in the file name and as the cache key.

        Returns:
            True on success; False if any step raised (the error is logged).
        """
        try:
            history = pd.read_parquet(
                self.base_storage_path / f"{pair}_5T_indics_1H_not_norm.parquet"
            )

            index_is_naive = history.index.tz is None
            if index_is_naive:
                history.index = history.index.tz_localize('UTC')

            # NOTE: trimming the frame to its last ``max_history_size`` rows is
            # deliberately switched off for now - the window may be too short,
            # and it is unverified whether trimming alters the values and
            # confuses the agent.

            with self.data_lock:
                # Raw frame is stored first so it is kept even if
                # normalization raises.
                self.raw_data[pair] = history
                normalized = self.data_processor.normalize_simple(history)
                self.normalized_data[pair] = normalized

        except Exception as exc:
            logger.error(f"Failed to initialize {pair}: {str(exc)}")
            return False
        else:
            return True

    def update_pair_data(self, pair: str) -> bool:
        """Fetch candles newer than the stored history and refresh derived data.

        When at least five minutes have passed since the last stored candle,
        new candles are fetched, merged into the stored frame (duplicates
        resolved in favour of the newly fetched row), indicators are
        recalculated over the merged frame, the normalized copy is rebuilt and
        the result is queued for saving.

        Args:
            pair: Pair name; must already be present in ``raw_data``.

        Returns:
            True when new candles were merged and stored. False when the last
            stored candle is less than five minutes old or the fetch returned
            no candles.

        Raises:
            KeyError: If ``pair`` has not been initialized.
            Exception: Any other failure is logged and re-raised.
        """
        print(f"Updating data for {pair} - time {get_current_time()}")

        try:
            with self.data_lock:
                if pair not in self.raw_data:
                    raise KeyError(f"Pair {pair} not initialized")

                df = self.raw_data[pair]
                print(df)
                last_timestamp = df.index[-1]
                print(f"Last timestamp for {pair}: {last_timestamp}")

            current_time = pd.Timestamp.now(tz='UTC')

            # Less than one M5 interval has elapsed since the last stored
            # candle, so skip the request.
            if current_time - last_timestamp < timedelta(minutes=5):
                return False

            print(f"Fetching new data for {pair}")
            # The network call happens outside the lock so readers are not
            # blocked while waiting on the API.
            new_data = self.fetch_missing_candles(pair, last_timestamp)

            if new_data.empty:
                logger.info(f"No new data available for {pair}")
                return False

            with self.data_lock:
                # Ensure indices are datetime and timezone-aware
                if df.index.tz is None:
                    df.index = df.index.tz_localize('UTC')
                if new_data.index.tz is None:
                    new_data.index = new_data.index.tz_localize('UTC')

                # Combine old and new data; on a duplicated timestamp keep
                # the freshly fetched row.
                combined_df = pd.concat([df, new_data])
                combined_df = combined_df[~combined_df.index.duplicated(keep='last')]
                combined_df.sort_index(inplace=True)

                # Calculate indicators on the full dataset
                try:
                    combined_df_with_indicators = self.indicator_manager.calculate_indicators(combined_df)
                    print(f"Successfully calculated indicators for {pair}")

                    # Update the stored raw and normalized data
                    self.raw_data[pair] = combined_df_with_indicators
                    self.normalized_data[pair] = self.data_processor.normalize_simple(
                        combined_df_with_indicators
                    )

                    # Queue the save operation; a queuing failure is logged
                    # but does not undo the in-memory update.
                    try:
                        self.save_queue.put((pair, combined_df_with_indicators))
                        print(f"Data queued for saving for {pair}")
                    except Exception as e:
                        logger.error(f"Error queuing save operation for {pair}: {str(e)}")

                    return True

                except Exception as e:
                    logger.error(f"Error calculating indicators for {pair}: {str(e)}")
                    raise

        except Exception as e:
            logger.error(f"Error updating data for {pair}: {str(e)}")
            raise