# DLinear + TimesNet + TimeMixer Training & Backtesting

This notebook demonstrates how to:
1. Load per-asset OHLCV candle data from the `tensorlink-dev/open-synth-training-data` HF dataset
2. Engineer 16 micro-structure features per 1-hour bar using `OHLCVEngineer`
3. Build a hybrid model using **DLinearBlock**, **TimesNetBlock**, and **TimeMixerBlock**
4. Train with CRPS loss and backtest with multi-interval scoring

## 1. Imports & Setup

In [None]:
import sys
import os

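# Ensure the project root is on the path. The notebook is assumed to live one
# level below the project root; the `src` check keeps this cell idempotent,
# since re-running it after the chdir would otherwise climb one level too far.
PROJECT_ROOT = os.getcwd()
if not os.path.isdir(os.path.join(PROJECT_ROOT, "src")):
    PROJECT_ROOT = os.path.abspath(os.path.join(PROJECT_ROOT, ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)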

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
import matplotlib.pyplot as plt

from src.models.registry import discover_components, registry
from src.models.factory import HybridBackbone, SynthModel
from src.models.heads import GBMHead
from src.data.market_data_loader import (
    HFOHLCVSource,
    MockDataSource,
    OHLCVEngineer,
    OHLCV_FEATURE_NAMES,
    MarketDataLoader,
)
from src.research.trainer import Trainer, DataToModelAdapter
from src.research.metrics import (
    crps_ensemble,
    CRPSMultiIntervalScorer,
    SCORING_INTERVALS,
)

# Auto-discover all registered blocks
discover_components("src/models/components")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"Registered blocks: {list(registry.blocks.keys())}")
print(f"OHLCV feature count: {len(OHLCV_FEATURE_NAMES)}")
print(f"Features: {OHLCV_FEATURE_NAMES}")

## 2. Configuration

In [None]:
# ----- Data config -----
REPO_ID = "tensorlink-dev/open-synth-training-data"

# Map asset names to their parquet paths in the HF repo.
# The repo has per-asset folders: BTC_USD/, ETH_USD/, PAXG_USD/, SOL_USD/
# Adjust the filenames to match the actual files in each folder.
ASSET_FILES = {
    "BTC_USD": "BTC_USD/data.parquet",
    "ETH_USD": "ETH_USD/data.parquet",
    "SOL_USD": "SOL_USD/data.parquet",
}
ASSETS = list(ASSET_FILES.keys())

# Set USE_HF = False to fall back to MockDataSource for offline testing
USE_HF = True

INPUT_LEN = 64           # Context window (in 1-hour bars)
PRED_LEN = 12            # Prediction horizon (1-hour bars)
BATCH_SIZE = 8
FEATURE_DIM = 16         # OHLCVEngineer produces 16 features

# ----- Model config -----
D_MODEL = 64             # Internal latent dimension
DLINEAR_KERNEL = 25      # Moving-average kernel for DLinear decomposition
TIMESNET_TOP_K = 3       # Number of dominant periods for TimesNet
TIMEMIXER_DOWN_LAYERS = 2  # Number of downsampling scales for TimeMixer
TIMEMIXER_DOWN_WINDOW = 2  # Downsampling factor per layer
N_PATHS = 500            # GBM simulation paths per sample

# ----- Training config -----
EPOCHS = 15
LR = 1e-3

# ----- Backtest config -----
TIME_INCREMENT = 3600    # Seconds per step (1-hour bars)

## 3. Data Loading

`HFOHLCVSource` downloads per-asset parquet files from the HF dataset and returns `AssetData`
with full OHLCV columns. `OHLCVEngineer` then resamples the raw candles (1m/5m) to 1-hour bars
and computes 16 micro-structure features:

| Feature | Description |
|---------|-------------|
| `open, high, low, close, volume` | Standard 1h OHLCV |
| `realized_vol` | Intra-hour log-return std |
| `skew, kurtosis` | Higher moments of intra-hour returns |
| `parkinson_vol` | Range-based volatility estimator |
| `efficiency` | Fractal efficiency (net move / total path) |
| `vwap_dev` | Close deviation from VWAP |
| `signed_vol_sum` | Net buying pressure |
| `up_wick, down_wick` | Upper/lower wick ratios |
| `body_size` | Candle body as fraction of range |
| `clv` | Close Location Value (-1 to +1) |

Set `USE_HF = False` in the config cell to fall back to `MockDataSource` for offline testing.
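
To make a few rows of the table concrete, here is a minimal sketch of the candle-shape features and fractal efficiency using their textbook definitions. The helper names `bar_features` and `efficiency` are hypothetical, and the exact formulas inside `OHLCVEngineer` may differ (for example in scaling or in how flat bars are handled).

```python
import math


def bar_features(o, h, l, c):
    """Shape features for one OHLC bar (textbook definitions, assumed)."""
    rng = h - l
    if rng <= 0:
        # Flat bar, e.g. the single-price mock case where O = H = L = C.
        return {"parkinson_vol": 0.0, "clv": 0.0, "body_size": 0.0,
                "up_wick": 0.0, "down_wick": 0.0}
    return {
        # Parkinson estimator for a single bar: ln(H/L)^2 / (4 ln 2), square-rooted.
        "parkinson_vol": math.sqrt(math.log(h / l) ** 2 / (4.0 * math.log(2.0))),
        # Close Location Value: -1 when close == low, +1 when close == high.
        "clv": ((c - l) - (h - c)) / rng,
        # Body and wicks as fractions of the bar range (they sum to 1).
        "body_size": abs(c - o) / rng,
        "up_wick": (h - max(o, c)) / rng,
        "down_wick": (min(o, c) - l) / rng,
    }


def efficiency(closes):
    """Fractal efficiency: |net move| divided by total path length."""
    path = sum(abs(b - a) for a, b in zip(closes, closes[1:]))
    return abs(closes[-1] - closes[0]) / path if path > 0 else 0.0


feats = bar_features(o=100.0, h=110.0, l=95.0, c=105.0)
eff = efficiency([100.0, 102.0, 101.0, 105.0])
print(feats)
print(f"efficiency = {eff:.4f}")
```

An efficiency near 1 means the price moved almost in a straight line within the hour; values near 0 indicate choppy, mean-reverting movement.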

In [None]:
engineer = OHLCVEngineer(resample_rule="1h")

if USE_HF:
    source = HFOHLCVSource(
        repo_id=REPO_ID,
        asset_files=ASSET_FILES,
        repo_type="dataset",
    )
else:
    # Offline fallback: MockDataSource generates synthetic random-walk prices.
    # OHLCVEngineer gracefully handles the single-price case (O=H=L=C=price).
    source = MockDataSource(length=8000, freq="1min", seed=42, base_price=100.0)

loader = MarketDataLoader(
    data_source=source,
    engineer=engineer,
    assets=ASSETS,
    input_len=INPUT_LEN,
    pred_len=PRED_LEN,
    batch_size=BATCH_SIZE,
    feature_dim=FEATURE_DIM,
)

print(f"Assets loaded:  {[a.name for a in loader.assets_data]}")
print(f"Total windows:  {len(loader.dataset)}")
print(f"Sample input shape (F, T): {loader.dataset[0]['inputs'].shape}")
print(f"Sample target shape (1, T): {loader.dataset[0]['target'].shape}")

### Train / Validation / Test Split

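A static holdout with a fractional cutoff creates leak-safe temporal splits: the most recent windows form the test set, and validation is carved out of the data that precedes the cutoff.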

In [None]:
train_dl, val_dl, test_dl = loader.static_holdout(
    cutoff=0.2,           # Last 20% of data reserved for test
    val_size=0.15,        # 15% of pre-cutoff data for validation
    shuffle_train=True,
)

print(f"Train batches: {len(train_dl)}")
print(f"Val batches:   {len(val_dl)}")
print(f"Test batches:  {len(test_dl)}")
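
The index arithmetic behind a fractional temporal holdout can be sketched as follows. This is an illustration only: `temporal_split` and its `gap` parameter are hypothetical, and the internals of `loader.static_holdout` are not shown in this notebook, so it may place or purge boundaries differently.

```python
def temporal_split(n_windows, cutoff=0.2, val_size=0.15, gap=0):
    """Chronological train/val/test index ranges over time-ordered windows.

    `gap` drops that many windows before each boundary. With stride-1 windows
    of length input_len + pred_len, a gap of (input_len + pred_len - 1)
    removes every window that overlaps the next split.
    """
    test_start = round(n_windows * (1.0 - cutoff))
    val_len = round(test_start * val_size)
    val_start = test_start - val_len
    train = range(0, max(val_start - gap, 0))
    val = range(val_start, max(test_start - gap, val_start))
    test = range(test_start, n_windows)
    return train, val, test


train_idx, val_idx, test_idx = temporal_split(1000, cutoff=0.2, val_size=0.15)
print(train_idx, val_idx, test_idx)

# Purged variant for INPUT_LEN = 64 and PRED_LEN = 12 (assumes stride 1).
purged_train, purged_val, _ = temporal_split(1000, gap=64 + 12 - 1)
print(purged_train, purged_val)
```

Shuffling is safe only inside the training range; the validation and test ranges stay in chronological order so that no future bar informs an earlier forecast.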

## 4. Build the DLinear + TimesNet + TimeMixer Model

The backbone stacks four blocks drawn from three block types, each with a different inductive bias:

1. **DLinearBlock** - decomposes input into trend (moving average) and seasonal (residual), applies separate linear layers to each, then sums them
2. **TimesNetBlock** - discovers dominant periods via FFT, reshapes 1D sequences into 2D grids (period x segments), applies Inception-style 2D convolutions, then folds back
3. **TimeMixerBlock** - builds multi-scale representations via average-pool downsampling, applies bottom-up seasonal mixing and top-down trend mixing across scales, then recombines at original resolution
4. **DLinearBlock** - a final decomposition pass to refine the representation

The `HybridBackbone` projects inputs to `d_model`, runs them through the block stack,
and extracts the last time-step embedding. A `GBMHead` then maps the latent to
drift (mu) and volatility (sigma) for GBM path simulation.
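
The trend/seasonal split described for `DLinearBlock` can be sketched in plain Python. This is a minimal illustration assuming a centered moving average with edge replication; the registered block may pad differently, and it additionally applies learned linear layers to each component, which are omitted here.

```python
import math


def moving_average(x, kernel_size):
    """Centered moving average; edges are replicated so the length is kept."""
    pad = (kernel_size - 1) // 2
    padded = [x[0]] * pad + list(x) + [x[-1]] * (kernel_size - 1 - pad)
    return [sum(padded[i:i + kernel_size]) / kernel_size for i in range(len(x))]


def decompose(x, kernel_size=25):
    """Split a series into trend (moving average) and seasonal (residual)."""
    trend = moving_average(x, kernel_size)
    seasonal = [xi - ti for xi, ti in zip(x, trend)]
    return trend, seasonal


# Toy series: a linear drift plus a period-8 oscillation over 64 steps.
series = [0.1 * t + math.sin(2 * math.pi * t / 8) for t in range(64)]
trend, seasonal = decompose(series, kernel_size=25)
print(f"trend[32] = {trend[32]:.3f}, seasonal[32] = {seasonal[32]:.3f}")
```

Because the seasonal part is defined as the residual, the two components always sum back to the input, so the decomposition itself loses no information.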

In [None]:
# Retrieve block classes from the registry
DLinearBlock = registry.get_block("dlinearblock")
TimesNetBlock = registry.get_block("timesnetblock")
TimeMixerBlock = registry.get_block("timemixerblock")

# Instantiate the blocks
blocks = [
    DLinearBlock(d_model=D_MODEL, kernel_size=DLINEAR_KERNEL),
    TimesNetBlock(d_model=D_MODEL, top_k=TIMESNET_TOP_K, dropout=0.1),
    TimeMixerBlock(
        d_model=D_MODEL,
        down_sampling_window=TIMEMIXER_DOWN_WINDOW,
        down_sampling_layers=TIMEMIXER_DOWN_LAYERS,
        moving_avg_kernel=DLINEAR_KERNEL,
        dropout=0.1,
    ),
    DLinearBlock(d_model=D_MODEL, kernel_size=DLINEAR_KERNEL),
]

backbone = HybridBackbone(
    input_size=FEATURE_DIM,
    d_model=D_MODEL,
    blocks=blocks,
    validate_shapes=True,
)

head = GBMHead(latent_size=backbone.output_dim)

model = SynthModel(backbone=backbone, head=head).to(device)

total_params = sum(p.numel() for p in model.parameters())
print("DLinear + TimesNet + TimeMixer SynthModel built successfully")
print(f"  Backbone output dim: {backbone.output_dim}")
print(f"  Total parameters:    {total_params:,}")
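
The period discovery step that `TimesNetBlock` relies on can be illustrated with a naive DFT. The sketch below assumes the usual recipe (drop the zero frequency, take the `top_k` largest amplitudes, convert frequency index `k` to period `T // k`); `dominant_periods` and `fold` are hypothetical helper names, and the actual block presumably uses `torch.fft` and pads rather than truncates.

```python
import cmath
import math


def dominant_periods(x, top_k=3):
    """Return the periods of the top_k strongest non-zero frequencies."""
    n = len(x)
    mean = sum(x) / n
    centered = [v - mean for v in x]
    amps = []
    for k in range(1, n // 2 + 1):  # skip k = 0, the constant component
        coeff = sum(v * cmath.exp(-2j * math.pi * k * t / n)
                    for t, v in enumerate(centered))
        amps.append((abs(coeff), k))
    amps.sort(reverse=True)
    return [n // k for _, k in amps[:top_k]]


def fold(x, period):
    """Reshape a 1D series into rows of length `period` (truncating any tail)."""
    usable = len(x) - len(x) % period
    return [x[i:i + period] for i in range(0, usable, period)]


# Two superimposed cycles with periods 8 and 16 over a 64-step window.
signal = [math.sin(2 * math.pi * t / 8) + 0.5 * math.sin(2 * math.pi * t / 16)
          for t in range(64)]
periods = dominant_periods(signal, top_k=2)
grid = fold(signal, periods[0])
print(periods, len(grid), len(grid[0]))
```

Folding by a discovered period lines up the same phase of each cycle in a column, which is what lets 2D convolutions see both within-period and across-period variation.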

### Sanity Check: Forward Pass

In [None]:
# Quick smoke test with a dummy batch
dummy_x = torch.randn(2, INPUT_LEN, FEATURE_DIM, device=device)
dummy_price = torch.tensor([100.0, 50.0], device=device)

with torch.no_grad():
    paths, mu, sigma = model(dummy_x, initial_price=dummy_price, horizon=PRED_LEN, n_paths=100)

print(f"Paths shape:  {paths.shape}  (batch, n_paths, horizon)")
print(f"Mu:           {mu.detach().cpu().numpy()}")
print(f"Sigma:        {sigma.detach().cpu().numpy()}")
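
For intuition about what the head's `mu` and `sigma` drive, here is a standalone sketch of geometric Brownian motion using the exact log-normal update. The function `simulate_gbm`, its `dt` argument, and the per-step interpretation of `mu` and `sigma` are assumptions for illustration; `GBMHead` may scale or parameterize them differently.

```python
import math
import random


def simulate_gbm(initial_price, mu, sigma, horizon, n_paths, dt=1.0, seed=0):
    """Simulate GBM paths: S_{t+1} = S_t * exp((mu - sigma^2/2) dt + sigma sqrt(dt) Z)."""
    rng = random.Random(seed)
    paths = []
    for _ in range(n_paths):
        price, path = initial_price, []
        for _ in range(horizon):
            z = rng.gauss(0.0, 1.0)  # standard normal shock
            price *= math.exp((mu - 0.5 * sigma ** 2) * dt + sigma * math.sqrt(dt) * z)
            path.append(price)
        paths.append(path)
    return paths  # shape: (n_paths, horizon)


gbm_paths = simulate_gbm(100.0, mu=0.001, sigma=0.02, horizon=12, n_paths=100)
terminal = [p[-1] for p in gbm_paths]
print(f"terminal price: min={min(terminal):.2f}, max={max(terminal):.2f}")

# With sigma = 0 every path collapses to the deterministic drift curve.
drift_only = simulate_gbm(100.0, mu=0.01, sigma=0.0, horizon=12, n_paths=1)
print(f"drift-only terminal price: {drift_only[0][-1]:.4f}")
```

Prices stay strictly positive under this update, and a larger `sigma` widens the spread of terminal prices.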

## 5. Training Loop

Training uses the `Trainer` class, which handles:
- `DataToModelAdapter` to bridge `MarketDataLoader` batch format to `SynthModel` inputs
- CRPS loss for probabilistic calibration
- Sharpness and log-likelihood tracking
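
As background on the loss, the CRPS of an ensemble forecast can be estimated from samples as E|X - y| - 0.5 E|X - X'|. The sketch below implements that estimator directly for a single time step; `crps_from_samples` is a hypothetical name, and the repository's `crps_ensemble` may use a different (for example sorted or bias-corrected) variant.

```python
def crps_from_samples(samples, observation):
    """Sample-based CRPS: mean |x - y| minus half the mean pairwise |x - x'|."""
    m = len(samples)
    accuracy = sum(abs(s - observation) for s in samples) / m
    spread = sum(abs(a - b) for a in samples for b in samples) / (m * m)
    return accuracy - 0.5 * spread


wide = crps_from_samples([0.0, 2.0], observation=1.0)
sharp = crps_from_samples([0.9, 1.1], observation=1.0)
point = crps_from_samples([3.0], observation=1.0)
print(wide, sharp, point)
```

Lower is better: a tight ensemble centered on the outcome scores best, and for a single-sample forecast the CRPS reduces to the absolute error. This is why CRPS rewards sharpness only when the forecast remains calibrated.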

In [None]:
optimizer = optim.Adam(model.parameters(), lr=LR)
adapter = DataToModelAdapter(device=device, target_is_log_return=True)

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    n_paths=N_PATHS,
    device=device,
    adapter=adapter,
)

# Metric tracking
history = {
    "train_loss": [],
    "train_crps": [],
    "train_sharpness": [],
    "val_crps": [],
}

In [None]:
for epoch in range(1, EPOCHS + 1):
    # --- Training ---
    epoch_losses = []
    epoch_crps = []
    epoch_sharp = []

    for batch in train_dl:
        metrics = trainer.train_step(batch)
        epoch_losses.append(metrics["loss"])
        epoch_crps.append(metrics["crps"])
        epoch_sharp.append(metrics["sharpness"])

    avg_loss = np.mean(epoch_losses)
    avg_crps = np.mean(epoch_crps)
    avg_sharp = np.mean(epoch_sharp)

    history["train_loss"].append(avg_loss)
    history["train_crps"].append(avg_crps)
    history["train_sharpness"].append(avg_sharp)

    # --- Validation ---
    val_metrics = trainer.validate(val_dl)
    history["val_crps"].append(val_metrics["val_crps"])

    if epoch % 3 == 0 or epoch == 1:
        print(
            f"Epoch {epoch:3d}/{EPOCHS}  "
            f"train_loss={avg_loss:.5f}  "
            f"train_crps={avg_crps:.5f}  "
            f"val_crps={val_metrics['val_crps']:.5f}  "
            f"sharpness={avg_sharp:.5f}"
        )

### Training Curves

In [None]:
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

axes[0].plot(history["train_loss"], label="Train Loss (CRPS)")
axes[0].set_xlabel("Epoch")
axes[0].set_ylabel("Loss")
axes[0].set_title("Training Loss")
axes[0].legend()

axes[1].plot(history["train_crps"], label="Train CRPS")
axes[1].plot(history["val_crps"], label="Val CRPS")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("CRPS")
axes[1].set_title("CRPS: Train vs Validation")
axes[1].legend()

axes[2].plot(history["train_sharpness"], label="Sharpness", color="tab:green")
axes[2].set_xlabel("Epoch")
axes[2].set_ylabel("Std(paths)")
axes[2].set_title("Forecast Sharpness")
axes[2].legend()

plt.tight_layout()
plt.show()

## 6. Backtesting on Test Set

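We evaluate the trained model on the held-out test split using the `CRPSMultiIntervalScorer`,
which computes CRPS at the standard scoring intervals (5min, 30min, 3hour, 24hour).
With 1-hour bars (`TIME_INCREMENT = 3600`) and a 12-step horizon, the 5min and 30min intervals are finer than one bar
and the 24hour interval is longer than the forecast, so only the 3hour interval should produce scores; the others should print as N/A.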

In [None]:
model.eval()
scorer = CRPSMultiIntervalScorer(time_increment=TIME_INCREMENT)

interval_scores = {name: [] for name in SCORING_INTERVALS}
overall_scores = []
all_test_crps = []

for batch in test_dl:
    adapted = adapter(batch)
    history_t = adapted["history"]
    initial_price = adapted["initial_price"]
    target_factors = adapted["target_factors"]
    horizon = target_factors.shape[-1]

    with torch.no_grad():
        paths, mu, sigma = model(
            history_t,
            initial_price=initial_price,
            horizon=horizon,
            n_paths=N_PATHS,
        )

    # Ensemble CRPS on the target factors
    sim_paths = paths.transpose(1, 2)  # (batch, horizon, n_paths)
    crps_vals = crps_ensemble(sim_paths, target_factors)
    all_test_crps.append(crps_vals.mean().item())

    # Multi-interval CRPS per sample
    for sample_idx in range(paths.shape[0]):
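        # Score the simulated paths against the realized target path rather than one
        # of the model's own samples (assumes the scorer takes (simulations, real_path)).
        total_crps, detailed = scorer(paths[sample_idx], target_factors[sample_idx])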
        overall_scores.append(total_crps)
        for row in detailed:
            interval_name = row["Interval"]
            if interval_name in interval_scores and row["Increment"] == "Total":
                interval_scores[interval_name].append(float(row["CRPS"]))

avg_test_crps = np.mean(all_test_crps)
print(f"\n{'='*50}")
print("BACKTEST RESULTS")
print(f"{'='*50}")
print(f"Average Test CRPS: {avg_test_crps:.6f}")
print("\nMulti-Interval CRPS Breakdown:")
for name, scores in interval_scores.items():
    if scores:
        print(f"  {name:>12s}: {np.mean(scores):.6f} (n={len(scores)})")
    else:
        print(f"  {name:>12s}: N/A (interval not resolvable at this bar size or horizon)")
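
Which intervals can be scored at all follows from simple arithmetic on the bar size and horizon. The sketch below assumes the four interval lengths named in the text; `ASSUMED_INTERVALS` and `resolvable_intervals` are hypothetical, and the real `SCORING_INTERVALS` mapping and the scorer's own eligibility rule may differ.

```python
# Interval lengths in seconds, taken from the names used in this notebook.
ASSUMED_INTERVALS = {"5min": 300, "30min": 1800, "3hour": 10800, "24hour": 86400}


def resolvable_intervals(time_increment, horizon, intervals=ASSUMED_INTERVALS):
    """Map each interval to its lag in steps, or None if it cannot be scored.

    An interval is usable when it is a whole number of bars and at least two
    points that far apart fit inside a path of `horizon` points.
    """
    result = {}
    for name, seconds in intervals.items():
        steps, remainder = divmod(seconds, time_increment)
        usable = remainder == 0 and 1 <= steps <= horizon - 1
        result[name] = steps if usable else None
    return result


lags = resolvable_intervals(time_increment=3600, horizon=12)
print(lags)
```

Under these assumptions, scoring the sub-hour intervals would require finer bars, and scoring the 24hour interval would require a horizon longer than 24 steps.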

## 7. Fan Chart Visualization

Visualize the probabilistic forecasts as fan charts: P5-P95 and P25-P75 percentile bands around the median (P50), with the realized path overlaid.

In [None]:
# Grab a few test samples for visualization
model.eval()
test_batch = next(iter(test_dl))
adapted = adapter(test_batch)

with torch.no_grad():
    paths, mu, sigma = model(
        adapted["history"],
        initial_price=adapted["initial_price"],
        horizon=adapted["target_factors"].shape[-1],
        n_paths=N_PATHS,
    )

paths_np = paths.cpu().numpy()           # (batch, n_paths, horizon)
targets_np = adapted["target_factors"].cpu().numpy()  # (batch, horizon)

n_show = min(4, paths_np.shape[0])
fig, axes = plt.subplots(1, n_show, figsize=(5 * n_show, 4), squeeze=False)

for i in range(n_show):
    ax = axes[0, i]
    sample_paths = paths_np[i]  # (n_paths, horizon)
    t = np.arange(sample_paths.shape[1])

    p5 = np.percentile(sample_paths, 5, axis=0)
    p25 = np.percentile(sample_paths, 25, axis=0)
    p50 = np.percentile(sample_paths, 50, axis=0)
    p75 = np.percentile(sample_paths, 75, axis=0)
    p95 = np.percentile(sample_paths, 95, axis=0)

    ax.fill_between(t, p5, p95, alpha=0.15, color="tab:blue", label="P5-P95")
    ax.fill_between(t, p25, p75, alpha=0.3, color="tab:blue", label="P25-P75")
    ax.plot(t, p50, color="tab:blue", linewidth=2, label="Median")
    ax.plot(t, targets_np[i], color="tab:red", linewidth=2, linestyle="--", label="Actual")

    ax.set_title(f"Sample {i}")
    ax.set_xlabel("Horizon Step")
    ax.set_ylabel("Price Factor")
    if i == 0:
        ax.legend(fontsize=8)

plt.suptitle("DLinear + TimesNet + TimeMixer Fan Charts (Test Set)", fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
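
A fan chart is easier to judge with a number attached. The sketch below computes the empirical coverage of a percentile band, that is, the fraction of horizon steps where the realized value falls inside it. Both helpers are hypothetical; `percentile` follows the linear-interpolation convention that `np.percentile` uses by default.

```python
def percentile(values, q):
    """Linear-interpolation percentile of a list, q in [0, 100]."""
    s = sorted(values)
    pos = (len(s) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


def band_coverage(paths, actual, lower_q=5, upper_q=95):
    """Fraction of steps where `actual` lies within the [lower_q, upper_q] band."""
    hits = 0
    for t, y in enumerate(actual):
        column = [p[t] for p in paths]  # all simulated values at step t
        if percentile(column, lower_q) <= y <= percentile(column, upper_q):
            hits += 1
    return hits / len(actual)


# Toy ensemble: 101 flat paths at levels 0..100, so the P5-P95 band is [5, 95].
toy_paths = [[float(v)] * 3 for v in range(101)]
coverage = band_coverage(toy_paths, actual=[50.0, 2.0, 97.0])
print(f"P5-P95 coverage: {coverage:.3f}")
```

For a well-calibrated model, coverage of the P5-P95 band averaged over many test samples should sit near 0.90; much lower suggests overconfident (too narrow) bands, much higher suggests bands that are too wide.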

## 8. Path Distribution Analysis

In [None]:
# Terminal price distribution for the first test sample
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

terminal_prices = paths_np[0, :, -1]
axes[0].hist(terminal_prices, bins=50, alpha=0.7, color="tab:blue", edgecolor="white")
axes[0].axvline(targets_np[0, -1], color="tab:red", linewidth=2, linestyle="--", label="Actual")
axes[0].set_title("Terminal Price Distribution")
axes[0].set_xlabel("Price Factor")
axes[0].set_ylabel("Count")
axes[0].legend()

# Predicted mu and sigma for every sample in the test set
mus = []
sigmas = []
model.eval()
for batch in test_dl:
    adapted_b = adapter(batch)
    with torch.no_grad():
        _, mu_b, sigma_b = model(
            adapted_b["history"],
            initial_price=adapted_b["initial_price"],
            horizon=adapted_b["target_factors"].shape[-1],
            n_paths=10,
        )
    mus.extend(mu_b.cpu().numpy().tolist())
    sigmas.extend(sigma_b.cpu().numpy().tolist())

axes[1].scatter(mus, sigmas, alpha=0.5, s=15)
axes[1].set_xlabel("Drift (mu)")
axes[1].set_ylabel("Volatility (sigma)")
axes[1].set_title("Learned GBM Parameters (Test Set)")

plt.tight_layout()
plt.show()

## 9. Summary

This notebook demonstrated the full training and backtesting workflow within the open-synth-miner framework:

- **Data**: `HFOHLCVSource` loads per-asset parquet files from `tensorlink-dev/open-synth-training-data`; this notebook uses BTC_USD, ETH_USD, and SOL_USD (the repo also has a PAXG_USD folder). `OHLCVEngineer` resamples raw candles to 1-hour bars and computes 16 micro-structure features (realized vol, Parkinson vol, skew, kurtosis, fractal efficiency, VWAP deviation, wick ratios, body size, CLV, signed volume)
- **Model**: `HybridBackbone` stacking four blocks from three complementary block types (`DLinearBlock` appears first and last), mapped to GBM simulation via `GBMHead`
  - `DLinearBlock` - trend-seasonal decomposition with separate linear projections
  - `TimesNetBlock` - FFT period discovery + 2D Inception convolutions for multi-period variation modeling
  - `TimeMixerBlock` - multi-scale past-decomposable mixing with bottom-up seasonal and top-down trend aggregation
- **Training**: CRPS-optimized training via the `Trainer` class
- **Backtesting**: Multi-interval CRPS evaluation at the standard scoring intervals (5min, 30min, 3hour, 24hour); with 1-hour bars and a 12-step horizon, only the 3hour interval is expected to yield scores

To extend this for production use:
- Add PAXG_USD to the asset list for gold-backed token exposure
- Use `walk_forward()` or `hybrid_nested()` validation strategies for more robust evaluation
- Enable W&B logging via the `ExperimentManager`
- Push trained models to Hugging Face Hub via `HubManager`