In [None]:
"""Binance OHLCV data loader with DuckDB storage."""

import asyncio
import aiohttp
import duckdb
import polars as pl
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Optional
from loguru import logger


@dataclass
class BinanceClient:
    """Async client for the Binance REST API.

    Use it as an async context manager (``async with BinanceClient() as c``) so
    the underlying ``aiohttp`` session is created and closed.

    All datetimes exchanged with callers are naive and interpreted as UTC.
    Timezone-aware inputs are accepted and converted to UTC first.
    """

    base_url: str = "https://api.binance.com"
    max_retries: int = 3
    _session: Optional["aiohttp.ClientSession"] = field(default=None, init=False)

    async def __aenter__(self):
        timeout = aiohttp.ClientTimeout(total=30)
        self._session = aiohttp.ClientSession(timeout=timeout)
        return self

    async def __aexit__(self, *args):
        if self._session:
            await self._session.close()
            self._session = None

    @staticmethod
    def _naive_utc(dt: datetime) -> datetime:
        """Returns *dt* as a naive UTC datetime (naive input is assumed to be UTC)."""
        if dt.tzinfo is None:
            return dt
        return dt.astimezone(timezone.utc).replace(tzinfo=None)

    @classmethod
    def _to_ms(cls, dt: datetime) -> int:
        """Converts a datetime to epoch milliseconds, treating naive values as UTC.

        ``datetime.timestamp()`` would treat naive values as *local* time, which
        is wrong for the UTC-naive datetimes used throughout this client.
        """
        return (cls._naive_utc(dt) - datetime(1970, 1, 1)) // timedelta(milliseconds=1)

    @staticmethod
    def _from_ms(ms: int) -> datetime:
        """Converts epoch milliseconds to a naive UTC datetime (exact, no float rounding)."""
        return datetime(1970, 1, 1) + timedelta(milliseconds=ms)

    @staticmethod
    def _interval_ms(timeframe: str, default: int = 60_000) -> int:
        """Returns the nominal length of an interval string like '15m' in milliseconds.

        Accepts ``<count><unit>`` with units s, m, h, d, w and M (a month is
        approximated as 30 days). Unrecognised strings return *default*.
        """
        unit_ms = {
            "s": 1_000,
            "m": 60_000,
            "h": 3_600_000,
            "d": 86_400_000,
            "w": 604_800_000,
            "M": 2_592_000_000,
        }
        count, unit = timeframe[:-1], timeframe[-1:]
        if not (count.isascii() and count.isdigit()) or unit not in unit_ms:
            return default
        return int(count) * unit_ms[unit]

    async def get_klines(
        self,
        symbol: str,
        timeframe: str = "1m",
        start_time: Optional[datetime] = None,
        end_time: Optional[datetime] = None,
        limit: int = 1000
    ) -> list[dict]:
        """Fetches one page of OHLCV candles from ``/api/v3/klines``.

        Args:
            symbol: Trading pair (e.g., 'BTCUSDT').
            timeframe: Candle interval ('1m', '5m', '15m', '1h', '4h', '1d', ...).
            start_time: Start of the range (naive = UTC).
            end_time: End of the range (naive = UTC).
            limit: Max candles per request (max 1000).

        Returns:
            List of dicts with keys ``timestamp`` (candle open time, naive UTC),
            ``open``, ``high``, ``low``, ``close``, ``volume`` (quote-asset
            volume) and ``trades``.

        Raises:
            RuntimeError: If used outside ``async with``, or if every attempt
                was rate limited.
            Exception: On a non-200, non-rate-limit HTTP response.
            aiohttp.ClientError / asyncio.TimeoutError: If the last retry fails.
        """
        if self._session is None:
            raise RuntimeError("BinanceClient must be used inside 'async with'")

        params = {
            "symbol": symbol,
            "interval": timeframe,
            "limit": limit
        }
        if start_time is not None:
            params["startTime"] = self._to_ms(start_time)
        if end_time is not None:
            params["endTime"] = self._to_ms(end_time)

        url = f"{self.base_url}/api/v3/klines"
        data = None

        for attempt in range(self.max_retries):
            try:
                async with self._session.get(url, params=params) as resp:
                    # 429 = request-weight limit exceeded; 418 = IP auto-banned for
                    # continuing after 429s. Both carry a Retry-After header.
                    if resp.status in (418, 429):
                        retry_after = int(resp.headers.get("Retry-After", 60))
                        logger.warning(f"Rate limited, waiting {retry_after}s")
                        await asyncio.sleep(retry_after)
                        continue

                    if resp.status != 200:
                        text = await resp.text()
                        raise Exception(f"Binance API error {resp.status}: {text}")

                    data = await resp.json()
                    break

            except (aiohttp.ClientError, asyncio.TimeoutError) as e:
                if attempt < self.max_retries - 1:
                    wait = 2 ** attempt
                    logger.warning(f"Request failed, retrying in {wait}s: {e}")
                    await asyncio.sleep(wait)
                else:
                    raise

        # Reached when every attempt was rate limited (or max_retries <= 0).
        if data is None:
            raise RuntimeError(
                f"Binance API: no data for {symbol} after {self.max_retries} attempts"
            )

        return [
            {
                # Index 0 is the candle open time; index 6 would be its close time.
                "timestamp": self._from_ms(int(k[0])),
                "open": float(k[1]),
                "high": float(k[2]),
                "low": float(k[3]),
                "close": float(k[4]),
                # Index 7 is quote-asset volume (index 5 is base-asset volume).
                "volume": float(k[7]),
                "trades": int(k[8]),
            }
            for k in data
        ]

    async def get_klines_range(
        self,
        symbol: str,
        timeframe: str,
        start_time: datetime,
        end_time: datetime,
    ) -> list[dict]:
        """Downloads all candles in ``[start_time, end_time]``, paginating as needed.

        Each request covers a window of 1000 candles (the per-request maximum).
        Empty windows, such as before a pair was listed or during exchange
        downtime, are skipped rather than ending the download.

        Args:
            symbol: Trading pair.
            timeframe: Candle interval.
            start_time: Start of the range (naive = UTC).
            end_time: End of the range (naive = UTC).

        Returns:
            List of OHLCV dicts, in the format of ``get_klines``, ordered by time.
        """
        start_time = self._naive_utc(start_time)
        end_time = self._naive_utc(end_time)

        all_klines: list[dict] = []
        current_start = start_time
        step = timedelta(milliseconds=self._interval_ms(timeframe) * 1000)
        one_ms = timedelta(milliseconds=1)
        next_log = 10_000

        while current_start < end_time:
            window_end = min(current_start + step, end_time)
            klines = await self.get_klines(
                symbol=symbol,
                timeframe=timeframe,
                start_time=current_start,
                end_time=window_end,
                limit=1000
            )

            if klines:
                all_klines.extend(klines)
                # Resume just after the last returned open time; never move backwards.
                current_start = max(klines[-1]["timestamp"], current_start) + one_ms
            else:
                current_start = window_end + one_ms
            await asyncio.sleep(0.05)

            if len(all_klines) >= next_log:
                logger.info(f"{symbol}: loaded {len(all_klines):,} candles...")
                next_log = (len(all_klines) // 10_000 + 1) * 10_000

        return all_klines


@dataclass
class SpotStore:
    """DuckDB storage for OHLCV data.

    Rows live in a single ``ohlcv`` table with primary key
    ``(symbol, timeframe, timestamp)``. Timestamps are stored as naive
    ``TIMESTAMP`` values and are expected to be UTC: ``load(hours=...)``
    compares them against the current UTC time.
    """

    db_path: Path
    _con: duckdb.DuckDBPyConnection = field(init=False)

    def __post_init__(self):
        self._con = duckdb.connect(str(self.db_path))
        self._ensure_tables()

    def _create_table(self, name: str, if_not_exists: bool = False):
        """Creates an OHLCV table called *name* with the canonical schema."""
        clause = "IF NOT EXISTS " if if_not_exists else ""
        self._con.execute(f"""
            CREATE TABLE {clause}{name} (
                symbol VARCHAR NOT NULL,
                timeframe VARCHAR NOT NULL,
                timestamp TIMESTAMP NOT NULL,
                open DOUBLE NOT NULL,
                high DOUBLE NOT NULL,
                low DOUBLE NOT NULL,
                close DOUBLE NOT NULL,
                volume DOUBLE NOT NULL,
                trades INTEGER,
                PRIMARY KEY (symbol, timeframe, timestamp)
            )
        """)

    def _ensure_tables(self):
        """Creates the ``ohlcv`` table and index, migrating the legacy schema if present.

        The legacy layout is detected by an ``open_time`` column. Its
        ``interval``/``open_time``/``quote_volume`` columns are copied into the
        new layout inside one transaction, so a failure leaves the old table
        untouched.
        """
        existing = self._con.execute("""
            SELECT column_name FROM information_schema.columns 
            WHERE table_name = 'ohlcv'
        """).fetchall()
        existing_cols = {row[0] for row in existing}

        if "open_time" in existing_cols:
            logger.info("Migrating old schema...")
            self._con.execute("BEGIN TRANSACTION")
            try:
                # Remove a leftover from an earlier, interrupted migration.
                self._con.execute("DROP TABLE IF EXISTS ohlcv_new")
                self._create_table("ohlcv_new")
                self._con.execute("""
                    INSERT INTO ohlcv_new 
                    SELECT symbol, interval, open_time, open, high, low, close, quote_volume, trades
                    FROM ohlcv
                """)
                self._con.execute("DROP TABLE ohlcv")
                self._con.execute("ALTER TABLE ohlcv_new RENAME TO ohlcv")
                self._con.execute("COMMIT")
            except Exception:
                self._con.execute("ROLLBACK")
                raise
            logger.info("Migration complete")

        self._create_table("ohlcv", if_not_exists=True)

        self._con.execute("""
            CREATE INDEX IF NOT EXISTS idx_ohlcv_symbol_time 
            ON ohlcv(symbol, timeframe, timestamp DESC)
        """)

        logger.info(f"Database initialized: {self.db_path}")

    def insert_klines(self, symbol: str, timeframe: str, klines: list[dict]):
        """Inserts klines, replacing rows that share the same primary key.

        Small batches (<= 10 rows) use ``executemany``. Larger batches are
        loaded through a temporarily registered Arrow table.

        Args:
            symbol: Trading pair.
            timeframe: Candle interval.
            klines: OHLCV dicts with keys timestamp, open, high, low, close,
                volume and trades.
        """
        if not klines:
            return

        if len(klines) <= 10:
            self._con.executemany(
                "INSERT OR REPLACE INTO ohlcv VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
                [
                    (symbol, timeframe, k["timestamp"], k["open"], k["high"], 
                     k["low"], k["close"], k["volume"], k["trades"])
                    for k in klines
                ]
            )
        else:
            df = pl.DataFrame({
                "symbol": [symbol] * len(klines),
                "timeframe": [timeframe] * len(klines),
                "timestamp": [k["timestamp"] for k in klines],
                "open": [k["open"] for k in klines],
                "high": [k["high"] for k in klines],
                "low": [k["low"] for k in klines],
                "close": [k["close"] for k in klines],
                "volume": [k["volume"] for k in klines],
                "trades": [k["trades"] for k in klines],
            })

            self._con.register("temp_klines", df.to_arrow())
            try:
                # Explicit column list: don't depend on the DataFrame's column order.
                self._con.execute("""
                    INSERT OR REPLACE INTO ohlcv
                    SELECT symbol, timeframe, timestamp, open, high, low, close, volume, trades
                    FROM temp_klines
                """)
            finally:
                self._con.unregister("temp_klines")

        logger.debug(f"Inserted {len(klines):,} rows for {symbol} ({timeframe})")

    def get_time_bounds(self, symbol: str, timeframe: str) -> tuple[Optional[datetime], Optional[datetime]]:
        """Returns the min and max stored timestamps for a symbol/timeframe.

        Args:
            symbol: Trading pair.
            timeframe: Candle interval.

        Returns:
            Tuple of (min_timestamp, max_timestamp), or (None, None) if no rows exist.
        """
        result = self._con.execute("""
            SELECT MIN(timestamp), MAX(timestamp) FROM ohlcv 
            WHERE symbol = ? AND timeframe = ?
        """, [symbol, timeframe]).fetchone()
        return (result[0], result[1]) if result and result[0] else (None, None)

    def find_gaps(
        self,
        symbol: str,
        timeframe: str,
        start: datetime,
        end: datetime,
        tf_minutes: int
    ) -> list[tuple[datetime, datetime]]:
        """Finds runs of missing candles in ``[start, end]``.

        Expected candle slots are every ``tf_minutes`` on a grid aligned to the
        Unix epoch. *start* is rounded up onto that grid, so an arbitrary start
        such as "now - 10 days" still matches stored candle timestamps.

        Args:
            symbol: Trading pair.
            timeframe: Candle interval.
            start: Start of the range.
            end: End of the range.
            tf_minutes: Timeframe duration in minutes (must be positive).

        Returns:
            List of (gap_start, gap_end) tuples. ``gap_end`` is the last missing
            slot, or *end* for a gap that runs to the end of the range.
            ``[(start, end)]`` is returned if there are no rows in the range.

        Raises:
            ValueError: If ``tf_minutes`` is not positive.
        """
        if tf_minutes <= 0:
            raise ValueError(f"tf_minutes must be positive, got {tf_minutes}")
        step = timedelta(minutes=tf_minutes)

        existing = self._con.execute("""
            SELECT timestamp FROM ohlcv 
            WHERE symbol = ? AND timeframe = ? AND timestamp BETWEEN ? AND ?
            ORDER BY timestamp
        """, [symbol, timeframe, start, end]).fetchall()

        if not existing:
            return [(start, end)]

        existing_times = {row[0] for row in existing}
        gaps = []
        gap_start = None

        # NOTE(review): assumes candle timestamps are open times on an
        # epoch-aligned grid (true for Binance intervals up to 1d).
        epoch = datetime(1970, 1, 1, tzinfo=timezone.utc if start.tzinfo else None)
        remainder = (start - epoch) % step
        current = start + (step - remainder) if remainder else start

        while current <= end:
            if current not in existing_times:
                if gap_start is None:
                    gap_start = current
            elif gap_start is not None:
                gaps.append((gap_start, current - step))
                gap_start = None
            current += step

        if gap_start is not None:
            gaps.append((gap_start, end))

        return gaps

    def load(
        self,
        symbol: str,
        timeframe: str = "1m",
        hours: Optional[int] = None,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
    ) -> pl.DataFrame:
        """Loads OHLCV data into a Polars DataFrame, ordered by timestamp.

        Args:
            symbol: Trading pair.
            timeframe: Candle interval.
            hours: Load rows newer than N hours before the current UTC time.
                Takes precedence over start/end.
            start: Start of the range (inclusive).
            end: End of the range (inclusive).

        Returns:
            Polars DataFrame with columns: timestamp, open, high, low, close, volume, trades.
        """
        query = """
            SELECT timestamp, open, high, low, close, volume, trades
            FROM ohlcv
            WHERE symbol = ? AND timeframe = ?
        """
        params = [symbol, timeframe]

        if hours:
            # Compute the cutoff as naive UTC to match stored timestamps, and bind
            # it as a parameter instead of formatting it into the SQL text.
            cutoff = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(hours=hours)
            query += " AND timestamp > ?"
            params.append(cutoff)
        elif start and end:
            query += " AND timestamp BETWEEN ? AND ?"
            params.extend([start, end])
        elif start:
            query += " AND timestamp >= ?"
            params.append(start)
        elif end:
            query += " AND timestamp <= ?"
            params.append(end)

        query += " ORDER BY timestamp"

        return self._con.execute(query, params).pl()

    def get_stats(self) -> pl.DataFrame:
        """Returns per-(symbol, timeframe) statistics for everything in the database.

        Returns:
            DataFrame with row counts, first/last candle and total volume.
        """
        return self._con.execute("""
            SELECT 
                symbol,
                timeframe,
                COUNT(*) as rows,
                MIN(timestamp) as first_candle,
                MAX(timestamp) as last_candle,
                ROUND(SUM(volume), 2) as total_volume
            FROM ohlcv
            GROUP BY symbol, timeframe
            ORDER BY symbol, timeframe
        """).pl()

    def close(self):
        """Closes the database connection."""
        self._con.close()


@dataclass
class BinanceSpotLoader:
    """Downloads and stores Binance spot OHLCV data.

    Naive datetimes passed to ``download`` are interpreted as UTC.
    Timezone-aware ones are converted to UTC first.
    """

    db_path: Path = field(default_factory=lambda: Path("raw_data.duckdb"))

    @staticmethod
    def _to_utc_naive(dt: Optional[datetime]) -> Optional[datetime]:
        """Returns *dt* as naive UTC. ``None`` and naive values pass through unchanged."""
        if dt is None or dt.tzinfo is None:
            return dt
        return dt.astimezone(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _timeframe_minutes(timeframe: str) -> int:
        """Returns the candle length in minutes for strings like '15m', '4h', '1d', '1w'.

        Unrecognised or sub-minute intervals fall back to 1.
        """
        units = {"m": 1, "h": 60, "d": 1440, "w": 10080}
        count, unit = timeframe[:-1], timeframe[-1:]
        if count.isascii() and count.isdigit() and unit in units:
            return int(count) * units[unit] or 1
        return 1

    async def download(
        self,
        pairs: list[str],
        timeframe: str = "1m",
        days: Optional[int] = None,
        start: Optional[datetime] = None,
        end: Optional[datetime] = None,
        fill_gaps: bool = True,
    ):
        """Downloads historical data for a list of pairs concurrently.

        Only ranges not already covered by the database are requested: the
        part before the stored minimum, the part after the stored maximum and,
        if *fill_gaps* is set, holes inside the stored span. Prints per-pair
        statistics at the end. The store is always closed, even on error.

        Args:
            pairs: List of trading pairs (e.g., ['BTCUSDT', 'ETHUSDT']).
            timeframe: Candle interval.
            days: Number of days to download before *end* (used when *start*
                is not given; defaults to 7).
            start: Start of the range.
            end: End of the range (defaults to now, UTC).
            fill_gaps: Whether to fill gaps in existing data.
        """
        store = SpotStore(self.db_path)
        try:
            now = datetime.now(timezone.utc).replace(tzinfo=None)

            # Stored timestamps are naive UTC; normalize inputs so comparisons work.
            start = self._to_utc_naive(start)
            end = self._to_utc_naive(end)

            if end is None:
                end = now
            if start is None:
                start = end - timedelta(days=days if days else 7)

            tf_step = timedelta(minutes=self._timeframe_minutes(timeframe))
            tf_minutes = self._timeframe_minutes(timeframe)

            async def download_pair(client: BinanceClient, pair: str):
                logger.info(f"Processing {pair} from {start} to {end}")

                db_min, db_max = store.get_time_bounds(pair, timeframe)
                ranges_to_download = []

                if db_min is None:
                    ranges_to_download.append((start, end))
                else:
                    if start < db_min:
                        ranges_to_download.append((start, db_min - tf_step))

                    if end > db_max:
                        ranges_to_download.append((db_max + tf_step, end))

                    if fill_gaps:
                        overlap_start = max(start, db_min)
                        overlap_end = min(end, db_max)
                        if overlap_start < overlap_end:
                            gaps = store.find_gaps(pair, timeframe, overlap_start, overlap_end, tf_minutes)
                            ranges_to_download.extend(gaps)

                for range_start, range_end in ranges_to_download:
                    if range_start >= range_end:
                        continue

                    logger.info(f"{pair}: downloading {range_start} -> {range_end}")

                    try:
                        klines = await client.get_klines_range(
                            symbol=pair,
                            timeframe=timeframe,
                            start_time=range_start,
                            end_time=range_end,
                        )
                        store.insert_klines(pair, timeframe, klines)

                    except Exception as e:
                        logger.error(f"Error downloading {pair}: {e}")

            async with BinanceClient() as client:
                await asyncio.gather(*[download_pair(client, pair) for pair in pairs])

            logger.info("=" * 60)
            print(store.get_stats())
        finally:
            store.close()

    async def sync(
        self,
        pairs: list[str],
        timeframe: str = "1m",
        update_interval_sec: int = 60,
    ):
        """Real-time sync that periodically re-fetches the latest 5 candles per pair.

        Runs until cancelled. Re-fetched candles replace stored rows, so the
        still-forming candle is updated on each pass. The store is closed when
        the loop exits (including on cancellation).

        Args:
            pairs: List of trading pairs.
            timeframe: Candle interval.
            update_interval_sec: Seconds between updates.
        """
        store = SpotStore(self.db_path)
        try:
            logger.info(f"Starting real-time sync for {pairs}")
            logger.info(f"Update interval: {update_interval_sec}s")

            async def fetch_and_store(client: BinanceClient, pair: str):
                try:
                    klines = await client.get_klines(symbol=pair, timeframe=timeframe, limit=5)
                    store.insert_klines(pair, timeframe, klines)
                except Exception as e:
                    logger.error(f"Error syncing {pair}: {e}")

            async with BinanceClient() as client:
                while True:
                    await asyncio.gather(*[fetch_and_store(client, pair) for pair in pairs])
                    logger.debug(f"Synced {len(pairs)} pairs")
                    await asyncio.sleep(update_interval_sec)
        finally:
            store.close()

In [None]:
loader = BinanceSpotLoader(db_path=Path("raw_data3.duckdb"))

await loader.download(
    pairs=["BTCUSDT", "ETHUSDT", "SOLUSDT", "XRPUSDT", "BNBUSDT"],
    timeframe="1m",
    days=10,
    fill_gaps=False
)