# Transformer-PPO Portfolio Optimization with Precious Metals

This notebook combines Transformer-PPO reinforcement learning for portfolio optimization with:
- **Dynamic NIFTY 50 stock selection** (top 15 by trailing quarterly return; the ranking is computed once, over the last ~63 trading days of the download window)
- **Precious metals ETFs**: GOLDBEES and SILVERBEES
- **17-asset portfolio** optimized using PPO

## Sections
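1. Setup and Imports
2. Configuration
3. Data Fetching with yfinance
4. Dynamic Stock Selection (Quarterly Returns)
5. Add Precious Metals ETFs
6. Feature Engineering (ATR, MFI, Technical Indicators)
7. Preprocessing and Normalization
8. Create Trading Environment
9. Initialize Transformer-PPO Agent
10. Training Loop
11. Training Visualization
12. Backtesting
13. Performance Metrics
14. Visualization
15. Final Asset Weights
16. Summary Report
17. Save Results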

## 1. Setup and Imports

In [None]:
# Install required packages (Jupyter/IPython only: the leading `!` runs this
# line as a shell command, so this cell is not valid plain Python).
# NOTE(review): pandas-ta is installed but never imported in this notebook;
# ATR and MFI are computed by hand further down.
!pip install yfinance pandas-ta gymnasium torch numpy pandas matplotlib seaborn plotly tqdm scikit-learn --quiet

In [None]:
import sys
import os

# Add src to path for imports
sys.path.insert(0, os.path.abspath('../src'))

import numpy as np
import pandas as pd
import yfinance as yf
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
from datetime import datetime, timedelta
from tqdm import tqdm
import torch
import torch.nn as nn
from typing import List, Dict, Tuple, Optional

# Import transformer-PPO components
from models.ppo_agent import PPOAgent
from environment.trading_env import TradingEnvironment
from training.trainer import PPOTrainer
from data.features import FeatureEngineering
from data.preprocessing import FeaturePreprocessor
from backtesting.engine import BacktestEngine
from backtesting.metrics import PerformanceMetrics
from backtesting.visualization import PerformanceVisualizer

# Silence *all* warnings for the rest of the session (including library
# deprecation notices) to keep notebook output readable. Remove while debugging.
warnings.filterwarnings('ignore')
# Matplotlib >= 3.6 names its bundled seaborn styles 'seaborn-v0_8-*'.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

print("✓ All imports successful")
print(f"PyTorch version: {torch.__version__}")
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

## 2. Configuration

In [None]:
# Central configuration shared by every later cell: data window, stock
# selection, model sizes, PPO hyper-parameters, environment costs and features.
# NOTE: yfinance treats `end` as exclusive, so end_date='2024-12-31' yields
# data up to and including 2024-12-30.
CONFIG = dict(
    # --- Data --------------------------------------------------------------
    start_date='2020-01-01',
    end_date='2024-12-31',
    num_stocks=15,  # how many NIFTY 50 names to keep (ranked by quarterly return)
    total_assets=17,  # num_stocks + the two precious-metal ETFs (GOLDBEES, SILVERBEES)

    # --- Stock selection ---------------------------------------------------
    selection_period=63,  # ranking window in trading days (~3 months)
    rebalance_frequency='quarterly',  # intended reselection cadence; not referenced elsewhere in this notebook

    # --- Model -------------------------------------------------------------
    stock_embedding_dim=64,
    market_embedding_dim=32,
    num_transformer_heads=4,
    num_transformer_layers=2,
    policy_hidden_dim=64,
    value_hidden_dim=128,
    dropout=0.1,
    max_weight=0.25,  # cap on any single asset's portfolio weight

    # --- PPO training ------------------------------------------------------
    n_episodes=100,
    episode_length=252,  # trading days in one year
    learning_rate=3e-4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_epsilon=0.2,
    value_coef=0.5,
    entropy_coef=0.01,
    batch_size=64,
    buffer_size=2048,

    # --- Trading environment -----------------------------------------------
    transaction_cost=0.001,
    turnover_penalty=0.0005,
    initial_cash=1000000.0,

    # --- Features ----------------------------------------------------------
    lookback_window=20,

    # --- Hardware ----------------------------------------------------------
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

print("Configuration:")
print("\n".join(f"  {name}: {setting}" for name, setting in CONFIG.items()))

## 3. Data Fetching with yfinance

### 3.1 NIFTY 50 Stock Universe

In [None]:
# NIFTY 50 constituents (as of 2024). Index membership changes over time, so
# this is a fixed snapshot rather than a point-in-time universe.
NIFTY_50_STOCKS = """
    ADANIENT    ADANIPORTS  APOLLOHOSP  ASIANPAINT  AXISBANK
    BAJAJ-AUTO  BAJFINANCE  BAJAJFINSV  BPCL        BHARTIARTL
    BRITANNIA   CIPLA       COALINDIA   DIVISLAB    DRREDDY
    EICHERMOT   GRASIM      HCLTECH     HDFCBANK    HDFCLIFE
    HEROMOTOCO  HINDALCO    HINDUNILVR  ICICIBANK   ITC
    INDUSINDBK  INFY        JSWSTEEL    KOTAKBANK   LT
    M&M         MARUTI      NTPC        NESTLEIND   ONGC
    POWERGRID   RELIANCE    SBILIFE     SBIN        SUNPHARMA
    TCS         TATACONSUM  TATAMOTORS  TATASTEEL   TECHM
    TITAN       ULTRACEMCO  UPL         WIPRO       LTIM
""".split()

# Yahoo Finance lists NSE symbols with a ".NS" suffix.
NIFTY_50_TICKERS = [symbol + ".NS" for symbol in NIFTY_50_STOCKS]

# NSE-listed precious metals ETFs added on top of the selected stocks.
PRECIOUS_METALS = ['GOLDBEES.NS', 'SILVERBEES.NS']

print(f"NIFTY 50 Universe: {len(NIFTY_50_TICKERS)} stocks")
print(f"Precious Metals ETFs: {PRECIOUS_METALS}")

### 3.2 Download Historical Data

In [None]:
def download_stock_data(
    tickers: List[str],
    start_date: str,
    end_date: str,
    min_rows: int = 100,
) -> pd.DataFrame:
    """
    Download daily OHLCV data for several tickers from yfinance.

    Tickers are fetched one at a time. A ticker is skipped (and reported) if
    its download raises, returns no rows, or returns fewer than ``min_rows``
    rows.

    Args:
        tickers: Yahoo Finance tickers, e.g. ``"RELIANCE.NS"``.
        start_date: Start date (YYYY-MM-DD).
        end_date: End date (YYYY-MM-DD); yfinance treats it as exclusive.
        min_rows: Minimum number of rows required to keep a ticker.

    Returns:
        Long-format DataFrame with columns
        ``Date, Stock, Open, High, Low, Close, Volume`` sorted by
        ``(Date, Stock)``. ``Stock`` is the ticker with ``.NS`` removed.

    Raises:
        ValueError: If no ticker was downloaded successfully.
    """
    output_cols = ['Date', 'Stock', 'Open', 'High', 'Low', 'Close', 'Volume']
    print(f"Downloading data for {len(tickers)} tickers...")

    data_list = []
    failed_tickers = []

    for ticker in tqdm(tickers, desc="Downloading"):
        try:
            df = yf.download(ticker, start=start_date, end=end_date, progress=False)

            if df.empty or len(df) < min_rows:
                failed_tickers.append(ticker)
                continue

            # Newer yfinance releases return (Price, Ticker) MultiIndex columns
            # even for a single ticker. Flatten to the price level so every
            # frame shares the same plain schema before concatenation.
            if isinstance(df.columns, pd.MultiIndex):
                df.columns = df.columns.get_level_values(0)

            df = df.reset_index()
            df['Stock'] = ticker.replace('.NS', '')

            # Strip timezone info so dates from all tickers compare and merge cleanly.
            if 'Date' in df.columns and df['Date'].dt.tz is not None:
                df['Date'] = df['Date'].dt.tz_localize(None)

            data_list.append(df[output_cols])

        except Exception as e:
            # Best-effort download: report the failure and continue with the rest.
            print(f"  Failed to download {ticker}: {e}")
            failed_tickers.append(ticker)

    if failed_tickers:
        more = "..." if len(failed_tickers) > 5 else ""
        print(f"\n⚠️  Failed tickers ({len(failed_tickers)}): {failed_tickers[:5]}{more}")

    if not data_list:
        raise ValueError("No data downloaded successfully!")

    combined_df = pd.concat(data_list, ignore_index=True)
    combined_df = combined_df.sort_values(['Date', 'Stock']).reset_index(drop=True)

    print(f"\n✓ Downloaded {combined_df['Stock'].nunique()} stocks")
    print(f"  Date range: {combined_df['Date'].min().date()} to {combined_df['Date'].max().date()}")
    print(f"  Total records: {len(combined_df):,}")

    return combined_df


# Fetch the full NIFTY 50 universe for the configured date window.
banner = "=" * 80
print(f"{banner}\nDOWNLOADING NIFTY 50 DATA\n{banner}")
nifty_data = download_stock_data(
    NIFTY_50_TICKERS,
    CONFIG['start_date'],
    CONFIG['end_date'],
)

## 4. Dynamic Stock Selection (Quarterly Returns)

In [None]:
def calculate_quarterly_returns(df: pd.DataFrame, lookback_days: int = 63) -> pd.DataFrame:
    """
    Calculate the trailing ``lookback_days`` return of every stock.

    The return compares each stock's most recent close with the close
    ``lookback_days`` trading rows earlier (the ``pct_change(lookback_days)``
    convention), so a stock needs at least ``lookback_days + 1`` non-NaN
    closes to be included. Rows with a missing close are ignored.

    Args:
        df: Long-format stock data with ``Date``, ``Stock`` and ``Close`` columns.
        lookback_days: Window in trading days (default 63, about 3 months).

    Returns:
        DataFrame with columns ``Stock``, ``QuarterlyReturn`` (percent),
        ``LatestPrice`` and ``DataPoints`` (rows available for the stock),
        sorted by ``QuarterlyReturn`` descending. Empty (with those columns)
        when no stock has enough history.

    Raises:
        ValueError: If ``lookback_days`` is smaller than 1.
    """
    if lookback_days < 1:
        raise ValueError(f"lookback_days must be >= 1, got {lookback_days}")

    columns = ['Stock', 'QuarterlyReturn', 'LatestPrice', 'DataPoints']
    returns_list = []

    # sort=False keeps stocks in order of first appearance.
    for stock, stock_data in df.groupby('Stock', sort=False):
        closes = stock_data.sort_values('Date')['Close'].dropna()

        # An N-day return spans N + 1 observations.
        if len(closes) <= lookback_days:
            continue

        latest_price = closes.iloc[-1]
        past_price = closes.iloc[-(lookback_days + 1)]

        returns_list.append({
            'Stock': stock,
            'QuarterlyReturn': (latest_price / past_price - 1) * 100,
            'LatestPrice': latest_price,
            'DataPoints': len(stock_data),
        })

    # Passing explicit columns keeps the schema (and sort_values) valid when empty.
    returns_df = pd.DataFrame(returns_list, columns=columns)
    return returns_df.sort_values('QuarterlyReturn', ascending=False).reset_index(drop=True)


def select_top_performers(
    df: pd.DataFrame, returns_df: pd.DataFrame, top_n: int = 15
) -> Tuple[pd.DataFrame, List[str]]:
    """
    Select the top N stocks by quarterly return.

    Args:
        df: Full long-format stock data with a ``Stock`` column.
        returns_df: Quarterly returns, already sorted best-first (as returned
            by ``calculate_quarterly_returns``).
        top_n: Number of leading rows of ``returns_df`` to keep.

    Returns:
        A ``(filtered_df, top_stocks)`` tuple:
        ``filtered_df`` is a copy of the rows of ``df`` whose ``Stock`` is
        selected (original row order and index preserved), and
        ``top_stocks`` lists the selected names in ranking order.
    """
    top_stocks = returns_df.head(top_n)['Stock'].tolist()
    filtered_df = df[df['Stock'].isin(top_stocks)].copy()

    return filtered_df, top_stocks


# Rank the universe by trailing quarterly return and keep the leaders.
print("\n" + "=" * 80)
print("DYNAMIC STOCK SELECTION")
print("=" * 80)

# NOTE(review): the ranking uses the last CONFIG['selection_period'] rows of the
# *entire* download window, and the chosen stocks are then trained on and
# backtested over that same window. That is look-ahead/selection bias: for an
# out-of-sample evaluation, select using only data before the evaluation period
# (and re-select each quarter if the selection is meant to be dynamic).
quarterly_returns = calculate_quarterly_returns(nifty_data, CONFIG['selection_period'])

top_n = CONFIG['num_stocks']
print(f"\nTop {top_n} Stocks by Quarterly Returns:")
print("-" * 60)
print(quarterly_returns.head(top_n).to_string(index=False))

selected_stock_data, selected_stocks = select_top_performers(
    nifty_data,
    quarterly_returns,
    top_n,
)

print(f"\n✓ Selected {len(selected_stocks)} top-performing stocks:")
print(f"  {', '.join(selected_stocks)}")

## 5. Add Precious Metals ETFs

In [None]:
# Download precious metals ETF data
print("\n" + "=" * 80)
print("DOWNLOADING PRECIOUS METALS ETFs")
print("=" * 80)

precious_metals_data = download_stock_data(
    PRECIOUS_METALS,
    CONFIG['start_date'],
    CONFIG['end_date'],
)

# Combine selected stocks with precious metals into one long-format frame.
all_assets_data = pd.concat([selected_stock_data, precious_metals_data], ignore_index=True)
all_assets_data = all_assets_data.sort_values(['Date', 'Stock']).reset_index(drop=True)

# download_stock_data skips tickers that fail (it only prints them), so report
# the ETFs that actually arrived rather than the length of the requested list.
n_etfs = precious_metals_data['Stock'].nunique()

print(f"\n✓ Total portfolio assets: {all_assets_data['Stock'].nunique()}")
print(f"  - Stocks: {len(selected_stocks)}")
print(f"  - ETFs: {n_etfs}")
if n_etfs < len(PRECIOUS_METALS):
    print(f"  ⚠️  Only {n_etfs} of {len(PRECIOUS_METALS)} precious-metal ETFs were downloaded")
print(f"\n  Assets: {sorted(all_assets_data['Stock'].unique())}")

## 6. Feature Engineering

### 6.1 Technical Indicators (ATR, MFI)

In [None]:
def calculate_atr(df: pd.DataFrame, period: int = 14) -> pd.Series:
    """
    Calculate the Average True Range as a simple rolling mean of the true range.

    The true range of a bar is the largest of ``high - low``,
    ``|high - previous close|`` and ``|low - previous close|``. On the first bar
    there is no previous close; those NaN terms are skipped, so the true range
    falls back to ``high - low``.

    This uses a plain ``rolling(period).mean()`` (not Wilder's smoothing), so
    the first ``period - 1`` values are NaN.
    """
    prev_close = df['Close'].shift()
    candidate_ranges = pd.DataFrame({
        'high_low': df['High'] - df['Low'],
        'high_prev_close': (df['High'] - prev_close).abs(),
        'low_prev_close': (df['Low'] - prev_close).abs(),
    })
    true_range = candidate_ranges.max(axis=1)
    return true_range.rolling(window=period).mean()


def calculate_mfi(df: pd.DataFrame, period: int = 14) -> pd.Series:
    """
    Calculate the Money Flow Index (MFI) over a rolling window.

    Raw money flow is ``typical_price * Volume`` with
    ``typical_price = (High + Low + Close) / 3``. A bar's flow counts as
    positive when its typical price rose versus the previous bar, negative when
    it fell, and is ignored when unchanged (the first bar is always ignored).

    Args:
        df: Rows for a single asset, in date order, with High/Low/Close/Volume.
        period: Rolling window length in rows.

    Returns:
        Series aligned to ``df.index``. The first ``period - 1`` values are NaN.
        A small epsilon avoids division by zero when a window has no negative
        flow (MFI is then ~100). If a window has no positive *and* no negative
        flow, the result is 0.
    """
    typical_price = (df['High'] + df['Low'] + df['Close']) / 3
    money_flow = typical_price * df['Volume']

    # Vectorised replacement for a per-row loop: classify each bar by the sign
    # of the change in typical price. NaN changes (first row) compare False.
    price_change = typical_price.diff()
    positive_flow = money_flow.where(price_change > 0, 0.0)
    negative_flow = money_flow.where(price_change < 0, 0.0)

    positive_mf = positive_flow.rolling(window=period).sum()
    negative_mf = negative_flow.rolling(window=period).sum()

    mfi = 100 - (100 / (1 + positive_mf / (negative_mf + 1e-10)))

    return mfi


def add_technical_indicators(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a new frame with per-asset ``ATR_14`` and ``MFI_14`` columns.

    Each asset (distinct ``Stock`` value) is processed on its own, sorted by
    ``Date``, so rolling windows never mix rows from different assets. The
    per-asset frames are stacked in order of each asset's first appearance in
    ``df`` and given a fresh RangeIndex. ``df`` itself is not modified.
    """
    source = df.copy()

    def _with_indicators(symbol):
        # Date-ordered slice for one asset with both indicators appended.
        frame = source[source['Stock'] == symbol].sort_values('Date').copy()
        frame['ATR_14'] = calculate_atr(frame, period=14)
        frame['MFI_14'] = calculate_mfi(frame, period=14)
        return frame

    symbols = source['Stock'].unique()
    enriched = [
        _with_indicators(symbol)
        for symbol in tqdm(symbols, desc="Computing indicators")
    ]
    return pd.concat(enriched, ignore_index=True)


# Add technical indicators
print("\n" + "=" * 80)
print("FEATURE ENGINEERING")
print("=" * 80)

# Appends ATR_14 and MFI_14, computed separately for each asset.
all_assets_data = add_technical_indicators(all_assets_data)

print("\n✓ Added technical indicators: ATR, MFI")

### 6.2 Stock Features

In [None]:
# Compute stock features using the project's FeatureEngineering class
feature_eng = FeatureEngineering(lookback_window=CONFIG['lookback_window'])
features_df = feature_eng.compute_stock_features(all_assets_data)

# Raw OHLCV/identity columns, excluded from the feature listing below.
base_cols = {'Date', 'Stock', 'Open', 'High', 'Low', 'Close', 'Volume'}

print("\n✓ Computed stock features")
print(f"  Features: {[col for col in features_df.columns if col not in base_cols]}")
print(f"  Shape: {features_df.shape}")

### 6.3 Market Features

In [None]:
def compute_market_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Build one row of market-level features per date from per-asset returns.

    Every asset present in ``df`` on a date is weighted equally.

    Args:
        df: Long-format frame with at least ``Date`` and ``Returns`` columns.

    Returns:
        DataFrame with one row per date, sorted by ``Date``, with columns:
        ``Date``; ``MarketReturn`` (cross-sectional mean return);
        ``MarketVolatility`` (20-row rolling std of MarketReturn);
        ``MarketMomentum_5`` / ``MarketMomentum_20`` (5- and 20-row rolling
        means of MarketReturn); ``MarketDispersion`` (cross-sectional std of
        returns on that date).
    """
    daily = df.groupby('Date')['Returns'].agg(['mean', 'std']).reset_index()
    market_return = daily['mean']

    return pd.DataFrame({
        'Date': daily['Date'],
        'MarketReturn': market_return,
        'MarketVolatility': market_return.rolling(20).std(),
        'MarketMomentum_5': market_return.rolling(5).mean(),
        'MarketMomentum_20': market_return.rolling(20).mean(),
        'MarketDispersion': daily['std'],
    })


# NOTE: features_df contains every portfolio asset, including GOLDBEES and
# SILVERBEES, so this equal-weighted "market" blends stocks and metal ETFs.
market_features = compute_market_features(features_df)

print("\n✓ Computed market features:")
print(f"  {list(market_features.columns)}")

### 6.4 Prepare Data for Training

In [None]:
# Merge market features into stock data
features_df = features_df.merge(market_features, on='Date', how='left')

# Drop NaN rows (from rolling-window warm-up periods).
# NOTE(review): dropna() works row by row, so an asset with a missing value on
# some date loses only that row and assets can end up with different date sets.
# Confirm FeaturePreprocessor.prepare_sequences tolerates (or re-aligns) that.
features_df = features_df.dropna().reset_index(drop=True)

print("\n✓ Prepared features:")
print(f"  Shape: {features_df.shape}")
print(f"  Date range: {features_df['Date'].min().date()} to {features_df['Date'].max().date()}")
print(f"  Assets: {features_df['Stock'].nunique()}")

## 7. Preprocessing and Normalization

In [None]:
# Define feature columns
stock_feature_cols = [
    'Returns', 'LogReturns', 'Momentum_5', 'Momentum_10', 'Momentum_20',
    'Volatility_5', 'Volatility_10', 'Volatility_20',
    'HighLow_Range', 'OpenClose_Range',
    'Price_to_SMA5', 'Price_to_SMA10', 'Price_to_SMA20',
    'Volume_Change', 'Volume_Ratio', 'RSI_14',
    'ATR_14', 'MFI_14'
]

market_feature_cols = [
    'MarketReturn', 'MarketVolatility',
    'MarketMomentum_5', 'MarketMomentum_20',
    'MarketDispersion'
]

# Use FeaturePreprocessor
preprocessor = FeaturePreprocessor(
    stock_feature_cols=stock_feature_cols,
    market_feature_cols=market_feature_cols
)

# Build model-ready sequences (the preprocessor is expected to normalise them).
# NOTE(review): prepare_sequences receives the full sample; if it fits its
# normalisation statistics here, they include the period later used for
# backtesting (look-ahead leakage). Confirm, and fit on the training span only.
stock_sequences, market_sequences, returns, dates, stock_names = preprocessor.prepare_sequences(
    features_df
)

print("\n✓ Data preprocessed and normalized:")
print(f"  Stock sequences shape: {stock_sequences.shape}")
print(f"  Market sequences shape: {market_sequences.shape}")
print(f"  Returns shape: {returns.shape}")
print(f"  Number of timesteps: {len(dates)}")
print(f"  Number of assets: {len(stock_names)}")
print(f"  Stock names: {stock_names}")

## 8. Create Trading Environment

In [None]:
# Create trading environment over the preprocessed sequences.
env = TradingEnvironment(
    stock_sequences=stock_sequences,
    market_sequences=market_sequences,
    returns=returns,
    dates=dates,
    transaction_cost=CONFIG['transaction_cost'],
    turnover_penalty=CONFIG['turnover_penalty'],
    initial_cash=CONFIG['initial_cash'],
    normalize_rewards=True,
    random_start=True,
    episode_length=CONFIG['episode_length']
)

print("\n✓ Trading environment created:")
print(f"  Assets: {env.num_stocks}")
print(f"  Stock features: {env.num_stock_features}")
print(f"  Market features: {env.num_market_features}")
print(f"  Episode length: {env.episode_length} days")

## 9. Initialize Transformer-PPO Agent

In [None]:
# Create PPO agent sized from the environment's observation dimensions.
agent = PPOAgent(
    num_stock_features=env.num_stock_features,
    num_market_features=env.num_market_features,
    num_stocks=env.num_stocks,
    stock_embedding_dim=CONFIG['stock_embedding_dim'],
    market_embedding_dim=CONFIG['market_embedding_dim'],
    num_transformer_heads=CONFIG['num_transformer_heads'],
    num_transformer_layers=CONFIG['num_transformer_layers'],
    policy_hidden_dim=CONFIG['policy_hidden_dim'],
    value_hidden_dim=CONFIG['value_hidden_dim'],
    dropout=CONFIG['dropout'],
    max_weight=CONFIG['max_weight'],
    device=CONFIG['device']
)

# Count parameters
total_params = sum(p.numel() for p in agent.parameters())
trainable_params = sum(p.numel() for p in agent.parameters() if p.requires_grad)

print("\n✓ Transformer-PPO Agent initialized:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Device: {CONFIG['device']}")

# Print model architecture
print("\nModel Architecture:")
print(agent)

## 10. Training Loop

In [None]:
# Create trainer
trainer = PPOTrainer(
    agent=agent,
    env=env,
    learning_rate=CONFIG['learning_rate'],
    gamma=CONFIG['gamma'],
    gae_lambda=CONFIG['gae_lambda'],
    clip_epsilon=CONFIG['clip_epsilon'],
    value_coef=CONFIG['value_coef'],
    entropy_coef=CONFIG['entropy_coef'],
    batch_size=CONFIG['batch_size'],
    buffer_size=CONFIG['buffer_size'],
    device=CONFIG['device']
)

print("\n✓ PPO Trainer initialized")

In [None]:
# Train the agent
print("\n" + "=" * 80)
print("TRAINING TRANSFORMER-PPO AGENT")
print("=" * 80)

training_history = {
    'episode': [],
    'reward': [],
    'length': [],
    'sharpe': [],
    'max_drawdown': []
}

# Each iteration is one PPO cycle: collect CONFIG['buffer_size'] environment
# steps, then call trainer.update(). When buffer_size exceeds episode_length
# (2048 vs 252 in CONFIG), one rollout presumably spans several environment
# episodes, so "episode" here counts update cycles, not env episodes.
# NOTE(review): confirm what env.get_episode_statistics() reports when a
# rollout ends mid-episode.
for episode in tqdm(range(CONFIG['n_episodes']), desc="Training"):
    # Collect rollout
    rollout_stats = trainer.collect_rollout(n_steps=CONFIG['buffer_size'])

    # Update policy
    update_stats = trainer.update()

    # Get episode statistics
    episode_stats = env.get_episode_statistics()

    # Log
    training_history['episode'].append(episode)
    training_history['reward'].append(rollout_stats['mean_reward'])
    training_history['length'].append(rollout_stats['mean_length'])
    training_history['sharpe'].append(episode_stats.get('sharpe_ratio', 0))
    training_history['max_drawdown'].append(episode_stats.get('max_drawdown', 0))

    # Print progress. tqdm.write keeps the progress bar intact; plain print()
    # inside a tqdm loop interleaves with and garbles the bar.
    if (episode + 1) % 10 == 0:
        tqdm.write(
            f"\nEpisode {episode + 1}/{CONFIG['n_episodes']}\n"
            f"  Reward: {rollout_stats['mean_reward']:.4f}\n"
            f"  Sharpe: {episode_stats.get('sharpe_ratio', 0):.2f}\n"
            f"  MaxDD: {episode_stats.get('max_drawdown', 0):.2%}\n"
            f"  Policy Loss: {update_stats['policy_loss']:.4f}\n"
            f"  Value Loss: {update_stats['value_loss']:.4f}"
        )

print("\n✓ Training completed!")

## 11. Training Visualization

In [None]:
# Plot the four logged training curves on a 2x2 grid, all against the episode index.
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# (grid position, history key, panel title, y-axis label)
training_panels = [
    ((0, 0), 'reward', 'Episode Reward', 'Mean Reward'),
    ((0, 1), 'sharpe', 'Sharpe Ratio', 'Sharpe Ratio'),
    ((1, 0), 'max_drawdown', 'Maximum Drawdown', 'Max Drawdown'),
    ((1, 1), 'length', 'Episode Length', 'Days'),
]

for panel_pos, metric_key, panel_title, panel_ylabel in training_panels:
    panel_ax = axes[panel_pos]
    panel_ax.plot(training_history['episode'], training_history[metric_key])
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel('Episode')
    panel_ax.set_ylabel(panel_ylabel)
    panel_ax.grid(True)

plt.tight_layout()
plt.show()

## 12. Backtesting

In [None]:
# Create backtest engine
# NOTE(review): the backtest replays the same stock_sequences/returns/dates the
# agent was trained on (and the same window used for stock selection), so these
# results are in-sample. Hold out a later date range for an honest estimate.
backtest_engine = BacktestEngine(
    agent=agent,
    stock_sequences=stock_sequences,
    market_sequences=market_sequences,
    returns=returns,
    dates=dates,
    stock_names=stock_names,
    transaction_cost=CONFIG['transaction_cost'],
    initial_cash=CONFIG['initial_cash']
)

print("\n" + "=" * 80)
print("BACKTESTING")
print("=" * 80)

# Run backtest
backtest_results = backtest_engine.run_backtest(deterministic=True)

print("\n✓ Backtest completed")
print(f"  Total days: {len(backtest_results['portfolio_values'])}")
print(f"  Date range: {dates[0].date()} to {dates[-1].date()}")

## 13. Performance Metrics

In [None]:
# Summarise the backtest equity curve into risk/return statistics.
metrics = PerformanceMetrics()
performance = metrics.calculate_metrics(
    portfolio_values=backtest_results['portfolio_values'],
    dates=dates,
    initial_cash=CONFIG['initial_cash'],
)

banner = "=" * 80
print("\n" + banner)
print("PERFORMANCE METRICS")
print(banner)

# (printed label, key in `performance`, format spec), in display order.
metric_rows = [
    ("\nTotal Return", 'total_return', '.2%'),
    ("Annualized Return", 'annualized_return', '.2%'),
    ("Annualized Volatility", 'annualized_volatility', '.2%'),
    ("Sharpe Ratio", 'sharpe_ratio', '.2f'),
    ("Maximum Drawdown", 'max_drawdown', '.2%'),
    ("Calmar Ratio", 'calmar_ratio', '.2f'),
    ("Win Rate", 'win_rate', '.2%'),
]
for row_label, metric_name, row_spec in metric_rows:
    print(f"{row_label}: {performance[metric_name]:{row_spec}}")

## 14. Visualization

### 14.1 Portfolio Value Over Time

In [None]:
# Equity curve of the backtest. No benchmark overlay yet; a NIFTY 50 index
# series could be supplied through benchmark_values.
visualizer = PerformanceVisualizer()
equity_curve = backtest_results['portfolio_values']
visualizer.plot_portfolio_value(
    portfolio_values=equity_curve,
    dates=dates,
    benchmark_values=None,
)

### 14.2 Drawdown

In [None]:
# Drawdown chart for the backtest equity curve.
visualizer.plot_drawdown(
    dates=dates,
    portfolio_values=backtest_results['portfolio_values'],
)

### 14.3 Monthly Returns Heatmap

In [None]:
# Calculate monthly returns from the backtest equity curve.
# NOTE(review): assumes len(backtest_results['portfolio_values']) == len(dates);
# the DataFrame constructor raises if they differ.
portfolio_df = pd.DataFrame({
    'Date': dates,
    'Value': backtest_results['portfolio_values']
})
portfolio_df['Date'] = pd.to_datetime(portfolio_df['Date'])
portfolio_df = portfolio_df.set_index('Date')
portfolio_df['Returns'] = portfolio_df['Value'].pct_change()

# Month-end bins: pandas >= 2.2 spells month-end 'ME' and deprecates 'M';
# older pandas only knows 'M' and rejects 'ME' with a ValueError.
try:
    monthly_groups = portfolio_df['Returns'].resample('ME')
except ValueError:
    monthly_groups = portfolio_df['Returns'].resample('M')

# Compound daily returns within each month. min_count=1 leaves a month with no
# observed returns as NaN (blank cell) instead of a fabricated 0% return.
monthly_returns = monthly_groups.apply(lambda x: (1 + x).prod(min_count=1) - 1)
monthly_returns_pivot = monthly_returns.to_frame()
monthly_returns_pivot['Year'] = monthly_returns_pivot.index.year
monthly_returns_pivot['Month'] = monthly_returns_pivot.index.month
monthly_returns_pivot = monthly_returns_pivot.pivot(index='Year', columns='Month', values='Returns')

# Plot heatmap
plt.figure(figsize=(12, 8))
sns.heatmap(
    monthly_returns_pivot * 100,
    annot=True,
    fmt='.1f',
    cmap='RdYlGn',
    center=0,
    cbar_kws={'label': 'Return (%)'}
)
plt.title('Monthly Returns Heatmap (%)', fontsize=14, fontweight='bold')
plt.xlabel('Month')
plt.ylabel('Year')
plt.tight_layout()
plt.show()

## 15. Final Asset Weights

In [None]:
# Final portfolio weights (last backtest step). Coerce to a float array so the
# percentage column below is element-wise even if the engine stores plain
# lists (list * 100 would repeat the list instead of scaling it).
final_weights = np.asarray(backtest_results['weights'][-1], dtype=float)

# Create weights DataFrame
weights_df = pd.DataFrame({
    'Asset': stock_names,
    'Weight': final_weights,
    'Weight_Pct': final_weights * 100
})
weights_df = weights_df.sort_values('Weight', ascending=False).reset_index(drop=True)

print("\n" + "=" * 80)
print("FINAL PORTFOLIO WEIGHTS")
print("=" * 80)
print(weights_df.to_string(index=False))

# Precious-metal weights in percent; each is an empty array if that ETF is
# not in the asset universe.
gold_weight = weights_df.loc[weights_df['Asset'] == 'GOLDBEES', 'Weight_Pct'].values
silver_weight = weights_df.loc[weights_df['Asset'] == 'SILVERBEES', 'Weight_Pct'].values

print("\n📊 Precious Metals Allocation:")
if len(gold_weight) > 0:
    print(f"  GOLDBEES: {gold_weight[0]:.2f}%")
if len(silver_weight) > 0:
    print(f"  SILVERBEES: {silver_weight[0]:.2f}%")

total_precious = sum(w[0] for w in (gold_weight, silver_weight) if len(w) > 0)
print(f"  Total Precious Metals: {total_precious:.2f}%")

### 15.1 Weight Visualization

In [None]:
# Pie chart of final weights
fig, ax = plt.subplots(figsize=(12, 8))

# One distinct colour per slice from the "husl" palette, with gold/silver
# overriding the precious metals. matplotlib's pie draws a None entry in
# `colors` with the single default patch colour, so leaving the stocks as None
# would paint every stock slice the same colour.
palette = sns.color_palette("husl", len(weights_df))
metal_colors = {'GOLDBEES': 'gold', 'SILVERBEES': 'silver'}
colors = [
    metal_colors.get(asset, palette[i])
    for i, asset in enumerate(weights_df['Asset'])
]

# NOTE(review): Axes.pie rejects negative values, so this assumes long-only weights.
wedges, texts, autotexts = ax.pie(
    weights_df['Weight'],
    labels=weights_df['Asset'],
    autopct='%1.1f%%',
    colors=colors,
    startangle=90
)

# Enhance text
for text in texts:
    text.set_fontsize(10)
for autotext in autotexts:
    autotext.set_color('white')
    autotext.set_fontweight('bold')

ax.set_title('Final Portfolio Allocation', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

### 15.2 Weight Evolution Over Time

In [None]:
# Plot weight evolution over the backtest.
# NOTE(review): assumes backtest_results['weights'] has one row per entry of
# `dates`; ax.plot raises on a length mismatch.
weights_over_time = np.array(backtest_results['weights'])

fig, ax = plt.subplots(figsize=(15, 8))

# list() so .index() also works if stock_names is a NumPy array.
asset_names = list(stock_names)
gold_idx = asset_names.index('GOLDBEES') if 'GOLDBEES' in asset_names else None
silver_idx = asset_names.index('SILVERBEES') if 'SILVERBEES' in asset_names else None

# The five largest final holdings among the stocks. The metals are plotted
# separately, so they are excluded here rather than reducing the stock count.
top_stock_names = [
    asset for asset in weights_df['Asset']
    if asset not in ('GOLDBEES', 'SILVERBEES')
][:5]

# Plot precious metals with bold lines
if gold_idx is not None:
    ax.plot(dates, weights_over_time[:, gold_idx], label='GOLDBEES', linewidth=2.5, color='gold')
if silver_idx is not None:
    ax.plot(dates, weights_over_time[:, silver_idx], label='SILVERBEES', linewidth=2.5, color='silver')

# Plot top 5 stocks
for stock in top_stock_names:
    ax.plot(dates, weights_over_time[:, asset_names.index(stock)], label=stock, alpha=0.7)

ax.set_xlabel('Date')
ax.set_ylabel('Weight')
ax.set_title('Portfolio Weight Evolution (Top Assets)', fontsize=14, fontweight='bold')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 16. Summary Report

In [None]:
print("\n" + "=" * 80)
print("TRANSFORMER-PPO PORTFOLIO OPTIMIZATION - SUMMARY")
print("=" * 80)

precious_metal_names = {'GOLDBEES', 'SILVERBEES'}
n_metal_assets = sum(name in precious_metal_names for name in stock_names)

print("\n📋 Portfolio Composition:")
print(f"  Total Assets: {len(stock_names)}")
print(f"  - Stocks (Top {CONFIG['num_stocks']} NIFTY 50): {len(stock_names) - n_metal_assets}")
print(f"  - Precious Metals ETFs: {n_metal_assets}")

print("\n📈 Performance Summary:")
print(f"  Period: {dates[0].date()} to {dates[-1].date()}")
print(f"  Total Return: {performance['total_return']:.2%}")
print(f"  Annualized Return: {performance['annualized_return']:.2%}")
print(f"  Volatility: {performance['annualized_volatility']:.2%}")
print(f"  Sharpe Ratio: {performance['sharpe_ratio']:.2f}")
print(f"  Max Drawdown: {performance['max_drawdown']:.2%}")

print("\n🏆 Top 5 Holdings:")
holding_icons = {'GOLDBEES': '🥇', 'SILVERBEES': '🥈'}
for _, row in weights_df.head(5).iterrows():
    icon = holding_icons.get(row['Asset'], '📊')
    print(f"  {icon} {row['Asset']:<15} {row['Weight_Pct']:>6.2f}%")

print("\n💎 Precious Metals Exposure:")
print(f"  Total allocation: {total_precious:.2f}%")
if len(gold_weight) > 0:
    print(f"  - Gold (GOLDBEES): {gold_weight[0]:.2f}%")
if len(silver_weight) > 0:
    print(f"  - Silver (SILVERBEES): {silver_weight[0]:.2f}%")

print("\n✓ Analysis complete!")
print("=" * 80)

## 17. Save Results

In [None]:
# Save model checkpoint and backtest artefacts.
checkpoint_dir = '../checkpoints'
results_dir = '../results'
os.makedirs(checkpoint_dir, exist_ok=True)
# The CSV writes below fail if the results directory does not exist yet.
os.makedirs(results_dir, exist_ok=True)

checkpoint_path = os.path.join(checkpoint_dir, 'transformer_ppo_precious_metals.pt')

# NOTE: besides the state dict, this pickles plain Python objects (CONFIG,
# performance metrics, weight records). torch.load(weights_only=True), the
# default from PyTorch 2.6, may refuse some of them; use weights_only=False
# only for checkpoints from a trusted source.
torch.save({
    'model_state_dict': agent.state_dict(),
    'config': CONFIG,
    'stock_names': stock_names,
    'performance': performance,
    'final_weights': weights_df.to_dict('records')
}, checkpoint_path)

print(f"✓ Model saved to {checkpoint_path}")

# Save results to CSV
results_path = os.path.join(results_dir, 'transformer_ppo_precious_metals_results.csv')
results_df = pd.DataFrame({
    'Date': dates,
    'PortfolioValue': backtest_results['portfolio_values']
})
results_df.to_csv(results_path, index=False)
print(f"✓ Results saved to {results_path}")

# Save weights
weights_path = os.path.join(results_dir, 'final_weights_precious_metals.csv')
weights_df.to_csv(weights_path, index=False)
print(f"✓ Weights saved to {weights_path}")