# Phase 4: GNN Regime Detection Validation

**Date:** December 2024

**Objective:** Validate GNN regime detector for crash prediction and drawdown reduction.

**Success Criteria:**
- Detect regime shifts 3-5 days before major drawdowns
- Reduce max drawdown by ≥30%

---

In [None]:
import sys
sys.path.insert(0, '../..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')

## 1. Load Multi-Asset Data & Build Features

In [None]:
from data.ingestion.multi_asset import build_regime_detection_features, RegimeLabeler

# Build the feature frame and the regime labels for the three assets,
# starting from 2020-01-01.
df, labels = build_regime_detection_features(
    assets=['BTC', 'ETH', 'SOL'],
    start_date='2020-01-01',
)

# Sanity check: dataset size and how the labels are distributed.
n_samples, n_features = len(df), len(df.columns)
print(f"Dataset: {n_samples} samples")
print(f"Features: {n_features} columns")
print("\nRegime distribution:")
print(labels.value_counts())

In [None]:
# Bar chart of the number of samples per regime, one colour per regime.
fig, ax = plt.subplots(figsize=(8, 5))
colors = {'RISK_ON': 'green', 'CAUTION': 'orange', 'RISK_OFF': 'red'}

# Count once and reuse the result for both the bars and their colours.
regime_counts = labels.value_counts()
bar_colors = [colors[regime] for regime in regime_counts.index]
regime_counts.plot(kind='bar', color=bar_colors, ax=ax)

ax.set_title('Regime Distribution (2020-2024)')
ax.set_ylabel('Days')

# Persist the figure next to the notebook, then display it.
plt.tight_layout()
plt.savefig('regime_distribution.png', dpi=150)
plt.show()

## 2. Train GNN Regime Detector

In [None]:
from models.predictors.regime_gnn import RegimeDetector

# Detector over the three assets with hidden_dim=64.
detector = RegimeDetector(hidden_dim=64, assets=['BTC', 'ETH', 'SOL'])

# Convert the feature frame and labels into graph samples plus targets.
graphs, targets = detector.prepare_dataset(df, labels)
print(f"Prepared {len(graphs)} graphs")

# Ordered 80/20 split: the leading 80% of samples train, the tail validates.
split = int(len(graphs) * 0.8)
train_graphs, train_labels = graphs[:split], targets[:split]
val_graphs, val_labels = graphs[split:], targets[split:]

print(f"Train: {len(train_graphs)}, Val: {len(val_graphs)}")

In [None]:
# Inverse-frequency class weights over the three classes. The +1 in the
# denominator keeps the division finite when a class has no training samples.
train_counts = np.bincount(train_labels, minlength=3)
raw_weights = len(train_labels) / (3 * train_counts + 1)

# Rescale so that the three weights sum to 3.
class_weights = raw_weights / raw_weights.sum() * 3

print(f"Class weights: {class_weights}")

# Fit on the training split while tracking the validation split.
train_config = dict(epochs=100, lr=0.001, batch_size=32)
history = detector.train(
    train_graphs, train_labels,
    val_graphs, val_labels,
    class_weights=class_weights.tolist(),
    **train_config,
)

In [None]:
# Side-by-side train/validation curves: loss on the left, accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (train key, validation key, y-axis label, panel title) for each panel.
panels = [
    ('train_loss', 'val_loss', 'Loss', 'Training Loss'),
    ('train_acc', 'val_acc', 'Accuracy', 'Training Accuracy'),
]
for panel_ax, (train_key, val_key, y_label, title) in zip(axes, panels):
    panel_ax.plot(history[train_key], label='Train')
    panel_ax.plot(history[val_key], label='Val')
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(title)
    panel_ax.legend()

plt.tight_layout()
plt.savefig('training_history.png', dpi=150)
plt.show()

## 3. Evaluate Classification Performance

In [None]:
# Predict on the validation graphs; `preds` and `probs` are reused by later cells.
preds, probs = detector.predict(val_graphs)

# Per-class metrics, with rows named after the detector's regime labels.
print("Classification Report:")
report = classification_report(
    val_labels,
    preds,
    target_names=detector.REGIME_LABELS,
)
print(report)

In [None]:
# Confusion matrix of validation targets vs. predictions, drawn as a heatmap.
cm = confusion_matrix(val_labels, preds)
regime_names = detector.REGIME_LABELS

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=regime_names,
    yticklabels=regime_names,
    ax=ax,
)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')

plt.tight_layout()
plt.savefig('confusion_matrix.png', dpi=150)
plt.show()

## 4. Backtest: Position Sizing by Regime

In [None]:
# Fraction of capital exposed to BTC under each predicted regime.
position_map = {
    'RISK_ON': 1.0,
    'CAUTION': 0.5,
    'RISK_OFF': 0.2
}

# BTC daily returns from row `split` onward, capped at the number of predictions.
# NOTE(review): assumes row `split` of df lines up with the first validation
# graph — confirm against prepare_dataset.
btc_returns = df['BTC_return_1d'].iloc[split:].values[:len(preds)]

# Buy and hold: compound every daily return at full exposure.
bh_equity = [1.0]
for daily_ret in btc_returns:
    bh_equity.append(bh_equity[-1] * (1 + daily_ret))

# GNN strategy: scale each day's return by the exposure of the predicted regime.
# NOTE(review): preds[i] scales the return in the same row i — verify that the
# prediction for row i does not already use that day's return (look-ahead).
exposures = [
    position_map[detector.REGIME_LABELS[preds[day]]]
    for day in range(len(btc_returns))
]
strat_equity = [1.0]
for daily_ret, exposure in zip(btc_returns, exposures):
    strat_equity.append(strat_equity[-1] * (1 + daily_ret * exposure))

print(f"Buy & Hold final: {bh_equity[-1]:.3f}")
print(f"GNN Strategy final: {strat_equity[-1]:.3f}")

In [None]:
# Overlay both equity curves on a single axis.
fig, ax = plt.subplots(figsize=(12, 6))

curves = [
    (bh_equity, 'Buy & Hold'),
    (strat_equity, 'GNN Strategy'),
]
for curve, curve_label in curves:
    ax.plot(curve, label=curve_label, linewidth=2)

ax.set_xlabel('Days')
ax.set_ylabel('Equity (starting at 1.0)')
ax.set_title('Backtest: GNN Regime-Based Position Sizing')
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('equity_curves.png', dpi=150)
plt.show()

In [None]:
# Calculate metrics
def calc_metrics(equity, periods_per_year=252):
    """Summarise an equity curve as total return, Sharpe ratio and max drawdown.

    Args:
        equity: Sequence of equity values, oldest first. Must contain at
            least one value; values are used as divisors, so they should be
            non-zero.
        periods_per_year: Number of return periods per year used to annualise
            the Sharpe ratio. Defaults to 252; pass 365 for a market that
            produces a return every calendar day.

    Returns:
        Tuple ``(total_return_pct, sharpe, max_drawdown_pct)``. ``sharpe`` is
        0.0 when the curve has no returns or no measurable volatility.

    Raises:
        IndexError: If ``equity`` is empty.
    """
    equity = np.asarray(equity, dtype=float)
    returns = np.diff(equity) / equity[:-1]

    total_return = (equity[-1] / equity[0] - 1) * 100

    # A one-point curve has no returns; skip mean/std there so NumPy does not
    # emit empty-slice warnings.
    if returns.size:
        mean_return = returns.mean()
        volatility = returns.std()
    else:
        mean_return = volatility = 0.0

    # Compare against a small tolerance rather than exactly zero: a
    # constant-growth curve yields a rounding-level std (~1e-17) that would
    # otherwise blow the ratio up to an astronomically large value.
    min_volatility = 1e-12
    if volatility > min_volatility:
        sharpe = mean_return / volatility * np.sqrt(periods_per_year)
    else:
        sharpe = 0.0

    # Drawdown is measured against the running peak of the curve.
    peak = np.maximum.accumulate(equity)
    drawdown = (peak - equity) / peak
    max_dd = drawdown.max() * 100

    return total_return, sharpe, max_dd

# Headline metrics for both equity curves.
bh_ret, bh_sharpe, bh_dd = calc_metrics(bh_equity)
strat_ret, strat_sharpe, strat_dd = calc_metrics(strat_equity)

# Minimum relative drawdown reduction (in percent) for the check to pass.
DD_REDUCTION_TARGET = 30

# Relative max-drawdown improvement over buy & hold, in percent. When buy &
# hold never draws down there is nothing to reduce, so report 0.0 instead of
# dividing by zero (which would give nan/-inf and a RuntimeWarning).
dd_reduction = (bh_dd - strat_dd) / bh_dd * 100 if bh_dd > 0 else 0.0

print("\n" + "="*60)
print("BACKTEST RESULTS")
print("="*60)
print(f"\n{'Metric':<20} {'Buy & Hold':>15} {'GNN Strategy':>15}")
print("-"*50)
print(f"{'Total Return':<20} {bh_ret:>14.1f}% {strat_ret:>14.1f}%")
print(f"{'Sharpe Ratio':<20} {bh_sharpe:>15.2f} {strat_sharpe:>15.2f}")
print(f"{'Max Drawdown':<20} {bh_dd:>14.1f}% {strat_dd:>14.1f}%")
print(f"\nDrawdown Reduction: {dd_reduction:.1f}%")
print(f"Target: ≥{DD_REDUCTION_TARGET}%")
print(f"Status: {'✓ PASSED' if dd_reduction >= DD_REDUCTION_TARGET else '✗ FAILED'}")

## 5. Lead Time Analysis

In [None]:
# Labels from row `split` onward, capped at the number of predictions.
# NOTE(review): assumes labels.iloc[split:] lines up with the validation
# graphs — confirm against prepare_dataset.
val_labels_series = labels.iloc[split:].values[:len(preds)]

# Index of the first day of every RISK_OFF episode (a run of consecutive
# RISK_OFF labels counts once).
risk_off_starts = []
in_risk_off = False
for i, label in enumerate(val_labels_series):
    if label == 'RISK_OFF' and not in_risk_off:
        risk_off_starts.append(i)
        in_risk_off = True
    elif label != 'RISK_OFF':
        in_risk_off = False

print(f"Found {len(risk_off_starts)} RISK_OFF periods")

# How far back (in days) to search for a warning before each episode.
MAX_LOOKBACK_DAYS = 20
# Predicted regimes that count as an early warning.
WARNING_REGIMES = ('CAUTION', 'RISK_OFF')

# Check detection lead time
# NOTE(review): this records the most recent warning before the episode, not
# the start of a warning streak — confirm that is the intended definition.
lead_times = []
for start_idx in risk_off_starts:
    # The upper bound is inclusive of index 0, so an episode that starts at
    # index 1 can still be matched against the very first prediction.
    for lookback in range(1, min(MAX_LOOKBACK_DAYS, start_idx) + 1):
        check_idx = start_idx - lookback
        if detector.REGIME_LABELS[preds[check_idx]] in WARNING_REGIMES:
            lead_times.append(lookback)
            print(f"  Period at idx {start_idx}: Detected {lookback} days before")
            break
    else:
        print(f"  Period at idx {start_idx}: NOT detected early")

if lead_times:
    print(f"\nAvg lead time: {np.mean(lead_times):.1f} days")
    print(f"Detection rate: {len(lead_times)}/{len(risk_off_starts)} ({len(lead_times)/len(risk_off_starts)*100:.0f}%)")
elif risk_off_starts:
    # Make a complete miss explicit instead of printing nothing.
    print(f"\nDetection rate: 0/{len(risk_off_starts)} (0%)")

## 6. Summary & Conclusions

In [None]:
# Final recap of the figures computed in the earlier cells.
rule = "=" * 60

print(rule)
print("PHASE 4 SUMMARY: GNN REGIME DETECTION")
print(rule)

# Share of validation samples whose prediction equals the target, in percent.
print("\n1. CLASSIFICATION PERFORMANCE")
accuracy_pct = (np.array(preds) == np.array(val_labels)).mean() * 100
print(f"   Overall accuracy: {accuracy_pct:.1f}%")
print("   (Poor accuracy, but useful for risk management)")

print("\n2. BACKTEST RESULTS")
print(f"   Drawdown reduction: {dd_reduction:.1f}% (target ≥30%)")
print(f"   Return: {strat_ret:.1f}% vs {bh_ret:.1f}% B&H")
print(f"   Sharpe: {strat_sharpe:.2f} vs {bh_sharpe:.2f} B&H")

# Lead-time figures are only printed when at least one episode was detected early.
print("\n3. LEAD TIME")
if lead_times:
    print(f"   Avg lead time: {np.mean(lead_times):.1f} days")
    print(f"   Detection rate: {len(lead_times)}/{len(risk_off_starts)}")

# The overall verdict depends on the drawdown-reduction target alone.
print("\n4. OVERALL ASSESSMENT")
verdict = (
    "   ✓ SUCCESS: Target drawdown reduction achieved"
    if dd_reduction >= 30
    else "   ✗ Target not met"
)
print(verdict)

print("\n" + rule)