# StockFormer Attention Analysis

Interactive exploration of the cross-attention StockFormer's internal representations.
This notebook extracts and visualizes:

1. **Cross-attention heatmaps** — which market timesteps each stock timestep attends to
2. **Gate activations** — when the model uses market context vs stock-specific signal
3. **Self-attention patterns** — temporal dependencies within the stock sequence
4. **Attention pooling** — which timesteps matter most for the final prediction
5. **Head specialization** — do different heads learn different patterns?
6. **Input saliency** — which features drive the prediction (gradient-based)

## Setup
Run from the project root: `jupyter notebook notebooks/attention_analysis.ipynb`

In [None]:
import sys
sys.path.insert(0, '..')  # add project root to path

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
plt.rcParams['figure.dpi'] = 120

from stockformer.model import CrossAttentionStockTransformer, create_model, infer_arch_from_state_dict
from stockformer.config import BASE_FEATURE_COLUMNS, MARKET_FEATURE_COLUMNS
from stockformer.explainability import (
    AttentionExtractor,
    make_extractor,
    compute_input_saliency,
    plot_attention_heatmap,
    plot_attention_all_heads,
    plot_cross_attention,
    plot_cross_attention_all_heads,
    plot_self_attention,
    plot_gate_activations,
    plot_gate_timeseries,
    plot_attention_pooling,
    plot_layer_evolution,
    plot_head_specialization,
    plot_saliency_map,
    plot_saliency_summary,
    generate_report,
)

# Prefer CUDA when it is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Device: {device}')

## 1. Load Model

Two options:
- **Option A**: Load a trained checkpoint
- **Option B**: Create a randomly initialized model (for testing the viz pipeline)

In [None]:
# ---- Option A: Load from checkpoint ----
# Uncomment and set your checkpoint path:
#
# CHECKPOINT = '../stockformer/output/models/binary_3d_best.pt'
# ckpt = torch.load(CHECKPOINT, map_location=device, weights_only=False)
# arch = infer_arch_from_state_dict(ckpt['model_state_dict'], CHECKPOINT)
# cfg_from_ckpt = {
#     'd_model': arch['d_model'], 'nhead': arch['nhead'],
#     'num_layers': arch['num_layers'], 'dim_feedforward': arch['dim_feedforward'],
#     'dropout': arch.get('dropout', 0.1), 'market_layers': arch.get('market_layers', 2),
#     'market_feature_dim': arch.get('market_dim', len(MARKET_FEATURE_COLUMNS)),
#     'layer_drop': 0.0,
#     'loss_name': 'coral' if arch.get('use_coral') else None,
# }
# model = create_model(
#     feature_dim=arch['feature_dim'],
#     label_mode=arch['output_mode'],
#     bucket_edges=list(range(arch.get('num_buckets', 5) - 1)) if arch['output_mode'] == 'buckets' else None,
#     cfg=cfg_from_ckpt,
#     model_type='cross_attention',
# )
# model.load_state_dict(ckpt['model_state_dict'])
# model = model.to(device).eval()
# print(f'Loaded: {CHECKPOINT}')
# print(f'Architecture: {arch}')

# ---- Option B: Random model for testing ----
# Input widths are taken from the configured column lists (41 stock / 11 market
# columns according to the original notes); SEQ_LEN is the lookback window.
SEQ_LEN = 60
FEATURE_DIM = len(BASE_FEATURE_COLUMNS)
MARKET_DIM = len(MARKET_FEATURE_COLUMNS)

# Architecture hyperparameters handed to create_model.
cfg = dict(
    d_model=128,
    nhead=8,
    num_layers=4,
    dim_feedforward=512,
    dropout=0.1,
    market_layers=2,
    market_feature_dim=MARKET_DIM,
    layer_drop=0.0,
)
model = create_model(
    feature_dim=FEATURE_DIM, label_mode='binary', bucket_edges=None,
    cfg=cfg, model_type='cross_attention',
)
# Move to the selected device, then switch to inference mode.
model = model.to(device)
model = model.eval()

# Total number of scalar parameters across all parameter tensors.
n_params = sum(map(torch.numel, model.parameters()))
print(f'Model: CrossAttentionStockTransformer ({n_params:,} params)')
print(f'Features: {FEATURE_DIM} stock, {MARKET_DIM} market, seq_len={SEQ_LEN}')

## 2. Prepare Input

Either use real data from a CSV or generate random tensors for testing.

In [None]:
# ---- Option A: Real data ----
# import pandas as pd
# from stockformer.dataset import StockDataset
#
# df = pd.read_csv('../data/all_data_AAPL.csv')
# dataset = StockDataset(df, feature_cols=BASE_FEATURE_COLUMNS,
#                        market_feature_cols=MARKET_FEATURE_COLUMNS,
#                        target_col='close_3d_fwd_return',
#                        lookback=SEQ_LEN, label_mode='binary')
# x_real, y_real, market_real = dataset[-1]  # last sample
# x = x_real.unsqueeze(0).to(device)
# market_x = market_real.unsqueeze(0).to(device)

# ---- Option B: Random tensors for testing ----
# A fixed seed keeps the synthetic single-sample batch reproducible. Both
# tensors are drawn with the default (CPU) generator and then moved to `device`.
torch.manual_seed(42)
x = torch.randn((1, SEQ_LEN, FEATURE_DIM))
x = x.to(device)
market_x = torch.randn((1, SEQ_LEN, MARKET_DIM))
market_x = market_x.to(device)

print(f"Stock input:  {x.size()}")
print(f"Market input: {market_x.size()}")

## 3. Extract Attention Data

The `AttentionExtractor` registers forward hooks that capture:
- Cross-attention weights from each `GatedCrossAttention` layer
- Gate values (sigmoid output controlling market context injection)
- Self-attention weights from each `TransformerEncoderLayer`
- Attention pooling weights

In [None]:
extractor = AttentionExtractor(model)

# Single forward pass with gradients disabled; the extractor is used as a
# context manager around the call so it can record what the model produces.
with torch.no_grad(), extractor:
    output = model(x, market_x)

attn = extractor.get_data()

# Summary of what was captured.
print(f"Output shape: {output.shape}")
print(f"Cross-attn layers captured: {len(attn['cross_attn_weights'])}")
print(f"Gate layers captured:       {len(attn['gate_values'])}")
print(f"Self-attn layers captured:  {len(attn['self_attn_weights'])}")
print(f"Pool weights shape:         {attn['pool_weights'].shape}")
print()
# Shapes of the first captured layer, followed by the axis legend.
print(f"Cross-attn weight shape (per layer): {attn['cross_attn_weights'][0].shape}")
print('  -> [batch, heads, stock_len, market_len]')
print(f"Gate value shape (per layer):        {attn['gate_values'][0].shape}")
print('  -> [batch, seq_len, d_model]')

## 4. Cross-Attention Heatmaps

These show which market timesteps each stock timestep attends to.

**What to look for:**
- Diagonal pattern = each stock timestep attends to same-day market data
- Vertical stripes = all stock timesteps attend to specific market events
- Diffuse = model hasn't learned meaningful cross-attention yet

In [None]:
# Average attention across heads for each layer:
# one figure per entry in attn['cross_attn_weights'] (one entry per captured layer).
for layer in range(len(attn['cross_attn_weights'])):
    fig = plot_cross_attention(attn, layer=layer)
    plt.show()

In [None]:
# All heads for layer 0 (8 with the Option B config, 'nhead': 8) — do they specialize?
# NOTE(review): a checkpoint loaded via Option A may have a different head count.
fig = plot_cross_attention_all_heads(attn, layer=0)
plt.show()

## 5. Gate Activations

The gate controls how much market context is injected:
- **Gate near 0**: stock-specific signal dominates (calm market)
- **Gate near 1**: market context dominates (crash/rally)

The timeseries plot shows mean gate value over the lookback window.

In [None]:
# Gate activation heatmap for layer 0 (top 32 most variable dimensions)
# NOTE(review): the "top 32" selection is not set in this call — presumably a
# default inside plot_gate_activations; confirm there.
fig = plot_gate_activations(attn, layer=0)
plt.show()

In [None]:
# Mean gate over time — all layers overlaid
fig = plot_gate_timeseries(attn)
plt.show()

## 6. Self-Attention Patterns

Stock self-attention: how the stock sequence attends to itself across time.

**What to look for:**
- Diagonal = local attention (each day attends to itself)
- Lower-triangular = causal pattern (attends to past only)
- Off-diagonal hot spots = the model links distant days

In [None]:
# Self-attention for layer 0 (avg over heads)
fig = plot_self_attention(attn, layer=0)
plt.show()

In [None]:
# Compare layer 0 vs last layer: the final layer is plotted only when at least
# two self-attention layers were captured (otherwise it is layer 0 again).
n_layers = len(attn['self_attn_weights'])
if n_layers >= 2:
    fig = plot_self_attention(
        attn,
        layer=n_layers - 1,
        title=f'Self-Attention — Layer {n_layers-1} (last), Avg heads',
    )
    plt.show()

## 7. Attention Pooling

After all encoder layers, the model uses learned attention pooling to
aggregate the sequence into a single vector for the output head.

**What to look for:**
- Peak at recent timesteps = recency bias (expected)
- Peaks at earlier timesteps = long-range dependency detection

In [None]:
# Plot the attention-pooling weights from the captured data
# (presumably attn['pool_weights'] — confirm against plot_attention_pooling).
fig = plot_attention_pooling(attn)
plt.show()

## 8. Layer Evolution

How attention patterns change across layers.

- **Entropy decreasing** across layers = model sharpens attention (learns to focus)
- **Entropy flat/high** = model hasn't learned to differentiate (possible underfitting)
- **Sparsity increasing** = deeper layers attend to fewer positions

In [None]:
# Layer-evolution figure for the cross-attention weights...
fig = plot_layer_evolution(attn, attention_type='cross')
plt.show()

# ...and the same figure for the self-attention weights.
fig = plot_layer_evolution(attn, attention_type='self')
plt.show()

## 9. Head Specialization

Do different attention heads learn different behaviors?

- **Different entropies** = some heads are focused, others diffuse (good)
- **Different peak positions** = heads attend to different temporal regions (good)
- **All heads identical** = redundancy, capacity waste (bad)

In [None]:
# Per-head comparison for layer 0, cross-attention first...
fig = plot_head_specialization(attn, layer=0, attention_type='cross')
plt.show()

# ...then the same comparison for layer 0 self-attention.
fig = plot_head_specialization(attn, layer=0, attention_type='self')
plt.show()

## 10. Input Saliency (Gradient-Based)

Uses backpropagation to compute which input features the model is most
sensitive to. This is different from attention — saliency estimates the
**local sensitivity** of the output to each input (a first-order,
gradient-based approximation rather than a true causal effect), whereas
attention only shows what the model "looks at".

Method: Input x Gradient (Grad-CAM analog for tabular data)

In [None]:
# compute_input_saliency returns a (stock, market) pair; both expose `.shape`.
# The axis layouts named in the prints come from the original labels — TODO confirm.
stock_sal, market_sal = compute_input_saliency(model, x, market_x)

print(f'Stock saliency shape:  {stock_sal.shape}  (seq_len x feature_dim)')
print(f'Market saliency shape: {market_sal.shape}  (seq_len x market_dim)')

In [None]:
# Stock feature saliency heatmap (top 20 features)
fig = plot_saliency_map(stock_sal, feature_names=BASE_FEATURE_COLUMNS, top_k=20)
plt.show()

In [None]:
# Market feature saliency heatmap (all market features).
# top_k follows the length of MARKET_FEATURE_COLUMNS instead of a hard-coded 11,
# so the plot keeps showing every market feature if the column list changes.
fig = plot_saliency_map(market_sal, feature_names=MARKET_FEATURE_COLUMNS,
                        title='Market Input Saliency',
                        top_k=len(MARKET_FEATURE_COLUMNS))
plt.show()

In [None]:
# Side-by-side feature importance summary
fig = plot_saliency_summary(
    stock_sal, market_sal,
    stock_feature_names=BASE_FEATURE_COLUMNS,
    market_feature_names=MARKET_FEATURE_COLUMNS,
)
plt.show()

## 11. Full Report Generation

Generate all visualizations at once and save to disk as PNGs.

In [None]:
# Generate all plots and save to disk.
# The result is iterated with .items() below, so it is treated as a mapping of
# figure name -> saved path.
saved = generate_report(
    model, x, market_x,
    output_dir='../stockformer/output/explainability',
    input_names=['Stock', 'Market'],
    feature_names=[BASE_FEATURE_COLUMNS, MARKET_FEATURE_COLUMNS],
)

# List every generated figure with its location.
print(f'\nGenerated {len(saved)} figures:')
for name, path in saved.items():
    print(f'  {name}: {path}')

## 12. Comparing Samples

Compare attention patterns for different stocks or market conditions.
This is useful for understanding if the model behaves differently during
high-volatility vs calm periods.

In [None]:
# Example: Compare gate activations for two different inputs
# (Uncomment when using real data)
#
# # Sample 1: calm market period
# x1, _, m1 = dataset[100]
# extractor = AttentionExtractor(model)
# with torch.no_grad():
#     with extractor:
#         _ = model(x1.unsqueeze(0).to(device), m1.unsqueeze(0).to(device))
# attn_calm = extractor.get_data()
#
# # Sample 2: volatile market period
# x2, _, m2 = dataset[500]
# extractor.clear()
# with torch.no_grad():
#     with extractor:
#         _ = model(x2.unsqueeze(0).to(device), m2.unsqueeze(0).to(device))
# attn_volatile = extractor.get_data()
#
# # Compare mean gate values
# fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
# for layer in range(len(attn_calm['gate_values'])):
#     g1 = attn_calm['gate_values'][layer][0].mean(dim=-1).numpy()
#     g2 = attn_volatile['gate_values'][layer][0].mean(dim=-1).numpy()
#     ax1.plot(g1, label=f'L{layer}', alpha=0.8)
#     ax2.plot(g2, label=f'L{layer}', alpha=0.8)
# ax1.set_title('Calm Market — Gate Activations')
# ax2.set_title('Volatile Market — Gate Activations')
# for ax in (ax1, ax2):
#     ax.set_ylim(0, 1); ax.legend(); ax.grid(True, alpha=0.3)
# plt.tight_layout()
# plt.show()

print('Uncomment the cell above when using real data to compare market regimes.')

In [None]:
def analyze_output_distribution(model, x, market_x=None, n_samples=100):
    """Run the model on several inputs and summarize its output distribution.

    The first forward pass uses the provided ``x`` / ``market_x``; every later
    pass uses standard-normal noise of the same shape (``torch.randn_like``),
    so the statistics mostly describe the model's response to random inputs.
    Histograms are drawn with matplotlib and summary statistics are printed.

    Args:
        model: Callable module invoked as ``model(x)`` or ``model(x, market_x)``.
            It is switched to eval mode and left in that mode.
        x: Stock input tensor used for the first sample.
        market_x: Optional market input tensor; when ``None`` the model is
            called with the stock input only.
        n_samples: Number of forward passes to run (must be >= 1).

    Returns:
        The concatenated raw outputs when they form a 1-D tensor (treated as
        regression); otherwise the per-class probability tensor.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        # torch.cat on an empty list fails with an opaque error; fail early instead.
        raise ValueError(f'n_samples must be >= 1, got {n_samples}')

    model.eval()
    all_logits = []

    with torch.no_grad():
        for i in range(n_samples):
            # First pass: the caller's inputs. Later passes: noise of the same shape.
            if i == 0:
                xi, mi = x, market_x
            else:
                xi = torch.randn_like(x)
                mi = torch.randn_like(market_x) if market_x is not None else None

            args = (xi,) if mi is None else (xi, mi)
            out = model(*args)
            all_logits.append(out.cpu())

    # Per-call outputs are concatenated along the first (batch) axis.
    logits = torch.cat(all_logits, dim=0)

    if logits.dim() == 1:
        # Regression: one scalar per sample.
        print(f'Regression output: mean={logits.mean():.4f}, std={logits.std():.4f}')
        fig, ax = plt.subplots(figsize=(10, 4))
        ax.hist(logits.numpy(), bins=50, color='steelblue', alpha=0.8, edgecolor='navy')
        ax.set_xlabel('Predicted value')
        ax.set_ylabel('Count')
        ax.set_title(f'Regression Output Distribution (n={n_samples})')
        ax.axvline(0, color='red', linestyle='--', alpha=0.5, label='zero')
        ax.legend()
        plt.show()
        return logits

    # Classification — compute probabilities.
    # NOTE(review): a 2-D output with a single column falls through to the
    # softmax branch and yields probability 1.0 everywhere — confirm that the
    # model never emits that shape for regression / single-logit binary heads.
    if getattr(model, 'use_coral', False):
        # CORAL: sigmoid of each threshold logit gives a cumulative probability.
        cum_probs = torch.sigmoid(logits)  # [n, K-1]
        probs = _coral_class_probs(cum_probs)
        preds = (cum_probs > 0.5).sum(dim=-1)
        print(f'CORAL output: {logits.shape[1]} thresholds -> {probs.shape[1]} classes')

        # Plot cumulative probabilities
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
        for row in range(min(10, cum_probs.shape[0])):
            ax1.plot(cum_probs[row].numpy(), alpha=0.5)
        ax1.set_xlabel('Threshold k')
        ax1.set_ylabel('P(class > k)')
        ax1.set_title('CORAL Cumulative Probabilities (first 10 samples)')
        ax1.set_ylim(0, 1)
        ax1.grid(True, alpha=0.3)
    else:
        # Standard softmax
        probs = F.softmax(logits, dim=-1)
        preds = logits.argmax(dim=-1)

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
        # Per-class probability distribution
        for c in range(probs.shape[1]):
            ax1.hist(probs[:, c].numpy(), bins=30, alpha=0.5, label=f'Class {c}')
        ax1.set_xlabel('Probability')
        ax1.set_ylabel('Count')
        ax1.set_title('Per-Class Probability Distribution')
        ax1.legend()

    # Prediction distribution (right-hand panel of the figure created above)
    pred_counts = torch.bincount(preds.long(), minlength=probs.shape[1])
    ax2.bar(range(len(pred_counts)), pred_counts.numpy(), color='steelblue', alpha=0.8)
    ax2.set_xlabel('Predicted class')
    ax2.set_ylabel('Count')
    ax2.set_title(f'Prediction Distribution (n={n_samples})')

    fig.tight_layout()
    plt.show()

    # Confidence analysis: largest class probability per sample
    max_probs = probs.max(dim=-1).values
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.hist(max_probs.numpy(), bins=50, color='coral', alpha=0.8, edgecolor='darkred')
    ax.set_xlabel('Max probability (confidence)')
    ax.set_ylabel('Count')
    ax.set_title(f'Confidence Distribution — mean={max_probs.mean():.3f}')
    ax.axvline(1.0/probs.shape[1], color='gray', linestyle='--', label='Uniform', alpha=0.5)
    ax.legend()
    plt.show()

    # Entropy of predictions; the small epsilon avoids log(0)
    entropy = -(probs * (probs + 1e-10).log()).sum(dim=-1)
    max_entropy = np.log(probs.shape[1])
    print(f'Confidence: mean={max_probs.mean():.3f}, median={max_probs.median():.3f}')
    print(f'Entropy: mean={entropy.mean():.3f} (max possible={max_entropy:.3f})')
    print(f'Prediction breakdown: {dict(enumerate(pred_counts.tolist()))}')

    return probs


def _coral_class_probs(cum_probs):
    """Turn CORAL cumulative probabilities [n, K-1] into class probabilities [n, K]."""
    # Pad with P=1 on the left and P=0 on the right, matching the input's dtype
    # and device, then difference neighbouring columns.
    ones = cum_probs.new_ones(cum_probs.size(0), 1)
    zeros = cum_probs.new_zeros(cum_probs.size(0), 1)
    extended = torch.cat([ones, cum_probs, zeros], dim=1)
    probs = (extended[:, :-1] - extended[:, 1:]).clamp(min=0)
    # The raw differences telescope to 1; clamping negatives (non-monotone
    # cumulative probabilities) can only raise the row sum, so renormalize to
    # keep each row a valid distribution. The sum is always >= 1, never zero.
    return probs / probs.sum(dim=1, keepdim=True)

# Analyze the current model with the default 100 passes: the first pass uses
# (x, market_x); the remaining ones use random noise of the same shape.
probs = analyze_output_distribution(model, x, market_x)

## 13. Output Probability Analysis

Examine the model's output distribution — the post-softmax probabilities
(classification) or raw predictions (regression) — to understand confidence
and calibration.

**What to look for:**
- **Overconfident**: P(class) near 1.0 for most samples = model is overfit or collapsed
- **Underconfident**: P(class) near uniform = model hasn't learned signal
- **Calibration**: Does 80% confidence actually mean 80% accuracy?
- **CORAL ordinal**: Cumulative sigmoid probabilities should be monotonically decreasing