# SDE Head + SDE Evolution Block with Patching & RevIN

This notebook trains a model that pairs advanced preprocessing blocks (**Patching** + **RevIN**) 
with the **SDEEvolutionBlock** backbone and **SDEHead** (or **NeuralSDEHead**) for price-path generation.

**New Features:**
- **FlexiblePatchEmbed**: Converts input sequences into patches for more efficient processing
- **RevIN**: Reversible Instance Normalization for handling non-stationary time series
- **ZScoreEngineer**: Rolling z-score features (short/long window) instead of OHLCV features
- **288-step prediction horizon**: Extended from 12 to 288 steps; at 5-minute bars, 288 steps cover 24 hours

**Architecture:**
- **FlexiblePatchEmbed** converts raw sequences into patches with channel independence
- **RevIN** normalizes the patched sequences to handle non-stationarity
- **SDEEvolutionBlock** learns residual stochastic updates inside the backbone
- **SDEHead** maps the backbone output to `(mu, sigma)` drift/volatility parameters
- **NeuralSDEHead** learns full drift/diffusion networks and integrates via `torchsde.sdeint`

**Workflow:**
1. Load OHLCV data from `tensorlink-dev/open-synth-training-data`
2. Engineer z-score features using **ZScoreEngineer**
3. Build a backbone using **FlexiblePatchEmbed + RevIN + TransformerBlock + SDEEvolutionBlock**
4. Attach an **SDEHead** (or **NeuralSDEHead**)
5. Train with CRPS loss and backtest with multi-interval scoring

## 1. Imports & Setup

In [None]:
import sys
import os

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

from src.models.registry import discover_components, registry, SDEEvolutionBlock, TransformerBlock
from src.models.factory import HybridBackbone, SynthModel
from src.models.heads import SDEHead, NeuralSDEHead
from src.models.components.advanced_blocks import RevIN, FlexiblePatchEmbed
from src.data.market_data_loader import (
    HFOHLCVSource,
    MockDataSource,
    ZScoreEngineer,
    MarketDataLoader,
)
from src.research.trainer import Trainer, DataToModelAdapter
from src.research.metrics import (
    crps_ensemble,
    CRPSMultiIntervalScorer,
    SCORING_INTERVALS,
)

discover_components("src/models/components")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print("ZScoreEngineer produces 3 features: returns, z_short, z_long")

## 2. Configuration

In [None]:
# ----- Data config -----
REPO_ID = "tensorlink-dev/open-synth-training-data"
ASSET_FILES = {
    "BTC_USD": "data/BTC_USD/5m.parquet",
    "ETH_USD": "data/ETH_USD/5m.parquet",
    "SOL_USD": "data/SOL_USD/5m.parquet",
}
ASSETS = list(ASSET_FILES.keys())
USE_HF = True

INPUT_LEN = 64
PRED_LEN = 288  # Updated to 288 steps
BATCH_SIZE = 8
FEATURE_DIM = 3  # ZScoreEngineer produces 3 features

# ----- Model config -----
D_MODEL = 64
SDE_BLOCK_HIDDEN = 128   # Hidden dim for SDEEvolutionBlock
SDE_HEAD_HIDDEN = 64     # Hidden dim for the SDE head MLPs
N_PATHS = 500

# ----- Patching config -----
PATCH_LEN = 16
PATCH_STRIDE = 8

# ----- Head selection -----
# "sde"        -> SDEHead  (outputs mu, sigma for GBM simulation)
# "neural_sde" -> NeuralSDEHead (learns drift/diffusion networks, integrates via torchsde)
HEAD_TYPE = "sde"

# NeuralSDEHead-specific config
NEURAL_SDE_SOLVER = "euler"   # "euler", "milstein", or "srk"
NEURAL_SDE_ADJOINT = False    # Use adjoint method for memory-efficient backprop

# ----- Training config -----
EPOCHS = 15
LR = 1e-3

# ----- Backtest config -----
TIME_INCREMENT = 300  # seconds per forecast step; the 5m.parquet files hold 5-minute bars

## 3. Data Loading
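
`ZScoreEngineer` is configured in the next cell with `short_win=20` and `long_win=200`, and the setup cell reports three output features: `returns`, `z_short`, and `z_long`. Its implementation is not shown in this notebook, so the sketch below is only an assumption about what rolling z-score features can look like: log returns standardized by a trailing mean and standard deviation. The helpers `rolling_zscore` and `zscore_features`, and the choice to z-score returns rather than prices, are hypothetical.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def rolling_zscore(x: np.ndarray, window: int) -> np.ndarray:
    """Z-score of each point against its trailing `window` values (inclusive)."""
    windows = sliding_window_view(x, window)  # shape: (len(x) - window + 1, window)
    mean = windows.mean(axis=-1)
    std = windows.std(axis=-1, ddof=1)
    out = np.full(x.shape, np.nan)
    # The first window - 1 points have no full trailing window and stay NaN.
    out[window - 1:] = (x[window - 1:] - mean) / np.where(std > 0, std, np.nan)
    return out


def zscore_features(close: np.ndarray, short_win: int = 20, long_win: int = 200) -> np.ndarray:
    """Hypothetical stand-in for ZScoreEngineer: columns are returns, z_short, z_long."""
    returns = np.diff(np.log(close))
    feats = np.stack(
        [returns, rolling_zscore(returns, short_win), rolling_zscore(returns, long_win)],
        axis=-1,
    )
    # Drop warm-up rows where the long window is not yet full (assumes long_win >= short_win).
    return feats[long_win - 1:]


rng = np.random.default_rng(0)
close = 100.0 * np.exp(np.cumsum(rng.normal(0.0, 0.001, size=1000)))
features = zscore_features(close)
print(features.shape)
```

Whatever the exact definition, the point of the two windows is that the short one reacts to recent moves while the long one measures how unusual a return is against a slower baseline.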

In [None]:
engineer = ZScoreEngineer(short_win=20, long_win=200)

if USE_HF:
    source = HFOHLCVSource(
        repo_id=REPO_ID,
        asset_files=ASSET_FILES,
        repo_type="dataset",
    )
else:
    source = MockDataSource(length=8000, freq="5min", seed=42, base_price=100.0)

loader = MarketDataLoader(
    data_source=source,
    engineer=engineer,
    assets=ASSETS,
    input_len=INPUT_LEN,
    pred_len=PRED_LEN,
    batch_size=BATCH_SIZE,
    feature_dim=FEATURE_DIM,
)

print(f"Assets loaded:  {[a.name for a in loader.assets_data]}")
print(f"Total windows:  {len(loader.dataset)}")
print(f"Sample input shape (F, T): {loader.dataset[0]['inputs'].shape}")
print(f"Sample target shape (1, T): {loader.dataset[0]['target'].shape}")

In [None]:
train_dl, val_dl, test_dl = loader.static_holdout(
    cutoff=0.2,
    val_size=0.15,
    shuffle_train=True,
)

print(f"Train batches: {len(train_dl)}")
print(f"Val batches:   {len(val_dl)}")
print(f"Test batches:  {len(test_dl)}")

## 4. Build the Model

### Backbone: Patching + RevIN + Transformer + SDE Evolution Block

The backbone now includes advanced preprocessing and normalization:

1. **FlexiblePatchEmbed**: Converts the input sequence into patches using strided convolutions.
   Patches capture local patterns and shorten the sequence that the attention layers operate on, which lowers compute for long inputs.
   With `channel_independence=True`, each feature channel is processed independently.

2. **RevIN (Reversible Instance Normalization)**: Normalizes each input instance by subtracting its own mean
   and dividing by its own standard deviation, with an optional learnable affine transform (`affine=True`).
   This reduces the distribution shift between windows of a non-stationary series. The statistics are kept, so the transformation can be reversed after prediction.

3. **TransformerBlock**: Multi-head self-attention + gated MLP for capturing long-range dependencies.

4. **SDEEvolutionBlock**: Learns residual stochastic updates inside the backbone (`x + MLP(x)`),
   giving the latent representation a diffusion-like inductive bias. This complements the
   deterministic attention patterns with learned stochastic perturbations.
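
The normalize-then-restore idea behind RevIN can be sketched in a few lines. The version below is a minimal NumPy illustration without the learnable affine parameters, assuming the statistics are taken over the time axis of each instance; the `RevIN` block used in this notebook may compute them differently. `revin_normalize` and `revin_denormalize` are hypothetical names.

```python
import numpy as np


def revin_normalize(x: np.ndarray, eps: float = 1e-5):
    """Normalize each instance over time. x has shape (batch, time, features)."""
    mean = x.mean(axis=1, keepdims=True)
    std = np.sqrt(x.var(axis=1, keepdims=True) + eps)
    return (x - mean) / std, (mean, std)


def revin_denormalize(y: np.ndarray, stats) -> np.ndarray:
    """Undo revin_normalize using the stored per-instance statistics."""
    mean, std = stats
    return y * std + mean


rng = np.random.default_rng(0)
# Two instances with very different levels and scales, as in a non-stationary series.
x = np.stack(
    [rng.normal(100.0, 5.0, size=(64, 3)), rng.normal(-3.0, 0.1, size=(64, 3))]
)
x_norm, stats = revin_normalize(x)
x_back = revin_denormalize(x_norm, stats)

print("max |mean| after normalization:", np.abs(x_norm.mean(axis=1)).max())
print("max round-trip error:", np.abs(x_back - x).max())
```

After normalization both instances sit on a comparable scale, and the stored statistics are enough to map outputs back to the original level.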

### Head: SDEHead or NeuralSDEHead

- **SDEHead**: 2-layer MLP mapping `h_t -> (mu, sigma)`. Paths are simulated
  externally via standard GBM (Euler-Maruyama).
- **NeuralSDEHead**: Learns drift `f(t, y | ctx)` and diffusion `g(t, y | ctx)`
  networks that are integrated using `torchsde.sdeint` in log-price space.
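
For the `SDEHead` case, path generation from a single `(mu, sigma)` pair can be sketched as follows. This is an illustrative NumPy version, assuming a unit time step and an Euler-Maruyama update applied to the log-price (which is exact for GBM with constant parameters); the project's simulator may use a different time scale or discretization. `simulate_gbm_paths` is a hypothetical helper.

```python
import numpy as np


def simulate_gbm_paths(initial_price, mu, sigma, horizon, n_paths, dt=1.0, seed=None):
    """Simulate GBM price paths; returns an array of shape (n_paths, horizon)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((n_paths, horizon))
    # Euler-Maruyama step for d(log S) = (mu - sigma^2 / 2) dt + sigma dW.
    log_increments = (mu - 0.5 * sigma**2) * dt + sigma * np.sqrt(dt) * z
    return initial_price * np.exp(np.cumsum(log_increments, axis=1))


paths = simulate_gbm_paths(
    initial_price=100.0, mu=0.0, sigma=0.002, horizon=288, n_paths=500, seed=0
)
print(paths.shape)  # (500, 288)
```

Because `mu` and `sigma` are fixed for the whole horizon, every path from one sample shares the same dynamics, and the spread of the fan grows with the square root of the step count.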

In [None]:
blocks = [
    FlexiblePatchEmbed(
        d_model=D_MODEL,
        patch_len=PATCH_LEN,
        stride=PATCH_STRIDE,
        in_channels=FEATURE_DIM,
        channel_independence=True,
        mask_ratio=0.0,
    ),
    RevIN(d_model=D_MODEL, affine=True, eps=1e-5),
    TransformerBlock(d_model=D_MODEL, nhead=4, dropout=0.1),
    SDEEvolutionBlock(d_model=D_MODEL, hidden=SDE_BLOCK_HIDDEN, dropout=0.1),
    TransformerBlock(d_model=D_MODEL, nhead=4, dropout=0.1),
    SDEEvolutionBlock(d_model=D_MODEL, hidden=SDE_BLOCK_HIDDEN, dropout=0.1),
]

backbone = HybridBackbone(
    input_size=FEATURE_DIM,
    d_model=D_MODEL,
    blocks=blocks,
    validate_shapes=False,  # Patching changes sequence length
)

if HEAD_TYPE == "neural_sde":
    head = NeuralSDEHead(
        latent_size=backbone.output_dim,
        hidden=SDE_HEAD_HIDDEN,
        solver=NEURAL_SDE_SOLVER,
        adjoint=NEURAL_SDE_ADJOINT,
    )
    head_label = f"NeuralSDEHead (solver={NEURAL_SDE_SOLVER}, adjoint={NEURAL_SDE_ADJOINT})"
else:
    head = SDEHead(
        latent_size=backbone.output_dim,
        hidden=SDE_HEAD_HIDDEN,
    )
    head_label = f"SDEHead (hidden={SDE_HEAD_HIDDEN})"

model = SynthModel(backbone=backbone, head=head).to(device)

total_params = sum(p.numel() for p in model.parameters())
backbone_params = sum(p.numel() for p in backbone.parameters())
head_params = sum(p.numel() for p in head.parameters())
print(f"SynthModel with {head_label} built successfully")
print("  Backbone: FlexiblePatchEmbed + RevIN + (TransformerBlock + SDEEvolutionBlock) x2")
print(f"  Patch len: {PATCH_LEN}, stride: {PATCH_STRIDE}")
print(f"  Backbone output dim: {backbone.output_dim}")
print(f"  Backbone params:     {backbone_params:,}")
print(f"  Head params:         {head_params:,}")
print(f"  Total parameters:    {total_params:,}")

### Sanity Check: Forward Pass
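
A quick shape calculation helps when reading the shapes printed by the next cell. Assuming `FlexiblePatchEmbed` slides a window of `PATCH_LEN` with step `PATCH_STRIDE` and applies no padding (an assumption; the block's internals are not shown here), a sequence of 64 steps yields `(64 - 16) // 8 + 1 = 7` patches. The NumPy sketch below only illustrates that windowing; `patchify` is a hypothetical helper and not part of the project.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def patchify(x: np.ndarray, patch_len: int, stride: int) -> np.ndarray:
    """Split a (T, F) sequence into overlapping patches of shape (n_patches, F, patch_len)."""
    windows = sliding_window_view(x, patch_len, axis=0)  # (T - patch_len + 1, F, patch_len)
    return windows[::stride]


x = np.arange(64 * 3, dtype=float).reshape(64, 3)  # 64 time steps, 3 features
patches = patchify(x, patch_len=16, stride=8)
print(patches.shape)  # (7, 3, 16)
```

With a stride of half the patch length, consecutive patches overlap by 8 steps, so each time step (apart from the edges) appears in two patches.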

In [None]:
dummy_x = torch.randn(2, INPUT_LEN, FEATURE_DIM, device=device)
dummy_price = torch.tensor([100.0, 50.0], device=device)

with torch.no_grad():
    paths, mu, sigma = model(dummy_x, initial_price=dummy_price, horizon=PRED_LEN, n_paths=100)

print(f"Paths shape:  {paths.shape}   (batch, n_paths, horizon)")
print(f"Mu shape:     {mu.shape}")
print(f"Sigma shape:  {sigma.shape}")
print(f"Mu values:    {mu.cpu().numpy()}")
print(f"Sigma values: {sigma.cpu().numpy()}")
print(f"Path range:   [{paths.min().item():.4f}, {paths.max().item():.4f}]")

## 5. Training Loop
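
Training minimizes CRPS (Continuous Ranked Probability Score) over the simulated paths. For an ensemble of samples `x_1, ..., x_m` and an observation `y`, a standard estimator is `mean|x_i - y| - 0.5 * mean|x_i - x_j|`. The sketch below implements that formula in NumPy for illustration only; it is not taken from `src.research.metrics`, and the project's `crps_ensemble` may differ in details such as axis order or bias correction. `crps_from_samples` is a hypothetical name.

```python
import numpy as np


def crps_from_samples(samples, observation: float) -> float:
    """CRPS estimate for a 1-D ensemble of samples against a scalar observation."""
    samples = np.asarray(samples, dtype=float)
    # Average distance between the ensemble members and the observation.
    accuracy = np.abs(samples - observation).mean()
    # Average distance between all pairs of ensemble members.
    spread = np.abs(samples[:, None] - samples[None, :]).mean()
    return float(accuracy - 0.5 * spread)


print(crps_from_samples([0.0, 2.0], 1.0))  # 0.5
print(crps_from_samples([1.0, 1.0], 1.0))  # 0.0
```

Lower is better. An ensemble that is narrow but centered away from the observation is penalized by the first term, while the second term keeps the score from rewarding needlessly wide ensembles.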

In [None]:
optimizer = optim.Adam(model.parameters(), lr=LR)
adapter = DataToModelAdapter(device=device, target_is_log_return=True)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    n_paths=N_PATHS,
    device=device,
    adapter=adapter,
)

history = {
    "train_loss": [],
    "train_crps": [],
    "train_sharpness": [],
    "val_crps": [],
}

In [None]:
for epoch in range(1, EPOCHS + 1):
    epoch_losses = []
    epoch_crps = []
    epoch_sharp = []

    for batch in train_dl:
        metrics = trainer.train_step(batch)
        epoch_losses.append(metrics["loss"])
        epoch_crps.append(metrics["crps"])
        epoch_sharp.append(metrics["sharpness"])

    avg_loss = np.mean(epoch_losses)
    avg_crps = np.mean(epoch_crps)
    avg_sharp = np.mean(epoch_sharp)

    history["train_loss"].append(avg_loss)
    history["train_crps"].append(avg_crps)
    history["train_sharpness"].append(avg_sharp)

    val_metrics = trainer.validate(val_dl)
    history["val_crps"].append(val_metrics["val_crps"])

    if epoch % 3 == 0 or epoch == 1:
        print(
            f"Epoch {epoch:3d}/{EPOCHS}  "
            f"train_loss={avg_loss:.5f}  "
            f"train_crps={avg_crps:.5f}  "
            f"val_crps={val_metrics['val_crps']:.5f}  "
            f"sharpness={avg_sharp:.5f}"
        )

### Training Curves

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(history["train_loss"], label="Train Loss (CRPS)")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training Loss")
axes[0].legend()

axes[1].plot(history["train_crps"], label="Train CRPS")
axes[1].plot(history["val_crps"], label="Val CRPS")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("CRPS")
axes[1].set_title("CRPS: Train vs Validation")
axes[1].legend()

axes[2].plot(history["train_sharpness"], label="Sharpness", color="tab:green")
axes[2].set_xlabel("Epoch")
axes[2].set_ylabel("Std(paths)")
axes[2].set_title("Forecast Sharpness")
axes[2].legend()

plt.tight_layout()
plt.show()

## 6. Backtesting on Test Set

In [None]:
model.eval()
scorer = CRPSMultiIntervalScorer(time_increment=TIME_INCREMENT)

interval_scores = {name: [] for name in SCORING_INTERVALS}
overall_scores = []
all_test_crps = []

for batch in test_dl:
    adapted = adapter(batch)
    history_t = adapted["history"]
    initial_price = adapted["initial_price"]
    target_factors = adapted["target_factors"]
    horizon = target_factors.shape[-1]

    with torch.no_grad():
        paths, mu, sigma = model(
            history_t,
            initial_price=initial_price,
            horizon=horizon,
            n_paths=N_PATHS,
        )

    sim_paths = paths.transpose(1, 2)
    crps_vals = crps_ensemble(sim_paths, target_factors)
    all_test_crps.append(crps_vals.mean().item())

    for sample_idx in range(paths.shape[0]):
        # Score the ensemble against the realized path, not against one of its own samples.
        total_crps, detailed = scorer(paths[sample_idx], target_factors[sample_idx])
        overall_scores.append(total_crps)
        for row in detailed:
            interval_name = row["Interval"]
            if interval_name in interval_scores and row["Increment"] == "Total":
                interval_scores[interval_name].append(float(row["CRPS"]))

avg_test_crps = np.mean(all_test_crps)
print(f"\n{'='*50}")
print("BACKTEST RESULTS -- SDE Head + SDE Block")
print(f"{'='*50}")
print(f"Head type:         {HEAD_TYPE}")
print(f"Average Test CRPS: {avg_test_crps:.6f}")
print(f"\nMulti-Interval CRPS Breakdown:")
for name, scores in interval_scores.items():
    if scores:
        print(f"  {name:>12s}: {np.mean(scores):.6f} (n={len(scores)})")
    else:
        print(f"  {name:>12s}: N/A (horizon too short)")

## 7. Fan Chart Visualization

In [None]:
model.eval()
test_batch = next(iter(test_dl))
adapted = adapter(test_batch)

with torch.no_grad():
    paths, mu, sigma = model(
        adapted["history"],
        initial_price=adapted["initial_price"],
        horizon=adapted["target_factors"].shape[-1],
        n_paths=N_PATHS,
    )

paths_np = paths.cpu().numpy()
targets_np = adapted["target_factors"].cpu().numpy()

n_show = min(4, paths_np.shape[0])
fig, axes = plt.subplots(1, n_show, figsize=(5 * n_show, 4), squeeze=False)

for i in range(n_show):
    ax = axes[0, i]
    sample_paths = paths_np[i]
    t = np.arange(sample_paths.shape[1])

    p5 = np.percentile(sample_paths, 5, axis=0)
    p25 = np.percentile(sample_paths, 25, axis=0)
    p50 = np.percentile(sample_paths, 50, axis=0)
    p75 = np.percentile(sample_paths, 75, axis=0)
    p95 = np.percentile(sample_paths, 95, axis=0)

    ax.fill_between(t, p5, p95, alpha=0.15, color="tab:blue", label="P5-P95")
    ax.fill_between(t, p25, p75, alpha=0.3, color="tab:blue", label="P25-P75")
    ax.plot(t, p50, color="tab:blue", linewidth=2, label="Median")
    ax.plot(t, targets_np[i], color="tab:red", linewidth=2, linestyle="--", label="Actual")

    ax.set_title(f"Sample {i}")
    ax.set_xlabel("Horizon Step")
    ax.set_ylabel("Price Factor")
    if i == 0:
        ax.legend(fontsize=8)

plt.suptitle(f"SDE Head + SDE Block Fan Charts (head={HEAD_TYPE})", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()

## 8. Path Distribution & Learned Parameters

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Terminal price distribution for sample 0
terminal_prices = paths_np[0, :, -1]
axes[0].hist(terminal_prices, bins=50, alpha=0.7, color="tab:blue", edgecolor="white")
axes[0].axvline(targets_np[0, -1], color="tab:red", linewidth=2, linestyle="--", label="Actual")
axes[0].set_title("Terminal Price Distribution")
axes[0].set_xlabel("Price Factor")
axes[0].set_ylabel("Count")
axes[0].legend()

# Collect mu and sigma across multiple batches
model.eval()
all_mu = []
all_sigma = []
for batch in test_dl:
    adapted_viz = adapter(batch)
    with torch.no_grad():
        _, mu_viz, sigma_viz = model(
            adapted_viz["history"],
            initial_price=adapted_viz["initial_price"],
            horizon=adapted_viz["target_factors"].shape[-1],
            n_paths=10,
        )
    all_mu.append(mu_viz.cpu().numpy().flatten())
    all_sigma.append(sigma_viz.cpu().numpy().flatten())

all_mu = np.concatenate(all_mu)
all_sigma = np.concatenate(all_sigma)

axes[1].hist(all_mu, bins=50, alpha=0.7, color="tab:orange", edgecolor="white")
axes[1].set_title("Learned Drift (mu) Distribution")
axes[1].set_xlabel("mu")
axes[1].set_ylabel("Count")

axes[2].hist(all_sigma, bins=50, alpha=0.7, color="tab:green", edgecolor="white")
axes[2].set_title("Learned Volatility (sigma) Distribution")
axes[2].set_xlabel("sigma")
axes[2].set_ylabel("Count")

plt.tight_layout()
plt.show()

print(f"Drift   -- mean: {all_mu.mean():.6f}, std: {all_mu.std():.6f}")
print(f"Vol     -- mean: {all_sigma.mean():.6f}, std: {all_sigma.std():.6f}")

## 9. Compare SDEHead vs NeuralSDEHead (Optional)

Re-run the notebook with `HEAD_TYPE = "neural_sde"` to compare:

| Aspect | SDEHead | NeuralSDEHead |
|--------|---------|---------------|
| Output | `(mu, sigma)` scalars | Full paths via `torchsde.sdeint` |
| Simulation | External GBM (constant params) | Internal SDE integration (state-dependent) |
| Dynamics | Constant drift/vol per sample | Time-varying, state-dependent drift/vol |
| Speed | Faster (no numerical SDE solve) | Slower (numerical integration) |
| Expressiveness | Lower (constant parameters) | Higher (neural drift and diffusion) |
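
The state-dependent dynamics are the main difference between the two heads. As a rough illustration, the sketch below integrates a scalar SDE `dy = f(t, y) dt + g(t, y) dW` with a hand-written Euler-Maruyama loop, where `f` and `g` are plain Python callables standing in for the learned drift and diffusion networks. It does not use `torchsde` and is not how `NeuralSDEHead` is implemented; the mean-reverting drift and level-dependent diffusion are arbitrary choices for the example, and `euler_maruyama` is a hypothetical helper.

```python
import numpy as np


def euler_maruyama(f, g, y0, n_steps, dt, n_paths, seed=None):
    """Integrate dy = f(t, y) dt + g(t, y) dW; returns shape (n_paths, n_steps + 1)."""
    rng = np.random.default_rng(seed)
    y = np.full(n_paths, y0, dtype=float)
    out = [y.copy()]
    for step in range(n_steps):
        t = step * dt
        dw = rng.standard_normal(n_paths) * np.sqrt(dt)
        # Drift and diffusion are re-evaluated at the current state on every step.
        y = y + f(t, y) * dt + g(t, y) * dw
        out.append(y.copy())
    return np.stack(out, axis=1)


def drift(t, y):
    # Pulls the log-price back toward 0.
    return -0.5 * y


def diffusion(t, y):
    # Volatility grows with the distance from 0.
    return 0.1 * (1.0 + np.abs(y))


log_paths = euler_maruyama(
    drift, diffusion, y0=0.2, n_steps=288, dt=1.0 / 288, n_paths=500, seed=0
)
print(log_paths.shape)  # (500, 289)
```

With constant `f` and `g` this loop reduces to the fixed-parameter simulation that `SDEHead` relies on; letting both depend on `t` and `y` is what makes the neural variant more expressive and more expensive.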