# Install Libraries

In [None]:
!pip install ccxt -qqq
!pip install yfinance -qqq
!pip install mplfinance -qqq

In [None]:
!pip install tensorflow -qqq
!pip install keras -qqq

In [None]:
!pip install gymnasium -qqq
!pip install stable-baselines3 -qqq
!pip install stable-baselines3[extra] -qqq
!pip install sb3_contrib -qqq

# Import Libraries

In [None]:
import sys
import os
import time
import numpy as np
import pandas as pd
import yfinance as yf
import statsmodels.api as sm
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from matplotlib.animation import FuncAnimation
from numpy.linalg import norm
from scipy.stats import entropy
from scipy.stats import entropy
from sklearn.cluster import Birch
from sklearn.preprocessing import MinMaxScaler
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_absolute_error, mean_squared_error

import ccxt
import logging
from pathlib import Path
from typing import List, Optional, Union

np.random.seed(0)

# Load Data


Data can be loaded from any CCXT-supported exchange.

e.g. coinbase, gemini, kraken, binanceus, etc.

Recommended => Coinbase.



## CCXT

In [None]:
# Market to download, candle size, and the ccxt exchange id used by the cells below.
INSTRUMENT, TIMEFRAME, EXCHANGE_ID = "SOL/USDT", "4h", "binanceus"

In [None]:
# Instantiate the exchange named by EXCHANGE_ID (a ccxt exchange id such as "binanceus").
exchange: ccxt.Exchange = getattr(ccxt, EXCHANGE_ID)()

# Load the markets to get exchange information, including timeframes
exchange.load_markets()

# Candle sizes the exchange advertises; may be empty/None if it does not advertise any.
exchange_timeframes = exchange.timeframes

# List them so the TIMEFRAME setting can be checked against what the exchange supports.
if exchange_timeframes:
    print(f"Supported timeframes for {EXCHANGE_ID.capitalize()}:")
    for timeframe in exchange_timeframes:
        print(timeframe)
else:
    print(f"Could not retrieve timeframes for {EXCHANGE_ID.capitalize()}.")

Supported timeframes for Binanceus:
1s
1m
3m
5m
15m
30m
1h
2h
4h
6h
8h
12h
1d
3d
1w
1M


In [None]:
# Configure root logging for the notebook: DEBUG level, timestamped "level - message" records.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)

# Emit one record so it is obvious the configuration took effect.
logger.info("Logger configured successfully.")

def fetch_ohlcv_with_retries(exchange: ccxt.Exchange, symbol: str, timeframe: str, since: int, limit: int, max_retries: int = 3, retry_delay: float = 1.0) -> List[List[Union[int, float]]]:
    """Fetch one page of OHLCV candles, retrying failed requests with exponential backoff.

    Args:
        exchange: An initialised ccxt exchange instance.
        symbol: Market symbol, e.g. ``"SOL/USDT"``.
        timeframe: Candle size, e.g. ``"4h"``.
        since: Start of the page, in epoch milliseconds.
        limit: Maximum number of candles to request.
        max_retries: Total number of attempts before giving up.
        retry_delay: Seconds to wait after the first failure; the wait doubles after
            each further failure. Pass 0 to retry immediately.

    Returns:
        Whatever ``exchange.fetch_ohlcv`` returns, or ``[]`` when ``max_retries < 1``.

    Raises:
        The exception from the final attempt, after it has been logged.
    """
    for attempt in range(max_retries):
        try:
            return exchange.fetch_ohlcv(symbol, timeframe, since, limit)
        except Exception as e:
            if attempt == max_retries - 1:
                logger.error(f"Failed to fetch {timeframe} {symbol} OHLCV after {max_retries} attempts: {e}")
                raise
            # Back off before retrying so transient network/rate-limit errors can clear.
            delay = retry_delay * (2 ** attempt)
            logger.warning(f"Attempt {attempt + 1}/{max_retries} to fetch {timeframe} {symbol} OHLCV failed: {e}; retrying in {delay:.1f}s")
            time.sleep(delay)
    return []

def load_existing_data(filename: Path) -> pd.DataFrame:
    """Return previously saved OHLCV candles, or an empty frame if none exist yet.

    Args:
        filename: CSV path with a ``timestamp`` column.

    Returns:
        The CSV contents indexed by the parsed ``timestamp`` column. If the file does not
        exist, an empty DataFrame with open/high/low/close/volume columns is returned.
    """
    if not filename.exists():
        return pd.DataFrame(columns=["open", "high", "low", "close", "volume"])
    return pd.read_csv(filename, parse_dates=["timestamp"], index_col="timestamp")

def scrape_ohlcv(exchange: ccxt.Exchange, symbol: str, timeframe: str, since: int, until: int, limit: int, max_retries: int = 3) -> List[List[Union[int, float]]]:
    """Scrape historical OHLCV candles opening in the half-open window ``[since, until)``.

    Pages forward through the exchange ``limit`` candles at a time. Each request starts
    1 ms after the last candle received.

    Args:
        exchange: An initialised ccxt exchange instance.
        symbol: Market symbol, e.g. ``"SOL/USDT"``.
        timeframe: Candle size, e.g. ``"4h"``.
        since: Window start, epoch milliseconds.
        until: Window end (exclusive), epoch milliseconds.
        limit: Maximum candles per request.
        max_retries: Attempts per request (see ``fetch_ohlcv_with_retries``).

    Returns:
        Candles as ``[timestamp_ms, open, high, low, close, volume]`` lists, in the order
        the exchange returned them. Candles opening at or after ``until`` are dropped.
    """
    all_ohlcv: List[List[Union[int, float]]] = []

    while since < until:
        ohlcv: List[List[Union[int, float]]] = fetch_ohlcv_with_retries(exchange, symbol, timeframe, since, limit, max_retries)

        if not ohlcv:
            break

        last_timestamp = ohlcv[-1][0]
        if last_timestamp < since:
            # The exchange ignored `since`; moving `since` backwards could loop forever.
            logger.warning(f"Exchange returned {symbol} candles older than requested; stopping pagination.")
            break

        # The final page can overshoot `until` by up to `limit` candles; keep only the window.
        all_ohlcv.extend(candle for candle in ohlcv if candle[0] < until)
        since = last_timestamp + 1  # Move forward in time

        if all_ohlcv:
            logger.info(f"{len(all_ohlcv)} {symbol} candles collected from {exchange.iso8601(all_ohlcv[0][0])} to {exchange.iso8601(all_ohlcv[-1][0])}")

    return all_ohlcv

def save_to_csv(filename: Path, data: pd.DataFrame) -> None:
    """Write ``data`` (including its index) to ``filename`` as CSV.

    A new file is created with a header row. If the file already exists, the rows are
    appended without a header, so ``data`` must have the same column layout as the file.
    """
    file_already_exists = filename.exists()
    data.to_csv(
        filename,
        mode='a' if file_already_exists else 'w',
        header=not file_already_exists,
    )
    logger.info(f"Data saved to {filename}")

def scrape_and_save_candles(exchange_id: str, symbol: str, timeframe: str, since: Union[int, str], until: Union[int, str], limit: int, max_retries: int = 3, filename: Optional[str] = None, exchange_options: Optional[dict] = None) -> None:
    """Scrape OHLCV data and save it to a CSV file, resuming from any existing download.

    The file lives at ``./data/ccxt/<exchange_id>/<filename>``. When it already exists:

    * scraping resumes at its most recent stored candle, which is re-fetched because it
      may have been saved while still open;
    * the new candles are merged with the stored ones, with freshly fetched rows winning
      on duplicate timestamps;
    * the whole file is rewritten, sorted by timestamp.

    The written CSV has ``timestamp`` as its first column, which is the layout
    ``load_existing_data`` reads back.

    Args:
        exchange_id: ccxt exchange id, e.g. ``"binanceus"``.
        symbol: Market symbol, e.g. ``"SOL/USDT"``.
        timeframe: Candle size, e.g. ``"4h"``.
        since: Start as epoch milliseconds or an ISO-8601 string.
        until: End as epoch milliseconds or an ISO-8601 string; a falsy value means "now".
        limit: Maximum candles requested per API call.
        max_retries: Attempts per API call before giving up.
        filename: CSV file name; defaults to ``"<symbol with '/'→'_'>_<timeframe>.csv"`` lower-cased.
        exchange_options: Extra ccxt ``options``, e.g. ``{'defaultType': 'future'}``.
    """
    columns = ["timestamp", "open", "high", "low", "close", "volume"]

    if not filename:
        filename = f"{symbol.replace('/', '_')}_{timeframe}.csv".lower()

    exchange_options = exchange_options or {}
    exchange: ccxt.Exchange = getattr(ccxt, exchange_id)({'enableRateLimit': True, 'options': exchange_options})

    if isinstance(since, str):
        since = exchange.parse8601(since)
    if not until:
        until = exchange.milliseconds()
    elif isinstance(until, str):
        until = exchange.parse8601(until)

    exchange.load_markets()
    file_path = Path("./data/ccxt/") / exchange_id / filename
    file_path.parent.mkdir(parents=True, exist_ok=True)
    existing_data = load_existing_data(file_path)

    if not existing_data.empty:
        # max() rather than [-1]: files written by earlier versions of this code may be unsorted.
        last_timestamp = int(existing_data.index.max().timestamp() * 1000)  # Convert to ms
        if last_timestamp > since:
            # Resume AT the last stored candle so one saved while still open gets refreshed.
            since = last_timestamp

    ohlcv = scrape_ohlcv(exchange, symbol, timeframe, since, until, limit, max_retries)
    if not ohlcv:
        logger.warning("No new OHLCV data retrieved.")
        return

    new_data = pd.DataFrame(ohlcv, columns=columns)
    new_data["timestamp"] = pd.to_datetime(new_data["timestamp"], unit='ms')

    if existing_data.empty:
        combined_data = new_data
    else:
        # reset_index turns the timestamp index back into a column so both frames align.
        combined_data = pd.concat([existing_data.reset_index(), new_data], ignore_index=True)

    # Keep only the OHLCV columns (drops stray ones such as "Unnamed: 0" from older files),
    # let freshly fetched rows replace stored ones, and index by timestamp for writing.
    combined_data = (
        combined_data[columns]
        .drop_duplicates(subset=["timestamp"], keep="last")
        .sort_values("timestamp")
        .set_index("timestamp")
    )

    # Overwrite, never append: combined_data already contains every stored row.
    combined_data.to_csv(file_path)
    logger.info(f"Data saved to {file_path}")
    logger.info(f"Saved {len(new_data)} new candles from {new_data['timestamp'].iloc[0]} to {new_data['timestamp'].iloc[-1]} to {filename}")


In [None]:
# Echo the active configuration (exchange id, symbol, timeframe) as the cell output.
(EXCHANGE_ID, INSTRUMENT, TIMEFRAME)

('binanceus', 'SOL/USDT', '4h')

In [None]:
# Download (or resume downloading) candles for the configured market into ./data/ccxt/.
scrape_and_save_candles(
    exchange_id=EXCHANGE_ID,
    symbol=INSTRUMENT,
    timeframe=TIMEFRAME,
    since="2019-06-01T00:00:00Z",
    until="2025-06-19T23:59:59Z",
    limit=1000,
)

# Exchange options:
# scrape_and_save_candles("binance", "BTC/USDT", "4h", "2011-01-01T00:00:00Z", "2023-12-01T00:00:00Z", 1000, exchange_options={'defaultType': 'future'})

In [None]:
# Reload the saved candles; the path mirrors the default file name used by the scraper.
ohlcv = pd.read_csv(
    f"./data/ccxt/{EXCHANGE_ID}/{INSTRUMENT.replace('/', '_')}_{TIMEFRAME}.csv".lower(),
    parse_dates=["timestamp"],
)
# Files written with a RangeIndex carry it as an unnamed first column; discard it if present.
ohlcv = ohlcv.drop(columns=["Unnamed: 0"], errors="ignore")

ohlcv


Unnamed: 0,timestamp,open,high,low,close,volume
0,2020-09-18 12:00:00,3.0887,3.1355,2.8178,2.8929,5938.630
1,2020-09-18 16:00:00,2.9105,3.1543,2.8191,3.1487,9460.390
2,2020-09-18 20:00:00,3.1429,3.1490,3.0340,3.0994,1170.880
3,2020-09-19 00:00:00,3.0960,3.2430,3.0946,3.1240,3186.390
4,2020-09-19 04:00:00,3.1298,3.1453,3.0708,3.1044,327.100
...,...,...,...,...,...,...
10405,2025-06-19 00:00:00,146.4400,147.7400,145.0000,146.4300,100.995
10406,2025-06-19 04:00:00,146.6700,147.2600,145.1400,145.2900,75.422
10407,2025-06-19 08:00:00,145.7800,146.4800,144.9200,145.1800,251.374
10408,2025-06-19 12:00:00,145.2500,145.4300,143.1800,143.5000,435.471


## Yahoo Finance


Not enough data for the following:
- 'PI35697-USD'


**Indices for Alpha Generation:**
- U.S. Dollar Index (DXY/USDX; Yahoo ticker DX-Y.NYB)
- Trade-Weighted Dollar Index (DTWEXBGS)

*Why they matter for alpha generation:*
- They are proxies of **USD strength**, impacting global liquidity, capital flows, and inflation expectations.
- Many risk assets, commodities, emerging markets, and cryptocurrencies are inversely correlated with the dollar.
- Useful for both directional strategies and regime-based filters.




In [None]:
# # Fetch AAPL data
# advpp_data = yf.download('AAPL', start='2020-01-01', end='2024-01-01')

# # Display the first few rows of the dataframe
# advpp_data.head()

# --- Step 1: Fetch BTC/USD 15-min data for the last 7 days ---
# ticker="BTC-USD"
# interval="15m"
# ticker_data = yf.download(ticker, interval=interval, period="7d")
# ticker_data.dropna(inplace=True)
# ticker_data.columns = [col.lower() for col in ticker_data.columns]
# ticker_data

# ---- Parameters ----
indices = ['^GSPC', '^DJI', '^IXIC', '^RUT', '^VIX', 'DX-Y.NYB']
stocks = ['AAPL', 'AMZN', 'MSFT', 'NVDA', 'TSLA', 'INTC', 'QUBT', 'COIN', 'META']
fx = ['EURUSD=X', 'GBPUSD=X', 'USDJPY=X', 'USDCHF=X', 'EURGBP=X', 'EURNZD=X', 'EURCHF=X', 'GBPNZD=X', 'GBPCHF=X']
metals = ['^XAU', '^XAG']
commodities = ['CL=F']
crypto = ['BTC-USD', 'ETH-USD', 'SOL-USD', 'XRP-USD', 'LTC-USD', 'AAVE-USD', 'BNB-USD', 'TRX-USD', 'DOT-USD', 'XLM-USD']
symbols = indices + stocks + fx + metals + commodities + crypto
start_date = '2018-06-01'
end_date = '2025-06-19'  # yfinance treats `end` as exclusive: the last daily bar is the day before
interval = '1d'

# Create the output directory (no-op if it already exists)
data_dir = './data/yahoo'
os.makedirs(data_dir, exist_ok=True)

# Fetch and save data for each symbol individually
for symbol in symbols:
    try:
        print(f"Fetching data for {symbol}...")
        df = yf.download(symbol, start=start_date, end=end_date, interval=interval)
        if not df.empty:
            # Recent yfinance versions return (field, ticker) MultiIndex columns even for a
            # single ticker; flatten them to plain Open/High/Low/Close/Volume. Flat columns
            # are left untouched (an unconditional droplevel(1) would raise and skip the symbol).
            if isinstance(df.columns, pd.MultiIndex):
                df.columns = df.columns.droplevel(1)
            # Same '-'→'_' / drop '=' mangling that get_yahoo_data() uses to find the file.
            filename = os.path.join(data_dir, f'{symbol.replace("-", "_").replace("=", "")}.csv')
            df.to_csv(filename)
            print(f"Saved data for {symbol} to {filename}")
        else:
            print(f"No data found for {symbol}")
    except Exception as e:
        print(f"Error fetching data for {symbol}: {e}")


Fetching data for ^GSPC...


  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)
[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for ^GSPC to ./data/yahoo/^GSPC.csv
Fetching data for ^DJI...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for ^DJI to ./data/yahoo/^DJI.csv
Fetching data for ^IXIC...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for ^IXIC to ./data/yahoo/^IXIC.csv
Fetching data for ^RUT...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for ^RUT to ./data/yahoo/^RUT.csv
Fetching data for ^VIX...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for ^VIX to ./data/yahoo/^VIX.csv
Fetching data for DX-Y.NYB...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for DX-Y.NYB to ./data/yahoo/DX_Y.NYB.csv
Fetching data for AAPL...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for AAPL to ./data/yahoo/AAPL.csv
Fetching data for AMZN...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for AMZN to ./data/yahoo/AMZN.csv
Fetching data for MSFT...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for MSFT to ./data/yahoo/MSFT.csv
Fetching data for NVDA...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for NVDA to ./data/yahoo/NVDA.csv
Fetching data for TSLA...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for TSLA to ./data/yahoo/TSLA.csv
Fetching data for INTC...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for INTC to ./data/yahoo/INTC.csv
Fetching data for QUBT...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for QUBT to ./data/yahoo/QUBT.csv
Fetching data for COIN...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for COIN to ./data/yahoo/COIN.csv
Fetching data for META...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for META to ./data/yahoo/META.csv
Fetching data for EURUSD=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for EURUSD=X to ./data/yahoo/EURUSDX.csv
Fetching data for GBPUSD=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for GBPUSD=X to ./data/yahoo/GBPUSDX.csv
Fetching data for USDJPY=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for USDJPY=X to ./data/yahoo/USDJPYX.csv
Fetching data for USDCHF=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for USDCHF=X to ./data/yahoo/USDCHFX.csv
Fetching data for EURGBP=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for EURGBP=X to ./data/yahoo/EURGBPX.csv
Fetching data for EURNZD=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for EURNZD=X to ./data/yahoo/EURNZDX.csv
Fetching data for EURCHF=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for EURCHF=X to ./data/yahoo/EURCHFX.csv
Fetching data for GBPNZD=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for GBPNZD=X to ./data/yahoo/GBPNZDX.csv
Fetching data for GBPCHF=X...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for GBPCHF=X to ./data/yahoo/GBPCHFX.csv
Fetching data for ^XAU...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for ^XAU to ./data/yahoo/^XAU.csv
Fetching data for ^XAG...


[*********************100%***********************]  1 of 1 completed
ERROR:yfinance:
1 Failed download:
ERROR:yfinance:['^XAG']: YFPricesMissingError('possibly delisted; no price data found  (1d 2019-06-01 -> 2025-06-18)')
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


No data found for ^XAG
Fetching data for CL=F...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for CL=F to ./data/yahoo/CLF.csv
Fetching data for BTC-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for BTC-USD to ./data/yahoo/BTC_USD.csv
Fetching data for ETH-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for ETH-USD to ./data/yahoo/ETH_USD.csv
Fetching data for SOL-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for SOL-USD to ./data/yahoo/SOL_USD.csv
Fetching data for XRP-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for XRP-USD to ./data/yahoo/XRP_USD.csv
Fetching data for LTC-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for LTC-USD to ./data/yahoo/LTC_USD.csv
Fetching data for AAVE-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for AAVE-USD to ./data/yahoo/AAVE_USD.csv
Fetching data for BNB-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for BNB-USD to ./data/yahoo/BNB_USD.csv
Fetching data for TRX-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for TRX-USD to ./data/yahoo/TRX_USD.csv
Fetching data for DOT-USD...


[*********************100%***********************]  1 of 1 completed
  df = yf.download(symbol, start=start_date, end=end_date, interval=interval)


Saved data for DOT-USD to ./data/yahoo/DOT_USD.csv
Fetching data for XLM-USD...


[*********************100%***********************]  1 of 1 completed

Saved data for XLM-USD to ./data/yahoo/XLM_USD.csv





In [None]:

def get_yahoo_data(symbol, directory=None, min_rows=10):
    """
    Load a symbol's cached Yahoo Finance CSV, coerce OHLCV columns to numeric and validate it.

    Args:
        symbol: Yahoo ticker, e.g. 'BTC-USD' or 'EURUSD=X'. The file name is derived with the
            same substitutions the download loop uses ('-' -> '_', '=' removed).
        directory: Folder containing the CSV files. Defaults to the notebook-level `data_dir`.
        min_rows: Minimum number of rows that must remain after cleaning.

    Returns:
        A DataFrame indexed by the parsed 'Date' column. An empty DataFrame is returned when:
        - the file is missing or unreadable;
        - none of Close/High/Low is present;
        - fewer than `min_rows` rows survive cleaning.
    """
    if directory is None:
        directory = data_dir

    # Replace characters that might cause issues in filenames
    filename_symbol = symbol.replace("-", "_").replace("=", "")
    filepath = os.path.join(directory, f'{filename_symbol}.csv')

    try:
        # Load the data from the individual CSV file
        df = pd.read_csv(filepath, parse_dates=['Date'], index_col='Date')
        print(f"\nLoaded data for {symbol} from {filepath}")
        print(df.head()) # Print the head of the loaded DataFrame (df)
    except FileNotFoundError:
        print(f"Error: Data file for {symbol} not found at {filepath}.")
        return pd.DataFrame() # Return empty DataFrame if file not found
    except Exception as e:
        print(f"Error loading or parsing data for {symbol} from {filepath}: {e}")
        return pd.DataFrame() # Return empty DataFrame on other errors

    # Safeguard: force OHLCV columns to numeric; unparseable cells become NaN.
    for col in ['Open', 'High', 'Low', 'Close', 'Volume']:
        if col in df.columns:
            df[col] = pd.to_numeric(df[col], errors='coerce')
        else:
            print(f"Warning: Column '{col}' not found in data for {symbol}.")

    # Drop rows where every available critical price is NaN.
    critical_cols = [col for col in ['Close', 'High', 'Low'] if col in df.columns]
    if not critical_cols:
        # Without any critical column the data is not usable.
        print(f"Warning: Critical columns (Close, High, Low) missing for {symbol}. Returning empty DataFrame.")
        return pd.DataFrame()
    df.dropna(subset=critical_cols, how='all', inplace=True)

    # Require a minimum number of data points after cleaning
    if df.empty or len(df) < min_rows:
        print(f"Warning: Not enough valid data points (less than {min_rows}) for symbol {symbol} after cleaning. Found {len(df)} rows.")
        return pd.DataFrame()

    return df


## Extract Price Series


## ✅ **Proposal**

We want to construct a custom price like:

$$
\text{Price}_{\text{custom}} = w_1 \cdot \text{High}_t + w_2 \cdot \text{Low}_t + w_3 \cdot \text{Close}_t
$$

Where:

* $w_1, w_2, w_3$ are weights, possibly adaptive or fixed.
* This synthetic series is then used as the "price" input into your features, indicators, or models.

---

## 🔍 Why This Can Be Powerful

Most indicators use **close-only**, which discards valuable intraday range information. A weighted blend:

* Captures **intrabar structure**.
* Smooths volatility.
* Can encode **market sentiment** better than raw closes.

---

## 📐 Common Weighted Price Schemes

Here are some **existing techniques** that can inspire or be combined with your idea:

### 1. **Typical Price**

$$
\text{TP} = \frac{High + Low + Close}{3}
$$

### 2. **Weighted Close**

$$
\text{WC} = \frac{High + Low + 2 \cdot Close}{4}
$$

### 3. **OHLC Average**

$$
\text{OHLC} = \frac{Open + High + Low + Close}{4}
$$

### 4. **Mid Price**

$$
\text{Mid} = \frac{High + Low}{2}
$$

### 5. **Custom Weights (Your Idea)**

$$
\text{P}_{\text{custom}} = \alpha \cdot H + \beta \cdot L + \gamma \cdot C \quad \text{where} \quad \alpha + \beta + \gamma = 1
$$

You can learn these weights in a model, optimize for Sharpe, or define them heuristically.

---

## 💡 How to Choose the Weights

### 1. **Heuristic**

Try:

* $\alpha = 0.25, \beta = 0.25, \gamma = 0.5$
* $\gamma = 1.0$ (use close-only as baseline)
* Use volatility to scale H/L contributions.

### 2. **Machine Learning / Optimization**

* Use **Sharpe-ratio maximization** to learn optimal weights.
* Or use a **regression model** to predict next return and fit the weights that best predict it.

### 3. **Reinforcement Learning Integration**

* Let your RL agent learn the weights $\alpha, \beta, \gamma$ as parameters over time.
* Plug into observation: `obs = price_custom`, where `price_custom = weighted(H, L, C)`.

---

## 🧪 Implementation Snippet (Python)

```python
def custom_price(high, low, close, alpha=0.3, beta=0.3, gamma=0.4):
    return alpha * high + beta * low + gamma * close
```

Or, to adapt the weights dynamically to volatility:

```python
def adaptive_price(high, low, close, vol):
    alpha = 0.3 + 0.2 * vol
    beta = 0.3 - 0.1 * vol
    gamma = 1 - alpha - beta
    return alpha * high + beta * low + gamma * close
```

---

## 🧠 Strategic Use in a Trading Pipeline

* Use as the **main input price** to:

  * Indicators (RSI, MACD, Bollinger)
  * Pattern detectors (reversals, breakouts)
  * RL environments (in `price_matrix`)
* Apply it to **volume-weighted** or **volatility-adjusted** views
* Use difference or change in this price to detect **price aggression** or **liquidity pressure**



In [None]:
# === STEP 1: Get price series ===
# dates = pd.date_range(start="2021-01-01", periods=15, freq="M")
# prices = np.array([
#     30000, 33000, 29000, 35000, 34000,
#     38000, 36000, 42000, 40000, 48000,
#     47000, 52000, 50000, 58000, 62000
# ], dtype=float)

# Pull the raw numpy arrays for each candle field out of the OHLCV DataFrame.
dates, highs, lows, closes, volumes = (
    ohlcv[column].values for column in ("timestamp", "high", "low", "close", "volume")
)

# Close-only baseline price series.
prices = closes


# Advanced Decision Making


Using reinforcement learning:

Use a combination of the pivot points and the price prediction

to train an LSTM policy, using a Gymnasium environment, Stable-Baselines3, and `RecurrentPPO` from sb3-contrib.

In [None]:
import datetime
import gymnasium as gym
from enum import IntEnum
from stable_baselines3 import PPO
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.env_util import make_vec_env

## Define Utils

In [None]:
class Action(IntEnum):
    """Discrete trading actions; each member's value is its integer action id (0-4).

    NOTE(review): TradingEnv below declares its own Discrete(3) sell/hold/buy mapping and
    does not reference this enum in the visible code.
    """

    LONG_ENTER = 0
    LONG_EXIT = 1
    SHORT_ENTER = 2
    SHORT_EXIT = 3
    NEUTRAL = 4

## Define Environments

In [None]:
# Define a custom Gymnasium environment
class TradingEnv(gym.Env):
    def __init__(self, data, window_size=60):
        """Create the environment over `data`, a DataFrame read via its 'Close' column.

        Args:
            data: Price history; must be longer than `window_size` rows.
            window_size: Number of past Close prices included in each observation.
        """
        super().__init__()

        # Price history and look-back configuration; episodes begin once a full window exists.
        self.data = data
        self.window_size = window_size
        self.current_step = self.window_size
        self.max_steps = len(self.data) - 1

        # Actions: 0 = sell, 1 = hold, 2 = buy.
        self.action_space = gym.spaces.Discrete(3)

        # Observation = scaled price window + 1 prediction feature + 2 pivot-distance features,
        # all expected to lie in [0, 1].
        n_features = self.window_size + 1 + 2
        self.observation_space = gym.spaces.Box(low=0, high=1, shape=(n_features,), dtype=np.float32)

        # Initialise the portfolio/episode state.
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        `seed` is forwarded to the base Env's reset; `options` is accepted for API
        compatibility but unused.
        """
        super().reset(seed=seed)

        # Rewind to the first step that has a full look-back window.
        self.current_step = self.window_size

        # Fresh portfolio: all cash, no shares, no open positions.
        self.balance = 10000
        self.shares_held = 0
        self.net_worth = self.balance
        self.positions = []  # one dict per buy: {'buy_price': ..., 'shares': ...}
        self.episode_return = 0

        return self._get_observation(), {}

    def _get_observation(self):
        """Build the observation vector for ``self.current_step``.

        Layout (length ``window_size + 3``):
          * ``window_size`` Close prices from ``[current_step - window_size, current_step)``,
            min-max scaled to [0, 1] within that window;
          * 1 feature: the next-price prediction of the notebook-level ``model``, expressed
            on the same window scale and clipped to [0, 1]. It is 0.5 if ``model`` is
            undefined or the window is flat;
          * 2 features: relative distance from the current Close to the nearest support and
            resistance in the notebook-level ``center_prices``, each clipped to [-1, 1] and
            mapped affinely onto [0, 1]. A value of 0.5 means zero distance and is also the
            default when no pivots exist.

        The result is float32 and lies inside ``self.observation_space`` (Box(0, 1, float32)).
        """
        window_start = self.current_step - self.window_size
        window_end = self.current_step
        price_window = self.data['Close'].iloc[window_start:window_end].values

        # Scale the look-back window to [0, 1]; the same scaler is reused for the model input.
        scaler = MinMaxScaler(feature_range=(0, 1))
        scaled_price_window = scaler.fit_transform(price_window.reshape(-1, 1)).flatten()

        # Model input: the window shifted forward one bar so it ends at current_step,
        # shaped (batch=1, window_size, features=1).
        current_price_sequence = self.data['Close'].iloc[window_start + 1:window_end + 1].values
        current_price_sequence_scaled = scaler.transform(current_price_sequence.reshape(-1, 1)).flatten()
        current_batch = current_price_sequence_scaled.reshape(1, self.window_size, 1)

        min_price_window = price_window.min()
        price_range = price_window.max() - min_price_window
        try:
            # `model` is assumed to be a trained LSTM defined elsewhere in the notebook.
            next_price_prediction_scaled = model.predict(current_batch)
            next_price_prediction = scaler.inverse_transform(next_price_prediction_scaled)[0, 0]
            if price_range > 0:
                # Express the prediction on the current window's price scale.
                predicted_price_scaled_obs = (next_price_prediction - min_price_window) / price_range
            else:
                # Flat window: there is no range to express the prediction on.
                predicted_price_scaled_obs = 0.5
            predicted_price_scaled_obs = float(np.clip(predicted_price_scaled_obs, 0, 1))
        except NameError:
            # Handle the case where the LSTM model is not defined
            print("LSTM prediction model 'model' not found. Ensure it's loaded before initializing the environment.")
            predicted_price_scaled_obs = 0.5

        # Pivot features: distance to the nearest support/resistance among `center_prices`
        # (presumably the Birch cluster centres; not computed in this class).
        current_price = self.data['Close'].iloc[self.current_step]
        pivots = np.asarray(globals().get('center_prices', []), dtype=float).ravel()
        if pivots.size > 0 and current_price != 0:
            below = pivots[pivots <= current_price]
            above = pivots[pivots >= current_price]
            # With no pivot on one side, fall back to the extreme pivot (distance turns negative).
            nearest_support = below.max() if below.size else pivots.min()
            nearest_resistance = above.min() if above.size else pivots.max()

            distance_to_support = (current_price - nearest_support) / current_price
            distance_to_resistance = (nearest_resistance - current_price) / current_price

            # Distances may be negative; map [-1, 1] onto [0, 1] so they stay inside the Box.
            scaled_distance_to_support = float(0.5 * (np.clip(distance_to_support, -1, 1) + 1))
            scaled_distance_to_resistance = float(0.5 * (np.clip(distance_to_resistance, -1, 1) + 1))
        else:
            scaled_distance_to_support = 0.5 # Default if no pivot points are available
            scaled_distance_to_resistance = 0.5

        # float32 to match the observation space's declared dtype.
        observation = np.concatenate((
            scaled_price_window,
            [predicted_price_scaled_obs],
            [scaled_distance_to_support, scaled_distance_to_resistance],
        )).astype(np.float32)

        # Ensure the observation has the correct shape
        assert observation.shape == self.observation_space.shape, f"Observation shape mismatch: {observation.shape} vs {self.observation_space.shape}"

        return observation

    def step(self, action):
        """Advance one bar and apply a discrete trading action.

        Action ``0`` sells the whole holding and ``2`` buys as many whole
        shares as the cash balance covers; any other action holds.

        Reward: the realized profit/loss of a sell on that step; on the
        terminal step it is replaced by the episode's total net-worth change
        (``net_worth - initial_net_worth``).

        Returns:
            ``(observation, reward, terminated, truncated, info)`` following
            the Gymnasium ``Env.step`` API (``truncated`` is always ``False``
            and ``info`` is empty).
        """
        self.current_step += 1
        terminated = self.current_step >= self.max_steps
        reward = 0
        info = {}

        current_price = self.data['Close'].iloc[self.current_step]

        if action == 0:  # Sell
            if self.shares_held > 0:
                sell_price = current_price
                # Repeated buys append one lot each to ``self.positions`` and
                # this sale liquidates all of them, so the realized P/L is
                # summed over every lot rather than taken from the last one.
                trade_profit_loss = sum(
                    (sell_price - lot['buy_price']) * lot['shares']
                    for lot in self.positions
                )
                self.balance += self.shares_held * sell_price
                self.shares_held = 0
                reward += trade_profit_loss
                self.positions = []  # every lot has been closed
        elif action == 2:  # Buy
            buy_price = current_price
            if buy_price > 0:  # guards the floor division below
                # Whole shares only; a positive count implies balance >= price.
                shares_to_buy = self.balance // buy_price
                if shares_to_buy > 0:
                    self.shares_held += shares_to_buy
                    self.balance -= shares_to_buy * buy_price
                    self.positions.append({'buy_price': buy_price, 'shares': shares_to_buy})

        # Mark-to-market value of cash plus holdings at this bar's close.
        self.net_worth = self.balance + self.shares_held * current_price

        # Cumulative profit/loss of the episode so far
        # (initial_net_worth is assumed to be set in reset()).
        self.episode_return = self.net_worth - self.initial_net_worth

        if terminated:
            # The terminal reward is the total episode profit/loss.
            reward = self.episode_return

        observation = self._get_observation()

        # Episodes only end by reaching the end of the data.
        truncated = False

        return observation, reward, terminated, truncated, info


In [None]:
# Assumes `advpp_data` is available and has a 'Close' column and a DateTime index.

# 1. Imports (already done at the top, just ensuring they are present for clarity)
# import gymnasium as gym
# from stable_baselines3 import PPO
# from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
# from sb3_contrib import RecurrentPPO
# from stable_baselines3.common.vec_env import DummyVecEnv
# from stable_baselines3.common.vec_env import VecNormalize
# from enum import IntEnum
# import numpy as np
# import pandas as pd
# import torch as th
# from torch import nn

# 2. Define the Action Enum and TradingEnv (already provided)
# Make sure the TradingEnv is correctly implemented and the observation space matches the features used.
# The current `TradingEnv` definition uses a Box observation space of shape `(self.window_size + 1 + 2,)`.

# Correct the `TradingEnv` to use the defined `Action` IntEnum
class TradingEnv(gym.Env):
    """Long-only, single-asset trading environment following the Gymnasium API.

    Observation: a flat ``float32`` vector of length ``window_size + 3``:

    * the ``window_size`` closes preceding ``current_step``, min-max scaled
      within that window;
    * the Keras model's next-close prediction, scaled with the same window
      scaler and clipped to ``[-1, 2]`` (``0.0`` when the notebook globals
      ``model``/``scaler`` are missing);
    * the current close's relative distance to the nearest support and
      resistance level from the notebook global ``center_prices``
      (``0.0`` each when unavailable).

    Actions are the members of the ``Action`` enum. Only ``LONG_ENTER`` and
    ``LONG_EXIT`` trade; the short actions and ``NEUTRAL`` leave the
    portfolio unchanged.

    Args:
        data: DataFrame with a ``'Close'`` column (accessed positionally via
            ``.iloc``).
        window_size: Number of past closes included in each observation.

    Raises:
        ValueError: If ``data`` is too short to take even one step.
    """

    def __init__(self, data, window_size=60):
        super(TradingEnv, self).__init__()

        # reset() reads index ``window_size`` and step() reads the bar after
        # it, so at least ``window_size + 2`` rows are needed for one step.
        if len(data) < window_size + 2:
            raise ValueError(
                f"TradingEnv needs at least window_size + 2 = {window_size + 2} rows, "
                f"got {len(data)}"
            )

        self.data = data
        self.window_size = window_size
        self.current_step = self.window_size  # first step with a full window
        self.max_steps = len(self.data) - 1  # last valid positional index

        self.action_space = gym.spaces.Discrete(len(Action))

        # window closes + 1 prediction + 2 pivot distances. Unbounded because
        # the prediction (clipped to [-1, 2]) and the relative pivot distances
        # are not confined to [0, 1].
        self.observation_space = gym.spaces.Box(
            low=-np.inf, high=np.inf, shape=(self.window_size + 1 + 2,), dtype=np.float32
        )

        self.initial_balance = 10000
        # Notices already printed about missing notebook globals; each one is
        # reported once per environment instead of on every step.
        self._reported_missing = set()
        self.reset()

    def reset(self, seed=None, options=None):
        """Start a new episode with no position and the full cash balance.

        Returns:
            ``(observation, info)`` per the Gymnasium API; ``info`` is empty.
        """
        super().reset(seed=seed)
        self.current_step = self.window_size
        self.balance = self.initial_balance
        self.shares_held = 0
        self.net_worth = self.balance
        self.positions = []  # open long position: entry price and share count
        self.initial_net_worth = self.net_worth

        observation = self._get_observation()
        info = {}
        return observation, info

    def _report_missing(self, message):
        """Print ``message`` the first time it occurs for this environment."""
        if message not in self._reported_missing:
            self._reported_missing.add(message)
            print(message)

    def _get_observation(self):
        """Build the observation vector for ``self.current_step``.

        Returns:
            ``np.ndarray`` of dtype ``float32`` (matching the Box space) with
            shape ``self.observation_space.shape``.
        """
        # Closes strictly before the current step; the same slice feeds both
        # the price-window feature and the Keras prediction.
        price_window = self.data['Close'].iloc[
            self.current_step - self.window_size:self.current_step
        ].values

        # NOTE: fit on the current window only. This avoids look-ahead, but the
        # scale changes from step to step.
        window_scaler = MinMaxScaler(feature_range=(0, 1))
        scaled_price_window = window_scaler.fit_transform(price_window.reshape(-1, 1)).flatten()

        predicted_price_scaled_obs = self._predicted_price_feature(price_window, window_scaler)
        scaled_distance_to_support, scaled_distance_to_resistance = self._pivot_distance_features()

        # Cast to float32: np.concatenate yields float64, which the float32
        # Box space would reject in ``observation_space.contains``.
        observation = np.concatenate((
            scaled_price_window,
            [predicted_price_scaled_obs],
            [scaled_distance_to_support, scaled_distance_to_resistance],
        )).astype(np.float32)

        assert observation.shape == self.observation_space.shape, f"Observation shape mismatch: {observation.shape} vs {self.observation_space.shape}"

        return observation

    def _predicted_price_feature(self, price_window, window_scaler):
        """Return the Keras next-close prediction rescaled by ``window_scaler``.

        Uses the notebook globals ``model`` (the trained Keras model) and
        ``scaler`` (the scaler it was trained with). Returns ``0.0`` when
        either is undefined.
        """
        try:
            sequence_scaled = scaler.transform(price_window.reshape(-1, 1)).flatten()
            # Keras LSTM input layout: (batch=1, window_size, features=1).
            batch = sequence_scaled.reshape(1, self.window_size, 1)
            prediction_scaled = model.predict(batch, verbose=0)
            prediction = scaler.inverse_transform(prediction_scaled)[0, 0]
        except NameError:
            self._report_missing("Keras prediction model 'model' or scaler not found. Ensure they are available.")
            return 0.0

        scaled = window_scaler.transform(np.array([[prediction]]))[0, 0]
        # Predictions may fall outside the window's range, so allow some slack.
        return np.clip(scaled, -1, 2)

    def _pivot_distance_features(self):
        """Return relative distances from the current close to support/resistance.

        Levels come from the notebook global ``center_prices``. Support is the
        highest level <= the close (falling back to the lowest level), and
        resistance is the lowest level >= the close (falling back to the
        highest level). Both distances are divided by the close.

        Returns:
            ``(distance_to_support, distance_to_resistance)``, or
            ``(0.0, 0.0)`` when no levels are available.
        """
        levels = globals().get('center_prices')
        if levels is None:
            self._report_missing("Pivot points (center_prices) not found. Ensure Birch clustering was run.")
            return 0.0, 0.0

        # Accept lists or (k, 1) cluster-center arrays alike.
        levels = np.asarray(levels, dtype=float).ravel()
        if levels.size == 0:
            return 0.0, 0.0

        current_price = self.data['Close'].iloc[self.current_step]
        supports = levels[levels <= current_price]
        resistances = levels[levels >= current_price]
        nearest_support = supports.max() if supports.size else levels.min()
        nearest_resistance = resistances.min() if resistances.size else levels.max()

        epsilon = 1e-8  # avoids division by zero for a zero close
        return (
            (current_price - nearest_support) / (current_price + epsilon),
            (nearest_resistance - current_price) / (current_price + epsilon),
        )

    def step(self, action):
        """Advance one bar and apply ``action`` at that bar's close.

        Reward: the realized profit/loss when a long position is closed, plus
        the episode's total net-worth change on the terminal step.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where
            ``info`` holds the current ``net_worth``.
        """
        action = Action(action)

        self.current_step += 1
        terminated = self.current_step >= self.max_steps
        reward = 0

        current_price = self.data['Close'].iloc[self.current_step]

        if action == Action.LONG_ENTER:
            # One long position at a time, sized to as many whole shares as
            # the cash covers (a positive count implies balance >= price).
            if self.shares_held == 0 and current_price > 0:
                shares_to_buy = self.balance // current_price
                if shares_to_buy > 0:
                    self.shares_held += shares_to_buy
                    self.balance -= shares_to_buy * current_price
                    self.positions.append({'type': 'long', 'entry_price': current_price, 'shares': shares_to_buy})

        elif action == Action.LONG_EXIT:
            if self.shares_held > 0 and len(self.positions) > 0 and self.positions[-1]['type'] == 'long':
                entry_price = self.positions[-1]['entry_price']
                shares = self.positions[-1]['shares']

                self.balance += self.shares_held * current_price
                self.shares_held = 0
                self.positions = []

                reward += (current_price - entry_price) * shares

        # SHORT_ENTER, SHORT_EXIT and NEUTRAL: shorting is not implemented, so
        # these actions leave the portfolio unchanged.

        self.net_worth = self.balance + self.shares_held * current_price

        if terminated:
            # NOTE: realized profits are already in net worth, so trades closed
            # during the episode contribute to the reward a second time here.
            reward += self.net_worth - self.initial_net_worth

        observation = self._get_observation()

        # Episodes only end by reaching the end of the data.
        truncated = False

        info = {"net_worth": self.net_worth}

        return observation, reward, terminated, truncated, info

    def render(self):
        # Rendering is not implemented.
        pass

    def close(self):
        # No resources to release.
        pass



## Define Agents

In [None]:

import torch as th
from torch import nn
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import gymnasium as gym
from gymnasium import spaces
from enum import IntEnum
from datetime import timedelta

# Custom Feature Extractor
class CustomFeatureExtractor(BaseFeaturesExtractor):
    """Map a flat observation vector to ``features_dim`` features.

    The network is a single ``Linear -> ReLU`` layer.

    :param observation_space: (gym.Space) Observation space; only
        ``shape[0]`` (the flat observation length) is used.
    :param features_dim: (int) Output width, i.e. the number of units in
        the layer.
    """
    def __init__(self, observation_space: gym.Space, features_dim: int = 64):
        super().__init__(observation_space, features_dim)
        input_width = observation_space.shape[0]
        layers = [nn.Linear(input_width, features_dim), nn.ReLU()]
        self.linear = nn.Sequential(*layers)

    def forward(self, observations: th.Tensor) -> th.Tensor:
        projected = self.linear(observations)
        return projected

# Define the custom policy network
class CustomPolicy(nn.Module):
    """Feature extractor -> LSTM -> separate policy (logits) and value heads.

    Every call treats each batch row as a sequence of length 1. The LSTM
    state is carried only through the ``hidden_state`` argument and the
    returned state. This module is not wired into SB3's policy classes.

    Args:
        feature_extractor: ``nn.Module`` exposing ``features_dim`` and mapping
            ``(batch, obs_dim)`` tensors to ``(batch, features_dim)``.
        lstm_hidden_size: Hidden size of the LSTM.
        action_space: Discrete action space; ``action_space.n`` sets the
            number of action logits.
    """

    def __init__(self, feature_extractor, lstm_hidden_size, action_space):
        super(CustomPolicy, self).__init__()
        self.feature_extractor = feature_extractor
        extracted_features_dim = feature_extractor.features_dim

        # batch_first: inputs are (batch, seq_len, features).
        self.lstm = nn.LSTM(extracted_features_dim, lstm_hidden_size, batch_first=True)

        # Both heads consume the LSTM output.
        self.policy_net = nn.Sequential(
            nn.Linear(lstm_hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, action_space.n)
        )
        self.value_net = nn.Sequential(
            nn.Linear(lstm_hidden_size, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def _encode(self, observations: th.Tensor, hidden_state):
        """Run the feature extractor and one LSTM timestep.

        Returns ``(latent, new_hidden_state)``, where ``latent`` has shape
        ``(batch, lstm_hidden_size)``.
        """
        features = self.feature_extractor(observations)  # (batch, features_dim)
        # Add the length-1 time axis the LSTM expects.
        lstm_out, new_hidden_state = self.lstm(features.unsqueeze(1), hidden_state)
        return lstm_out[:, -1, :], new_hidden_state

    def forward(self, observations: th.Tensor, hidden_state):
        """Return ``(action_logits, value, new_hidden_state)``.

        ``action_logits`` are unnormalized scores of shape ``(batch, n)``;
        ``value`` has shape ``(batch, 1)``; ``hidden_state`` may be ``None``
        to start from a zero LSTM state.
        """
        latent, new_hidden_state = self._encode(observations, hidden_state)
        action_logits = self.policy_net(latent)
        value = self.value_net(latent)
        return action_logits, value, new_hidden_state

    def _get_action_dist_from_latent(self, latent_pi: th.Tensor):
        """
        Retrieve the action distribution from a latent policy representation.

        :param latent_pi: (th.Tensor) Latent of shape ``(batch, lstm_hidden_size)``,
            e.g. the LSTM output.
        :return: (th.distributions.Categorical) Distribution parameterized by the
            policy head's logits.
        """
        # SB3's CategoricalDistribution takes ``action_dim`` and has no
        # ``logits=`` argument, so the torch distribution is built directly.
        return th.distributions.Categorical(logits=self.policy_net(latent_pi))

    def evaluate_actions(self, observations: th.Tensor, actions: th.Tensor, hidden_state):
        """
        Evaluate actions given observations.

        :param observations: (th.Tensor) Observations with a leading batch axis;
            trailing axes are flattened before the feature extractor.
        :param actions: (th.Tensor) Discrete actions, shape ``(batch,)`` or ``(batch, 1)``.
        :param hidden_state: (tuple or None) LSTM hidden state (h, c).
        :return: (th.Tensor, th.Tensor, th.Tensor) Log likelihood of actions,
            value estimates, entropy of the distribution.
        """
        # NOTE: sequence length is fixed to 1. RecurrentPPO-style training would
        # need real sequences plus episode-start masks for the hidden state.
        flat_observations = observations.reshape(observations.shape[0], -1)
        latent, _ = self._encode(flat_observations, hidden_state)

        distribution = th.distributions.Categorical(logits=self.policy_net(latent))
        values = self.value_net(latent)

        # Flatten so (batch, 1) actions do not broadcast to (batch, batch).
        log_prob = distribution.log_prob(actions.reshape(-1))
        action_entropy = distribution.entropy()

        return log_prob, values, action_entropy


# You would integrate this custom policy with RecurrentPPO like this:
# from sb3_contrib import RecurrentPPO
# from stable_baselines3.common.vec_env import DummyVecEnv
# from stable_baselines3.common.vec_env import VecNormalize
# from stable_baselines3.common.torch_layers import BaseFeaturesExtractor

# # Create the environment
# env = DummyVecEnv([lambda: TradingEnv(advpp_data)])
# # Optional: Normalize the observation and reward
# # env = VecNormalize(env, norm_obs=True, norm_reward=False)

# # Define policy kwargs to use the custom feature extractor and network architecture
# policy_kwargs = dict(
#     features_extractor_class=CustomFeatureExtractor,
#     features_extractor_kwargs=dict(features_dim=64), # Feature extractor output dim
#     # Define the network architecture for the policy and value functions after the LSTM
#     # This is where you would specify the layers that take the LSTM output as input
#     # and output the action logits and value.
#     # RecurrentPPO's MlpLstmPolicy always adds the LSTM layer itself (there is no
#     # 'enable_lstm' option); it is configured via `lstm_hidden_size`, `n_lstm_layers`, etc.
#     # However, if you want a custom architecture including LSTM and other layers,
#     # defining a custom model class (like CustomPolicy above) and integrating it
#     # with RecurrentPPO's `policy_aliases` or by overriding policy classes
#     # is a more advanced approach.

#     # A simpler approach with RecurrentPPO is to use its built-in LSTM support
#     # and define the CNN/MLP features extractor that runs *before* the LSTM.
#     # The observation space for the LSTM will be the output of the feature extractor.

#     # Let's go back to the simpler RecurrentPPO structure: define the feature extractor
#     # that processes the potentially structured observation into a flat vector,
#     # and let RecurrentPPO add the LSTM on top of this flat vector.

#     # If your observation is flat (like in TradingEnv), the CustomFeatureExtractor
#     # is suitable to process this flat vector. RecurrentPPO will then add an LSTM
#     # layer that takes the output of the CustomFeatureExtractor as input.

#     n_lstm_layers=1, # RecurrentPPO has no `enable_lstm` flag; the LSTM is always present
#     lstm_hidden_size=128, # Size of the LSTM hidden state
#     # Specify the network architecture for the policy and value heads *after* the LSTM.
#     # These networks take the LSTM output as input.
#     net_arch=dict(pi=[64], vf=[64]) # Example: one hidden layer of 64 units for both policy and value (SB3 >= 1.8 expects a dict, not a list of dicts)
#     # You can customize this further based on your needs.
#     # The LSTM output size is `lstm_hidden_size`.
# )

# # Instantiate the agent
# model = RecurrentPPO("MlpLstmPolicy", env, verbose=1, policy_kwargs=policy_kwargs)

# # Train the agent
# # model.learn(total_timesteps=10000)

# # Save the model
# # model.save("recurrent_ppo_trading")

# # Load the trained model
# # model = RecurrentPPO.load("recurrent_ppo_trading")

# # Evaluate the agent
# # obs = env.reset()
# # lstm_states = None
# # num_envs = 1 # Number of environments
# # episode_starts = np.ones((num_envs,), dtype=bool)
# # for _ in range(1000):
# #     action, lstm_states = model.predict(obs, state=lstm_states, episode_start=episode_starts, deterministic=True)
# #     obs, reward, done, info = env.step(action)
# #     episode_starts = done
# #     if done:
# #         obs = env.reset()
# #         lstm_states = None # Reset LSTM states when episode ends

# The provided code already has a TradingEnv and a trained Keras model (`model`) for price prediction.
# The task is to build a *custom policy network* to integrate the various observation features
# (price window, prediction, pivot points) and train an RL agent.

# Let's define the components needed for integrating with RecurrentPPO.
# We will use the `TradingEnv` defined previously.
# The observation space of `TradingEnv` is a flat Box space.

# The features in the observation are:
# 1. `scaled_price_window`: `window_size` features
# 2. `predicted_price_scaled_obs`: 1 feature
# 3. `scaled_distance_to_support`: 1 feature
# 4. `scaled_distance_to_resistance`: 1 feature
# Total observation space dimension: `window_size + 1 + 2`

# RecurrentPPO's `MlpLstmPolicy` first passes the observation through the features extractor, then through
# the LSTM, and only then through the MLP controlled by `net_arch`. The LSTM input is the features-extractor output.

# If we want the LSTM to process the sequence of raw (or pre-processed) observations
# over time steps within an episode, the `TradingEnv` needs to return observations
# that represent sequences, or we need a custom `features_extractor` that restructures
# the batch of observations from the vector environment into sequences for the LSTM.
# Stable-Baselines3's `VecEnv` collects observations in batches. RecurrentPPO handles
# the sequence creation and state passing based on the `episode_starts` information.

# So, the standard `MlpLstmPolicy` (which always includes an LSTM) should work, with a custom
# `features_extractor` if the raw observation space needs initial processing before the LSTM.
# In our case, the `TradingEnv` already provides a flat observation vector at each step.
# We can use the `CustomFeatureExtractor` to process this flat vector before it goes into the LSTM.

# Let's define the components and set up the RecurrentPPO agent.
# Assuming `advpp_data` (your OHLCV DataFrame used for training the Keras model)
# is available and has a 'Close' column and a DateTime index.
# 3. Define the Custom Feature Extractor (already provided)
class CustomFeatureExtractor(BaseFeaturesExtractor):
    """
    Embed each flat observation vector into a ``features_dim``-sized feature vector.

    RecurrentPPO's policy calls the features extractor on a batch of flat
    observations, i.e. a tensor shaped ``(batch, n_features)``, and only
    afterwards feeds the result through its LSTM. Because ``nn.Linear`` acts on
    the last dimension, inputs with extra leading dimensions (for example
    ``(batch, seq, n_features)``) are handled as well.

    :param observation_space: (gym.Space) Flat Box observation space; its first
        shape entry is used as the input width.
    :param features_dim: (int) Number of features extracted.
        This corresponds to the number of units in the last layer.
    """

    def __init__(self, observation_space: gym.Space, features_dim: int = 64):
        super().__init__(observation_space, features_dim)
        n_features = observation_space.shape[0]
        # A single linear layer followed by ReLU as the feature extractor.
        self.linear = nn.Sequential(nn.Linear(n_features, features_dim), nn.ReLU())

    def forward(self, observations: th.Tensor) -> th.Tensor:
        # No reshaping: nn.Linear broadcasts over every leading dimension, so this works
        # for the 2-D batches SB3 passes in as well as for any 3-D input.
        return self.linear(observations)


# 4. Instantiate the environment and agent
# Make sure `advpp_data` is loaded and available (assuming it's the DataFrame from previous cells).
# Also ensure the Keras `model` and `scaler` are trained and available.
# NOTE(review): the notes below say TradingEnv uses the Keras price model through the global
# name `model`, so the RL agent is stored under a different name and `model` is left untouched.

# Create the environment.
# Use a DummyVecEnv for simplicity with a single environment.
env = DummyVecEnv([lambda: TradingEnv(advpp_data)])

# Normalize observations (generally recommended for neural networks).
# Reward normalization can help stabilize training; start without it.
env = VecNormalize(env, norm_obs=True, norm_reward=False)

# Policy kwargs for RecurrentPPO's MlpLstmPolicy.
# MlpLstmPolicy always contains an LSTM; there is no `enable_lstm` switch, and unknown
# keywords make the policy constructor raise a TypeError.
policy_kwargs = dict(
    features_extractor_class=CustomFeatureExtractor,
    features_extractor_kwargs=dict(features_dim=64),  # Output dimension of the feature extractor
    lstm_hidden_size=128,  # Hidden size of the LSTM
    # Policy and value heads applied *after* the LSTM (their input size is lstm_hidden_size).
    # SB3 >= 1.8 expects a dict here, not a list containing a dict.
    net_arch=dict(pi=[64], vf=[64]),
    # The final output layer sizes are determined by the action space (policy) and 1 (value).
)

# Instantiate the RecurrentPPO agent (MlpLstmPolicy from sb3_contrib).
ppo_lstm_agent = RecurrentPPO(
    "MlpLstmPolicy",
    env,
    verbose=1,
    policy_kwargs=policy_kwargs,
    tensorboard_log="./trading_ppo_lstm_tensorboard/",
)

# 5. Train the agent
print("Starting training...")
# Adjust total_timesteps based on your data size and training goals.
training_succeeded = False
try:
    ppo_lstm_agent.learn(total_timesteps=100000)
except Exception as e:
    print(f"An error occurred during training: {e}")
else:
    training_succeeded = True
    print("Training finished.")

# 6. Save the trained model, but only if training completed. The VecNormalize running
# statistics are saved too: the same observation normalization is needed at inference time.
if training_succeeded:
    try:
        ppo_lstm_agent.save("recurrent_ppo_trading_policy")
        env.save("recurrent_ppo_trading_vecnormalize.pkl")
        print("Model saved successfully.")
    except Exception as e:
        print(f"An error occurred while saving the model: {e}")


# 7. Evaluate the trained agent (optional).
# Use a separate evaluation environment (ideally a different data split) and reuse the
# training normalization statistics in evaluation mode:
# eval_env = DummyVecEnv([lambda: TradingEnv(advpp_data)])
# eval_env = VecNormalize.load("recurrent_ppo_trading_vecnormalize.pkl", eval_env)
# eval_env.training = False     # do not update running statistics during evaluation
# eval_env.norm_reward = False

# print("Starting evaluation...")
# obs = eval_env.reset()
# lstm_states = None  # None -> the policy starts from zero LSTM states
# num_envs = eval_env.num_envs
# episode_starts = np.ones((num_envs,), dtype=bool)
# total_reward = 0.0
# max_eval_steps = 1000  # Number of steps to evaluate for

# for _ in range(max_eval_steps):
#     action, lstm_states = ppo_lstm_agent.predict(
#         obs, state=lstm_states, episode_start=episode_starts, deterministic=True
#     )
#     obs, rewards, dones, infos = eval_env.step(action)
#     total_reward += float(rewards.sum())
#     # VecEnvs reset finished sub-environments automatically; flagging them as episode
#     # starts makes the policy reset their LSTM state on the next predict call.
#     episode_starts = dones
#     for i, done in enumerate(dones):
#         if done:
#             print(f"Episode finished. Final info: {infos[i]}")

# print(f"Evaluation finished after {max_eval_steps} steps. Average reward per step: {total_reward / max_eval_steps:.2f}")

# To run this code, ensure:
# 1. The `advpp_data` DataFrame is loaded and contains a 'Close' column with a DateTime index.
# 2. The Keras `model` (for price prediction) is trained and available globally or passed to the environment.
# 3. The Keras `scaler` (used for scaling the price data for Keras model) is available globally or passed.
# 4. The `center_prices` list (from Birch clustering) is available globally or passed to the environment.
# 5. Install `sb3_contrib`: `!pip install sb3-contrib`
# 6. Install `gymnasium`: `!pip install gymnasium`
```

## Train Agent

In [None]:
# --- Data Preparation ---
# Uses the `ohlcv_copy` DataFrame from previous steps. It must either already be indexed by a
# DatetimeIndex or contain a 'timestamp' column, and it must contain a 'Close' column.
if not isinstance(ohlcv_copy.index, pd.DatetimeIndex):
    ohlcv_copy['timestamp'] = pd.to_datetime(ohlcv_copy['timestamp'])
    ohlcv_copy = ohlcv_copy.set_index('timestamp')

# --- Environment Setup ---
WINDOW_SIZE = 60  # Needs to match the window size used for price prediction

# Select a subset of the data for training the RL agent (adjust dates as needed).
# Date-string slicing with .loc is only reliable on a sorted index, so sort first.
rl_data = ohlcv_copy.sort_index().loc['2021-01-01':'2023-01-01'].copy()

# The slice must hold more than one observation window. 100 rows is an additional
# arbitrary floor so the agent has some steps to trade on.
min_rl_rows = max(100, WINDOW_SIZE + 1)
if len(rl_data) < min_rl_rows:
    raise ValueError(
        f"Not enough data for the specified date range ({len(rl_data)} rows, "
        f"need at least {min_rl_rows}). Please adjust dates."
    )

# Make sure the Keras LSTM price model `model` is trained and available in the global scope.
# If not, load it here or pass it to the TradingEnv.

# Create the environment (its observation is already a flat vector, see _get_observation).
env = TradingEnv(data=rl_data, window_size=WINDOW_SIZE)

# Vectorize the environment. The factory returns this single instance, which is fine for
# n_envs=1; construct a new TradingEnv inside the factory before raising n_envs.
vec_env = make_vec_env(lambda: env, n_envs=1)

# --- Model Training ---
# RecurrentPPO (sb3_contrib) with the built-in "MlpLstmPolicy":
# features extractor -> LSTM -> MLP policy/value heads.
# A custom features extractor / architecture can be plugged in later via `policy_kwargs`.
# Adjust hyperparameters as needed.
model_rl = RecurrentPPO("MlpLstmPolicy", vec_env, verbose=1, device="auto")

# Train the agent; increase total_timesteps for better training.
model_rl.learn(total_timesteps=10000)

# --- Save the trained agent ---
# Saved as 'recurrent_ppo_trading_agent.zip'; the live-decision cell loads this file.
model_rl.save("recurrent_ppo_trading_agent")


## Evaluate Agent

In [None]:

# --- Evaluation (Optional) ---
# Load the trained agent
# model_rl = RecurrentPPO.load("recurrent_ppo_trading_agent")

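# Run the trained agent on a held-out test set. RecurrentPPO must be given its previous
# LSTM state and an episode-start flag on every call; otherwise each step is predicted
# from a fresh (zero) memory.
# test_data = ohlcv_copy.sort_index().loc['2023-01-02':'2024-01-01'].copy()  # Adjust dates
# test_env = TradingEnv(data=test_data, window_size=WINDOW_SIZE)
# obs, _ = test_env.reset()
# lstm_states = None
# episode_start = np.ones((1,), dtype=bool)
# terminated = truncated = False
# total_reward = 0
# while not terminated and not truncated:
#     action, lstm_states = model_rl.predict(
#         obs, state=lstm_states, episode_start=episode_start, deterministic=True
#     )
#     obs, reward, terminated, truncated, info = test_env.step(int(action))
#     episode_start = np.zeros((1,), dtype=bool)
#     total_reward += reward
#     # Render the environment if it supports it (optional)
#     # test_env.render()

# print(f"Total reward on test set: {total_reward}")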

# --- Integration with Pivots and Predictions ---
# The current environment already incorporates the price prediction and pivot point information
# (as implemented in the _get_observation method).
# The LSTM policy learns to use these features to make decisions (Buy/Hold/Sell).

# Further steps for a more robust integration:
# 1. Refine the observation space: Experiment with different ways to represent pivot points
#    (e.g., distance to nearest support/resistance, number of pivots in a window,
#    whether the current price is above/below a pivot zone).
#    Include more prediction horizons (e.g., 1-day, 4-day, 7-day predictions).
#    Include other relevant features (e.g., volume, technical indicators, sentiment).
# 2. Design a sophisticated reward function: This is crucial for training an effective trading agent.
#    Consider factors like transaction costs, risk, drawdown, Sharpe ratio, etc.
# 3. Hyperparameter tuning: Optimize the RL agent's hyperparameters for better performance.
# 4. Custom Policy Network: Implement a custom policy network in stable-baselines3
#    to have more control over how the LSTM processes the various observation features.
#    For example, you might want separate branches for price window processing,
#    prediction features, and pivot features before combining them.
# 5. Backtesting: Rigorously backtest the trained agent on unseen data to evaluate its profitability and risk.
# 6. Training Data: Ensure you have sufficient and representative data for training.
#    Consider using data from multiple instruments or timeframes.


## Live Decision

In [None]:
from gymnasium.wrappers import TimeLimit
from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.callbacks import BaseCallback
import gymnasium as gym
from datetime import timedelta

# Assuming 'model' is your trained LSTM prediction model and 'scaler' is its scaler

def _close_series(price_data: pd.DataFrame) -> pd.Series:
    """Return the 'Close' prices of *price_data* as a single Series.

    With MultiIndex columns such as (field, ticker), which newer yfinance
    versions can return even for one ticker, ``price_data['Close']`` is a
    one-column DataFrame; its first column is used in that case.
    """
    close = price_data['Close']
    if isinstance(close, pd.DataFrame):
        close = close.iloc[:, 0]
    return close


def _predicted_price_feature(model, scaler, scaled_window: np.ndarray, price_window: np.ndarray) -> float:
    """Predict the next price and express it as a [0, 1] position within the window's range.

    Returns 0.5 when the prediction fails or when the window is flat (max == min),
    because the range-relative scaling is undefined then.
    """
    batch = scaled_window.reshape(1, len(scaled_window), 1)  # (samples, timesteps, features)
    try:
        prediction_scaled = model.predict(batch, verbose=0)
        next_price = scaler.inverse_transform(prediction_scaled)[0, 0]
    except Exception as e:
        print(f"Error getting price prediction: {e}")
        return 0.5
    low, high = price_window.min(), price_window.max()
    if high == low:
        return 0.5
    # NOTE(review): this window-relative scaling must match TradingEnv._get_observation - confirm.
    return float(np.clip((next_price - low) / (high - low), 0, 1))


def _pivot_distance_features(current_price: float, pivot_centers) -> tuple[float, float]:
    """Return (support, resistance) distance features in [0, 1] for *current_price*.

    Distances to the nearest pivot at/below and at/above the price are taken
    relative to the price, clipped to [-1, 1] and mapped to [0, 1] via (d + 1) / 2.
    If no pivot lies below (above) the price, the lowest (highest) pivot is used.
    Returns (0.5, 0.5) when there are no pivots or the price is 0.
    """
    # Accept lists as well as arrays so boolean masking below works.
    centers = np.asarray(pivot_centers, dtype=float).ravel()
    if centers.size == 0 or current_price == 0:
        return 0.5, 0.5
    below = centers[centers <= current_price]
    above = centers[centers >= current_price]
    nearest_support = below.max() if below.size else centers.min()
    nearest_resistance = above.min() if above.size else centers.max()
    distance_to_support = (current_price - nearest_support) / current_price
    distance_to_resistance = (nearest_resistance - current_price) / current_price
    # NOTE(review): this relative scaling must match TradingEnv._get_observation - confirm.
    return (
        (float(np.clip(distance_to_support, -1, 1)) + 1) / 2,
        (float(np.clip(distance_to_resistance, -1, 1)) + 1) / 2,
    )


def make_live_decision(
    current_price_data: pd.DataFrame,
    model,
    scaler,
    pivot_centers,
    *,
    agent=None,
    window_size: int = 60,
    agent_path: str = "recurrent_ppo_trading_agent",
) -> str:
    """
    Make a live trading decision (Buy, Hold, Sell) from the latest prices,
    the LSTM price prediction and pivot points, using a trained RL agent.

    The observation is [scaled price window, predicted-price feature,
    support-distance feature, resistance-distance feature], shaped (1, window_size + 3).

    Args:
        current_price_data: DataFrame with at least ``window_size`` rows and a 'Close'
                            column (a one-column 'Close' group under MultiIndex columns
                            is also accepted).
        model: The trained Keras LSTM prediction model.
        scaler: The MinMaxScaler used to scale the training data for the LSTM model.
        pivot_centers: A list or array of pivot point prices (from Birch clustering).
        agent: Optional already-loaded RL agent; if None it is loaded from ``agent_path``
               on every call.
        window_size: Number of closing prices per observation; must match RL training.
        agent_path: Path of the saved RecurrentPPO agent (without the '.zip' suffix).

    Returns:
        'Buy', 'Hold' or 'Sell'. 'Hold' is also returned when there is not enough data,
        the window cannot be scaled, or the agent cannot be loaded or queried.

    Note:
        Each call starts the agent's LSTM from a zero state, so no memory is carried
        between decisions.
    """
    if len(current_price_data) < window_size:
        print(f"Not enough data for the window size ({window_size}). Cannot make a decision.")
        return 'Hold'  # Default to hold if not enough data

    # --- Prepare Observation for the RL Agent ---
    price_window = _close_series(current_price_data).to_numpy(dtype=float)[-window_size:]

    # Scale the window with the prediction model's scaler. The same scaled window is the
    # RL price feature and the LSTM prediction input.
    # NOTE: the scaler must cover the current price range; consider refitting if prices drift.
    try:
        scaled_price_window = scaler.transform(price_window.reshape(-1, 1)).flatten()
    except Exception as e:
        print(f"Error scaling price window: {e}")
        return 'Hold'

    predicted_price_scaled_obs = _predicted_price_feature(model, scaler, scaled_price_window, price_window)

    try:
        scaled_distance_to_support, scaled_distance_to_resistance = _pivot_distance_features(
            float(price_window[-1]), pivot_centers
        )
    except Exception as e:
        print(f"Error calculating pivot relation: {e}")
        scaled_distance_to_support = scaled_distance_to_resistance = 0.5

    # Feature order must match the RL environment's observation exactly.
    observation = np.concatenate((
        scaled_price_window,
        [predicted_price_scaled_obs],
        [scaled_distance_to_support, scaled_distance_to_resistance],
    )).reshape(1, -1)  # (num_envs=1, observation_dim)

    # --- Query the Trained RL Agent ---
    try:
        if agent is None:
            agent = RecurrentPPO.load(agent_path, device="auto")
        # state=None lets the recurrent policy build correctly shaped zero (h, c) LSTM states;
        # episode_start=True marks this as the beginning of a sequence.
        action, _ = agent.predict(
            observation,
            state=None,
            episode_start=np.ones((1,), dtype=bool),
            deterministic=True,
        )
    except FileNotFoundError:
        print(f"Error: Trained RL agent model not found. Please ensure '{agent_path}.zip' exists.")
        return 'Hold'
    except Exception as e:
        print(f"Error loading or predicting with the RL agent: {e}")
        return 'Hold'

    # --- Interpret the Action ---
    # NOTE(review): assumes TradingEnv encodes actions as 0=Sell, 1=Hold, 2=Buy - confirm.
    action_map = {0: 'Sell', 1: 'Hold', 2: 'Buy'}
    return action_map.get(int(np.asarray(action).item()), 'Hold')

# --- Example Usage ---
# You need to have:
# 1. Your trained Keras LSTM prediction 'model'
# 2. The 'scaler' used for the LSTM model
# 3. Your 'pivot_centers' from the Birch clustering
#
# If any of them is missing (or has the wrong type), a placeholder is created below so the
# cell still runs. The placeholders produce meaningless decisions: REPLACE THEM with your
# actual trained model, scaler and pivot centers.

# Import Keras explicitly (under an alias, so no existing `keras` name is rebound). Without it,
# a missing `keras` name raises NameError inside the check below, which would silently replace
# a valid trained model with the placeholder.
from tensorflow import keras as tf_keras

try:
    if 'model' not in globals() or 'scaler' not in globals():
        raise NameError("LSTM 'model' or 'scaler' not found.")
    if not isinstance(model, tf_keras.Model):
        raise TypeError("'model' is not a Keras Model.")
    if not isinstance(scaler, MinMaxScaler):
        raise TypeError("'scaler' is not a MinMaxScaler.")
    print("Using existing 'model' and 'scaler'.")
except (NameError, TypeError) as e:
    print(f"Could not use existing 'model' and 'scaler': {e}")
    print("Creating dummy model and scaler for demonstration. REPLACE THIS.")
    from tensorflow.keras.layers import Input, LSTM, Dense
    from tensorflow.keras.models import Model
    from sklearn.preprocessing import MinMaxScaler

    # Dummy scaler fitted on a synthetic price range (100-400).
    dummy_prices = np.linspace(100, 400, 5000).reshape(-1, 1)
    scaler = MinMaxScaler(feature_range=(0, 1))
    scaler.fit(dummy_prices)

    def build_dummy_model(input_shape):
        """Build an *untrained* single-layer LSTM regressor with the given input shape.

        Its predictions are arbitrary; it only mirrors the real model's input/output shapes.
        """
        inputs = Input(shape=input_shape)
        lstm_out = LSTM(units=50, return_sequences=False)(inputs)
        outputs = Dense(1)(lstm_out)
        dummy = Model(inputs=inputs, outputs=outputs)
        dummy.compile(optimizer='adam', loss='mse')
        return dummy

    # The live decision uses 60-step windows with one feature per step.
    dummy_input_shape = (60, 1)
    model = build_dummy_model(dummy_input_shape)
    print("Dummy model and scaler created.")


# Pivot centers: prefer the ones computed earlier with Birch clustering.
if 'center_prices' in globals() and len(center_prices) > 0:
    print("Using existing 'center_prices' for pivot points.")
    # A flat float array, so boolean masks such as `pivot_centers <= price` work downstream.
    pivot_centers = np.asarray(center_prices, dtype=float).ravel()
else:
    print("Could not use existing 'center_prices'.")
    print("Creating dummy pivot centers for demonstration. REPLACE THIS.")
    pivot_centers = np.array([150, 200, 250, 300, 350])  # Example pivot levels
    print("Dummy pivot centers created.")


In [None]:
# Fetch the latest daily bars required for the observation window.
# NOTE(review): the RL agent above was trained on 4h SOL/USDT candles; feeding it daily AAPL
# bars is a different data distribution - confirm this instrument/timeframe is intended.
INSTRUMENT = 'AAPL'  # Use your instrument
WINDOW_SIZE = 60
# Exchange-traded instruments like AAPL have no bars on weekends and holidays, so a span of
# WINDOW_SIZE + 5 calendar days returns fewer than WINDOW_SIZE daily bars. Request roughly
# twice the window in calendar days instead.
start_date = (
    pd.Timestamp.today().normalize() - pd.Timedelta(days=2 * WINDOW_SIZE + 10)
).strftime('%Y-%m-%d')
latest_data = yf.download(INSTRUMENT, start=start_date, interval='1d')

if latest_data.empty:
    print(f"Could not fetch latest data for {INSTRUMENT}. Cannot make a decision.")
else:
    # Some yfinance versions return (field, ticker) MultiIndex columns even for one ticker;
    # keep only the field level so latest_data['Close'] is a single column.
    if isinstance(latest_data.columns, pd.MultiIndex):
        latest_data.columns = latest_data.columns.get_level_values(0)

    # Ensure the data is sorted by date
    latest_data.sort_index(inplace=True)

    # Make the live decision
    decision = make_live_decision(latest_data, model, scaler, pivot_centers)
    print(f"\nLive Trading Decision for {INSTRUMENT}: {decision}")

