# In [23]:
"""
BTC Backtesting Script with Historical Data Collection
Collects BTC historical data from Alpaca, stores in TimescaleDB, and runs backtests
"""

import pandas as pd
import numpy as np
import time
from datetime import datetime, timedelta
import pytz
import logging
import sys
import os
from typing import Optional, Dict, Any
import warnings

# Suppress pandas SettingWithCopyWarning - we handle copying explicitly.
# Two switches are set: the pandas chained-assignment option is cleared, and
# the warnings filter ignores the warning category itself.
# NOTE(review): both names are resolved at import time - confirm they still
# exist on the pandas version this script is deployed with.
pd.options.mode.chained_assignment = None
warnings.filterwarnings('ignore', category=pd.errors.SettingWithCopyWarning)

# Modern Alpaca imports
from alpaca.trading.client import TradingClient
from alpaca.data.historical.crypto import CryptoHistoricalDataClient
from alpaca.data.requests import CryptoBarsRequest
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit

# Observability imports
from prometheus_client import Counter, Gauge, start_http_server, REGISTRY
import json
import requests

# Loki logging
from pythonjsonlogger import jsonlogger

# Database (TimescaleDB)
import psycopg2
from psycopg2.pool import SimpleConnectionPool

# MLflow
import mlflow

# Configure logging
class LokiHandler(logging.Handler):
    """Logging handler that pushes every record to a Loki HTTP endpoint.

    Each record is sent as its own POST request carrying a single stream with
    a single entry. Delivery is best effort: a failed push never raises into
    the code that emitted the log line.
    """

    # Top-level logger names of the HTTP stack used for the push itself.
    # Shipping their records would issue another POST, which logs again, and
    # so on - a feedback loop as soon as the root logger runs at DEBUG level.
    _TRANSPORT_LOGGERS = frozenset({"urllib3", "requests"})

    def __init__(self, loki_url, labels=None):
        """Create the handler.

        Args:
            loki_url: URL the push payload is POSTed to.
            labels: Optional mapping of extra stream labels. They are merged
                after the built-in ``job`` and ``level`` labels, so a key of
                the same name overrides the built-in value.
        """
        super().__init__()
        self.loki_url = loki_url
        self.labels = labels or {}
        self.session = requests.Session()

    def _build_payload(self, record):
        """Return the JSON-serializable push body for one log record."""
        return {
            "streams": [{
                "stream": {
                    "job": "btc_backtest",
                    "level": record.levelname.lower(),
                    **self.labels
                },
                "values": [[
                    # Record creation time converted to nanoseconds, as a string
                    str(int(record.created * 1e9)),
                    json.dumps({
                        "message": self.format(record),
                        "level": record.levelname,
                        "logger": record.name
                    })
                ]]
            }]
        }

    def emit(self, record):
        """Send one record to Loki; failures are ignored on purpose."""
        if record.name.partition(".")[0] in self._TRANSPORT_LOGGERS:
            return
        try:
            response = self.session.post(
                self.loki_url, json=self._build_payload(record), timeout=5
            )
            response.raise_for_status()
        except Exception:
            # Deliberately silent: an unreachable or failing log sink must not
            # break (or spam) the application that is trying to log.
            pass

    def close(self):
        """Release the HTTP session, then run the standard handler cleanup."""
        try:
            session = getattr(self, "session", None)
            if session is not None:
                session.close()
        finally:
            super().close()

def setup_logging(level=logging.INFO, loki_url=None, loki_labels=None):
    """Configure the root logger and return it.

    Existing root handlers are removed, then a stdout handler is attached.
    When ``loki_url`` is given, a ``LokiHandler`` is attached as well; if that
    fails, logging continues on stdout only and a warning is emitted.

    Args:
        level: Level set on the root logger.
        loki_url: Optional Loki push URL; falsy disables Loki shipping.
        loki_labels: Optional extra labels forwarded to ``LokiHandler``.

    Returns:
        The configured root logger.
    """
    root_logger = logging.getLogger()
    root_logger.setLevel(level)
    root_logger.handlers.clear()

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(formatter)
    root_logger.addHandler(stream_handler)

    if loki_url:
        try:
            loki_handler = LokiHandler(loki_url, loki_labels)
            loki_handler.setFormatter(formatter)
            root_logger.addHandler(loki_handler)
        except Exception as exc:
            # Still best effort, but say so: otherwise nobody notices that
            # logs are not being shipped. The stdout handler is already
            # attached, so this warning is visible.
            root_logger.warning(
                "Loki logging disabled, continuing with stdout only: %s", exc
            )

    return root_logger

# Module-wide logger: the root logger, configured at import time with the
# setup_logging() defaults (INFO level, stdout only - no Loki URL is passed).
logger = setup_logging()

# Database connection pool for TimescaleDB
class TimescaleDB:
    """TimescaleDB connection and operations"""
    def __init__(self, connection_string=None):
        self.connection_string = connection_string
        self.pool = None
        if connection_string:
            logger.info(f"Attempting to connect to TimescaleDB...")
            logger.debug(f"Connection string: {self._mask_connection_string(connection_string)}")
            try:
                # Test connection first with a simple connection
                test_conn = psycopg2.connect(connection_string)
                test_conn.close()
                logger.info("✓ Database connection test successful")
                
                # Create connection pool
                self.pool = SimpleConnectionPool(1, 10, connection_string)
                logger.info("✓ Connection pool created")
                
                # Initialize schema
                self._init_schema()
                logger.info("✓ TimescaleDB connection pool created successfully")
            except psycopg2.OperationalError as e:
                error_msg = str(e)
                if "too many clients" in error_msg.lower():
                    logger.error(f"❌ Database connection failed: Too many connections!")
                    logger.error(f"   The database has reached its maximum connection limit.")
                    logger.error(f"   Solutions:")
                    logger.error(f"   1. Close other database connections (other scripts, notebooks, etc.)")
                    logger.error(f"   2. Restart the database service")
                    logger.error(f"   3. Increase max_connections in PostgreSQL config")
                    logger.error(f"   4. Wait a few minutes for idle connections to timeout")
                else:
                    logger.error(f"❌ Database connection failed (OperationalError): {e}")
                    logger.error(f"   Check if database is running and accessible")
                    logger.error(f"   If running outside Docker, try: postgresql://user:password@localhost:5432/database")
                    logger.error(f"   If running in Docker, ensure hostname 'timescaledb' is reachable")
                self.pool = None
            except Exception as e:
                logger.error(f"❌ Failed to connect to TimescaleDB: {e}")
                logger.error(f"   Connection string format: postgresql://user:password@host:port/database")
                logger.error(f"   Error type: {type(e).__name__}")
                import traceback
                logger.debug(f"   Full traceback: {traceback.format_exc()}")
                self.pool = None
    
    def _mask_connection_string(self, conn_str):
        """Mask password in connection string for logging"""
        try:
            from urllib.parse import urlparse, urlunparse
            parsed = urlparse(conn_str)
            if parsed.password:
                masked = parsed._replace(netloc=f"{parsed.username}:****@{parsed.hostname}:{parsed.port or 5432}")
                return urlunparse(masked)
            return conn_str
        except:
            # If parsing fails, just return a masked version
            if '@' in conn_str:
                parts = conn_str.split('@')
                if ':' in parts[0]:
                    user_pass = parts[0].split(':')
                    if len(user_pass) == 2:
                        return f"{user_pass[0]}:****@{parts[1]}"
            return "****"
    
    def test_connection(self):
        """Check that the database answers a ``SELECT version()`` query.

        A connection is borrowed from the pool for the check and handed back
        afterwards, whether the query succeeded or not.

        Returns:
            bool: True when the query succeeded; False when there is no pool
            or the query raised.
        """
        if not self.pool:
            logger.warning("Cannot test connection: pool is None")
            return False

        borrowed = None
        reachable = False
        try:
            borrowed = self.pool.getconn()
            cursor = borrowed.cursor()
            cursor.execute("SELECT version();")
            row = cursor.fetchone()
            server_version = row[0]
            logger.info("✓ Database connection test successful")
            # Only the first 50 characters of the version banner are logged
            logger.debug(f"   PostgreSQL version: {server_version[:50]}...")
            reachable = True
        except Exception as e:
            logger.error(f"❌ Connection test failed: {e}")
        finally:
            if borrowed and self.pool:
                self.pool.putconn(borrowed)
        return reachable
    
    def close_all_connections(self):
        """Close all connections in the pool"""
        if self.pool:
            try:
                self.pool.closeall()
                logger.info("Closed all database connections")
            except Exception as e:
                logger.warning(f"Error closing connections: {e}")
    
    def _init_schema(self):
        """Create the ``trades`` and ``backtest_results`` tables if missing.

        After the tables and the ``backtest_results`` index are created, each
        table is offered to ``create_hypertable``. Every attempt runs inside
        its own savepoint: PostgreSQL aborts the whole transaction when a
        statement fails, so without the savepoint a failed conversion would
        make the later ``commit()`` discard the ``CREATE TABLE`` statements.

        Errors are logged, not raised. When the main path fails, a second
        attempt creates only the ``trades`` table.
        """
        if not self.pool:
            logger.error("Cannot initialize schema: database pool is not available")
            return

        conn = self.pool.getconn()
        try:
            cur = conn.cursor()

            # Create trades table
            cur.execute("""
                CREATE TABLE IF NOT EXISTS trades (
                    id SERIAL PRIMARY KEY,
                    timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    symbol VARCHAR(50) NOT NULL,
                    direction VARCHAR(10) NOT NULL,
                    quantity DECIMAL(18, 8) NOT NULL,
                    entry_price DECIMAL(18, 2) NOT NULL,
                    stop_loss DECIMAL(18, 2),
                    take_profit DECIMAL(18, 2),
                    atr_value DECIMAL(18, 2),
                    status VARCHAR(20) DEFAULT 'open',
                    exit_price DECIMAL(18, 2),
                    exit_timestamp TIMESTAMPTZ,
                    pnl DECIMAL(18, 2),
                    daily_trade_number INTEGER,
                    mlflow_run_id VARCHAR(255)
                );
            """)

            # Create backtest_results table for storing metrics
            cur.execute("""
                CREATE TABLE IF NOT EXISTS backtest_results (
                    id SERIAL PRIMARY KEY,
                    timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    start_date DATE NOT NULL,
                    end_date DATE NOT NULL,
                    symbol VARCHAR(50) NOT NULL,
                    timeframe VARCHAR(10) NOT NULL,
                    total_trades INTEGER,
                    winning_trades INTEGER,
                    losing_trades INTEGER,
                    win_rate DECIMAL(5, 4),
                    total_return_pct DECIMAL(10, 4),
                    roi DECIMAL(10, 4),
                    sharpe_ratio DECIMAL(10, 4),
                    max_drawdown_pct DECIMAL(10, 4),
                    profit_factor DECIMAL(10, 4),
                    avg_win_pct DECIMAL(10, 4),
                    avg_loss_pct DECIMAL(10, 4),
                    initial_capital DECIMAL(18, 2),
                    final_equity DECIMAL(18, 2),
                    net_profit DECIMAL(18, 2),
                    mlflow_run_id VARCHAR(255)
                );
            """)

            # Create index for backtest_results
            cur.execute("""
                CREATE INDEX IF NOT EXISTS idx_backtest_results_symbol_date 
                ON backtest_results (symbol, start_date, end_date);
            """)

            # Try to create hypertables if TimescaleDB is available.
            # Assumes the pooled connection is transactional (the psycopg2
            # default), which SAVEPOINT requires.
            for table in ("trades", "backtest_results"):
                cur.execute("SAVEPOINT hypertable_attempt;")
                try:
                    cur.execute(f"SELECT create_hypertable('{table}', 'timestamp', if_not_exists => TRUE);")
                except Exception as e:
                    # Undo only the failed attempt; the tables created above
                    # stay part of the transaction.
                    cur.execute("ROLLBACK TO SAVEPOINT hypertable_attempt;")
                    logger.debug(f"Could not create hypertable for '{table}': {e}")
                else:
                    cur.execute("RELEASE SAVEPOINT hypertable_attempt;")
                    logger.debug(f"Created hypertable for '{table}'")

            conn.commit()

            # Verify tables were created
            cur.execute("""
                SELECT table_name 
                FROM information_schema.tables 
                WHERE table_schema = 'public' 
                AND table_name IN ('trades', 'backtest_results');
            """)
            created_tables = [row[0] for row in cur.fetchall()]

            if 'trades' not in created_tables:
                logger.error("⚠️  WARNING: 'trades' table was not created!")
            if 'backtest_results' not in created_tables:
                logger.error("⚠️  WARNING: 'backtest_results' table was not created!")

            logger.info(f"TimescaleDB schema initialized. Tables: {created_tables}")
        except Exception as e:
            conn.rollback()
            logger.error(f"❌ Schema initialization error: {e}", exc_info=True)
            # Try to create tables individually with better error handling
            try:
                cur2 = conn.cursor()
                cur2.execute("""
                    CREATE TABLE IF NOT EXISTS trades (
                        id SERIAL PRIMARY KEY,
                        timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        symbol VARCHAR(50) NOT NULL,
                        direction VARCHAR(10) NOT NULL,
                        quantity DECIMAL(18, 8) NOT NULL,
                        entry_price DECIMAL(18, 2) NOT NULL,
                        stop_loss DECIMAL(18, 2),
                        take_profit DECIMAL(18, 2),
                        atr_value DECIMAL(18, 2),
                        status VARCHAR(20) DEFAULT 'open',
                        exit_price DECIMAL(18, 2),
                        exit_timestamp TIMESTAMPTZ,
                        pnl DECIMAL(18, 2),
                        daily_trade_number INTEGER,
                        mlflow_run_id VARCHAR(255)
                    );
                """)
                conn.commit()
                cur2.close()
                logger.info("✓ Trades table created successfully (fallback)")
            except Exception as e2:
                logger.error(f"Failed to create trades table (fallback): {e2}", exc_info=True)
        finally:
            self.pool.putconn(conn)
    
    def ensure_backtest_results_table(self, cur, conn):
        """Create the ``backtest_results`` table, using the caller's cursor, if it is missing.

        When the table has to be created, its index is created too and a
        hypertable conversion is attempted. On a transactional connection the
        conversion runs inside a savepoint, so that a failure does not abort
        the transaction and turn the final ``commit()`` into a rollback of the
        freshly created table.

        Args:
            cur: Open cursor belonging to ``conn``.
            conn: Connection that is committed after the table was created.

        Returns:
            bool: True when the table was created by this call, False when it
            already existed.

        Raises:
            Exception: Errors from the existence check, table creation, index
                creation or commit are not caught here.
        """
        # Check if table exists
        cur.execute("""
            SELECT EXISTS (
                SELECT FROM information_schema.tables 
                WHERE table_schema = 'public' 
                AND table_name = 'backtest_results'
            );
        """)
        table_exists = cur.fetchone()[0]

        if table_exists:
            logger.debug("backtest_results table already exists")
            return False

        logger.info("Creating backtest_results table...")
        cur.execute("""
                CREATE TABLE IF NOT EXISTS backtest_results (
                    id SERIAL PRIMARY KEY,
                    timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                    start_date DATE NOT NULL,
                    end_date DATE NOT NULL,
                    symbol VARCHAR(50) NOT NULL,
                    timeframe VARCHAR(10) NOT NULL,
                    total_trades INTEGER,
                    winning_trades INTEGER,
                    losing_trades INTEGER,
                    win_rate DECIMAL(5, 4),
                    total_return_pct DECIMAL(10, 4),
                    roi DECIMAL(10, 4),
                    sharpe_ratio DECIMAL(10, 4),
                    max_drawdown_pct DECIMAL(10, 4),
                    profit_factor DECIMAL(10, 4),
                    avg_win_pct DECIMAL(10, 4),
                    avg_loss_pct DECIMAL(10, 4),
                    initial_capital DECIMAL(18, 2),
                    final_equity DECIMAL(18, 2),
                    net_profit DECIMAL(18, 2),
                    mlflow_run_id VARCHAR(255)
                );
            """)

        # Create index for better query performance
        cur.execute("""
                CREATE INDEX IF NOT EXISTS idx_backtest_results_symbol_date 
                ON backtest_results (symbol, start_date, end_date);
            """)

        # Try to create hypertable if TimescaleDB is available.
        # SAVEPOINT is only valid inside a transaction block, so it is skipped
        # when the caller's connection runs in autocommit mode (where a failed
        # statement cannot poison earlier ones anyway).
        use_savepoint = not getattr(conn, "autocommit", False)
        if use_savepoint:
            cur.execute("SAVEPOINT hypertable_attempt;")
        try:
            cur.execute("SELECT create_hypertable('backtest_results', 'timestamp', if_not_exists => TRUE);")
            logger.info("✓ Created TimescaleDB hypertable for backtest_results")
        except Exception as e:
            if use_savepoint:
                cur.execute("ROLLBACK TO SAVEPOINT hypertable_attempt;")
            logger.debug(f"Could not create hypertable (may be regular PostgreSQL): {e}")
        else:
            if use_savepoint:
                cur.execute("RELEASE SAVEPOINT hypertable_attempt;")

        conn.commit()
        logger.info("✓ backtest_results table created successfully")
        return True
    
    def insert_backtest_results(self, metrics_data):
        """Insert one row of backtest metrics and return its generated id.

        The ``public.backtest_results`` table and its index are created first
        if needed (committed separately); a failure in that step is logged and
        the insert is attempted anyway. Every failure path logs and returns
        ``None`` instead of raising.

        Args:
            metrics_data: Mapping read with ``.get`` for the keys start_date,
                end_date, symbol, timeframe, total_trades, winning_trades,
                losing_trades, win_rate, total_return_pct, roi, sharpe_ratio,
                max_drawdown_pct, profit_factor, avg_win_pct, avg_loss_pct,
                initial_capital, final_equity, net_profit and mlflow_run_id.
                Missing keys are passed to the database as ``None``.

        Returns:
            The ``id`` of the inserted row, or ``None`` when no pool is
            available, no connection could be obtained, or the insert failed.
        """
        if not self.pool:
            logger.warning("No database pool available, skipping backtest results insertion")
            return None

        # Acquire the connection inside its own handler so that pool problems
        # (exhausted or closed pool) follow the same "log and return None"
        # contract as every other failure of this method.
        try:
            conn = self.pool.getconn()
        except Exception as e:
            logger.error(f"Error getting database connection: {e}", exc_info=True)
            return None

        try:
            cur = conn.cursor()
            # Set search_path to ensure we're using public schema
            cur.execute("SET search_path TO public;")

            # Create table if it doesn't exist (in same transaction)
            try:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS public.backtest_results (
                        id SERIAL PRIMARY KEY,
                        timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        start_date DATE NOT NULL,
                        end_date DATE NOT NULL,
                        symbol VARCHAR(50) NOT NULL,
                        timeframe VARCHAR(10) NOT NULL,
                        total_trades INTEGER,
                        winning_trades INTEGER,
                        losing_trades INTEGER,
                        win_rate DECIMAL(5, 4),
                        total_return_pct DECIMAL(10, 4),
                        roi DECIMAL(10, 4),
                        sharpe_ratio DECIMAL(10, 4),
                        max_drawdown_pct DECIMAL(10, 4),
                        profit_factor DECIMAL(10, 4),
                        avg_win_pct DECIMAL(10, 4),
                        avg_loss_pct DECIMAL(10, 4),
                        initial_capital DECIMAL(18, 2),
                        final_equity DECIMAL(18, 2),
                        net_profit DECIMAL(18, 2),
                        mlflow_run_id VARCHAR(255)
                    );
                """)

                cur.execute("""
                    CREATE INDEX IF NOT EXISTS idx_backtest_results_symbol_date 
                    ON public.backtest_results (symbol, start_date, end_date);
                """)

                # Commit table creation first
                conn.commit()
                logger.info("✓ Ensured backtest_results table exists")

                # Verify table exists by trying to query it
                cur.execute("SELECT COUNT(*) FROM public.backtest_results;")
                count = cur.fetchone()[0]
                logger.debug(f"Verified table exists - current row count: {count}")
            except Exception as e:
                conn.rollback()
                logger.warning(f"Table creation/verification failed: {e}")
                # Try to continue anyway - table might exist

            # Now insert the data
            try:
                cur.execute("""
                    INSERT INTO public.backtest_results (
                        start_date, end_date, symbol, timeframe,
                        total_trades, winning_trades, losing_trades, win_rate,
                        total_return_pct, roi, sharpe_ratio, max_drawdown_pct,
                        profit_factor, avg_win_pct, avg_loss_pct,
                        initial_capital, final_equity, net_profit, mlflow_run_id
                    ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
                    RETURNING id;
                """, (
                    metrics_data.get('start_date'),
                    metrics_data.get('end_date'),
                    metrics_data.get('symbol'),
                    metrics_data.get('timeframe'),
                    metrics_data.get('total_trades'),
                    metrics_data.get('winning_trades'),
                    metrics_data.get('losing_trades'),
                    metrics_data.get('win_rate'),
                    metrics_data.get('total_return_pct'),
                    metrics_data.get('roi'),
                    metrics_data.get('sharpe_ratio'),
                    metrics_data.get('max_drawdown_pct'),
                    metrics_data.get('profit_factor'),
                    metrics_data.get('avg_win_pct'),
                    metrics_data.get('avg_loss_pct'),
                    metrics_data.get('initial_capital'),
                    metrics_data.get('final_equity'),
                    metrics_data.get('net_profit'),
                    metrics_data.get('mlflow_run_id')
                ))
                result_id = cur.fetchone()[0]
                conn.commit()
                logger.info(f"✓ Backtest results inserted successfully with ID: {result_id}")
                # ``or 0`` keeps the format specs below safe when a metric is None
                sharpe = metrics_data.get('sharpe_ratio', 0) or 0
                roi = metrics_data.get('roi', 0) or 0
                logger.info(f"  Metrics: Trades={metrics_data.get('total_trades', 0)}, "
                           f"Sharpe={sharpe:.4f}, "
                           f"ROI={roi:.2f}%")
                logger.info("  Data available in Grafana via TimescaleDB table 'backtest_results'")
                return result_id
            except psycopg2.errors.UndefinedTable as e:
                # Table still doesn't exist after creation attempt
                conn.rollback()
                logger.error(f"Table 'backtest_results' not found after creation attempt: {e}")
                logger.error("This may indicate a schema or permissions issue")
                return None
            except Exception as e:
                conn.rollback()
                logger.error(f"Error inserting backtest results: {e}", exc_info=True)
                return None
        except Exception as e:
            # Reached when cursor creation, SET search_path or one of the
            # rollbacks above fails - the connection itself was obtained fine.
            logger.error(f"Unexpected database error while storing backtest results: {e}", exc_info=True)
            return None
        finally:
            try:
                self.pool.putconn(conn)
            except Exception as e:
                logger.debug(f"Error returning connection: {e}")
    
    def _ensure_backtest_results_table_exists(self):
        """Ensure backtest_results table exists - uses separate connection"""
        if not self.pool:
            return
        
        conn = self.pool.getconn()
        try:
            # Set search_path to ensure we're using public schema
            cur = conn.cursor()
            cur.execute("SET search_path TO public;")
            
            # Check if table exists in public schema
            cur.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables 
                    WHERE table_schema = 'public' 
                    AND table_name = 'backtest_results'
                );
            """)
            table_exists = cur.fetchone()[0]
            
            if not table_exists:
                logger.info("Creating backtest_results table in public schema...")
                # Explicitly create in public schema
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS public.backtest_results (
                        id SERIAL PRIMARY KEY,
                        timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        start_date DATE NOT NULL,
                        end_date DATE NOT NULL,
                        symbol VARCHAR(50) NOT NULL,
                        timeframe VARCHAR(10) NOT NULL,
                        total_trades INTEGER,
                        winning_trades INTEGER,
                        losing_trades INTEGER,
                        win_rate DECIMAL(5, 4),
                        total_return_pct DECIMAL(10, 4),
                        roi DECIMAL(10, 4),
                        sharpe_ratio DECIMAL(10, 4),
                        max_drawdown_pct DECIMAL(10, 4),
                        profit_factor DECIMAL(10, 4),
                        avg_win_pct DECIMAL(10, 4),
                        avg_loss_pct DECIMAL(10, 4),
                        initial_capital DECIMAL(18, 2),
                        final_equity DECIMAL(18, 2),
                        net_profit DECIMAL(18, 2),
                        mlflow_run_id VARCHAR(255)
                    );
                """)
                
                cur.execute("""
                    CREATE INDEX IF NOT EXISTS idx_backtest_results_symbol_date 
                    ON public.backtest_results (symbol, start_date, end_date);
                """)
                
                try:
                    cur.execute("SELECT create_hypertable('public.backtest_results', 'timestamp', if_not_exists => TRUE);")
                    logger.info("✓ Created TimescaleDB hypertable for backtest_results")
                except Exception as e:
                    # Try without schema qualification
                    try:
                        cur.execute("SELECT create_hypertable('backtest_results', 'timestamp', if_not_exists => TRUE);")
                        logger.info("✓ Created TimescaleDB hypertable for backtest_results")
                    except Exception:
                        pass  # Not TimescaleDB or already exists
                
                # Commit the table creation
                conn.commit()
                logger.info("✓ backtest_results table created successfully and committed")
                
                # Small delay to ensure transaction is fully committed
                import time
                time.sleep(0.1)
            else:
                logger.debug("backtest_results table already exists in public schema")
        except Exception as e:
            conn.rollback()
            logger.error(f"Error ensuring table exists: {e}", exc_info=True)
            # Don't re-raise - let the insert attempt handle it
        finally:
            self.pool.putconn(conn)
    
    def insert_trade(self, trade_data):
        """Store one trade row and return its generated id.

        The ``trades`` table is created on the fly when the catalog lookup
        reports it missing. On any error the transaction is rolled back, the
        error is logged and ``None`` is returned.

        Args:
            trade_data: Mapping read with ``.get`` for the keys symbol,
                direction, quantity, entry_price, stop_loss, take_profit,
                atr_value, daily_trade_number and mlflow_run_id. Missing keys
                are passed to the database as ``None``.

        Returns:
            The ``id`` of the new row, or ``None`` when no pool is available
            or the insert failed.
        """
        if not self.pool:
            logger.warning("No database pool available, skipping trade insertion")
            return None

        # Same order as the column list of the INSERT statement below
        insert_fields = (
            'symbol', 'direction', 'quantity', 'entry_price', 'stop_loss',
            'take_profit', 'atr_value', 'daily_trade_number', 'mlflow_run_id',
        )

        connection = self.pool.getconn()
        try:
            cursor = connection.cursor()

            # Look the table up in the catalog before inserting
            cursor.execute("""
                SELECT EXISTS (
                    SELECT FROM information_schema.tables 
                    WHERE table_schema = 'public' 
                    AND table_name = 'trades'
                );
            """)
            trades_present = cursor.fetchone()[0]

            if not trades_present:
                logger.error("❌ 'trades' table does not exist! Creating it now...")
                cursor.execute("""
                    CREATE TABLE IF NOT EXISTS trades (
                        id SERIAL PRIMARY KEY,
                        timestamp TIMESTAMPTZ NOT NULL DEFAULT NOW(),
                        symbol VARCHAR(50) NOT NULL,
                        direction VARCHAR(10) NOT NULL,
                        quantity DECIMAL(18, 8) NOT NULL,
                        entry_price DECIMAL(18, 2) NOT NULL,
                        stop_loss DECIMAL(18, 2),
                        take_profit DECIMAL(18, 2),
                        atr_value DECIMAL(18, 2),
                        status VARCHAR(20) DEFAULT 'open',
                        exit_price DECIMAL(18, 2),
                        exit_timestamp TIMESTAMPTZ,
                        pnl DECIMAL(18, 2),
                        daily_trade_number INTEGER,
                        mlflow_run_id VARCHAR(255)
                    );
                """)
                connection.commit()
                logger.info("✓ Trades table created")

            row_values = tuple(trade_data.get(field) for field in insert_fields)
            cursor.execute("""
                INSERT INTO trades (
                    symbol, direction, quantity, entry_price, stop_loss, 
                    take_profit, atr_value, daily_trade_number, mlflow_run_id
                ) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
                RETURNING id;
            """, row_values)
            new_trade_id = cursor.fetchone()[0]
            connection.commit()
        except Exception as e:
            connection.rollback()
            logger.error(f"Error inserting trade: {e}", exc_info=True)
            return None
        else:
            return new_trade_id
        finally:
            self.pool.putconn(connection)

# Historical Data Collector
class HistoricalDataCollector:
    """Collect and store historical BTC data in TimescaleDB"""
    
    def __init__(self, db: TimescaleDB, api_key: str, api_secret: str):
        """Build the Alpaca crypto data client, keep the DB handle and ensure the OHLCV table exists."""
        # Alpaca historical-data client used by collect_historical_data
        self.data_client = CryptoHistoricalDataClient(api_key, api_secret)
        # Database wrapper whose connection pool is used for storage and retrieval
        self.db = db
        self._create_ohlcv_table()
    
    def _create_ohlcv_table(self):
        """Create the ``ohlcv_data`` table, its hypertable and its lookup index if missing.

        When no connection pool is available only warnings are logged. Errors
        raised while executing the DDL are logged rather than re-raised.
        """
        if not self.db or not self.db.pool:
            logger.warning("Database connection pool not available. Cannot create OHLCV table.")
            logger.warning("Data collection will work, but data won't be stored in database.")
            return
        
        pool = self.db.pool
        conn = pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute("""
                    CREATE TABLE IF NOT EXISTS ohlcv_data (
                        timestamp TIMESTAMPTZ NOT NULL,
                        symbol VARCHAR(20) NOT NULL,
                        timeframe VARCHAR(10) NOT NULL,
                        open DECIMAL(20, 8),
                        high DECIMAL(20, 8),
                        low DECIMAL(20, 8),
                        close DECIMAL(20, 8),
                        volume DECIMAL(30, 8),
                        PRIMARY KEY (timestamp, symbol, timeframe)
                    );
                """)
                # Commit the table on its own so that a failure of the optional
                # hypertable step below cannot roll the table creation back.
                conn.commit()
                
                try:
                    cur.execute("SELECT create_hypertable('ohlcv_data', 'timestamp', if_not_exists => TRUE);")
                    conn.commit()
                except Exception:
                    # A failed statement leaves the transaction aborted; roll back
                    # so the index statement that follows can still run.
                    conn.rollback()
                    logger.warning("Could not create hypertable for ohlcv_data")
                    logger.debug("create_hypertable failure details", exc_info=True)
                
                cur.execute("""
                    CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_timeframe 
                    ON ohlcv_data (symbol, timeframe, timestamp DESC);
                """)
                conn.commit()
            logger.info("OHLCV table created/verified")
        except Exception as e:
            conn.rollback()
            logger.error(f"Error creating OHLCV table: {e}")
        finally:
            pool.putconn(conn)
    
    def _normalize_symbol(self, symbol: str) -> str:
        """Normalize symbol format (BTC/USD -> BTCUSD)"""
        # Convert BTC/USD format to BTCUSD
        return symbol.replace('/', '').replace('-', '').upper()
    
    def collect_historical_data(self, symbol: str, start_date: str, end_date: str, 
                               timeframe: str = '5Min'):
        """Collect historical bars from Alpaca and upsert them into TimescaleDB.

        The date range is requested in 30-day chunks. No ``limit`` is passed to
        the bars request so a chunk is not capped at a fixed number of bars.
        A chunk whose request keeps failing is retried a bounded number of
        times and then skipped, so the loop always terminates.

        Args:
            symbol: Symbol passed to Alpaca (e.g. 'BTC/USD'). Rows are stored
                under the normalized form (e.g. 'BTCUSD').
            start_date: Start of the range, 'YYYY-MM-DD', taken as midnight UTC.
            end_date: End of the range, 'YYYY-MM-DD', taken as midnight UTC;
                requests stop at that instant.
            timeframe: One of '1Min', '5Min', '15Min', '1Hour', '1Day'. Any
                other value falls back to '5Min' (and is stored as '5Min').

        Returns:
            Total number of bars written to the database (0 when no
            connection pool is available).
        """
        timeframe_map = {
            '1Min': TimeFrame(1, TimeFrameUnit.Minute),
            '5Min': TimeFrame(5, TimeFrameUnit.Minute),
            '15Min': TimeFrame(15, TimeFrameUnit.Minute),
            '1Hour': TimeFrame(1, TimeFrameUnit.Hour),
            '1Day': TimeFrame(1, TimeFrameUnit.Day)
        }
        
        if timeframe not in timeframe_map:
            # Keep the stored label in sync with the data that is actually fetched
            logger.warning(f"Unsupported timeframe '{timeframe}'; falling back to '5Min'")
            timeframe = '5Min'
        tf = timeframe_map[timeframe]
        
        start_dt = datetime.strptime(start_date, '%Y-%m-%d').replace(tzinfo=pytz.UTC)
        end_dt = datetime.strptime(end_date, '%Y-%m-%d').replace(tzinfo=pytz.UTC)
        
        # Normalize symbol for storage (BTC/USD -> BTCUSD)
        normalized_symbol = self._normalize_symbol(symbol)
        
        logger.info(f"Collecting {symbol} data from {start_date} to {end_date} ({timeframe})")
        logger.info(f"Storing as symbol: {normalized_symbol}")
        
        max_attempts = 3  # consecutive failures tolerated per chunk before skipping it
        attempts = 0
        current_start = start_dt
        total_bars = 0
        
        while current_start < end_dt:
            chunk_end = min(current_start + timedelta(days=30), end_dt)
            
            try:
                request_params = CryptoBarsRequest(
                    symbol_or_symbols=[symbol],
                    timeframe=tf,
                    start=current_start,
                    end=chunk_end
                )
                barset = self.data_client.get_crypto_bars(request_params)
                
                returned_symbols = list(barset.data.keys()) if barset and barset.data else []
                if not returned_symbols:
                    logger.warning(f"No data for {symbol} from {current_start} to {chunk_end}")
                    current_start = chunk_end
                    attempts = 0
                    continue
                
                logger.debug(f"Alpaca returned symbols: {returned_symbols}")
                if symbol in barset.data:
                    symbol_to_use = symbol
                else:
                    # Fall back to the first symbol Alpaca returned
                    symbol_to_use = returned_symbols[0]
                    logger.info(f"Using symbol from Alpaca response: {symbol_to_use}")
                
                df = barset.df.reset_index()
                if 'symbol' in df.columns:
                    # Only keep rows of the symbol chosen above
                    df = df[df['symbol'] == symbol_to_use]
                df = df.set_index('timestamp')
                
                if not self.db or not self.db.pool:
                    logger.warning("Database connection not available. Skipping database storage.")
                    logger.info(f"Would store {len(df)} bars. Total: {total_bars}")
                    current_start = chunk_end
                    attempts = 0
                    time.sleep(0.5)
                    continue
                
                stored = self._store_bars(df, normalized_symbol, timeframe)
                if stored is not None:
                    total_bars += stored
                    logger.info(f"Stored {stored} bars. Total: {total_bars}")
                
                current_start = chunk_end
                attempts = 0
                time.sleep(0.5)  # Rate limiting
                
            except Exception as e:
                attempts += 1
                logger.error(f"Error fetching data: {e}")
                if attempts >= max_attempts:
                    # Give up on this chunk instead of retrying forever
                    logger.error(f"Skipping {current_start} to {chunk_end} after {attempts} failed attempts")
                    current_start = chunk_end
                    attempts = 0
                    continue
                time.sleep(5)
                continue
        
        logger.info(f"Collection complete! Total bars stored: {total_bars}")
        return total_bars
    
    def _store_bars(self, df, normalized_symbol: str, timeframe: str):
        """Upsert the rows of *df* into ``ohlcv_data``.

        Returns the number of rows written, or ``None`` if the write failed
        (the failure is logged and the transaction rolled back).
        """
        pool = self.db.pool
        conn = pool.getconn()
        try:
            rows = [
                (ts, normalized_symbol, timeframe,
                 float(open_), float(high), float(low), float(close), float(volume))
                for ts, open_, high, low, close, volume in zip(
                    df.index, df['open'], df['high'], df['low'], df['close'], df['volume']
                )
            ]
            with conn.cursor() as cur:
                cur.executemany("""
                    INSERT INTO ohlcv_data 
                    (timestamp, symbol, timeframe, open, high, low, close, volume)
                    VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
                    ON CONFLICT (timestamp, symbol, timeframe) 
                    DO UPDATE SET
                        open = EXCLUDED.open,
                        high = EXCLUDED.high,
                        low = EXCLUDED.low,
                        close = EXCLUDED.close,
                        volume = EXCLUDED.volume;
                """, rows)
            conn.commit()
            return len(rows)
        except Exception as e:
            conn.rollback()
            logger.error(f"Error storing data: {e}")
            return None
        finally:
            pool.putconn(conn)
    
    def get_historical_data_from_db(self, symbol: str, start_date: str, 
                                    end_date: str, timeframe: str = '5Min') -> Optional[pd.DataFrame]:
        """Load OHLCV bars for *symbol* from TimescaleDB.

        The requested symbol is matched against every symbol stored in
        ``ohlcv_data``: first in normalized form, then as given, then by
        normalizing each stored symbol. The queried range runs from
        ``start_date`` up to (but excluding) one day after ``end_date``; dates
        without a timezone are treated as UTC.

        Returns:
            A DataFrame indexed by UTC ``timestamp`` with float ``open``,
            ``high``, ``low``, ``close`` and ``volume`` columns, or ``None``
            when the pool is unavailable, no stored symbol matches, the dates
            cannot be parsed, the query returns no rows, or a database error
            occurs.
        """
        if not self.db or not self.db.pool:
            logger.error("Database connection not available. Cannot retrieve data from database.")
            return None
        
        pool = self.db.pool
        conn = pool.getconn()
        try:
            with conn.cursor() as cur:
                # Inspect every stored symbol (no LIMIT, so the requested one cannot be cut off)
                cur.execute("SELECT DISTINCT symbol FROM ohlcv_data;")
                existing_symbols = [row[0] for row in cur.fetchall()]
                normalized_symbol = self._normalize_symbol(symbol)
                logger.info(f"Symbols found in database: {existing_symbols}")
                logger.info(f"Looking for symbol: {symbol} (normalized: {normalized_symbol})")
                
                # Also check what timeframes exist for debugging
                cur.execute("SELECT DISTINCT timeframe FROM ohlcv_data;")
                existing_timeframes = [row[0] for row in cur.fetchall()]
                logger.info(f"Timeframes found in database: {existing_timeframes}")
                
                symbol_to_use = self._match_db_symbol(symbol, existing_symbols)
                if symbol_to_use is None:
                    logger.warning(f"✗ No matching symbol found for '{symbol}' (normalized: '{normalized_symbol}')")
                    logger.warning(f"Available symbols in DB: {existing_symbols}")
                    return None
                
                try:
                    # Parse dates and make sure they are timezone-aware UTC
                    start_dt = pd.to_datetime(start_date)
                    start_dt = start_dt.tz_localize('UTC') if start_dt.tz is None else start_dt.tz_convert('UTC')
                    end_dt = pd.to_datetime(end_date)
                    end_dt = end_dt.tz_localize('UTC') if end_dt.tz is None else end_dt.tz_convert('UTC')
                    # Add one day to end_date to include the entire end day
                    end_dt = end_dt + pd.Timedelta(days=1)
                except Exception as e:
                    logger.error(f"Error parsing dates: {e}")
                    return None
                
                query = """
                    SELECT timestamp, open, high, low, close, volume
                    FROM ohlcv_data
                    WHERE symbol = %s 
                    AND timeframe = %s
                    AND timestamp >= %s 
                    AND timestamp < %s
                    ORDER BY timestamp ASC;
                """
                logger.info(f"Executing query: symbol='{symbol_to_use}', timeframe='{timeframe}', start={start_dt}, end={end_dt}")
                cur.execute(query, (symbol_to_use, timeframe, start_dt, end_dt))
                rows = cur.fetchall()
                colnames = [desc[0] for desc in cur.description]
                
                if not rows:
                    logger.warning(f"No data found for '{symbol_to_use}' from {start_date} to {end_date} with timeframe {timeframe}")
                    self._log_missing_data_diagnostics(
                        cur, symbol_to_use, timeframe, start_date, end_date, start_dt, end_dt
                    )
                    return None
            
            df = pd.DataFrame(rows, columns=colnames)
            # utc=True yields one UTC-aware column even if the driver returns mixed offsets
            df['timestamp'] = pd.to_datetime(df['timestamp'], utc=True)
            
            # Convert Decimal columns to float (PostgreSQL DECIMAL returns as Decimal type)
            for col in ('open', 'high', 'low', 'close', 'volume'):
                if col in df.columns:
                    df[col] = df[col].astype(float)
            
            df.set_index('timestamp', inplace=True)
            logger.info(f"✓ Successfully retrieved {len(df)} bars from database for '{symbol_to_use}'")
            return df
        except Exception as e:
            logger.error(f"Error retrieving data from database: {e}", exc_info=True)
            return None
        finally:
            pool.putconn(conn)
    
    def _match_db_symbol(self, symbol: str, existing_symbols) -> Optional[str]:
        """Return the stored symbol corresponding to *symbol*, or ``None`` if there is none."""
        normalized_symbol = self._normalize_symbol(symbol)
        if normalized_symbol in existing_symbols:
            logger.info(f"✓ Found exact match: Using normalized symbol '{normalized_symbol}'")
            return normalized_symbol
        if symbol in existing_symbols:
            logger.info(f"✓ Found exact match: Using original symbol format '{symbol}'")
            return symbol
        # Last resort: compare after normalizing what the database holds
        for db_symbol in existing_symbols:
            db_normalized = self._normalize_symbol(db_symbol)
            if db_normalized == normalized_symbol:
                logger.info(f"✓ Found normalized match: '{symbol}' -> '{db_symbol}' (DB has '{db_symbol}', normalized to '{db_normalized}')")
                return db_symbol
        return None
    
    def _log_missing_data_diagnostics(self, cur, symbol_to_use, timeframe,
                                      start_date, end_date, start_dt, end_dt):
        """Log what range of data exists for a symbol/timeframe whose query came back empty."""
        try:
            cur.execute("SELECT MIN(timestamp), MAX(timestamp), COUNT(*) FROM ohlcv_data WHERE symbol = %s AND timeframe = %s;", 
                        (symbol_to_use, timeframe))
            debug_info = cur.fetchone()
            if debug_info and debug_info[2] > 0:
                logger.info(f"Debug: Symbol '{symbol_to_use}' has {debug_info[2]} total bars")
                logger.info(f"Debug: Available date range: {debug_info[0]} to {debug_info[1]}")
                logger.info(f"Debug: Requested date range: {start_dt} to {end_dt}")
                # Check if there's any overlap
                if debug_info[0] and debug_info[1]:
                    if end_dt < debug_info[0] or start_dt > debug_info[1]:
                        logger.warning(f"⚠️  Requested date range ({start_date} to {end_date}) does NOT overlap with available data ({debug_info[0]} to {debug_info[1]})")
                    else:
                        logger.info(f"✓ Date ranges overlap, but no data returned. This might be a timezone or date format issue.")
            else:
                logger.warning(f"Debug: No data found for symbol '{symbol_to_use}' with timeframe '{timeframe}' at all!")
        except Exception as debug_e:
            logger.error(f"Error in debug query: {debug_e}")

# Base Trading Bot Logic
class BaseCryptoTradingBot:
    def __init__(self, symbol='BTCUSD', max_daily_trades=3, quantity=0.001, 
                 entry_strategy='or'):
        """
        entry_strategy options:
        - 'and': Require both sweep AND FVG (most restrictive, original)
        - 'or': Require sweep OR FVG (more trades)
        - 'sweep_only': Only trade on sweeps
        - 'fvg_only': Only trade on FVGs
        """
        self.symbol = symbol
        self.max_daily_trades = max_daily_trades
        self.daily_trades = 0
        self.last_trade_date = None
        self.atr_length = 14
        self.fvg_lookback = 3
        self.sweep_lookback = 3
        self.quantity = quantity
        self.stop_loss_pct = 0.02
        self.take_profit_pct = 0.04
        self.entry_strategy = entry_strategy.lower()  # 'and', 'or', 'sweep_only', 'fvg_only'
        
    def reset_daily_counter(self):
        """Reset the daily trade counter if it's a new day"""
        current_date = datetime.now().date()
        if self.last_trade_date is None or current_date > self.last_trade_date:
            logger.info(f"New trading day: {current_date}. Resetting daily trade counter.")
            self.daily_trades = 0
            self.last_trade_date = current_date

    def calculate_atr(self, df: pd.DataFrame, length: int = 14) -> pd.Series:
        """Calculate Average True Range"""
        high = df['high']
        low = df['low']
        close = df['close'].shift(1)
        
        tr1 = high - low
        tr2 = abs(high - close)
        tr3 = abs(low - close)
        
        tr = pd.DataFrame({'tr1': tr1, 'tr2': tr2, 'tr3': tr3}).max(axis=1)
        atr = tr.rolling(window=length).mean()
        
        return atr

    def detect_sweep(self, df: pd.DataFrame):
        """Detect liquidity sweeps"""
        swing_low = df['low'].rolling(window=self.sweep_lookback).min().shift(2)
        swing_high = df['high'].rolling(window=self.sweep_lookback).max().shift(2)
        
        bullish_sweep = (df['low'].shift(2) == swing_low) & \
                         (df['low'].shift(1) < df['low'].shift(2)) & \
                         (df['close'] > df['high'].shift(2))
        
        bearish_sweep = (df['high'].shift(2) == swing_high) & \
                         (df['high'].shift(1) > df['high'].shift(2)) & \
                         (df['close'] < df['low'].shift(2))
        
        return bullish_sweep, bearish_sweep

    def detect_fvg(self, df: pd.DataFrame):
        """Detect Fair Value Gaps (FVG)
        A bullish FVG occurs when the current bar's low is above the highest high of the previous N bars
        A bearish FVG occurs when the current bar's high is below the lowest low of the previous N bars
        """
        # Look at previous bars (excluding current bar) to find the range
        highest_high = df['high'].shift(1).rolling(window=self.fvg_lookback).max()
        lowest_low = df['low'].shift(1).rolling(window=self.fvg_lookback).min()
        
        # Bullish FVG: current low is above the highest high of previous bars (gap up)
        bullish_fvg = df['low'] > highest_high
        
        # Bearish FVG: current high is below the lowest low of previous bars (gap down)
        bearish_fvg = df['high'] < lowest_low
        
        return bullish_fvg, bearish_fvg

    def check_for_signals(self, df: pd.DataFrame):
        """Check for trading signals in the data"""
        if df is None or len(df) == 0:
            logger.warning(f"Insufficient data for signal detection. DataFrame is None or empty")
            return None
        
        min_required = self.fvg_lookback + self.sweep_lookback + 5
        if len(df) < min_required:
            logger.warning(f"Insufficient data for signal detection. Have {len(df)} bars, need at least {min_required}")
            return None
        
        # Make a deep copy to avoid SettingWithCopyWarning when modifying
        df = df.copy(deep=True)
        
        # Calculate indicators - use .assign() to avoid SettingWithCopyWarning
        atr_series = self.calculate_atr(df, self.atr_length)
        df = df.assign(atr=atr_series)
        bullish_sweep, bearish_sweep = self.detect_sweep(df)
        bullish_fvg, bearish_fvg = self.detect_fvg(df)
        
        # Get latest signals
        latest = df.iloc[-1]
        
        # Convert price to float (handle Decimal types from database)
        try:
            price_value = float(latest['close'])
        except (TypeError, ValueError):
            price_value = 0.0
        
        # Handle NaN ATR - convert to float and replace with 0 if NaN
        try:
            atr_value = float(latest['atr'])
        except (TypeError, ValueError):
            atr_value = 0.0
        
        if pd.isna(atr_value) or atr_value is None:
            atr_value = 0.0
        
        signals = {
            'bullish_sweep': bullish_sweep.iloc[-1] if len(bullish_sweep) > 0 else False,
            'bearish_sweep': bearish_sweep.iloc[-1] if len(bearish_sweep) > 0 else False,
            'bullish_fvg': bullish_fvg.iloc[-1] if len(bullish_fvg) > 0 else False,
            'bearish_fvg': bearish_fvg.iloc[-1] if len(bearish_fvg) > 0 else False,
            'price': price_value,
            'atr': atr_value
        }
        
        # Log all signals for debugging (can be reduced later if too verbose)
        # Only log if there's an actual signal to reduce verbosity
        if signals['bullish_sweep'] or signals['bearish_sweep'] or signals['bullish_fvg'] or signals['bearish_fvg']:
            logger.debug(f"Signal detected: Price=${signals['price']:.2f}, "
                        f"BullishSweep={signals['bullish_sweep']}, BearishSweep={signals['bearish_sweep']}, "
                        f"BullishFVG={signals['bullish_fvg']}, BearishFVG={signals['bearish_fvg']}")
        
        return signals

# Backtest Bot using TimescaleDB
class TimescaleDBBacktestBot(BaseCryptoTradingBot):
    """Backtest bot that uses historical data from TimescaleDB"""
    
    def __init__(self, db: TimescaleDB, data_collector: HistoricalDataCollector, **kwargs):
        """Initialise the base strategy from **kwargs and set up backtest bookkeeping."""
        super().__init__(**kwargs)

        # Collaborators
        self.db, self.data_collector = db, data_collector
        self.current_mlflow_run = None

        # Per-run state
        self.backtest_results = []
        self.positions = []  # Track open positions
        self.equity_curve = []  # Track equity over time

        # Starting capital; current capital begins at the same value
        self.initial_capital = self.current_capital = 10000.0
    
    def execute_trade(self, direction: str, price: float, atr_value: float, timestamp=None):
        """Simulate a trade entry for backtesting and record it.

        No daily trade limit is enforced here. The trade is skipped (with a
        warning) when *price* cannot be converted to a finite positive float.

        Args:
            direction: 'long' for a long entry; any other value is treated as short.
            price: Entry price; may be any value accepted by ``float()``.
            atr_value: ATR at entry. Missing, unconvertible or non-finite
                values count as 0, in which case only the percentage-based
                take profit is used.
            timestamp: Entry time of the bar; falls back to ``datetime.now()``.

        Side effects:
            Appends to ``self.positions`` and ``self.backtest_results``,
            increments ``self.daily_trades`` and, when ``self.db`` is set,
            stores the trade via ``self.db.insert_trade``.
        """
        # Validate inputs - convert to float first to handle Decimal types
        try:
            price = float(price)
        except (TypeError, ValueError):
            logger.warning(f"Invalid price type: {type(price)}, value: {price}, skipping trade")
            return
        
        # isfinite rejects NaN as well as +/-inf
        if not np.isfinite(price) or price <= 0:
            logger.warning(f"Invalid price: {price}, skipping trade")
            return
        
        try:
            atr_value = float(atr_value) if atr_value is not None else 0.0
        except (TypeError, ValueError):
            atr_value = 0.0
        if not np.isfinite(atr_value):
            # NaN/inf ATR must neither drive the target nor reach the database
            atr_value = 0.0
        atr_valid = atr_value > 0
        
        # Calculate stop loss and take profit; with a usable ATR the target is
        # whichever of the percentage target and the 2*ATR target is closer to entry
        if direction == 'long':
            stop_price = price * (1 - self.stop_loss_pct)
            take_profit = price * (1 + self.take_profit_pct)
            if atr_valid:
                take_profit = min(take_profit, price + (2 * atr_value))
        else:  # short
            stop_price = price * (1 + self.stop_loss_pct)
            take_profit = price * (1 - self.take_profit_pct)
            if atr_valid:
                take_profit = max(take_profit, price - (2 * atr_value))
        
        logger.info(f"BACKTEST: Executed {direction.upper()} trade: {self.quantity} {self.symbol} at ${price:.2f}")
        logger.info(f"  Stop Loss: ${stop_price:.2f}, Take Profit: ${take_profit:.2f}")
        
        # Resolve the entry time once so the position and the result record agree
        entry_time = timestamp or datetime.now()
        
        # Track position
        self.positions.append({
            'direction': direction,
            'entry_price': price,
            'stop_loss': stop_price,
            'take_profit': take_profit,
            'quantity': self.quantity,
            'entry_time': entry_time,
            'entry_capital': self.current_capital
        })
        
        # Save to database
        trade_id = None
        if self.db:
            trade_id = self.db.insert_trade({
                'symbol': self.symbol,
                'direction': direction,
                'quantity': float(self.quantity),
                'entry_price': float(price),
                'stop_loss': float(stop_price),
                'take_profit': float(take_profit),
                'atr_value': atr_value,
                'daily_trade_number': self.daily_trades + 1,
                'mlflow_run_id': self.current_mlflow_run.info.run_id if self.current_mlflow_run else None
            })
        
        self.daily_trades += 1
        self.backtest_results.append({
            'direction': direction,
            'entry_price': price,
            'stop_loss': stop_price,
            'take_profit': take_profit,
            'trade_id': trade_id,
            'entry_time': entry_time
        })
    
    def calculate_metrics(self, df: pd.DataFrame):
        """Calculate performance metrics including Sharpe ratio"""
        if len(self.backtest_results) == 0:
            return {}
        
        # Calculate returns for each trade
        returns = []
        closed_trades = []
        
        for i, trade in enumerate(self.backtest_results):
            entry_price = trade['entry_price']
            entry_time = trade['entry_time']
            
            # Find closest index (handle timestamp matching)
            try:
                if entry_time in df.index:
                    entry_idx = df.index.get_loc(entry_time)
                else:
                    # Find closest timestamp
                    entry_idx = df.index.searchsorted(entry_time)
                    if entry_idx >= len(df):
                        entry_idx = len(df) - 1
            except Exception:
                continue
            
            # Find exit (stop loss or take profit hit)
            direction = trade['direction']
            stop_loss = trade['stop_loss']
            take_profit = trade['take_profit']
            
            exit_price = None
            exit_reason = None
            
            # Check subsequent bars for exit
            for j in range(entry_idx + 1, min(entry_idx + 100, len(df))):
                bar = df.iloc[j]
                
                if direction == 'long':
                    # Check stop loss
                    if bar['low'] <= stop_loss:
                        exit_price = stop_loss
                        exit_reason = 'stop_loss'
                        break
                    # Check take profit
                    elif bar['high'] >= take_profit:
                        exit_price = take_profit
                        exit_reason = 'take_profit'
                        break
                else:  # short
                    # Check stop loss
                    if bar['high'] >= stop_loss:
                        exit_price = stop_loss
                        exit_reason = 'stop_loss'
                        break
                    # Check take profit
                    elif bar['low'] <= take_profit:
                        exit_price = take_profit
                        exit_reason = 'take_profit'
                        break
            
            # If no exit found, use last bar price
            if exit_price is None:
                exit_price = df.iloc[-1]['close']
                exit_reason = 'end_of_data'
            
            # Calculate return
            if direction == 'long':
                trade_return = (exit_price - entry_price) / entry_price
            else:  # short
                trade_return = (entry_price - exit_price) / entry_price
            
            returns.append(trade_return)
            closed_trades.append({
                'entry_price': entry_price,
                'exit_price': exit_price,
                'direction': direction,
                'return': trade_return,
                'exit_reason': exit_reason
            })
        
        if len(returns) == 0:
            return {}
        
        returns_array = np.array(returns)
        
        # Calculate metrics
        total_return = np.sum(returns_array)
        total_trades = len(returns)
        winning_trades = len([r for r in returns if r > 0])
        losing_trades = len([r for r in returns if r < 0])
        win_rate = winning_trades / total_trades if total_trades > 0 else 0
        
        avg_return = np.mean(returns_array)
        std_return = np.std(returns_array)
        
        # Sharpe ratio (annualized, assuming 252 trading days)
        # For 5-min bars: 288 bars per day, so annualization factor
        bars_per_day = 288  # 5-min bars in a day
        trading_days = len(df) / bars_per_day
        sharpe_ratio = (avg_return / std_return * np.sqrt(bars_per_day)) if std_return > 0 else 0
        
        # Maximum drawdown
        cumulative_returns = np.cumsum(returns_array)
        running_max = np.maximum.accumulate(cumulative_returns)
        drawdown = cumulative_returns - running_max
        max_drawdown = np.min(drawdown) if len(drawdown) > 0 else 0
        
        # Profit factor
        gross_profit = sum([r for r in returns if r > 0])
        gross_loss = abs(sum([r for r in returns if r < 0]))
        profit_factor = gross_profit / gross_loss if gross_loss > 0 else float('inf')
        
        # Average win/loss
        avg_win = np.mean([r for r in returns if r > 0]) if winning_trades > 0 else 0
        avg_loss = np.mean([r for r in returns if r < 0]) if losing_trades > 0 else 0
        
        # Final equity
        final_equity = self.initial_capital * (1 + total_return)
        
        metrics = {
            'total_trades': total_trades,
            'winning_trades': winning_trades,
            'losing_trades': losing_trades,
            'win_rate': win_rate,
            'total_return': total_return,
            'total_return_pct': total_return * 100,
            'avg_return': avg_return,
            'avg_return_pct': avg_return * 100,
            'sharpe_ratio': sharpe_ratio,
            'max_drawdown': max_drawdown,
            'max_drawdown_pct': max_drawdown * 100,
            'profit_factor': profit_factor,
            'avg_win': avg_win,
            'avg_win_pct': avg_win * 100,
            'avg_loss': avg_loss,
            'avg_loss_pct': avg_loss * 100,
            'initial_capital': self.initial_capital,
            'final_equity': final_equity,
            'net_profit': final_equity - self.initial_capital,
            'roi': (final_equity - self.initial_capital) / self.initial_capital * 100
        }
        
        return metrics
    
    def run_backtest_from_db(self, start_date: str, end_date: str, timeframe: str = '5Min') -> Optional[Dict[str, Any]]:
        """Run a walk-forward backtest over bars loaded from TimescaleDB.

        Bars are fetched through ``self.data_collector``. If nothing comes
        back for ``self.symbol``, the alternative BTC symbol spellings are
        tried and ``self.symbol`` is updated to the first one that yields
        data. Every bar from ``min_lookback`` onwards is evaluated with
        ``check_for_signals`` on a trailing window and entries are passed to
        ``execute_trade``. Parameters and metrics are logged to MLflow, and
        the metrics are also stored through ``self.db`` when it is set.

        Args:
            start_date: Start of the range, passed through to the data collector.
            end_date: End of the range, passed through to the data collector.
            timeframe: Bar-size label, passed through to the data collector.

        Returns:
            The metrics dict produced by ``calculate_metrics``, or ``None``
            when no data is found, required columns are missing, or there are
            too few bars to fill the lookback window.
        """
        logger.info(f"Starting backtest from {start_date} to {end_date}")
        logger.info(f"Requested symbol: {self.symbol}")

        df = self.data_collector.get_historical_data_from_db(
            self.symbol, start_date, end_date, timeframe
        )

        # Fall back to the other common spellings of the symbol.
        if df is None or len(df) == 0:
            logger.warning(f"Initial query with '{self.symbol}' returned no data. Trying alternative formats...")
            for alt_symbol in ('BTC/USD', 'BTCUSD', 'BTC-USD'):
                if alt_symbol == self.symbol:
                    continue
                logger.info(f"Trying alternative symbol format: {alt_symbol}")
                df = self.data_collector.get_historical_data_from_db(
                    alt_symbol, start_date, end_date, timeframe
                )
                if df is not None and len(df) > 0:
                    logger.info(f"✓ Found data with symbol format: {alt_symbol}")
                    # Keep self.symbol in sync with what the database holds
                    self.symbol = alt_symbol
                    break

        if df is None or len(df) == 0:
            logger.error("No data found in database. Run data collection first!")
            logger.error(f"Tried symbol: {self.symbol} and alternatives: BTC/USD, BTCUSD, BTC-USD")
            return None

        logger.info(f"Backtesting on {len(df)} bars from database")

        required_cols = ['open', 'high', 'low', 'close', 'volume']
        missing_cols = [col for col in required_cols if col not in df.columns]
        if missing_cols:
            logger.error(f"DataFrame missing required columns: {missing_cols}")
            return None

        # Validate the data volume BEFORE opening an MLflow run so that an
        # unusable dataset does not leave an empty run behind.
        min_lookback = max(50, self.fvg_lookback + self.sweep_lookback + 5)
        total_bars = len(df)
        if total_bars < min_lookback + 1:
            logger.error(f"Insufficient data for backtest. Need at least {min_lookback + 1} bars, but have {total_bars}")
            return None

        if self.entry_strategy not in ('and', 'or', 'sweep_only', 'fvg_only'):
            # An unrecognised strategy never yields an entry; make that visible
            # instead of silently reporting a zero-trade backtest.
            logger.warning(f"Unknown entry_strategy '{self.entry_strategy}' - no trades will be taken")

        mlflow.set_experiment("btc_backtest_db")
        self.current_mlflow_run = mlflow.start_run()
        # Assume failure until the end of the happy path, so a run that raises
        # is not recorded as FINISHED.
        run_status = "FAILED"

        try:
            for param_name, param_value in (
                ("symbol", self.symbol),
                ("start_date", start_date),
                ("end_date", end_date),
                ("timeframe", timeframe),
                ("max_daily_trades", self.max_daily_trades),
                ("stop_loss_pct", self.stop_loss_pct),
                ("take_profit_pct", self.take_profit_pct),
                ("initial_capital", self.initial_capital),
                ("entry_strategy", self.entry_strategy),
            ):
                mlflow.log_param(param_name, param_value)

            logger.info(f"Starting backtest loop from index {min_lookback} to {total_bars-1} (total bars: {total_bars})")

            for i in range(min_lookback, total_bars):
                # Trailing window of min_lookback bars plus the current bar.
                # i >= min_lookback, so the start index is never negative.
                # The copy keeps check_for_signals from touching df itself.
                lookback_data = df.iloc[i - min_lookback:i + 1].copy()

                signals = self.check_for_signals(lookback_data)
                if signals:
                    direction = self._entry_direction(signals)
                    if direction is not None:
                        self.execute_trade(direction, signals['price'], signals['atr'], df.index[i])

                # Progress logging
                if i % 100 == 0:
                    logger.info(f"Backtest progress: {i}/{total_bars} bars ({i/total_bars*100:.1f}%)")

            metrics = self.calculate_metrics(df)

            logger.info(f"\n{'='*60}")
            logger.info("BACKTEST RESULTS")
            logger.info(f"{'='*60}")
            logger.info(f"Total Trades: {metrics.get('total_trades', 0)}")
            logger.info(f"Winning Trades: {metrics.get('winning_trades', 0)}")
            logger.info(f"Losing Trades: {metrics.get('losing_trades', 0)}")
            logger.info(f"Win Rate: {metrics.get('win_rate', 0)*100:.2f}%")
            logger.info("\nPerformance Metrics:")
            logger.info(f"Total Return: {metrics.get('total_return_pct', 0):.2f}%")
            logger.info(f"ROI: {metrics.get('roi', 0):.2f}%")
            logger.info(f"Initial Capital: ${metrics.get('initial_capital', 0):,.2f}")
            logger.info(f"Final Equity: ${metrics.get('final_equity', 0):,.2f}")
            logger.info(f"Net Profit: ${metrics.get('net_profit', 0):,.2f}")
            logger.info("\nRisk Metrics:")
            logger.info(f"Sharpe Ratio: {metrics.get('sharpe_ratio', 0):.4f}")
            logger.info(f"Max Drawdown: {metrics.get('max_drawdown_pct', 0):.2f}%")
            logger.info(f"Profit Factor: {metrics.get('profit_factor', 0):.4f}")
            logger.info(f"Average Win: {metrics.get('avg_win_pct', 0):.2f}%")
            logger.info(f"Average Loss: {metrics.get('avg_loss_pct', 0):.2f}%")
            logger.info(f"{'='*60}\n")

            # Log to MLflow (missing keys are reported as 0)
            for metric_name in (
                "total_trades", "winning_trades", "losing_trades", "win_rate",
                "total_return_pct", "roi", "final_equity", "net_profit",
                "sharpe_ratio", "max_drawdown_pct", "profit_factor",
                "avg_win_pct", "avg_loss_pct",
            ):
                mlflow.log_metric(metric_name, metrics.get(metric_name, 0))

            # Store metrics in TimescaleDB for Grafana visualization
            if self.db:
                metrics_for_db = metrics.copy()
                metrics_for_db.update({
                    'start_date': start_date,
                    'end_date': end_date,
                    'symbol': self.symbol,
                    'timeframe': timeframe,
                    'mlflow_run_id': self.current_mlflow_run.info.run_id if self.current_mlflow_run else None
                })
                result_id = self.db.insert_backtest_results(metrics_for_db)
                if result_id:
                    logger.info(f"Backtest metrics saved to TimescaleDB with ID: {result_id}")

            run_status = "FINISHED"
            return metrics

        finally:
            mlflow.end_run(status=run_status)

    def _entry_direction(self, signals):
        """Map a signal dict to 'long', 'short' or None under ``self.entry_strategy``.

        A long entry takes priority when both directions qualify. An
        unrecognised strategy never produces an entry.
        """
        strategy = self.entry_strategy
        if strategy == 'and':
            # Require both a sweep AND an FVG in the same direction
            long_signal = signals['bullish_sweep'] and signals['bullish_fvg']
            short_signal = signals['bearish_sweep'] and signals['bearish_fvg']
        elif strategy == 'or':
            # Either a sweep OR an FVG is enough
            long_signal = signals['bullish_sweep'] or signals['bullish_fvg']
            short_signal = signals['bearish_sweep'] or signals['bearish_fvg']
        elif strategy == 'sweep_only':
            long_signal = signals['bullish_sweep']
            short_signal = signals['bearish_sweep']
        elif strategy == 'fvg_only':
            long_signal = signals['bullish_fvg']
            short_signal = signals['bearish_fvg']
        else:
            return None

        if long_signal:
            return 'long'
        if short_signal:
            return 'short'
        return None

# Configuration
class BacktestConfig:
    """Runtime configuration and shared connections for the backtest.

    Settings are read from environment variables, falling back to built-in
    defaults. Constructing an instance also opens the TimescaleDB connection
    (``self.db``), sets the MLflow tracking URI and rebinds the module-level
    ``logger`` to the one returned by ``setup_logging`` with the Loki URL.
    """

    def __init__(self):
        global logger  # Rebound at the end of __init__ by setup_logging()

        # Alpaca credentials
        # SECURITY: the fallback values below are real-looking credentials
        # committed to source. They are kept only so existing runs keep
        # working; rotate them and supply ALPACA_API_KEY / ALPACA_API_SECRET
        # via the environment instead.
        self.ALPACA_API_KEY = os.getenv('ALPACA_API_KEY', 'PKSWFXHIT7WAESKFYXTTJ6DKUE')
        self.ALPACA_API_SECRET = os.getenv('ALPACA_API_SECRET', 'A4nDUtAxdWijWjmg4zPVXcPeciaKhfkzwJ2wVF4gS5sg')

        # TimescaleDB configuration
        # Note: If running outside Docker (e.g., Jupyter notebook), use 'localhost' instead of 'timescaledb'
        # SECURITY: same caveat as above - the default URL embeds a password.
        default_url = 'postgresql://rayhan:12102801Rr@timescaledb:5432/arafatdb'
        self.TIMESCALEDB_URL = os.getenv('TIMESCALEDB_URL', default_url)

        # Only probe the Docker hostname when the URL was not supplied through
        # the environment and still points at the 'timescaledb' host.
        if 'timescaledb' in self.TIMESCALEDB_URL and not os.getenv('TIMESCALEDB_URL'):
            import socket
            try:
                socket.gethostbyname('timescaledb')
                logger.debug("Hostname 'timescaledb' is resolvable - likely running in Docker")
            except socket.gaierror:
                logger.warning("⚠️  Hostname 'timescaledb' not resolvable - you may be running outside Docker")
                # The hint uses placeholders: the real URL contains the
                # database password and must not be written to the logs.
                logger.warning("   Try setting environment variable: TIMESCALEDB_URL=postgresql://<user>:<password>@localhost:5432/<database>")
                logger.warning("   Or modify the connection string in BacktestConfig class")

        # MLflow configuration
        self.MLFLOW_TRACKING_URI = os.getenv('MLFLOW_TRACKING_URI', 'http://mlflow:5000')

        # Loki configuration
        self.LOKI_URL = os.getenv('LOKI_URL', 'http://loki:3100/loki/api/v1/push')

        # Initialize connections
        logger.info("="*60)
        logger.info("Initializing Database Connection")
        logger.info("="*60)
        self.db = TimescaleDB(self.TIMESCALEDB_URL)

        # A missing pool means the connection could not be set up; the config
        # is still returned so callers can run without database features.
        if self.db.pool:
            if self.db.test_connection():
                logger.info("✓ Database is ready for use")
            else:
                logger.warning("⚠️  Database connection test failed, but continuing...")
        else:
            logger.error("❌ Database connection pool is None - database features will be disabled")
            logger.error("   Please check your TIMESCALEDB_URL environment variable or connection string")
            logger.error("   If running outside Docker, you may need to use 'localhost' instead of 'timescaledb'")

        mlflow.set_tracking_uri(self.MLFLOW_TRACKING_URI)

        # Setup logging with Loki (reconfigure logger with Loki settings)
        logger = setup_logging(
            level=logging.INFO,
            loki_url=self.LOKI_URL,
            loki_labels={'service': 'btc_backtest'}
        )

# Database connection test function
def test_database_connection(connection_string=None):
    """
    Test database connection independently and print a step-by-step report.

    Connects with psycopg2, runs a version query, reports connection
    statistics and checks for the expected tables. The connection is always
    closed, including when one of the checks fails.

    Usage:
        test_database_connection()  # Uses default from BacktestConfig
        test_database_connection('postgresql://user:pass@host:port/db')

    Args:
        connection_string: libpq URL or DSN. When None, TIMESCALEDB_URL from
            the environment is used, falling back to the built-in default.

    Returns:
        True when every check succeeds, False on any error.
    """
    if connection_string is None:
        # Use default from BacktestConfig
        # SECURITY: this default embeds a password; prefer TIMESCALEDB_URL.
        default_url = 'postgresql://rayhan:12102801Rr@timescaledb:5432/arafatdb'
        connection_string = os.getenv('TIMESCALEDB_URL', default_url)

    print("="*60)
    print("Testing Database Connection")
    print("="*60)
    print(f"Connection string: {_mask_connection_string(connection_string)}")
    print()

    test_conn = None
    try:
        # Test basic connection
        print("1. Testing basic connection...")
        test_conn = psycopg2.connect(connection_string)
        print("   ✓ Basic connection successful")

        # The cursor is closed on leaving the with-block, even on error
        with test_conn.cursor() as cur:
            # Test query
            print("2. Testing query execution...")
            cur.execute("SELECT version();")
            version = cur.fetchone()[0]
            print("   ✓ Query successful")
            print(f"   PostgreSQL version: {version[:60]}...")

            # Check connection count
            print("3. Checking connection statistics...")
            cur.execute("""
                SELECT 
                    count(*) as total_connections,
                    count(*) FILTER (WHERE state = 'active') as active_connections,
                    count(*) FILTER (WHERE state = 'idle') as idle_connections,
                    setting::int as max_connections
                FROM pg_stat_activity, pg_settings
                WHERE name = 'max_connections'
                GROUP BY setting;
            """)
            stats = cur.fetchone()
            if stats:
                total, active, idle, max_conn = stats
                print(f"   Total connections: {total}/{max_conn}")
                print(f"   Active: {active}, Idle: {idle}")
                if total >= max_conn * 0.9:
                    print("   ⚠️  WARNING: Connection pool is nearly full!")

            # Test table access
            print("4. Testing table access...")
            cur.execute("""
                SELECT table_name 
                FROM information_schema.tables 
                WHERE table_schema = 'public' 
                AND table_name IN ('trades', 'backtest_results', 'ohlcv_data')
                ORDER BY table_name;
            """)
            tables = [row[0] for row in cur.fetchall()]
            if tables:
                print(f"   ✓ Found tables: {', '.join(tables)}")
            else:
                print("   ⚠️  No expected tables found (this is OK if first run)")

        print()
        print("="*60)
        print("✓ Database connection test PASSED")
        print("="*60)
        return True

    except psycopg2.OperationalError as e:
        error_msg = str(e)
        print(f"   ❌ Connection failed: {e}")
        if "too many clients" in error_msg.lower():
            print()
            print("   SOLUTIONS:")
            print("   1. Close other database connections")
            print("   2. Restart PostgreSQL/TimescaleDB service")
            print("   3. Wait for idle connections to timeout")
            print("   4. Increase max_connections in postgresql.conf")
        print()
        print("="*60)
        print("❌ Database connection test FAILED")
        print("="*60)
        return False
    except Exception as e:
        print(f"   ❌ Unexpected error: {e}")
        print()
        print("="*60)
        print("❌ Database connection test FAILED")
        print("="*60)
        return False
    finally:
        # Release the server-side connection slot on every path; a diagnostic
        # for "too many clients" must not leak connections itself.
        if test_conn is not None:
            try:
                test_conn.close()
            except psycopg2.Error:
                pass  # best-effort cleanup; the outcome was already reported


def _mask_connection_string(connection_string):
    """Return *connection_string* with any password replaced by ``****``.

    URL-style strings are rebuilt from their parsed parts; for key/value DSNs
    the ``password=...`` field is masked. Falls back to ``[masked]`` if the
    string cannot be processed.
    """
    import re
    from urllib.parse import urlparse

    try:
        parsed = urlparse(connection_string)
        if parsed.password:
            return (
                f"{parsed.scheme}://{parsed.username}:****@"
                f"{parsed.hostname}:{parsed.port or 5432}{parsed.path}"
            )
        # No URL password: still hide a DSN-style or query-string password
        return re.sub(r"(?i)(password\s*=\s*)('[^']*'|\S+)", r"\1****", connection_string)
    except Exception:
        # Display only - never let masking problems abort the diagnostic
        return "[masked]"

# Main execution
if __name__ == "__main__":
    # Values shared by the data-existence probe and the backtest run
    rule = "=" * 60
    range_start, range_end, bar_size = '2021-01-01', '2022-12-31', '5Min'

    def _has_rows(frame):
        """Return True when the collector handed back a non-empty result."""
        return frame is not None and len(frame) > 0

    config = BacktestConfig()

    # Collector shares the config's database handle and Alpaca credentials
    collector = HistoricalDataCollector(
        db=config.db,
        api_key=config.ALPACA_API_KEY,
        api_secret=config.ALPACA_API_SECRET
    )

    # STEP 1: Collect historical data (UNCOMMENT AND RUN THIS FIRST!)
    print(rule)
    print("STEP 1: BTC Historical Data Collection")
    print(rule)
    print("\n⚠️  IMPORTANT: Uncomment the lines below to collect historical data FIRST!")
    print("This will take 30-60 minutes depending on date range.\n")

    # ⬇️ UNCOMMENT THE LINES BELOW TO COLLECT DATA ⬇️
    # collector.collect_historical_data(
    #     symbol='BTCUSD',
    #     start_date='2020-01-01',  # Start date
    #     end_date='2024-12-31',     # End date
    #     timeframe='5Min'           # 5-minute bars
    # )
    # ⬆️ UNCOMMENT THE LINES ABOVE TO COLLECT DATA ⬆️

    # STEP 2: Run backtest using collected data
    print("\n" + rule)
    print("STEP 2: BTC Backtest")
    print(rule)
    print("\n⚠️  Make sure you've collected data first (Step 1)!\n")

    # Probe the database first with the compact symbol, then the slash form
    test_df = collector.get_historical_data_from_db(
        'BTCUSD', range_start, range_end, bar_size
    )
    if not _has_rows(test_df):
        logger.info("Trying with BTC/USD format...")
        test_df = collector.get_historical_data_from_db(
            'BTC/USD', range_start, range_end, bar_size
        )

    if _has_rows(test_df):
        print(f"✅ Found {len(test_df)} bars in database. Running backtest...\n")

        backtest_bot = TimescaleDBBacktestBot(
            db=config.db,
            data_collector=collector,
            symbol='BTCUSD',
            max_daily_trades=3,
            quantity=0.001,
            entry_strategy='or'  # Options: 'and', 'or', 'sweep_only', 'fvg_only'
        )

        # Same date range and bar size as the probe above
        metrics = backtest_bot.run_backtest_from_db(
            start_date=range_start,
            end_date=range_end,
            timeframe=bar_size
        )

        if metrics:
            print(f"\n{rule}")
            print("BACKTEST SUMMARY")
            print(f"{rule}")
            print(f"Total Trades: {metrics.get('total_trades', 0)}")
            print(f"Long Trades: {sum(1 for t in backtest_bot.backtest_results if t['direction'] == 'long')}")
            print(f"Short Trades: {sum(1 for t in backtest_bot.backtest_results if t['direction'] == 'short')}")
            print(f"\nSharpe Ratio: {metrics.get('sharpe_ratio', 0):.4f}")
            print(f"Total Return: {metrics.get('total_return_pct', 0):.2f}%")
            print(f"ROI: {metrics.get('roi', 0):.2f}%")
            print(f"Max Drawdown: {metrics.get('max_drawdown_pct', 0):.2f}%")
            print(f"Profit Factor: {metrics.get('profit_factor', 0):.4f}")
            print(f"Win Rate: {metrics.get('win_rate', 0)*100:.2f}%")
            print(f"{rule}")
    else:
        # Nothing stored for either symbol spelling: explain how to collect data
        for line in (
            "❌ ERROR: No data found in database!",
            "\nPlease run data collection first:",
            "1. Uncomment the collector.collect_historical_data() lines above",
            "2. Run this cell again",
            "3. Wait for data collection to complete (30-60 minutes)",
            "4. Then run the backtest section below",
        ):
            print(line)



2025-11-14 19:53:26,231 - INFO - Initializing Database Connection
2025-11-14 19:53:26,231 - INFO - Attempting to connect to TimescaleDB...
2025-11-14 19:53:26,237 - INFO - ✓ Database connection test successful
2025-11-14 19:53:26,242 - INFO - ✓ Connection pool created
2025-11-14 19:53:26,261 - INFO - TimescaleDB schema initialized. Tables: ['trades']
2025-11-14 19:53:26,261 - INFO - ✓ TimescaleDB connection pool created successfully
2025-11-14 19:53:26,262 - INFO - ✓ Database connection test successful
2025-11-14 19:53:26,262 - INFO - ✓ Database is ready for use
2025-11-14 19:53:26,264 - INFO - OHLCV table created/verified
STEP 1: BTC Historical Data Collection

⚠️  IMPORTANT: Uncomment the lines below to collect historical data FIRST!
This will take 30-60 minutes depending on date range.


STEP 2: BTC Backtest

⚠️  Make sure you've collected data first (Step 1)!

2025-11-14 19:53:26,334 - INFO - Symbols found in database: ['BTC/USD']
2025-11-14 19:53:26,336 - INFO - Looking for symbol