# FEDformer Training & Backtesting

This notebook trains a **standalone FEDformer** (FourierBlock) model for price path generation.

FourierBlock is a frequency-enhanced block inspired by FEDformer that applies learned
complex-valued weights in the Fourier domain to capture periodic patterns.

1. Load per-asset OHLCV candle data from `tensorlink-dev/open-synth-training-data`
2. Engineer 16 microstructure features per 1-hour bar (resampled from the 5-minute candles) using `OHLCVEngineer`
3. Build a model using **FourierBlock** backbone
4. Train with a CRPS loss, then backtest on the held-out test split using multi-interval CRPS scoring

## 1. Imports & Setup

In [None]:
import sys
import os

# Make the project package (`src.*`) importable and run from the project root so
# that relative paths such as "src/models/components" resolve.
# NOTE(review): the notebook presumably lives one directory below the root; if the
# cwd already contains `src/models` (e.g. this cell is being re-run after the
# chdir below), stay put instead of climbing yet another directory level.
_cwd = os.getcwd()
if os.path.isdir(os.path.join(_cwd, "src", "models")):
    PROJECT_ROOT = _cwd
else:
    PROJECT_ROOT = os.path.abspath(os.path.join(_cwd, ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

from src.models.registry import discover_components, registry
from src.models.factory import HybridBackbone, SynthModel
from src.models.heads import HorizonHead, NeuralBridgeHead
from src.data.market_data_loader import (
    HFOHLCVSource,
    MockDataSource,
    OHLCVEngineer,
    OHLCV_FEATURE_NAMES,
    MarketDataLoader,
)
from src.research.trainer import Trainer, DataToModelAdapter
from src.research.metrics import (
    crps_ensemble,
    CRPSMultiIntervalScorer,
    SCORING_INTERVALS,
)

# Scan the on-disk component package; presumably this registers the blocks found
# there so that `registry.get_block(...)` can resolve them by name later on.
discover_components("src/models/components")

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(f"Using device: {device}")

n_feature_names = len(OHLCV_FEATURE_NAMES)
print(f"OHLCV feature count: {n_feature_names}")

## 2. Configuration

In [None]:
# ----- Data config -----
REPO_ID = "tensorlink-dev/open-synth-training-data"
ASSET_FILES = {
    "BTC_USD": "data/BTC_USD/5m.parquet",
    "ETH_USD": "data/ETH_USD/5m.parquet",
    "SOL_USD": "data/SOL_USD/5m.parquet",
}
ASSETS = list(ASSET_FILES.keys())
USE_HF = True            # False -> MockDataSource (synthetic candles, no download)

INPUT_LEN = 64           # input window length, passed to MarketDataLoader as input_len
PRED_LEN = 12            # forecast length, passed to MarketDataLoader as pred_len
BATCH_SIZE = 8
FEATURE_DIM = 16         # should equal len(OHLCV_FEATURE_NAMES), printed in the setup cell

# ----- Model config -----
D_MODEL = 64
FOURIER_MODES = 32       # Number of Fourier modes to keep
N_PATHS = 500            # simulated paths per sample in training and backtesting

# ----- Head selection -----
HEAD_TYPE = "neural_bridge"  # "horizon" or "neural_bridge"

# HorizonHead config
HORIZON_MAX = 48
HORIZON_NHEAD = 4
HORIZON_LAYERS = 2

# NeuralBridgeHead config
MICRO_STEPS = 12
BRIDGE_HIDDEN = 64

# ----- Training config -----
EPOCHS = 15
LR = 1e-3

# ----- Backtest config -----
TIME_INCREMENT = 3600    # 3600 s = one bar at resample_rule="1h" (unit assumed to be seconds)

# ----- Reproducibility -----
# Weight init, train-set shuffling and path sampling are all stochastic; seed the
# global numpy and torch generators so Restart & Run All reproduces the results.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

## 3. Data Loading

In [None]:
engineer = OHLCVEngineer(resample_rule="1h")

# Real candles from the Hugging Face dataset, or a synthetic stand-in for offline runs.
source = (
    HFOHLCVSource(repo_id=REPO_ID, asset_files=ASSET_FILES, repo_type="dataset")
    if USE_HF
    else MockDataSource(length=8000, freq="5min", seed=42, base_price=100.0)
)

loader = MarketDataLoader(
    data_source=source,
    engineer=engineer,
    assets=ASSETS,
    input_len=INPUT_LEN,
    pred_len=PRED_LEN,
    batch_size=BATCH_SIZE,
    feature_dim=FEATURE_DIM,
)

loaded_names = [asset.name for asset in loader.assets_data]
print(f"Assets loaded:  {loaded_names}")
print(f"Total windows:  {len(loader.dataset)}")

first_window = loader.dataset[0]
print(f"Sample input shape (F, T): {first_window['inputs'].shape}")
print(f"Sample target shape (1, T): {first_window['target'].shape}")

In [None]:
# Chronological holdout split into train / validation / test loaders.
# NOTE(review): exact meaning of `cutoff` and `val_size` lives in
# MarketDataLoader.static_holdout — confirm there before changing them.
train_dl, val_dl, test_dl = loader.static_holdout(
    cutoff=0.2,
    val_size=0.15,
    shuffle_train=True,
)

for _split_label, _split_dl in (
    ("Train batches: ", train_dl),
    ("Val batches:   ", val_dl),
    ("Test batches:  ", test_dl),
):
    print(f"{_split_label}{len(_split_dl)}")

## 4. Build the Model

### FEDformer Backbone

Two stacked FourierBlocks that apply learned complex-valued weights in the
frequency domain. Each block:
1. Applies FFT along the sequence dimension
2. Multiplies the first `modes` Fourier coefficients by learned complex weights
3. Applies inverse FFT to return to the time domain
4. Adds a residual connection

In [None]:
FourierBlock = registry.get_block("fourierblock")

# Two identically configured Fourier blocks stacked inside the backbone.
blocks = [FourierBlock(d_model=D_MODEL, modes=FOURIER_MODES) for _ in range(2)]

backbone = HybridBackbone(
    input_size=FEATURE_DIM,
    d_model=D_MODEL,
    blocks=blocks,
    validate_shapes=True,
)

# Any HEAD_TYPE other than "neural_bridge" falls back to the HorizonHead.
latent_size = backbone.output_dim
if HEAD_TYPE != "neural_bridge":
    head = HorizonHead(
        latent_size=latent_size,
        horizon_max=HORIZON_MAX,
        nhead=HORIZON_NHEAD,
        n_layers=HORIZON_LAYERS,
        dropout=0.1,
    )
    head_label = f"HorizonHead (horizon_max={HORIZON_MAX})"
else:
    head = NeuralBridgeHead(
        latent_size=latent_size,
        micro_steps=MICRO_STEPS,
        hidden_dim=BRIDGE_HIDDEN,
    )
    head_label = f"NeuralBridgeHead (micro_steps={MICRO_STEPS})"

model = SynthModel(backbone=backbone, head=head).to(device)

param_counts = {
    part: sum(p.numel() for p in module.parameters())
    for part, module in (("total", model), ("backbone", backbone), ("head", head))
}
total_params = param_counts["total"]
backbone_params = param_counts["backbone"]
head_params = param_counts["head"]

print(f"SynthModel with {head_label} built successfully")
print(f"  Backbone output dim: {backbone.output_dim}")
print(f"  Backbone params:     {backbone_params:,}")
print(f"  Head params:         {head_params:,}")
print(f"  Total parameters:    {total_params:,}")

### Sanity Check: Forward Pass

In [None]:
# Two random input windows with different starting prices, just to inspect output shapes.
dummy_x = torch.randn(2, INPUT_LEN, FEATURE_DIM, device=device)
dummy_price = torch.tensor([100.0, 50.0], device=device)

with torch.no_grad():
    paths, param_a, param_b = model(
        dummy_x,
        initial_price=dummy_price,
        horizon=PRED_LEN,
        n_paths=100,
    )

# The meaning of the three outputs depends on the selected head.
if HEAD_TYPE != "neural_bridge":
    print(f"Paths shape:     {paths.shape}      (batch, n_paths, horizon)")
    print(f"Mu_seq shape:    {param_a.shape}    (batch, horizon)")
    print(f"Sigma_seq shape: {param_b.shape}  (batch, horizon)")
else:
    print(f"Micro path shape:  {paths.shape}      (batch, micro_steps)")
    print(f"Macro return shape: {param_a.shape}    (batch, 1)")

## 5. Training Loop

In [None]:
# Converts loader batches into the tensors the model expects; targets are log returns.
adapter = DataToModelAdapter(device=device, target_is_log_return=True)
optimizer = optim.Adam(model.parameters(), lr=LR)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    n_paths=N_PATHS,
    device=device,
    adapter=adapter,
)

# Per-epoch metric history, filled by the training loop below.
history = {
    metric: []
    for metric in ("train_loss", "train_crps", "train_sharpness", "val_crps")
}

In [None]:
# One pass over the training split per epoch, followed by a validation pass.
# Note: `history` is appended to, so re-running this cell without re-running the
# setup cell continues the existing curves rather than starting fresh.
# NOTE(review): train/eval mode switching is assumed to happen inside
# Trainer.train_step / Trainer.validate — confirm.
for epoch in range(1, EPOCHS + 1):
    epoch_losses = []
    epoch_crps = []
    epoch_sharp = []

    for batch in train_dl:
        metrics = trainer.train_step(batch)
        epoch_losses.append(metrics["loss"])
        epoch_crps.append(metrics["crps"])
        epoch_sharp.append(metrics["sharpness"])

    avg_loss = np.mean(epoch_losses)
    avg_crps = np.mean(epoch_crps)
    avg_sharp = np.mean(epoch_sharp)

    history["train_loss"].append(avg_loss)
    history["train_crps"].append(avg_crps)
    history["train_sharpness"].append(avg_sharp)

    val_metrics = trainer.validate(val_dl)
    history["val_crps"].append(val_metrics["val_crps"])

    # Log the first epoch, every 3rd epoch, and always the final epoch, so the
    # last result is shown even when EPOCHS is not a multiple of 3.
    if epoch == 1 or epoch % 3 == 0 or epoch == EPOCHS:
        print(
            f"Epoch {epoch:3d}/{EPOCHS}  "
            f"train_loss={avg_loss:.5f}  "
            f"train_crps={avg_crps:.5f}  "
            f"val_crps={val_metrics['val_crps']:.5f}  "
            f"sharpness={avg_sharp:.5f}"
        )

### Training Curves

In [None]:
# The training loop numbers epochs from 1, so plot against 1-based epoch numbers
# rather than the 0-based list indices matplotlib would use by default.
epoch_numbers = np.arange(1, len(history["train_loss"]) + 1)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(epoch_numbers, history["train_loss"], label="Train Loss (CRPS)")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training Loss")
axes[0].legend()

axes[1].plot(epoch_numbers, history["train_crps"], label="Train CRPS")
axes[1].plot(epoch_numbers, history["val_crps"], label="Val CRPS")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("CRPS")
axes[1].set_title("CRPS: Train vs Validation")
axes[1].legend()

axes[2].plot(epoch_numbers, history["train_sharpness"], label="Sharpness", color="tab:green")
axes[2].set_xlabel("Epoch")
axes[2].set_ylabel("Std(paths)")
axes[2].set_title("Forecast Sharpness")
axes[2].legend()

plt.tight_layout()
plt.show()

## 6. Backtesting on Test Set

In [None]:
model.eval()
scorer = CRPSMultiIntervalScorer(time_increment=TIME_INCREMENT)

interval_scores = {name: [] for name in SCORING_INTERVALS}
overall_scores = []
all_test_crps = []   # mean CRPS of each test batch
batch_sizes = []     # samples per test batch, used to weight the batch means

for batch in test_dl:
    adapted = adapter(batch)
    history_t = adapted["history"]
    initial_price = adapted["initial_price"]
    target_factors = adapted["target_factors"]
    horizon = target_factors.shape[-1]

    with torch.no_grad():
        paths, mu, sigma = model(
            history_t,
            initial_price=initial_price,
            horizon=horizon,
            n_paths=N_PATHS,
        )

    # crps_ensemble is fed the simulations with the last two axes swapped;
    # presumably (batch, horizon, n_paths) — confirm against src.research.metrics.
    sim_paths = paths.transpose(1, 2)
    crps_vals = crps_ensemble(sim_paths, target_factors)
    all_test_crps.append(crps_vals.mean().item())
    batch_sizes.append(target_factors.shape[0])

    for sample_idx in range(paths.shape[0]):
        # Score the simulated ensemble against the realised target path, not
        # against one of the model's own simulated paths.
        total_crps, detailed = scorer(paths[sample_idx], target_factors[sample_idx])
        overall_scores.append(total_crps)
        for row in detailed:
            interval_name = row["Interval"]
            if interval_name in interval_scores and row["Increment"] == "Total":
                interval_scores[interval_name].append(float(row["CRPS"]))

# Weight each batch mean by its size so a short final batch does not skew the average.
avg_test_crps = (
    np.average(all_test_crps, weights=batch_sizes) if all_test_crps else float("nan")
)
print(f"\n{'='*50}")
print("BACKTEST RESULTS — FEDformer (FourierBlock)")
print(f"{'='*50}")
print(f"Average Test CRPS: {avg_test_crps:.6f}")
print("\nMulti-Interval CRPS Breakdown:")
for name, scores in interval_scores.items():
    if scores:
        print(f"  {name:>12s}: {np.mean(scores):.6f} (n={len(scores)})")
    else:
        print(f"  {name:>12s}: N/A (horizon too short)")

## 7. Fan Chart Visualization

In [None]:
model.eval()
test_batch = next(iter(test_dl))
adapted = adapter(test_batch)
horizon_steps = adapted["target_factors"].shape[-1]

with torch.no_grad():
    paths, mu, sigma = model(
        adapted["history"],
        initial_price=adapted["initial_price"],
        horizon=horizon_steps,
        n_paths=N_PATHS,
    )

# Kept as module-level names: the path-distribution cell below reuses them.
paths_np = paths.cpu().numpy()
targets_np = adapted["target_factors"].cpu().numpy()

n_show = min(4, paths_np.shape[0])
fig, axes = plt.subplots(1, n_show, figsize=(5 * n_show, 4), squeeze=False)

fan_levels = [5, 25, 50, 75, 95]
for i, ax in enumerate(axes[0]):
    # Percentiles are taken across axis 0 (the simulated paths) at each step.
    sample_paths = paths_np[i]
    t = np.arange(sample_paths.shape[1])
    p5, p25, p50, p75, p95 = np.percentile(sample_paths, fan_levels, axis=0)

    ax.fill_between(t, p5, p95, alpha=0.15, color="tab:blue", label="P5-P95")
    ax.fill_between(t, p25, p75, alpha=0.3, color="tab:blue", label="P25-P75")
    ax.plot(t, p50, color="tab:blue", linewidth=2, label="Median")
    ax.plot(t, targets_np[i], color="tab:red", linewidth=2, linestyle="--", label="Actual")

    ax.set_title(f"Sample {i}")
    ax.set_xlabel("Horizon Step")
    ax.set_ylabel("Price Factor")
    if i == 0:
        ax.legend(fontsize=8)

fig.suptitle("FEDformer (FourierBlock) Fan Charts (Test Set)", fontsize=14, y=1.02)
fig.tight_layout()
plt.show()

## 8. Path Distribution Analysis

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(16, 4))

# Panel 1: terminal-step distribution for the first sample, reusing paths_np /
# targets_np computed in the fan-chart cell above.
terminal_prices = paths_np[0, :, -1]
axes[0].hist(terminal_prices, bins=50, alpha=0.7, color="tab:blue", edgecolor="white")
axes[0].axvline(targets_np[0, -1], color="tab:red", linewidth=2, linestyle="--", label="Actual")
axes[0].set_title("Terminal Price Distribution")
axes[0].set_xlabel("Price Factor")
axes[0].set_ylabel("Count")
axes[0].legend()

# Panels 2-3: the model's second and third outputs for the first test batch.
# NOTE(review): these are per-step drift/volatility for the HorizonHead; for the
# NeuralBridgeHead their meaning differs — confirm before interpreting the plots.
model.eval()
test_batch_viz = next(iter(test_dl))
adapted_viz = adapter(test_batch_viz)
with torch.no_grad():
    _, mu_viz, sigma_viz = model(
        adapted_viz["history"],
        initial_price=adapted_viz["initial_price"],
        horizon=adapted_viz["target_factors"].shape[-1],
        n_paths=10,
    )
mu_np = mu_viz.cpu().numpy()
sigma_np = sigma_viz.cpu().numpy()
t_steps = np.arange(mu_np.shape[-1]) if mu_np.ndim > 1 else np.array([0])
# A one-point line draws nothing without a marker, so mark points when there is
# only a single step (e.g. a (batch, 1) output).
step_marker = "o" if t_steps.size == 1 else None

for i in range(min(4, mu_np.shape[0])):
    label = f"Sample {i}" if i < 3 else None
    if mu_np.ndim > 1:
        axes[1].plot(t_steps, mu_np[i], alpha=0.6, marker=step_marker, label=label)
        axes[2].plot(t_steps, sigma_np[i], alpha=0.6, marker=step_marker, label=label)
    else:
        axes[1].axhline(mu_np[i], alpha=0.6, label=label)
        axes[2].axhline(sigma_np[i], alpha=0.6, label=label)

axes[1].set_xlabel("Horizon Step")
axes[1].set_ylabel("Drift (mu_t)")
axes[1].set_title("Learned Per-Step Drift")

axes[2].set_xlabel("Horizon Step")
axes[2].set_ylabel("Volatility (sigma_t)")
axes[2].set_title("Learned Per-Step Volatility")

# Only draw legends when something was labelled (avoids matplotlib's
# "No artists with labels found" warning on an empty batch).
for ax in axes[1:]:
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(fontsize=8)

plt.tight_layout()
plt.show()