# TCN-GNN-LSTM Hybrid Architecture: Complete Implementation

## A Full End-to-End Deep Learning Pipeline for Multi-Asset Crypto Portfolio Optimization

---

### Table of Contents

1. **Environment Setup & Library Imports**
2. **Data Loading Pipeline**
   - 2.1 Exchange Configuration
   - 2.2 Multi-Asset OHLCV Fetching
   - 2.3 Data Quality Checks
3. **Exploratory Data Analysis (EDA)**
   - 3.1 Price Visualization
   - 3.2 Returns Distribution
   - 3.3 Correlation Analysis
   - 3.4 Volatility Analysis
4. **Feature Engineering**
   - 4.1 Technical Indicators
   - 4.2 Multi-Timeframe Features
   - 4.3 Cross-Asset Features
5. **Data Preprocessing**
   - 5.1 Scaling & Normalization
   - 5.2 Sequence Creation
   - 5.3 Train/Val/Test Split
6. **Model Architecture**
   - 6.1 TCN Feature Extractor
   - 6.2 Graph Neural Network (GNN)
   - 6.3 LSTM Processor
   - 6.4 Multi-Head Output
   - 6.5 Complete Model Assembly
7. **Loss Functions**
   - 7.1 Trading Loss (Sharpe-based)
   - 7.2 Prediction Loss (Gaussian NLL)
   - 7.3 Combined Multi-Task Loss
8. **Training Pipeline**
   - 8.1 Curriculum Learning Strategy
   - 8.2 Training Loop
   - 8.3 Validation & Monitoring
9. **Evaluation & Backtesting**
   - 9.1 Performance Metrics
   - 9.2 Uncertainty Analysis
   - 9.3 Portfolio Simulation
10. **Conclusion & Next Steps**

---

**Author:** AI Trading Research Team  
**Version:** 2.0 (Complete Implementation)  
**Last Updated:** February 2026

---

## 1. Environment Setup & Library Imports

First, let's import all necessary libraries and configure the environment.

In [None]:
# ============================================
# CORE LIBRARIES
# ============================================
import numpy as np
import pandas as pd
from datetime import datetime, timezone, timedelta
import warnings
warnings.filterwarnings('ignore')

# ============================================
# VISUALIZATION
# ============================================
import matplotlib.pyplot as plt
import seaborn as sns

# Set plotting style
plt.style.use('dark_background')
plt.rcParams['figure.figsize'] = (14, 6)
plt.rcParams['font.size'] = 10
plt.rcParams['axes.grid'] = True
plt.rcParams['grid.alpha'] = 0.3

# ============================================
# DATA FETCHING
# ============================================
import ccxt

# ============================================
# TECHNICAL ANALYSIS
# ============================================
import ta

# ============================================
# MACHINE LEARNING
# ============================================
from sklearn.preprocessing import RobustScaler, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, Model
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# ============================================
# OPTIMIZATION
# ============================================
from scipy.optimize import minimize
from scipy import stats

# ============================================
# UTILITIES
# ============================================
import json
import os
from typing import Dict, List, Tuple, Optional
import time

# Print environment info
print("="*70)
print("TCN-GNN-LSTM HYBRID ARCHITECTURE - COMPLETE IMPLEMENTATION")
print("="*70)
print(f"TensorFlow Version: {tf.__version__}")
print(f"NumPy Version:      {np.__version__}")
print(f"Pandas Version:     {pd.__version__}")
print(f"CCXT Version:       {ccxt.__version__}")
print(f"Timestamp:          {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')}")
print("="*70)

# Check GPU availability
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"\n✓ GPU Available: {len(gpus)} device(s)")
    for gpu in gpus:
        print(f"  - {gpu.name}")
else:
    print("\n⚠ No GPU detected. Training will use CPU.")

---

## 2. Data Loading Pipeline

### 2.1 Exchange Configuration

We use the CCXT library to fetch data from cryptocurrency exchanges. It provides a unified API across more than 100 exchanges.

In [None]:
class CryptoDataLoader:
    """
    Data loader for cryptocurrency OHLCV data from multiple exchanges.
    
    Features:
    - Multi-exchange support (Binance, Coinbase, Kraken, etc.)
    - Chunked fetching for large date ranges
    - Rate limit handling
    - Data caching
    """
    
    def __init__(self, exchange_name='binance'):
        """
        Initialize the data loader.
        
        Args:
            exchange_name: Name of the exchange ('binance', 'coinbase', 'kraken')
        """
        self.exchange_name = exchange_name
        self.exchange = self._init_exchange(exchange_name)
        self.data_cache = {}
        
    def _init_exchange(self, name):
        """Initialize exchange connection."""
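        # Note: 'binance' resolves to the Binance.US client (ccxt.binanceus),
        # not ccxt.binance. Unknown names also fall back to Binance.US.
        exchange_map = {
            'binance': ccxt.binanceus,
            'coinbase': ccxt.coinbase,
            'kraken': ccxt.kraken,
            'kucoin': ccxt.kucoin,
        }
        
        exchange_class = exchange_map.get(name, ccxt.binanceus)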
        return exchange_class({
            'enableRateLimit': True,
            'options': {'defaultType': 'spot'}
        })
    
    def get_timeframe_ms(self, timeframe: str) -> int:
        """Convert timeframe string to milliseconds."""
        timeframe_map = {
            '1m': 60 * 1000,
            '5m': 5 * 60 * 1000,
            '15m': 15 * 60 * 1000,
            '30m': 30 * 60 * 1000,
            '1h': 60 * 60 * 1000,
            '4h': 4 * 60 * 60 * 1000,
            '1d': 24 * 60 * 60 * 1000,
        }
        return timeframe_map.get(timeframe, 60 * 60 * 1000)
    
    def fetch_ohlcv(self, symbol: str, timeframe: str = '1d', 
                    limit: int = 365, since: datetime = None,
                    until: datetime = None) -> pd.DataFrame:
        """
        Fetch OHLCV data with support for large date ranges.
        
        Args:
            symbol: Trading pair (e.g., 'BTC/USDT')
            timeframe: Candle timeframe ('1m', '5m', '1h', '4h', '1d')
            limit: Number of candles to fetch (if no date range)
            since: Start date for fetching
            until: End date for fetching
            
        Returns:
            DataFrame with OHLCV data
        """
        cache_key = f"{symbol}_{timeframe}_{limit}_{since}_{until}"
        if cache_key in self.data_cache:
            print(f"  Using cached data for {symbol}")
            return self.data_cache[cache_key]
        
        try:
            since_ms = int(since.timestamp() * 1000) if since else None
            until_ms = int(until.timestamp() * 1000) if until else int(datetime.now(timezone.utc).timestamp() * 1000)
            
            # Calculate timeframe in milliseconds
            tf_ms = self.get_timeframe_ms(timeframe)
            
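            # Fetch in chunks (max 1000 per request for most exchanges)
            max_per_request = 1000
            
            # Without a start time the exchange returns only its most recent
            # candles, so paging forward can never collect more than one chunk.
            # For large limits, derive an explicit start from the end time.
            if since_ms is None and limit > max_per_request:
                since_ms = until_ms - limit * tf_ms
            
            # Determine how many candles we need
            if since_ms and until_ms:
                total_needed = int((until_ms - since_ms) / tf_ms) + 1
            else:
                total_needed = limit
            
            all_data = []
            current_since = since_ms
            fetched = 0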
            
            while fetched < total_needed:
                fetch_limit = min(max_per_request, total_needed - fetched)
                
                try:
                    data = self.exchange.fetch_ohlcv(
                        symbol, timeframe, current_since, fetch_limit
                    )
                except Exception as e:
                    print(f"  Warning: Fetch error: {e}")
                    if all_data:
                        break
                    raise
                
                if not data:
                    break
                
                # Filter by end date
                if until_ms:
                    data = [d for d in data if d[0] <= until_ms]
                
                if not data:
                    break
                
                all_data.extend(data)
                fetched += len(data)
                
                # Check if we're done
                if len(data) < fetch_limit:
                    break
                
                if until_ms and data[-1][0] >= until_ms:
                    break
                
                # Move to next chunk
                current_since = data[-1][0] + 1
                time.sleep(0.1)  # Rate limiting
            
            if not all_data:
                print(f"  Warning: No data fetched for {symbol}")
                return pd.DataFrame()
            
            # Convert to DataFrame
            df = pd.DataFrame(all_data, columns=['timestamp', 'open', 'high', 'low', 'close', 'volume'])
            df['timestamp'] = pd.to_datetime(df['timestamp'], unit='ms', utc=True)
            df.set_index('timestamp', inplace=True)
            df = df[~df.index.duplicated(keep='first')]
            df = df.sort_index()
            
            # Cache the result
            self.data_cache[cache_key] = df
            
            return df
            
        except Exception as e:
            print(f"  Error fetching {symbol}: {e}")
            return pd.DataFrame()
    
    def fetch_multi_asset(self, symbols: List[str], timeframe: str = '1d',
                          limit: int = 365) -> Dict[str, pd.DataFrame]:
        """
        Fetch OHLCV data for multiple assets.
        
        Args:
            symbols: List of trading pairs
            timeframe: Candle timeframe
            limit: Number of candles per asset
            
        Returns:
            Dictionary mapping symbol -> DataFrame
        """
        data = {}
        print(f"\nFetching data for {len(symbols)} assets...")
        
        for i, symbol in enumerate(symbols):
            print(f"  [{i+1}/{len(symbols)}] {symbol}...", end=" ")
            df = self.fetch_ohlcv(symbol, timeframe, limit)
            if not df.empty:
                data[symbol] = df
                print(f"✓ {len(df)} candles")
            else:
                print("✗ Failed")
        
        return data


# Initialize data loader
print("Initializing Data Loader...")
data_loader = CryptoDataLoader(exchange_name='binance')
print(f"Exchange: {data_loader.exchange.name}")
print(f"Rate Limit: {data_loader.exchange.enableRateLimit}")
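
The chunked loop in `fetch_ohlcv` can be exercised without network access. The sketch below is a minimal illustration, not CCXT code: `fetch_paginated` and `make_fake_exchange` are hypothetical helpers, and the fake exchange serves synthetic candles from memory in the same `[timestamp, open, high, low, close, volume]` row layout that the loader converts into a DataFrame.

```python
from typing import Callable, List

Candle = List[float]  # [timestamp_ms, open, high, low, close, volume]


def fetch_paginated(fetch_page: Callable[[int, int], List[Candle]],
                    since_ms: int, until_ms: int, tf_ms: int,
                    max_per_request: int = 1000) -> List[Candle]:
    """Collect candles in [since_ms, until_ms] by paging forward in time."""
    out: List[Candle] = []
    cursor = since_ms
    while cursor <= until_ms:
        page = fetch_page(cursor, max_per_request)
        # Drop anything past the requested end time.
        page = [c for c in page if c[0] <= until_ms]
        if not page:
            break
        out.extend(page)
        # Advance by one full candle so the next page cannot repeat the last row.
        cursor = int(page[-1][0]) + tf_ms
    return out


def make_fake_exchange(start_ms: int, n_candles: int, tf_ms: int):
    """Hypothetical stand-in for an exchange: serves synthetic candles from memory."""
    candles = [[start_ms + i * tf_ms, 1.0, 1.0, 1.0, 1.0, 0.0]
               for i in range(n_candles)]

    def fetch_page(since: int, limit: int) -> List[Candle]:
        return [c for c in candles if c[0] >= since][:limit]

    return fetch_page


DAY_MS = 24 * 60 * 60 * 1000
fake_fetch = make_fake_exchange(start_ms=0, n_candles=2500, tf_ms=DAY_MS)
rows = fetch_paginated(fake_fetch, since_ms=0, until_ms=2499 * DAY_MS, tf_ms=DAY_MS)
print(len(rows))  # number of candles collected across all pages
```

Advancing the cursor by a full candle interval rather than one millisecond is a choice made for this sketch; either works as long as the next request starts strictly after the last received timestamp.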

### 2.2 Multi-Asset OHLCV Fetching

Let's define our portfolio assets and fetch historical data.

In [None]:
# ============================================
# PORTFOLIO CONFIGURATION
# ============================================

# Define assets for portfolio optimization
PORTFOLIO_ASSETS = [
    "BTC/USDT",   # Bitcoin - Market leader
    "ETH/USDT",   # Ethereum - Smart contracts
    "BNB/USDT",   # Binance Coin - Exchange token
    "SOL/USDT",   # Solana - High-performance
    "XRP/USDT",   # Ripple - Payments
]

# Data configuration
TIMEFRAME = '1d'       # Daily candles
LOOKBACK_DAYS = 365    # 1 year of data

print("\n" + "="*60)
print("PORTFOLIO CONFIGURATION")
print("="*60)
print(f"\nAssets ({len(PORTFOLIO_ASSETS)}):")
for i, asset in enumerate(PORTFOLIO_ASSETS, 1):
    print(f"  {i}. {asset}")
print(f"\nTimeframe: {TIMEFRAME}")
print(f"Lookback: {LOOKBACK_DAYS} days")

In [None]:
# ============================================
# FETCH DATA
# ============================================

print("\n" + "="*60)
print("FETCHING OHLCV DATA")
print("="*60)

# Fetch data for all assets
raw_data = data_loader.fetch_multi_asset(
    symbols=PORTFOLIO_ASSETS,
    timeframe=TIMEFRAME,
    limit=LOOKBACK_DAYS
)

print(f"\n✓ Successfully loaded data for {len(raw_data)} assets")

### 2.3 Data Quality Checks

In [None]:
# ============================================
# DATA QUALITY REPORT
# ============================================

print("\n" + "="*60)
print("DATA QUALITY REPORT")
print("="*60)

quality_report = []

for symbol, df in raw_data.items():
    report = {
        'Symbol': symbol.replace('/USDT', ''),
        'Rows': len(df),
        'Start': df.index.min().strftime('%Y-%m-%d'),
        'End': df.index.max().strftime('%Y-%m-%d'),
        'Missing': df.isnull().sum().sum(),
        'Min Price': f"${df['close'].min():.2f}",
        'Max Price': f"${df['close'].max():.2f}",
        'Current': f"${df['close'].iloc[-1]:.2f}"
    }
    quality_report.append(report)

quality_df = pd.DataFrame(quality_report)
print("\n", quality_df.to_string(index=False))

# Check for common date range
all_dates = [set(df.index) for df in raw_data.values()]
common_dates = set.intersection(*all_dates)
print(f"\n✓ Common data points across all assets: {len(common_dates)}")

---

## 3. Exploratory Data Analysis (EDA)

### 3.1 Price Visualization

In [None]:
# ============================================
# PRICE VISUALIZATION
# ============================================

fig, axes = plt.subplots(2, 3, figsize=(16, 10))
axes = axes.flatten()

# Last entry lightened: XRP's brand color (#23292F) is nearly invisible on the dark background
colors = ['#F7931A', '#627EEA', '#F3BA2F', '#00FFA3', '#9AA5B1']

for idx, (symbol, df) in enumerate(raw_data.items()):
    if idx >= 5:
        break
    ax = axes[idx]
    
    # Plot price
    ax.plot(df.index, df['close'], color=colors[idx], linewidth=1.5, label='Close')
    ax.fill_between(df.index, df['low'], df['high'], alpha=0.2, color=colors[idx])
    
    # Add moving averages
    ma_20 = df['close'].rolling(window=20).mean()
    ma_50 = df['close'].rolling(window=50).mean()
    ax.plot(df.index, ma_20, '--', color='white', alpha=0.5, linewidth=1, label='MA20')
    ax.plot(df.index, ma_50, '--', color='yellow', alpha=0.5, linewidth=1, label='MA50')
    
    ax.set_title(f"{symbol.replace('/USDT', '')} Price", fontsize=12, fontweight='bold')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price (USDT)')
    ax.legend(loc='upper left', fontsize=8)
    ax.tick_params(axis='x', rotation=45)

# Hide extra subplot
axes[5].set_visible(False)

plt.suptitle('Asset Price History with Moving Averages', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

### 3.2 Returns Distribution

In [None]:
# ============================================
# CALCULATE RETURNS
# ============================================

# Calculate daily returns for each asset
returns_data = {}
for symbol, df in raw_data.items():
    returns_data[symbol] = df['close'].pct_change().dropna()

# Create returns DataFrame
returns_df = pd.DataFrame(returns_data)
returns_df.columns = [s.replace('/USDT', '') for s in returns_df.columns]

print("\n" + "="*60)
print("RETURNS STATISTICS")
print("="*60)

# Calculate statistics
stats_df = pd.DataFrame({
    'Mean (%)': (returns_df.mean() * 100).round(3),
    'Std (%)': (returns_df.std() * 100).round(3),
    'Min (%)': (returns_df.min() * 100).round(2),
    'Max (%)': (returns_df.max() * 100).round(2),
    'Skewness': returns_df.skew().round(3),
    'Kurtosis': returns_df.kurtosis().round(3),
    'Sharpe (Ann.)': ((returns_df.mean() / returns_df.std()) * np.sqrt(365)).round(3)
})

print("\n", stats_df.to_string())

In [None]:
# ============================================
# RETURNS DISTRIBUTION PLOTS
# ============================================

fig, axes = plt.subplots(2, 3, figsize=(16, 10))
axes = axes.flatten()

for idx, col in enumerate(returns_df.columns):
    if idx >= 5:
        break
    ax = axes[idx]
    
    # Histogram with KDE
    returns_df[col].hist(bins=50, ax=ax, color=colors[idx], alpha=0.7, density=True)
    returns_df[col].plot.kde(ax=ax, color='white', linewidth=2)
    
    # Add normal distribution for comparison
    x = np.linspace(returns_df[col].min(), returns_df[col].max(), 100)
    normal = stats.norm.pdf(x, returns_df[col].mean(), returns_df[col].std())
    ax.plot(x, normal, '--', color='red', linewidth=1.5, label='Normal')
    
    ax.axvline(x=0, color='white', linestyle='--', alpha=0.5)
    ax.set_title(f"{col} Daily Returns Distribution", fontweight='bold')
    ax.set_xlabel('Return')
    ax.set_ylabel('Density')
    ax.legend()

axes[5].set_visible(False)
plt.suptitle('Returns Distribution Analysis', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

### 3.3 Correlation Analysis

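**Key Insight for GNN:** The assets are interconnected, and the rolling correlations plotted below show that those relationships shift over time. A single full-period correlation matrix cannot capture this, which motivates a Graph Neural Network with dynamic edge weights.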

In [None]:
# ============================================
# CORRELATION ANALYSIS
# ============================================

# Calculate correlation matrix
correlation_matrix = returns_df.corr()

# Plot correlation heatmap
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Full period correlation
ax1 = axes[0]
mask = np.triu(np.ones_like(correlation_matrix, dtype=bool), k=1)
sns.heatmap(correlation_matrix, mask=mask, annot=True, fmt='.2f', 
            cmap='RdYlGn', center=0, ax=ax1,
            vmin=-1, vmax=1, square=True,
            linewidths=0.5, cbar_kws={'shrink': 0.8})
ax1.set_title('Full Period Correlation Matrix', fontweight='bold', fontsize=12)

# Rolling correlation (BTC vs others) - showing time-varying nature
ax2 = axes[1]
rolling_window = 30

for col in returns_df.columns[1:]:  # Skip BTC itself
    rolling_corr = returns_df['BTC'].rolling(window=rolling_window).corr(returns_df[col])
    ax2.plot(rolling_corr.index, rolling_corr, label=f'BTC-{col}', linewidth=1.5)

ax2.axhline(y=0, color='white', linestyle='--', alpha=0.3)
ax2.axhline(y=0.5, color='green', linestyle='--', alpha=0.3)
ax2.axhline(y=-0.5, color='red', linestyle='--', alpha=0.3)
ax2.set_title(f'{rolling_window}-Day Rolling Correlation with BTC', fontweight='bold', fontsize=12)
ax2.set_xlabel('Date')
ax2.set_ylabel('Correlation')
ax2.legend(loc='lower left')
ax2.set_ylim(-1, 1)

plt.tight_layout()
plt.show()

print("\n📊 Key Insight: Rolling correlations vary significantly over time!")
print("   → This is why we need a GNN with dynamic edge weights")
print("   → Static correlation matrices miss regime changes")
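
One way to turn rolling correlations like those plotted above into graph inputs is to threshold them into a per-day adjacency matrix. The sketch below uses synthetic returns for three hypothetical assets (`A`, `B`, `C`) and a hypothetical helper `rolling_correlation_graphs`; the 0.3 threshold is an arbitrary illustration value. It is only a hand-built baseline, not the learned attention used by the model.

```python
import numpy as np
import pandas as pd


def rolling_correlation_graphs(returns: pd.DataFrame, window: int = 30,
                               threshold: float = 0.3) -> np.ndarray:
    """Return a (time, assets, assets) array of thresholded |correlation| weights.

    Rows before the first full window are left as all-zero (no edges).
    """
    values = returns.to_numpy()
    n_steps, n_assets = values.shape
    adj = np.zeros((n_steps, n_assets, n_assets))
    for t in range(window - 1, n_steps):
        # Correlation over the trailing window ending at t (columns are assets).
        corr = np.corrcoef(values[t - window + 1 : t + 1], rowvar=False)
        weights = np.abs(corr)
        weights[weights < threshold] = 0.0   # prune weak relationships
        np.fill_diagonal(weights, 0.0)       # no self-loops
        adj[t] = weights
    return adj


# Synthetic data: A and B share a common factor, C is independent noise.
rng = np.random.default_rng(0)
market = rng.normal(0.0, 0.02, size=200)
returns = pd.DataFrame({
    "A": market + rng.normal(0.0, 0.01, size=200),
    "B": market + rng.normal(0.0, 0.01, size=200),
    "C": rng.normal(0.0, 0.02, size=200),
})

adj = rolling_correlation_graphs(returns, window=30, threshold=0.3)
print(adj.shape)
print(np.round(adj[-1], 2))  # adjacency for the most recent day
```

Because each day gets its own matrix, the edge weights change as the correlation regime changes, which is the behavior a static full-period matrix cannot express.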

### 3.4 Volatility Analysis

In [None]:
# ============================================
# VOLATILITY ANALYSIS
# ============================================

# Calculate rolling volatility (annualized)
volatility_window = 30
rolling_vol = returns_df.rolling(window=volatility_window).std() * np.sqrt(365) * 100

fig, ax = plt.subplots(figsize=(14, 6))

for idx, col in enumerate(rolling_vol.columns):
    ax.plot(rolling_vol.index, rolling_vol[col], label=col, color=colors[idx], linewidth=1.5)

ax.set_title(f'{volatility_window}-Day Rolling Annualized Volatility (%)', fontweight='bold', fontsize=12)
ax.set_xlabel('Date')
ax.set_ylabel('Volatility (%)')
ax.axhline(y=50, color='yellow', linestyle='--', alpha=0.5, label='50% threshold')
ax.axhline(y=100, color='red', linestyle='--', alpha=0.5, label='100% threshold')
# Build the legend after the threshold lines so their labels are included
ax.legend(loc='upper right')

plt.tight_layout()
plt.show()

print("\n📊 Volatility Insights:")
print(f"   - Average volatility: {rolling_vol.mean().mean():.1f}%")
print(f"   - Max volatility spike: {rolling_vol.max().max():.1f}%")
print("   → High volatility clustering suggests GARCH effects")
print("   → TCN can capture multi-scale volatility patterns")
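
The volatility-clustering remark can be checked numerically rather than by eye. A common diagnostic is the autocorrelation of squared returns: it sits near zero for independent returns and turns positive when calm and turbulent periods come in runs. The sketch below uses a hypothetical helper `squared_return_autocorr` on synthetic series only; the two-regime series is an assumption chosen to make the effect visible, not market data.

```python
import numpy as np
import pandas as pd


def squared_return_autocorr(returns: pd.Series, lag: int = 1) -> float:
    """Lag-`lag` autocorrelation of squared returns (a volatility-clustering proxy)."""
    return float((returns ** 2).autocorr(lag=lag))


rng = np.random.default_rng(42)
n = 10_000

# Independent returns with constant volatility: no clustering expected.
iid = pd.Series(rng.normal(0.0, 0.02, size=n))

# Two volatility regimes back to back: a calm half followed by a turbulent half.
sigma = np.where(np.arange(n) < n // 2, 0.01, 0.10)
clustered = pd.Series(rng.normal(0.0, 1.0, size=n) * sigma)

print(f"iid:       {squared_return_autocorr(iid):+.3f}")
print(f"clustered: {squared_return_autocorr(clustered):+.3f}")
```

Applied to a column of `returns_df`, a clearly positive value at short lags would support the clustering claim; a value near zero would not.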

---

## 4. Feature Engineering

### 4.1 Technical Indicators

In [None]:
def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add comprehensive technical indicators to OHLCV data.
    
    Categories:
    - Trend indicators (SMA, EMA, MACD, ADX)
    - Momentum indicators (RSI, Stochastic, Williams %R, ROC)
    - Volatility indicators (Bollinger Bands, ATR)
    - Volume indicators (OBV, MFI, volume ratio)
    
    Args:
        df: DataFrame with OHLCV columns
        
    Returns:
        DataFrame with added technical indicators
    """
    df = df.copy()
    
    # ========== TREND INDICATORS ==========
    # Simple and exponential moving averages
    for period in [7, 14, 21, 50, 100]:
        df[f'sma_{period}'] = df['close'].rolling(window=period).mean()
        df[f'ema_{period}'] = df['close'].ewm(span=period, adjust=False).mean()
    
    # MACD
    macd = ta.trend.MACD(df['close'])
    df['macd'] = macd.macd()
    df['macd_signal'] = macd.macd_signal()
    df['macd_diff'] = macd.macd_diff()
    
    # ADX (Average Directional Index)
    adx = ta.trend.ADXIndicator(df['high'], df['low'], df['close'])
    df['adx'] = adx.adx()
    df['adx_pos'] = adx.adx_pos()
    df['adx_neg'] = adx.adx_neg()
    
    # ========== MOMENTUM INDICATORS ==========
    # RSI
    for period in [7, 14, 21]:
        df[f'rsi_{period}'] = ta.momentum.RSIIndicator(df['close'], window=period).rsi()
    
    # Stochastic
    stoch = ta.momentum.StochasticOscillator(df['high'], df['low'], df['close'])
    df['stoch_k'] = stoch.stoch()
    df['stoch_d'] = stoch.stoch_signal()
    
    # Williams %R
    df['williams_r'] = ta.momentum.WilliamsRIndicator(df['high'], df['low'], df['close']).williams_r()
    
    # ROC (Rate of Change)
    for period in [5, 10, 20]:
        df[f'roc_{period}'] = ta.momentum.ROCIndicator(df['close'], window=period).roc()
    
    # ========== VOLATILITY INDICATORS ==========
    # Bollinger Bands
    bb = ta.volatility.BollingerBands(df['close'])
    df['bb_high'] = bb.bollinger_hband()
    df['bb_low'] = bb.bollinger_lband()
    df['bb_mid'] = bb.bollinger_mavg()
    df['bb_width'] = bb.bollinger_wband()
    df['bb_pband'] = bb.bollinger_pband()
    
    # ATR (Average True Range)
    for period in [7, 14, 21]:
        df[f'atr_{period}'] = ta.volatility.AverageTrueRange(df['high'], df['low'], df['close'], window=period).average_true_range()
    
    # ========== VOLUME INDICATORS ==========
    # OBV (On-Balance Volume)
    df['obv'] = ta.volume.OnBalanceVolumeIndicator(df['close'], df['volume']).on_balance_volume()
    
    # MFI (Money Flow Index)
    df['mfi'] = ta.volume.MFIIndicator(df['high'], df['low'], df['close'], df['volume']).money_flow_index()
    
    # Volume SMA
    df['volume_sma_20'] = df['volume'].rolling(window=20).mean()
    df['volume_ratio'] = df['volume'] / df['volume_sma_20']
    
    # ========== PRICE FEATURES ==========
    # Returns at different horizons
    for period in [1, 3, 5, 10, 20]:
        df[f'return_{period}d'] = df['close'].pct_change(period)
    
    # Price position within range
    df['high_low_ratio'] = (df['close'] - df['low']) / (df['high'] - df['low'] + 1e-8)
    
    # Gap features
    df['gap'] = (df['open'] - df['close'].shift(1)) / df['close'].shift(1)
    
    return df


# Apply technical indicators to all assets
print("\n" + "="*60)
print("FEATURE ENGINEERING")
print("="*60)

featured_data = {}
for symbol, df in raw_data.items():
    print(f"  Processing {symbol}...", end=" ")
    featured_data[symbol] = add_technical_indicators(df)
    print(f"✓ {len(featured_data[symbol].columns)} features")

# Show feature list for one asset
sample_features = featured_data[PORTFOLIO_ASSETS[0]].columns.tolist()
print(f"\n✓ Total features per asset: {len(sample_features)}")
print(f"\nFeature categories:")
print(f"  - OHLCV: 5")
print(f"  - Moving Averages: 10")
print(f"  - MACD: 3")
print(f"  - ADX: 3")
print(f"  - RSI: 3")
print(f"  - Stochastic: 2")
print(f"  - Williams %R: 1")
print(f"  - ROC: 3")
print(f"  - Bollinger: 5")
print(f"  - ATR: 3")
print(f"  - Volume: 4")
print(f"  - Returns/Price: 7")
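
To make the Bollinger-derived columns less opaque, here is a pandas-only sketch of the usual definitions of band width and %B. It is an illustration under stated assumptions, not the `ta` library's implementation: `bollinger_features` is a hypothetical helper, the price series is a synthetic random walk, and scaling conventions (for example expressing width in percent) may differ from what `ta` returns.

```python
import numpy as np
import pandas as pd


def bollinger_features(close: pd.Series, window: int = 20,
                       n_std: float = 2.0) -> pd.DataFrame:
    """Middle/upper/lower bands plus relative width and %B position."""
    mid = close.rolling(window).mean()
    std = close.rolling(window).std(ddof=0)   # population std over the window
    upper = mid + n_std * std
    lower = mid - n_std * std
    return pd.DataFrame({
        "mid": mid,
        "upper": upper,
        "lower": lower,
        # Band width relative to the moving average (unitless).
        "width": (upper - lower) / mid,
        # %B: 0 at the lower band, 0.5 at the mean, 1 at the upper band.
        "pband": (close - lower) / (upper - lower),
    })


# Synthetic random-walk prices, for illustration only.
rng = np.random.default_rng(7)
close = pd.Series(100.0 + np.cumsum(rng.normal(0.0, 1.0, size=200)))

bands = bollinger_features(close)
print(bands.dropna().head())
print("warm-up rows with NaN:", int(bands["mid"].isna().sum()))
```

The first `window - 1` rows are NaN by construction, which is the same warm-up effect every rolling indicator in `add_technical_indicators` has.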

---

## 5. Data Preprocessing

### 5.1 Scaling & Normalization

In [None]:
class DataPreprocessor:
    """
    Preprocess multi-asset data for TCN-GNN-LSTM model.
    
    Features:
    - Robust scaling (handles outliers)
    - NaN handling
    - Sequence creation for temporal models
    - Train/val/test splitting
    """
    
    def __init__(self):
        self.scalers = {}
        self.feature_names = None
        
    def prepare_multi_asset_data(self, data: Dict[str, pd.DataFrame], 
                                  target_col: str = 'return_1d') -> Tuple[np.ndarray, np.ndarray, List[str]]:
        """
        Prepare multi-asset data for model training.
        
        Args:
            data: Dictionary of DataFrames (symbol -> df)
            target_col: Column to use as target (return)
            
        Returns:
            X: (time, num_assets, features) array
            y: (time, num_assets) array of returns
            asset_names: List of asset names
        """
        # Find common dates
        common_index = None
        for df in data.values():
            if common_index is None:
                common_index = df.index
            else:
                common_index = common_index.intersection(df.index)
        
        asset_names = [s.replace('/USDT', '') for s in data.keys()]
        num_assets = len(data)
        
        # Get feature columns (exclude OHLCV base columns for features)
        exclude_cols = ['open', 'high', 'low', 'close', 'volume']
        first_df = list(data.values())[0]
        feature_cols = [c for c in first_df.columns if c not in exclude_cols]
        self.feature_names = feature_cols
        num_features = len(feature_cols)
        
        print(f"\n  Common data points: {len(common_index)}")
        print(f"  Assets: {num_assets}")
        print(f"  Features per asset: {num_features}")
        
        # Create arrays
        X_list = []
        y_list = []
        
        for symbol, df in data.items():
            df_aligned = df.loc[common_index].copy()
            
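            # Scale features
            # NOTE: the scaler is fit on the full aligned history, which includes
            # the periods later used for validation and testing. For a leakage-free
            # evaluation, fit it on the training rows only and reuse it on the rest.
            scaler = RobustScaler()
            features = df_aligned[feature_cols].values
            
            # Handle NaN (indicator warm-up rows, e.g. the first 99 rows of sma_100,
            # are zero-filled here rather than dropped)
            features = np.nan_to_num(features, nan=0.0, posinf=0.0, neginf=0.0)
            features = scaler.fit_transform(features)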
            
            self.scalers[symbol] = scaler
            X_list.append(features)
            
            # Target: next day return (shifted by 1)
            returns = df_aligned[target_col].shift(-1).values
            returns = np.nan_to_num(returns, nan=0.0)
            y_list.append(returns)
        
        # Stack to (time, num_assets, features)
        X = np.stack(X_list, axis=1)  # (time, assets, features)
        y = np.stack(y_list, axis=1)  # (time, assets)
        
        # Remove last row (NaN target)
        X = X[:-1]
        y = y[:-1]
        
        return X, y, asset_names
    
    def create_sequences(self, X: np.ndarray, y: np.ndarray, 
                         seq_length: int = 60) -> Tuple[np.ndarray, np.ndarray]:
        """
        Create sequences for temporal models.
        
        Args:
            X: (time, num_assets, features) array
            y: (time, num_assets) array
            seq_length: Length of input sequences
            
        Returns:
            X_seq: (samples, seq_length, num_assets, features)
            y_seq: (samples, num_assets)
        """
        X_seq = []
        y_seq = []
        
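        # y[t] already holds the return realized at t+1, so the target that
        # follows a window ending at row i-1 is y[i-1]. Using y[i] would skip
        # a day and forecast two steps ahead.
        for i in range(seq_length, len(X) + 1):
            X_seq.append(X[i-seq_length:i])
            y_seq.append(y[i-1])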
        
        return np.array(X_seq), np.array(y_seq)
    
    def train_val_test_split(self, X: np.ndarray, y: np.ndarray,
                              train_ratio: float = 0.7,
                              val_ratio: float = 0.15) -> Dict:
        """
        Split data chronologically (no shuffling for time series).
        
        Args:
            X, y: Input and target arrays
            train_ratio: Proportion for training
            val_ratio: Proportion for validation
            
        Returns:
            Dictionary with train/val/test splits
        """
        n = len(X)
        train_end = int(n * train_ratio)
        val_end = int(n * (train_ratio + val_ratio))
        
        return {
            'X_train': X[:train_end],
            'y_train': y[:train_end],
            'X_val': X[train_end:val_end],
            'y_val': y[train_end:val_end],
            'X_test': X[val_end:],
            'y_test': y[val_end:]
        }


# Preprocess data
print("\n" + "="*60)
print("DATA PREPROCESSING")
print("="*60)

preprocessor = DataPreprocessor()

# Prepare multi-asset data
X, y, asset_names = preprocessor.prepare_multi_asset_data(featured_data)
print(f"\n  X shape: {X.shape} (time, assets, features)")
print(f"  y shape: {y.shape} (time, assets)")

# Create sequences
SEQ_LENGTH = 60  # 60 days of history
X_seq, y_seq = preprocessor.create_sequences(X, y, seq_length=SEQ_LENGTH)
print(f"\n  X_seq shape: {X_seq.shape} (samples, seq_len, assets, features)")
print(f"  y_seq shape: {y_seq.shape} (samples, assets)")

# Split data
data_splits = preprocessor.train_val_test_split(X_seq, y_seq)
print(f"\n  Train: {len(data_splits['X_train'])} samples")
print(f"  Val:   {len(data_splits['X_val'])} samples")
print(f"  Test:  {len(data_splits['X_test'])} samples")
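
`prepare_multi_asset_data` fits its scaler on the full history before the chronological split, so statistics from the validation and test periods influence the training inputs. A leakage-free alternative is to split first and estimate the scaling statistics on the training rows only. The sketch below is a NumPy-only illustration on synthetic features: the helper names are hypothetical, and median/IQR scaling is used because that is what `RobustScaler` does with its default settings.

```python
import numpy as np


def chronological_split(n_samples: int, train_ratio: float = 0.7,
                        val_ratio: float = 0.15):
    """Return (train, val, test) slices in time order, with no shuffling."""
    train_end = int(n_samples * train_ratio)
    val_end = int(n_samples * (train_ratio + val_ratio))
    return slice(0, train_end), slice(train_end, val_end), slice(val_end, n_samples)


def fit_robust_stats(train: np.ndarray):
    """Per-feature median and interquartile range, from training rows only."""
    median = np.nanmedian(train, axis=0)
    q75, q25 = np.nanpercentile(train, [75, 25], axis=0)
    iqr = q75 - q25
    iqr = np.where(iqr == 0, 1.0, iqr)  # avoid division by zero on flat features
    return median, iqr


def apply_robust_stats(x: np.ndarray, median: np.ndarray, iqr: np.ndarray) -> np.ndarray:
    """Scale any split with statistics that were fit on the training split."""
    return (x - median) / iqr


# Synthetic (time, features) matrix whose later rows drift upward.
rng = np.random.default_rng(0)
features = rng.normal(size=(300, 4))
features[240:] += 5.0

tr, va, te = chronological_split(len(features))
median, iqr = fit_robust_stats(features[tr])

train_scaled = apply_robust_stats(features[tr], median, iqr)
val_scaled = apply_robust_stats(features[va], median, iqr)
test_scaled = apply_robust_stats(features[te], median, iqr)

print("train median:", np.round(np.median(train_scaled, axis=0), 3))
print("test mean:   ", np.round(test_scaled.mean(axis=0), 3))
```

The test split is deliberately not centered at zero here: the drift shows up as a shift in the scaled values, which is what the model would face in live use.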

---

## 6. Model Architecture

### 6.1 TCN Feature Extractor

In [None]:
class TCNBlock(layers.Layer):
    """
    Temporal Convolutional Network Block.
    
    Features:
    - Dilated causal convolution (no future information leakage)
    - Residual connection
    - Layer normalization
    - GELU activation
    """
    
    def __init__(self, filters, kernel_size=3, dilation_rate=1, dropout=0.2, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.dropout_rate = dropout
        
        # First conv block
        self.conv1 = layers.Conv1D(
            filters=filters,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding='causal',
            kernel_initializer='he_normal'
        )
        self.norm1 = layers.LayerNormalization()
        self.dropout1 = layers.Dropout(dropout)
        
        # Second conv block
        self.conv2 = layers.Conv1D(
            filters=filters,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding='causal',
            kernel_initializer='he_normal'
        )
        self.norm2 = layers.LayerNormalization()
        self.dropout2 = layers.Dropout(dropout)
        
        # Residual projection
        self.residual_conv = layers.Conv1D(filters, 1)
        
    def call(self, x, training=False):
        # First conv block
        out = self.conv1(x)
        out = self.norm1(out)
        out = tf.nn.gelu(out)
        out = self.dropout1(out, training=training)
        
        # Second conv block
        out = self.conv2(out)
        out = self.norm2(out)
        out = tf.nn.gelu(out)
        out = self.dropout2(out, training=training)
        
        # Residual connection
        residual = self.residual_conv(x)
        return out + residual
    
    def get_config(self):
        config = super().get_config()
        config.update({
            'filters': self.filters,
            'kernel_size': self.kernel_size,
            'dilation_rate': self.dilation_rate,
            'dropout': self.dropout_rate
        })
        return config


class TCNFeatureExtractor(layers.Layer):
    """
    Multi-scale TCN feature extractor.
    
    Uses exponentially increasing dilation rates to capture
    patterns at multiple time scales (dilations 1, 2, 4, 8 with the default four layers).
    """
    
    def __init__(self, num_channels=64, kernel_size=3, num_layers=4, dropout=0.2, **kwargs):
        super().__init__(**kwargs)
        self.num_channels = num_channels
        
        # TCN blocks with increasing dilation
        self.tcn_blocks = [
            TCNBlock(
                filters=num_channels,
                kernel_size=kernel_size,
                dilation_rate=2**i,
                dropout=dropout,
                name=f'tcn_block_{i}'
            )
            for i in range(num_layers)
        ]
        
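        # Calculate receptive field. Each TCNBlock applies two convolutions at the
        # same dilation, so every block contributes 2 * (kernel_size - 1) * dilation.
        self.receptive_field = 1 + 2 * (kernel_size - 1) * sum(2**i for i in range(num_layers))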
        
    def call(self, x, training=False):
        """Process input through TCN blocks."""
        for block in self.tcn_blocks:
            x = block(x, training=training)
        return x


# Test TCN
print("\n" + "="*60)
print("TCN FEATURE EXTRACTOR")
print("="*60)

tcn = TCNFeatureExtractor(num_channels=64, num_layers=4)
test_input = tf.random.normal((32, 60, 50))  # (batch, time, features)
test_output = tcn(test_input)

print(f"\n  Input shape:  {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"  Receptive field: {tcn.receptive_field} timesteps")
print(f"  Dilation rates: {[2**i for i in range(len(tcn.tcn_blocks))]}")
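
The receptive-field arithmetic can be checked by hand. The following is a minimal, dependency-free sketch; `receptive_field` and its `convs_per_block` parameter are hypothetical helpers, not part of the model code. It assumes stride-1 dilated convolutions, and whether each `TCNBlock` contributes one or two dilated convolutions depends on how `conv1` and `conv2` are configured.

```python
def receptive_field(kernel_size, num_layers, convs_per_block=1):
    """Receptive field (in timesteps) of stacked stride-1 dilated convolutions.

    Each convolution with dilation d widens the field by (kernel_size - 1) * d.
    convs_per_block is an assumption about the block layout, not read from the model.
    """
    dilations = [2 ** i for i in range(num_layers)]
    return 1 + convs_per_block * (kernel_size - 1) * sum(dilations)


# Default configuration from this section: kernel_size=3, num_layers=4.
single = receptive_field(3, 4)                      # one dilated conv per block
double = receptive_field(3, 4, convs_per_block=2)   # two dilated convs per block
print(single, double)
```

Under the one-convolution assumption the field is 31 timesteps, about half of the 60-step test input; under the two-convolution assumption it is 61, slightly more than the whole window.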

### 6.2 Graph Neural Network (GNN)

In [None]:
class GraphAttentionLayer(layers.Layer):
    """
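    Graph attention layer for modeling dynamic asset relationships.
    
    Applies multi-head scaled dot-product attention across assets, treating
    the assets as nodes of a fully connected graph.
    
    Key features:
    - Multi-head attention
    - Dynamic edge weights (attention recomputed at every timestep)
    - Residual connection with layer normalization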
    """
    
    def __init__(self, hidden_dim, num_heads=4, dropout=0.1, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        self.num_heads = num_heads
        if hidden_dim % num_heads != 0:
            raise ValueError("hidden_dim must be divisible by num_heads")
        self.head_dim = hidden_dim // num_heads
        
        # Query, Key, Value projections
        self.query = layers.Dense(hidden_dim, kernel_initializer='glorot_uniform')
        self.key = layers.Dense(hidden_dim, kernel_initializer='glorot_uniform')
        self.value = layers.Dense(hidden_dim, kernel_initializer='glorot_uniform')
        
        # Output projection
        self.output_proj = layers.Dense(hidden_dim)
        self.dropout = layers.Dropout(dropout)
        self.norm = layers.LayerNormalization()
        
    def call(self, x, training=False, return_attention=False):
        """
        Process node features through graph attention.
        
        Args:
            x: (batch, time, num_assets, features)
            
        Returns:
            Updated features and optionally attention weights
        """
        batch_size = tf.shape(x)[0]
        time_steps = tf.shape(x)[1]
        num_assets = tf.shape(x)[2]
        
        # Reshape for attention: (batch * time, assets, features)
        x_reshaped = tf.reshape(x, [-1, num_assets, self.hidden_dim])
        
        # Compute Q, K, V
        Q = self.query(x_reshaped)
        K = self.key(x_reshaped)
        V = self.value(x_reshaped)
        
        # Reshape for multi-head: (batch*time, heads, assets, head_dim)
        Q = tf.reshape(Q, [-1, num_assets, self.num_heads, self.head_dim])
        K = tf.reshape(K, [-1, num_assets, self.num_heads, self.head_dim])
        V = tf.reshape(V, [-1, num_assets, self.num_heads, self.head_dim])
        
        Q = tf.transpose(Q, [0, 2, 1, 3])  # (B*T, heads, assets, dim)
        K = tf.transpose(K, [0, 2, 1, 3])
        V = tf.transpose(V, [0, 2, 1, 3])
        
        # Attention scores
        scores = tf.matmul(Q, K, transpose_b=True) / tf.math.sqrt(tf.cast(self.head_dim, tf.float32))
        attention_weights = tf.nn.softmax(scores, axis=-1)
        attention_weights = self.dropout(attention_weights, training=training)
        
        # Apply attention
        attended = tf.matmul(attention_weights, V)
        
        # Reshape back
        attended = tf.transpose(attended, [0, 2, 1, 3])
        attended = tf.reshape(attended, [-1, num_assets, self.hidden_dim])
        
        # Output projection + residual
        output = self.output_proj(attended)
        output = self.norm(output + x_reshaped)
        
        # Reshape to original
        output = tf.reshape(output, [batch_size, time_steps, num_assets, self.hidden_dim])
        
        if return_attention:
            attention_weights = tf.reshape(attention_weights, 
                                           [batch_size, time_steps, self.num_heads, num_assets, num_assets])
            return output, attention_weights
        return output


# Test GNN
print("\n" + "="*60)
print("GRAPH ATTENTION NETWORK")
print("="*60)

gnn = GraphAttentionLayer(hidden_dim=64, num_heads=4)
test_input = tf.random.normal((32, 60, 5, 64))  # (batch, time, assets, features)
test_output, attn = gnn(test_input, return_attention=True)

print(f"\n  Input shape:      {test_input.shape}")
print(f"  Output shape:     {test_output.shape}")
print(f"  Attention shape:  {attn.shape} (batch, time, heads, assets, assets)")
print(f"\n  → Attention weights show how each asset relates to others")
print(f"  → Weights are dynamic (computed per timestep)")
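
To make the attention step concrete, here is a small plain-Python sketch of single-head scaled dot-product attention over three assets at one timestep. It is an illustration under simplifying assumptions: the learned query, key, and value projections are replaced by the identity, the feature values are made up, and `asset_attention` is a hypothetical helper rather than part of the layer above.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def asset_attention(Q, K, V):
    """Single-head scaled dot-product attention over assets at one timestep.

    Q, K, V: lists of per-asset vectors, each of length head_dim.
    Returns (attended, weights); weights[i][j] is how much asset i attends to asset j.
    """
    scale = math.sqrt(len(Q[0]))
    weights = [
        softmax([sum(q * k for q, k in zip(q_i, k_j)) / scale for k_j in K])
        for q_i in Q
    ]
    attended = [
        [sum(w * v_j[d] for w, v_j in zip(w_i, V)) for d in range(len(V[0]))]
        for w_i in weights
    ]
    return attended, weights


# Toy features for three assets (made-up numbers, identity projections assumed).
features = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0]]
attended, weights = asset_attention(features, features, features)
for row in weights:
    print([round(w, 3) for w in row])
```

Each row of `weights` sums to 1, and the first asset attends more to the second (similar features) than to the third (orthogonal features).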

### 6.3 LSTM Processor

In [None]:
class AssetLSTMProcessor(layers.Layer):
    """
    LSTM processor for temporal sequence modeling per asset.
    
    Features:
    - Bidirectional LSTM
    - Temporal attention for weighting important timesteps
    - Processes each asset independently (GNN already handled cross-asset info)
    """
    
    def __init__(self, hidden_dim=128, num_layers=2, dropout=0.2, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim
        
        # Stacked Bidirectional LSTM
        self.lstm_layers = [
            layers.Bidirectional(
                layers.LSTM(hidden_dim, return_sequences=True, dropout=dropout),
                name=f'bilstm_{i}'
            )
            for i in range(num_layers)
        ]
        
        # Temporal attention
        self.attention = layers.Dense(1, activation='tanh')
        self.final_norm = layers.LayerNormalization()
        
    def call(self, x, training=False):
        """
        Process temporal sequences.
        
        Args:
            x: (batch, time, num_assets, features)
            
        Returns:
            (batch, num_assets, 2*hidden_dim)
        """
        batch_size = tf.shape(x)[0]
        time_steps = tf.shape(x)[1]
        num_assets = tf.shape(x)[2]
        features = tf.shape(x)[3]
        
        # Reshape to process each asset: (batch * assets, time, features)
        x = tf.transpose(x, [0, 2, 1, 3])  # (batch, assets, time, features)
        x = tf.reshape(x, [-1, time_steps, features])
        
        # Apply LSTMs
        for lstm in self.lstm_layers:
            x = lstm(x, training=training)
        
        # Temporal attention
        attention_scores = self.attention(x)  # (B*N, T, 1)
        attention_weights = tf.nn.softmax(attention_scores, axis=1)
        context = tf.reduce_sum(x * attention_weights, axis=1)  # (B*N, 2*hidden)
        
        # Reshape back: (batch, assets, 2*hidden)
        context = tf.reshape(context, [batch_size, num_assets, -1])
        context = self.final_norm(context)
        
        return context


# Test LSTM
print("\n" + "="*60)
print("LSTM PROCESSOR")
print("="*60)

lstm_proc = AssetLSTMProcessor(hidden_dim=64, num_layers=2)
test_input = tf.random.normal((32, 60, 5, 64))
test_output = lstm_proc(test_input)

print(f"\n  Input shape:  {test_input.shape}")
print(f"  Output shape: {test_output.shape}")
print(f"\n  → Output is 2x hidden_dim (bidirectional)")
print(f"  → Temporal attention weights important timesteps")
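
The temporal attention pooling can be sketched in plain Python as follows. This is only an illustration: `temporal_attention_pool` is a hypothetical helper, and `score_weights` stands in for the parameters of the `Dense(1, activation='tanh')` scoring layer with made-up values.

```python
import math


def temporal_attention_pool(hidden_states, score_weights, score_bias=0.0):
    """Pool a sequence of hidden vectors into a single context vector.

    hidden_states: list of T vectors, one per timestep.
    score_weights / score_bias: stand-ins for the scoring layer's parameters.
    """
    # One tanh-bounded score per timestep.
    scores = [
        math.tanh(sum(w * h for w, h in zip(score_weights, h_t)) + score_bias)
        for h_t in hidden_states
    ]
    # Softmax over the time axis (scores lie in [-1, 1], so exp cannot overflow).
    exps = [math.exp(s) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    # Weighted sum of the hidden states.
    dim = len(hidden_states[0])
    context = [
        sum(w * h_t[d] for w, h_t in zip(weights, hidden_states)) for d in range(dim)
    ]
    return context, weights


hidden = [[0.1, 0.0], [0.5, 0.2], [2.0, 1.0]]   # T=3 timesteps, 2 features
context, weights = temporal_attention_pool(hidden, score_weights=[1.0, 1.0])
print([round(w, 3) for w in weights])
```

Because tanh bounds every score to [-1, 1], no timestep can receive more than e² (about 7.39) times the weight of another, so this pooling stays fairly close to a plain average.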

### 6.4 Multi-Head Output

In [None]:
class MultiHeadOutput(layers.Layer):
    """
    Three-headed output layer:
    
    1. Trading Head: Portfolio weights (softmax)
    2. Prediction Head: Gaussian parameters (mu, sigma) for uncertainty
    3. Value Head: Expected cumulative return (for RL)
    """
    
    def __init__(self, num_assets, hidden_dim=128, **kwargs):
        super().__init__(**kwargs)
        self.num_assets = num_assets
        
        # Trading head
        self.trading_hidden = layers.Dense(hidden_dim, activation='relu')
        self.trading_output = layers.Dense(num_assets)
        
        # Prediction head
        self.pred_hidden = layers.Dense(hidden_dim, activation='relu')
        self.pred_mu = layers.Dense(num_assets)
        self.pred_log_sigma = layers.Dense(num_assets)
        
        # Value head
        self.value_hidden = layers.Dense(hidden_dim, activation='relu')
        self.value_output = layers.Dense(1)
        
    def call(self, x, training=False):
        """
        Generate multi-head outputs.
        
        Args:
            x: (batch, num_assets, features) from LSTM
            
        Returns:
            trading_weights, pred_mu, pred_sigma, value
        """
        # Global pooling for trading and value heads
        global_features = tf.reduce_mean(x, axis=1)  # (batch, features)
        
        # ===== TRADING HEAD =====
        trading_h = self.trading_hidden(global_features)
        trading_logits = self.trading_output(trading_h)
        trading_weights = tf.nn.softmax(trading_logits, axis=-1)
        
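        # ===== PREDICTION HEAD =====
        # Each Dense(num_assets) output has shape (batch, assets, num_assets);
        # averaging over the last axis leaves one mu / log-sigma per asset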
        pred_h = self.pred_hidden(x)  # (batch, assets, hidden)
        pred_mu = tf.reduce_mean(self.pred_mu(pred_h), axis=-1)  # (batch, assets)
        pred_log_sigma = tf.reduce_mean(self.pred_log_sigma(pred_h), axis=-1)
        pred_log_sigma = tf.clip_by_value(pred_log_sigma, -4.6, 2.3)  # sigma in [0.01, 10]
        pred_sigma = tf.exp(pred_log_sigma)
        
        # ===== VALUE HEAD =====
        value_h = self.value_hidden(global_features)
        value = self.value_output(value_h)
        
        return trading_weights, pred_mu, pred_sigma, value


# Test Multi-Head
print("\n" + "="*60)
print("MULTI-HEAD OUTPUT")
print("="*60)

multi_head = MultiHeadOutput(num_assets=5, hidden_dim=64)
test_input = tf.random.normal((32, 5, 128))
weights, mu, sigma, value = multi_head(test_input)

print(f"\n  Input shape:  {test_input.shape}")
print(f"\n  Outputs:")
print(f"    Trading weights: {weights.shape} (sum to 1)")
print(f"    Pred mean (mu):  {mu.shape}")
print(f"    Pred std (sigma):{sigma.shape}")
print(f"    Value estimate:  {value.shape}")
print(f"\n  Sample weights: {weights[0].numpy().round(3)}")
print(f"  Sum of weights: {weights[0].numpy().sum():.4f}")
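
The effect of clipping log-sigma before exponentiating can be checked with a few values. The sketch below is a plain-Python illustration using the same bounds as the prediction head; `sigma_from_log_sigma` is a hypothetical helper.

```python
import math

LOG_SIGMA_MIN, LOG_SIGMA_MAX = -4.6, 2.3   # clip bounds used in the prediction head


def sigma_from_log_sigma(log_sigma):
    """Clip log-sigma to the head's bounds, then exponentiate."""
    clipped = min(max(log_sigma, LOG_SIGMA_MIN), LOG_SIGMA_MAX)
    return math.exp(clipped)


for raw in (-10.0, -4.6, 0.0, 2.3, 10.0):
    print(f"log_sigma={raw:6.1f} -> sigma={sigma_from_log_sigma(raw):.4f}")
```

Predicting log-sigma keeps sigma strictly positive, and the clip bounds it to roughly [0.01, 10], which keeps the log-variance term of the Gaussian NLL finite.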

### 6.5 Complete Model Assembly

In [None]:
class TCN_GNN_LSTM_Model(Model):
    """
    Complete TCN-GNN-LSTM Hybrid Model for Portfolio Optimization.
    
    Architecture:
    1. TCN: Multi-scale temporal feature extraction
    2. GNN: Dynamic cross-asset relationship modeling
    3. LSTM: Sequential processing with memory
    4. Multi-Head: Trading, Prediction, Value outputs
    """
    
    def __init__(self, 
                 num_assets,
                 input_features,
                 tcn_channels=64,
                 tcn_layers=4,
                 gnn_heads=4,
                 lstm_hidden=64,
                 lstm_layers=2,
                 dropout=0.2,
                 **kwargs):
        super().__init__(**kwargs)
        
        self.num_assets = num_assets
        self.input_features = input_features
        
        # Input projection (per asset)
        self.input_proj = layers.Dense(tcn_channels, activation='relu')
        
        # TCN Feature Extractor
        self.tcn = TCNFeatureExtractor(
            num_channels=tcn_channels,
            num_layers=tcn_layers,
            dropout=dropout
        )
        
        # Graph Attention Network
        self.gnn = GraphAttentionLayer(
            hidden_dim=tcn_channels,
            num_heads=gnn_heads,
            dropout=dropout
        )
        
        # LSTM Processor
        self.lstm = AssetLSTMProcessor(
            hidden_dim=lstm_hidden,
            num_layers=lstm_layers,
            dropout=dropout
        )
        
        # Multi-Head Output
        self.output_heads = MultiHeadOutput(
            num_assets=num_assets,
            hidden_dim=lstm_hidden * 2  # Bidirectional
        )
        
    def call(self, x, training=False, return_attention=False):
        """
        Forward pass.
        
        Args:
            x: (batch, time, num_assets, features)
            
        Returns:
            trading_weights, pred_mu, pred_sigma, value
        """
        batch_size = tf.shape(x)[0]
        time_steps = tf.shape(x)[1]
        num_assets = tf.shape(x)[2]
        
        # Project input features
        # Reshape to (batch * assets, time, features) for TCN
        x = tf.transpose(x, [0, 2, 1, 3])  # (batch, assets, time, features)
        x = tf.reshape(x, [-1, time_steps, self.input_features])
        x = self.input_proj(x)  # (batch * assets, time, tcn_channels)
        
        # TCN
        x = self.tcn(x, training=training)  # (batch * assets, time, tcn_channels)
        
        # Reshape for GNN: (batch, time, assets, features)
        x = tf.reshape(x, [batch_size, num_assets, time_steps, -1])
        x = tf.transpose(x, [0, 2, 1, 3])  # (batch, time, assets, features)
        
        # GNN
        if return_attention:
            x, attention = self.gnn(x, training=training, return_attention=True)
        else:
            x = self.gnn(x, training=training)
        
        # LSTM
        x = self.lstm(x, training=training)  # (batch, assets, 2*lstm_hidden)
        
        # Multi-Head Output
        outputs = self.output_heads(x, training=training)
        
        if return_attention:
            return outputs + (attention,)
        return outputs
    
    def get_config(self):
        return {
            'num_assets': self.num_assets,
            'input_features': self.input_features
        }


# Build and summarize model
print("\n" + "="*60)
print("COMPLETE TCN-GNN-LSTM MODEL")
print("="*60)

# Get dimensions from data
num_assets = len(PORTFOLIO_ASSETS)
seq_length = SEQ_LENGTH
num_features = X_seq.shape[-1]

print(f"\n  Configuration:")
print(f"    - Assets: {num_assets}")
print(f"    - Sequence length: {seq_length}")
print(f"    - Features per asset: {num_features}")

# Create model
model = TCN_GNN_LSTM_Model(
    num_assets=num_assets,
    input_features=num_features,
    tcn_channels=64,
    tcn_layers=4,
    gnn_heads=4,
    lstm_hidden=64,
    lstm_layers=2,
    dropout=0.2
)

# Test forward pass
test_input = tf.random.normal((4, seq_length, num_assets, num_features))
weights, mu, sigma, value = model(test_input)

print(f"\n  Forward pass test:")
print(f"    Input:  {test_input.shape}")
print(f"    Weights: {weights.shape}")
print(f"    Mu:     {mu.shape}")
print(f"    Sigma:  {sigma.shape}")
print(f"    Value:  {value.shape}")

# Count parameters (the forward pass above has already built every layer)
total_params = model.count_params()
print(f"\n  Total parameters: {total_params:,}")
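
The transpose-then-reshape bookkeeping in `call` is easy to get wrong, so here is a minimal sketch of the same index mapping using nested Python lists instead of tensors. The helper names are hypothetical, and each toy value encodes its own (batch, time, asset) position so the result can be read off directly.

```python
def fold_assets_into_batch(x):
    """(batch, time, assets, features) -> (batch * assets, time, features).

    Mirrors transpose([0, 2, 1, 3]) followed by reshape, using nested lists.
    """
    num_assets = len(x[0][0])
    return [
        [step[n] for step in sample]      # one asset's full time series
        for sample in x
        for n in range(num_assets)
    ]


def unfold_assets_from_batch(folded, num_assets):
    """(batch * assets, time, features) -> (batch, time, assets, features)."""
    batch = len(folded) // num_assets
    time_steps = len(folded[0])
    return [
        [[folded[b * num_assets + n][t] for n in range(num_assets)]
         for t in range(time_steps)]
        for b in range(batch)
    ]


# Tiny tensor: batch=2, time=3, assets=2, features=1; value = 100*b + 10*t + n.
x = [[[[100 * b + 10 * t + n] for n in range(2)] for t in range(3)] for b in range(2)]
folded = fold_assets_into_batch(x)
print(folded[1])                                   # asset 1 of sample 0, across time
print(unfold_assets_from_batch(folded, 2) == x)    # round trip restores the input
```

Row `b * num_assets + n` of the folded tensor holds the full time series of asset `n` in sample `b`. A plain reshape without the transpose would instead interleave timesteps of different assets within one sequence.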

---

## 7. Loss Functions

### 7.1 Trading Loss (Sharpe-based)

In [None]:
def sharpe_ratio_loss(y_true, weights, epsilon=1e-8):
    """
    Negative Sharpe Ratio loss for portfolio optimization.
    
    Sharpe = mean(returns) / std(returns)
    
    We minimize -Sharpe to maximize actual Sharpe.
    """
    # Portfolio returns: weighted sum of asset returns
    portfolio_returns = tf.reduce_sum(weights * y_true, axis=-1)  # (batch,)
    
    # Mean and std
    mean_return = tf.reduce_mean(portfolio_returns)
    std_return = tf.math.reduce_std(portfolio_returns) + epsilon
    
    # Sharpe ratio, annualized with 252 periods per year (the equity convention;
    # crypto markets trade every day, so 365 is also common)
    sharpe = mean_return / std_return * tf.math.sqrt(252.0)
    
    return -sharpe  # Negative for minimization


print("Trading Loss: Negative Sharpe Ratio")
print("====================================")
print("Formula: L = -Sharpe(portfolio_returns)")
print("Goal: Maximize risk-adjusted returns")
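
As a worked example of the loss, the following plain-Python sketch computes the annualized Sharpe ratio for one small batch of portfolio returns. The returns are made up, `annualized_sharpe` is a hypothetical helper, and the risk-free rate is taken as zero, as in the loss above.

```python
import math
import statistics


def annualized_sharpe(returns, periods_per_year=252, epsilon=1e-8):
    """Annualized Sharpe ratio with a zero risk-free rate.

    Uses the population standard deviation, which is what tf.math.reduce_std computes.
    """
    mean_return = statistics.mean(returns)
    std_return = statistics.pstdev(returns) + epsilon
    return mean_return / std_return * math.sqrt(periods_per_year)


# Made-up daily portfolio returns for one batch (illustration only).
batch_returns = [0.01, -0.005, 0.007, 0.002, -0.003]
sharpe = annualized_sharpe(batch_returns)
loss = -sharpe
print(f"Sharpe: {sharpe:.3f}, loss: {loss:.3f}")
```

The ratio is unchanged when all returns are scaled by the same positive factor, and it is estimated from a single batch, so small batch sizes give a noisy training signal.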

### 7.2 Prediction Loss (Gaussian NLL)

In [None]:
def gaussian_nll_loss(y_true, mu, sigma, epsilon=1e-8):
    """
    Gaussian Negative Log-Likelihood loss.
    
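    This loss:
    1. Penalizes prediction errors, scaled by the predicted variance
    2. Rewards confident (low-sigma) predictions when they are accurate
    3. Penalizes overconfident (low-sigma) predictions when they are wrong
    
    NLL = 0.5 * [log(sigma^2) + (y - mu)^2 / sigma^2]
    (the constant 0.5 * log(2 * pi) is dropped)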
    """
    sigma = tf.maximum(sigma, epsilon)
    variance = tf.square(sigma)
    squared_error = tf.square(y_true - mu)
    
    nll = 0.5 * (tf.math.log(variance) + squared_error / variance)
    
    return tf.reduce_mean(nll)


print("Prediction Loss: Gaussian NLL")
print("==============================")
print("Formula: L = 0.5 * [log(σ²) + (y - μ)² / σ²]")
print("Goal: Learn to predict returns WITH uncertainty")
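
The trade-off between error and predicted uncertainty can be seen by holding the error fixed and varying sigma. This is a minimal plain-Python sketch of the same formula for a single sample; the numbers are made up and `gaussian_nll` is a scalar stand-in for the tensor version above.

```python
import math


def gaussian_nll(y, mu, sigma, epsilon=1e-8):
    """Per-sample Gaussian NLL without the constant 0.5 * log(2 * pi)."""
    sigma = max(sigma, epsilon)
    variance = sigma ** 2
    return 0.5 * (math.log(variance) + (y - mu) ** 2 / variance)


error = 0.02   # |y - mu| for a made-up prediction
nll_by_sigma = {sigma: gaussian_nll(error, 0.0, sigma) for sigma in (0.005, 0.02, 0.08)}
for sigma, nll in nll_by_sigma.items():
    print(f"sigma={sigma:.3f} -> NLL={nll:.3f}")
```

For a fixed error the loss is smallest when sigma equals the absolute error. In this example the overconfident prediction (sigma = 0.005) is penalized more heavily than the underconfident one (sigma = 0.08).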

### 7.3 Combined Multi-Task Loss

In [None]:
class MultiTaskLoss:
    """
    Combined loss for multi-head TCN-GNN-LSTM model.
    
    L_total = λ1 * L_trading + λ2 * L_prediction + λ3 * L_value
    
    The prediction loss acts as a regularizer to prevent
    overfitting on the trading objective.
    """
    
    def __init__(self, lambda_trading=1.0, lambda_pred=0.1, lambda_value=0.0):
        self.lambda_trading = lambda_trading
        self.lambda_pred = lambda_pred
        self.lambda_value = lambda_value
        
    def __call__(self, y_true, weights, mu, sigma, value):
        """
        Compute combined loss.
        
        Args:
            y_true: Actual returns (batch, assets)
            weights: Predicted portfolio weights
            mu: Predicted mean returns
            sigma: Predicted uncertainty
            value: Predicted value estimate
        """
        losses = {}
        
        # Trading loss
        if self.lambda_trading > 0:
            losses['trading'] = sharpe_ratio_loss(y_true, weights)
        
        # Prediction loss
        if self.lambda_pred > 0:
            losses['prediction'] = gaussian_nll_loss(y_true, mu, sigma)
        
        # Value loss
        if self.lambda_value > 0:
            portfolio_return = tf.reduce_sum(weights * y_true, axis=-1, keepdims=True)
            losses['value'] = tf.reduce_mean(tf.square(value - portfolio_return))
        
        # Weighted sum
        total = (
            self.lambda_trading * losses.get('trading', 0) +
            self.lambda_pred * losses.get('prediction', 0) +
            self.lambda_value * losses.get('value', 0)
        )
        
        return total, losses


print("\nMulti-Task Loss Configuration")
print("==============================")
print("Stage 1 (Representation): λ_trading=0.0, λ_pred=1.0, λ_value=0.0")
print("Stage 2 (Trading):        λ_trading=1.0, λ_pred=0.1, λ_value=0.0")
print("Stage 3 (RL):             λ_trading=0.5, λ_pred=0.05, λ_value=0.45")
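
To show how the stage weights change the total, the sketch below combines one set of made-up per-head losses under each stage. `STAGE_WEIGHTS` and `combine_losses` are hypothetical names; the weight values are the ones listed above.

```python
STAGE_WEIGHTS = {
    1: {"trading": 0.0, "prediction": 1.0, "value": 0.0},     # representation learning
    2: {"trading": 1.0, "prediction": 0.1, "value": 0.0},     # trading fine-tuning
    3: {"trading": 0.5, "prediction": 0.05, "value": 0.45},   # RL enhancement
}


def combine_losses(stage, losses):
    """Weighted sum of per-head losses; heads with zero weight are skipped."""
    weights = STAGE_WEIGHTS[stage]
    return sum(w * losses[name] for name, w in weights.items() if w > 0)


# Made-up per-head loss values for one batch.
example = {"trading": -1.2, "prediction": -3.0, "value": 0.04}
totals = {stage: combine_losses(stage, example) for stage in (1, 2, 3)}
for stage, total in totals.items():
    print(stage, round(total, 4))
```

The total can be negative: both the negative Sharpe ratio and the Gaussian NLL (whose log-variance term is negative for sigma below 1) can fall below zero, so the trend of the loss matters rather than its sign.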

---

## 8. Training Pipeline

### 8.1 Curriculum Learning Strategy

In [None]:
class CurriculumTrainer:
    """
    Implements 3-stage curriculum learning:
    
    Stage 1: Learn representations via prediction task
    Stage 2: Fine-tune for trading with prediction regularizer
    Stage 3: (Optional) RL enhancement
    """
    
    def __init__(self, model, learning_rate=1e-3):
        self.model = model
        self.optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
        self.stage = 1
        self.loss_fn = MultiTaskLoss()
        self.history = {'stage': [], 'epoch': [], 'train_loss': [], 'val_loss': []}
        
    def set_stage(self, stage):
        """Configure loss weights for curriculum stage."""
        self.stage = stage
        
        if stage == 1:
            # Representation learning
            self.loss_fn = MultiTaskLoss(lambda_trading=0.0, lambda_pred=1.0, lambda_value=0.0)
            print("Stage 1: Representation Learning (prediction only)")
        elif stage == 2:
            # Trading fine-tuning
            self.loss_fn = MultiTaskLoss(lambda_trading=1.0, lambda_pred=0.1, lambda_value=0.0)
            print("Stage 2: Trading Fine-tuning (trading + prediction regularizer)")
        elif stage == 3:
            # RL enhancement
            self.loss_fn = MultiTaskLoss(lambda_trading=0.5, lambda_pred=0.05, lambda_value=0.45)
            print("Stage 3: RL Enhancement (all heads)")
        
        # tf.function bakes the Python-side loss weights into the traced graph,
        # so drop the cached traces; otherwise the previous stage's loss is reused.
        self._train_fn = None
        self._val_fn = None
    
    def _train_step(self, X, y):
        """Single training step (wrapped in tf.function by train_step)."""
        with tf.GradientTape() as tape:
            weights, mu, sigma, value = self.model(X, training=True)
            total_loss, losses = self.loss_fn(y, weights, mu, sigma, value)
        
        gradients = tape.gradient(total_loss, self.model.trainable_variables)
        # Gradient clipping
        gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        
        return total_loss, losses
    
    def train_step(self, X, y):
        """Run one training step, re-tracing after a stage change."""
        if getattr(self, '_train_fn', None) is None:
            self._train_fn = tf.function(self._train_step)
        return self._train_fn(X, y)
    
    def _val_step(self, X, y):
        """Validation step (wrapped in tf.function by val_step)."""
        weights, mu, sigma, value = self.model(X, training=False)
        total_loss, losses = self.loss_fn(y, weights, mu, sigma, value)
        return total_loss, losses
    
    def val_step(self, X, y):
        """Run one validation step, re-tracing after a stage change."""
        if getattr(self, '_val_fn', None) is None:
            self._val_fn = tf.function(self._val_step)
        return self._val_fn(X, y)
    
    def train_epoch(self, train_data, val_data, batch_size=32):
        """Train for one epoch."""
        X_train, y_train = train_data
        X_val, y_val = val_data
        
        # Training
        train_losses = []
        num_batches = len(X_train) // batch_size
        
        for i in range(num_batches):
            start = i * batch_size
            end = start + batch_size
            X_batch = tf.constant(X_train[start:end], dtype=tf.float32)
            y_batch = tf.constant(y_train[start:end], dtype=tf.float32)
            
            loss, _ = self.train_step(X_batch, y_batch)
            train_losses.append(loss.numpy())
        
        # Validation
        X_val_tensor = tf.constant(X_val, dtype=tf.float32)
        y_val_tensor = tf.constant(y_val, dtype=tf.float32)
        val_loss, _ = self.val_step(X_val_tensor, y_val_tensor)
        
        return np.mean(train_losses), val_loss.numpy()
    
    def fit(self, train_data, val_data, epochs=50, batch_size=32, patience=10):
        """Full training loop with early stopping."""
        best_val_loss = float('inf')
        patience_counter = 0
        
        for epoch in range(epochs):
            train_loss, val_loss = self.train_epoch(train_data, val_data, batch_size)
            
            self.history['stage'].append(self.stage)
            self.history['epoch'].append(epoch)
            self.history['train_loss'].append(train_loss)
            self.history['val_loss'].append(val_loss)
            
            print(f"  Epoch {epoch+1}/{epochs} - Train: {train_loss:.4f}, Val: {val_loss:.4f}")
            
            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"  Early stopping at epoch {epoch+1}")
                    break
        
        return self.history


print("\n" + "="*60)
print("CURRICULUM LEARNING TRAINER")
print("="*60)
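
The early-stopping rule in `fit` can be traced on a short validation curve. The sketch below is a plain-Python restatement of that rule with made-up loss values; `early_stopping_epoch` is a hypothetical helper.

```python
def early_stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which training stops, or None if it runs to the end.

    Mirrors the rule in fit(): stop once `patience` consecutive epochs
    fail to improve on the best validation loss seen so far.
    """
    best = float("inf")
    waited = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, waited = loss, 0
        else:
            waited += 1
            if waited >= patience:
                return epoch
    return None


# Made-up validation curve: improves for three epochs, then plateaus.
curve = [1.00, 0.80, 0.75, 0.76, 0.78, 0.77, 0.79, 0.74]
print(early_stopping_epoch(curve, patience=3))
print(early_stopping_epoch(curve, patience=5))
```

With `patience=3` training stops at epoch 6, before the late improvement at epoch 8; with `patience=5` it runs to the end. The rule only halts training: the weights kept are those from the stopping epoch, not from the best epoch.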

### 8.2 Training Loop

In [None]:
# ============================================
# TRAINING
# ============================================

print("\n" + "="*60)
print("MODEL TRAINING")
print("="*60)

# Create fresh model
model = TCN_GNN_LSTM_Model(
    num_assets=num_assets,
    input_features=num_features,
    tcn_channels=32,  # Reduced for faster training
    tcn_layers=3,
    gnn_heads=2,
    lstm_hidden=32,
    lstm_layers=1,
    dropout=0.2
)

# Initialize trainer
trainer = CurriculumTrainer(model, learning_rate=1e-3)

# Prepare data tuples
train_data = (data_splits['X_train'], data_splits['y_train'])
val_data = (data_splits['X_val'], data_splits['y_val'])

# STAGE 1: Representation Learning
print("\n--- STAGE 1: Representation Learning ---")
trainer.set_stage(1)
history_s1 = trainer.fit(train_data, val_data, epochs=20, batch_size=16, patience=5)

# STAGE 2: Trading Fine-tuning
print("\n--- STAGE 2: Trading Fine-tuning ---")
trainer.set_stage(2)
trainer.optimizer.learning_rate.assign(5e-4)  # Lower LR for fine-tuning
history_s2 = trainer.fit(train_data, val_data, epochs=20, batch_size=16, patience=5)

print("\n✓ Training complete!")

### 8.3 Training Visualization

In [None]:
# ============================================
# TRAINING VISUALIZATION
# ============================================

history_df = pd.DataFrame(trainer.history)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Stage 1
s1_data = history_df[history_df['stage'] == 1]
if len(s1_data) > 0:
    axes[0].plot(s1_data['epoch'], s1_data['train_loss'], 'b-', label='Train', linewidth=2)
    axes[0].plot(s1_data['epoch'], s1_data['val_loss'], 'r--', label='Val', linewidth=2)
    axes[0].set_title('Stage 1: Representation Learning', fontweight='bold')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

# Stage 2
s2_data = history_df[history_df['stage'] == 2]
if len(s2_data) > 0:
    axes[1].plot(range(len(s2_data)), s2_data['train_loss'], 'b-', label='Train', linewidth=2)
    axes[1].plot(range(len(s2_data)), s2_data['val_loss'], 'r--', label='Val', linewidth=2)
    axes[1].set_title('Stage 2: Trading Fine-tuning', fontweight='bold')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Loss (Neg. Sharpe + 0.1 * NLL)')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---

## 9. Evaluation & Backtesting

### 9.1 Performance Metrics

In [None]:
# ============================================
# MODEL EVALUATION
# ============================================

print("\n" + "="*60)
print("MODEL EVALUATION")
print("="*60)

# Get predictions on test set
X_test = tf.constant(data_splits['X_test'], dtype=tf.float32)
y_test = data_splits['y_test']

weights, mu, sigma, value = model(X_test, training=False)
weights = weights.numpy()
mu = mu.numpy()
sigma = sigma.numpy()

# Calculate portfolio returns
portfolio_returns = np.sum(weights * y_test, axis=1)

# Calculate metrics
total_return = (1 + portfolio_returns).prod() - 1
mean_return = portfolio_returns.mean()
std_return = portfolio_returns.std()
sharpe = mean_return / std_return * np.sqrt(252) if std_return > 0 else 0

# Max drawdown
cumulative = (1 + portfolio_returns).cumprod()
running_max = np.maximum.accumulate(cumulative)
drawdown = (cumulative - running_max) / running_max
max_drawdown = drawdown.min()

print(f"\n  Test Period Performance:")
print(f"  ========================")
print(f"  Total Return:    {total_return*100:.2f}%")
print(f"  Mean Daily:      {mean_return*100:.4f}%")
print(f"  Volatility:      {std_return*100:.4f}%")
print(f"  Sharpe Ratio:    {sharpe:.3f}")
print(f"  Max Drawdown:    {max_drawdown*100:.2f}%")
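
The drawdown calculation can be followed step by step on a short return series. This is a plain-Python sketch that mirrors the cumulative-product and running-maximum logic above; the returns are made up and `max_drawdown` is a hypothetical helper.

```python
def max_drawdown(returns):
    """Largest peak-to-trough decline of the compounded equity curve (<= 0).

    The running peak starts at the first equity value, as in the NumPy version.
    """
    equity = 1.0
    peak = float("-inf")
    worst = 0.0
    for r in returns:
        equity *= 1.0 + r
        peak = max(peak, equity)
        worst = min(worst, (equity - peak) / peak)
    return worst


# Made-up per-period returns: +10%, -20%, +5%, -10%, +30%.
sample_returns = [0.10, -0.20, 0.05, -0.10, 0.30]
print(f"{max_drawdown(sample_returns):.4f}")
```

Here the equity peaks at 1.10 after the first period and bottoms at 0.8316, a drawdown of 24.4%. Because the running peak starts at the first equity value rather than at the initial capital of 1.0, a loss in the very first period is not counted as a drawdown.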

### 9.2 Uncertainty Analysis

In [None]:
# ============================================
# UNCERTAINTY ANALYSIS
# ============================================

print("\n" + "="*60)
print("UNCERTAINTY ANALYSIS")
print("="*60)

# Confidence score in (0, 1]: shrinks as sigma grows
confidence = 1 / (1 + sigma)
mean_confidence_per_asset = confidence.mean(axis=0)

print(f"\n  Confidence per Asset:")
for i, name in enumerate(asset_names):
    print(f"    {name}: {mean_confidence_per_asset[i]:.3f}")

# Plot uncertainty over time
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Per-asset sigma across the test period
ax1 = axes[0]
for i, name in enumerate(asset_names):
    ax1.plot(sigma[:, i], label=name, alpha=0.7)
ax1.set_title('Prediction Uncertainty (σ) Over Time', fontweight='bold')
ax1.set_xlabel('Test Sample')
ax1.set_ylabel('Sigma')
ax1.legend()

# Confidence vs. Prediction Error
ax2 = axes[1]
prediction_errors = np.abs(y_test - mu)
avg_error = prediction_errors.mean(axis=1)
avg_confidence = confidence.mean(axis=1)

ax2.scatter(avg_confidence, avg_error, alpha=0.5, c='cyan')
ax2.set_title('Confidence vs. Prediction Error', fontweight='bold')
ax2.set_xlabel('Average Confidence')
ax2.set_ylabel('Average Absolute Error')

# Add trend line
z = np.polyfit(avg_confidence, avg_error, 1)
p = np.poly1d(z)
ax2.plot(np.sort(avg_confidence), p(np.sort(avg_confidence)), 'r--', linewidth=2)

plt.tight_layout()
plt.show()

print("\n  Key Insight: Lower confidence should correlate with higher errors")
print("  → A downward-sloping trend line suggests the model knows when it's uncertain")
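
One further way to check whether the predicted sigma is meaningful is to compare empirical coverage with the Gaussian nominal value: if the predictive distribution is well calibrated, about 68% of targets should fall within one sigma of the predicted mean. The sketch below is a plain-Python illustration with made-up numbers; both helpers are hypothetical and are not part of the notebook.

```python
import math


def empirical_coverage(y_true, mu, sigma, k=1.0):
    """Fraction of targets falling within mu +/- k * sigma."""
    hits = sum(abs(y - m) <= k * s for y, m, s in zip(y_true, mu, sigma))
    return hits / len(y_true)


def nominal_coverage(k=1.0):
    """Probability mass of a Gaussian within k standard deviations of its mean."""
    return math.erf(k / math.sqrt(2.0))


# Made-up targets and predictions for a single asset.
y_true = [0.010, -0.020, 0.005, 0.030, -0.015, 0.000]
mu = [0.008, -0.005, 0.004, 0.010, -0.012, 0.002]
sigma = [0.010, 0.010, 0.010, 0.010, 0.010, 0.010]
empirical = empirical_coverage(y_true, mu, sigma)
print(f"empirical: {empirical:.3f}, nominal: {nominal_coverage():.3f}")
```

Empirical coverage well below the nominal value points to overconfident sigmas, and coverage well above it points to sigmas that are too wide.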

### 9.3 Portfolio Allocation Visualization

In [None]:
# ============================================
# PORTFOLIO ALLOCATION VISUALIZATION
# ============================================

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Average weights
ax1 = axes[0]
avg_weights = weights.mean(axis=0)
colors_pie = ['#F7931A', '#627EEA', '#F3BA2F', '#00FFA3', '#23292F']
wedges, texts, autotexts = ax1.pie(
    avg_weights, 
    labels=asset_names, 
    autopct='%1.1f%%',
    colors=colors_pie,
    explode=[0.02]*len(asset_names)
)
ax1.set_title('Average Portfolio Allocation', fontweight='bold')

# Weights over time
ax2 = axes[1]
ax2.stackplot(range(len(weights)), weights.T, labels=asset_names, colors=colors_pie, alpha=0.8)
ax2.set_title('Portfolio Allocation Over Time', fontweight='bold')
ax2.set_xlabel('Test Sample')
ax2.set_ylabel('Weight')
ax2.legend(loc='upper right')
ax2.set_ylim(0, 1)

plt.tight_layout()
plt.show()

### 9.4 Equity Curve

In [None]:
# ============================================
# EQUITY CURVE
# ============================================

fig, axes = plt.subplots(2, 1, figsize=(14, 8))

# Equity curve
ax1 = axes[0]
cumulative_returns = (1 + portfolio_returns).cumprod()

# Compare with equal-weight portfolio
equal_weights = np.ones((len(y_test), num_assets)) / num_assets
equal_returns = np.sum(equal_weights * y_test, axis=1)
equal_cumulative = (1 + equal_returns).cumprod()

ax1.plot(cumulative_returns, label='TCN-GNN-LSTM', linewidth=2, color='cyan')
ax1.plot(equal_cumulative, label='Equal Weight', linewidth=2, color='orange', linestyle='--')
ax1.axhline(y=1, color='white', linestyle=':', alpha=0.5)
ax1.set_title('Cumulative Returns Comparison', fontweight='bold')
ax1.set_xlabel('Test Sample')
ax1.set_ylabel('Cumulative Return')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Drawdown
ax2 = axes[1]
ax2.fill_between(range(len(drawdown)), drawdown * 100, 0, color='red', alpha=0.5)
ax2.set_title('Portfolio Drawdown', fontweight='bold')
ax2.set_xlabel('Test Sample')
ax2.set_ylabel('Drawdown (%)')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Final comparison
equal_total = (1 + equal_returns).prod() - 1
print(f"\n  Strategy Comparison:")
print(f"  =====================")
print(f"  TCN-GNN-LSTM:  {total_return*100:.2f}%")
print(f"  Equal Weight:  {equal_total*100:.2f}%")
print(f"  Outperformance: {(total_return - equal_total)*100:.2f} percentage points")

---

## 10. Conclusion & Next Steps

### 10.1 Summary

In [None]:
print("\n" + "="*70)
print("TCN-GNN-LSTM ARCHITECTURE - IMPLEMENTATION COMPLETE")
print("="*70)

print("""
ARCHITECTURE COMPONENTS:
========================
1. TCN Feature Extractor
   - Multi-scale temporal patterns (dilations: 1, 2, 4, 8)
   - Causal convolutions (no future leakage)
   - Residual connections for gradient flow

2. Graph Neural Network
   - Dynamic cross-asset relationships
   - Multi-head attention mechanism
   - Time-varying correlation modeling

3. LSTM Processor
   - Bidirectional for richer context
   - Temporal attention for important timesteps
   - Sequential memory for long-term patterns

4. Multi-Head Output
   - Trading: Portfolio weights (softmax)
   - Prediction: Gaussian (mean + uncertainty)
   - Value: Expected return (for RL)

KEY INNOVATIONS:
================
✓ Uncertainty quantification via Gaussian head
✓ Curriculum learning (3 stages)
✓ Multi-task regularization
✓ Dynamic correlation modeling

NEXT STEPS:
===========
1. Increase model capacity for production
2. Add more assets (20+)
3. Implement Stage 3 (RL enhancement)
4. Real-time inference pipeline
5. Live trading integration
""")

print("="*70)