# In [None]:
"""
Example: Training TemporalValidator with SignalFlow-NN

Demonstrates the new config-based API for TemporalClassificator.
"""
import signalflow as sf
from signalflow.nn.validator import TemporalValidator
from signalflow.nn.model.temporal_classificator import TrainingConfig
from pathlib import Path
from datetime import datetime
import polars as pl
import torch

# ============================================================================
# 1. Load Raw Data
# ============================================================================

# The store and the factory are both pointed at the same DuckDB file.
db_path = Path("test.duckdb")

# NOTE(review): spot_store is not referenced again in this script; it is kept
# in case constructing it has side effects (e.g. creating the database) —
# confirm before removing.
spot_store = sf.data.raw_store.DuckDbSpotStore(db_path=db_path)

# Three months of spot data for three pairs.
window_start = datetime(2025, 10, 1)
window_end = datetime(2025, 12, 31)
raw_data = sf.data.RawDataFactory.from_duckdb_spot_store(
    spot_store_path=db_path,
    pairs=["BTCUSDT", "ETHUSDT", "SOLUSDT"],
    start=window_start,
    end=window_end,
    data_types=["spot"],
)

# View object handed to the feature extractors, the detector and the labeler.
raw_data_view = sf.core.RawDataView(raw_data)

# ============================================================================
# 2. Feature Engineering
# ============================================================================

# Technical-indicator extractors applied to the raw data view.
indicator_extractors = [
    sf.feature.pandasta.PandasTaRsiExtractor(length=14),
    sf.feature.pandasta.PandasTaMacdExtractor(fast=12, slow=26, signal=9),
    sf.feature.pandasta.PandasTaAtrExtractor(length=14),
    sf.feature.pandasta.PandasTaBbandsExtractor(length=20, std=2.0),
]
feature_set = sf.feature.FeatureSet(extractors=indicator_extractors)
features_df = feature_set.extract(raw_data_view)

# Every column other than the (pair, timestamp) key is treated as a feature.
key_cols = ("pair", "timestamp")
feature_cols = [name for name in features_df.columns if name not in key_cols]


def _zscore_within_pair(name: str) -> pl.Expr:
    """Build the per-pair z-score expression for column ``name``."""
    col = pl.col(name)
    centered = col - col.mean().over("pair")
    # The small constant keeps the division finite when a column's std is 0.
    scale = col.std().over("pair") + 1e-6
    return (centered / scale).alias(name)


# NOTE(review): mean/std are computed over all rows of a pair, i.e. before the
# temporal train/val/test split done by the validator, so later rows influence
# the scaling of earlier ones — confirm this is acceptable for the example.
features_df = features_df.with_columns(
    [_zscore_within_pair(name) for name in feature_cols]
)

print(f"Features shape: {features_df.shape}")
print(f"Feature columns: {feature_cols}")

# ============================================================================
# 3. Signal Detection
# ============================================================================

# SMA-cross detector with a 10-period fast and a 30-period slow setting.
detector = sf.detector.SmaCrossSignalDetector(
    fast_period=10,
    slow_period=30,
)
signals = detector.run(raw_data_view)

# signals.value is the frame of detected signals; report its row count.
n_detected = signals.value.height
print(f"Detected {n_detected} signals")

# ============================================================================
# 4. Labeling (Ground Truth)
# ============================================================================

from signalflow.target import FixedHorizonLabeler

# Labeler configured on the "close" price with a horizon of 12; its output
# is written to a column named "label".
labeler = FixedHorizonLabeler(
    out_col="label",
    price_col="close",
    horizon=12,
    include_meta=True,
)

# The labeler is given the spot frame together with the detected signals.
spot_df = raw_data_view.to_polars("spot")
labeled_df = labeler.compute(spot_df, signals)


def encode_labels(df: pl.DataFrame) -> pl.DataFrame:
    """Replace the ``label`` column with integer class ids.

    Mapping: ``"rise"`` -> 1, ``"fall"`` -> 2, any other non-null value
    (e.g. ``"none"``) -> 0. Null labels stay null, so callers can drop
    unlabeled rows with ``pl.col("label").is_not_null()`` instead of having
    them silently counted as class 0.

    Args:
        df: Frame containing a ``label`` column.

    Returns:
        A new frame whose ``label`` column has dtype ``pl.Int64``.
    """
    label = pl.col("label")
    class_id = (
        pl.when(label == "rise").then(1)
        .when(label == "fall").then(2)
        .otherwise(0)
    )
    return df.with_columns(
        # Comparing a null label yields null, which ``when`` treats as not
        # matched; without this guard nulls would reach ``otherwise(0)``.
        pl.when(label.is_not_null())
        .then(class_id)
        .otherwise(pl.lit(None))
        .cast(pl.Int64)
        .alias("label")
    )


# Attach each signal's type to the encoded labels. The left join keeps every
# labeled row; rows without a matching (pair, timestamp) in the signals frame
# end up with a null signal_type.
join_keys = ["pair", "timestamp"]
encoded_labels = encode_labels(labeled_df)
signal_types = signals.value.select([*join_keys, "signal_type"])

joined = encoded_labels.join(signal_types, on=join_keys, how="left")
train_signals_df = (
    joined
    .select([*join_keys, "label", "signal_type"])
    .filter(pl.col("label").is_not_null())
)

print(f"Training signals: {train_signals_df.height}")
label_counts = train_signals_df.group_by("label").len()
print(f"Label distribution:\n{label_counts}")

# ============================================================================
# 5. Configure and Train Validator
# ============================================================================

# The encoder's input width is the number of feature columns.
input_size = len(feature_cols)

# LSTM encoder hyper-parameters.
encoder_params = {
    "input_size": input_size,
    "hidden_size": 64,
    "num_layers": 2,
    "dropout": 0.2,
    "bidirectional": False,
}

# MLP classification head hyper-parameters.
head_params = {
    "hidden_sizes": [128, 64],
    "dropout": 0.3,
    "activation": "gelu",
}

# Optimizer / scheduler / loss settings, kept apart from the architecture.
training_config = {
    "learning_rate": 1e-3,
    "weight_decay": 1e-5,
    "optimizer": "adamw",
    "scheduler": "reduce_on_plateau",
    "scheduler_patience": 5,
    "label_smoothing": 0.1,
}

# The constructor arguments are grouped by concern and unpacked in the same
# order as a single flat call would list them.
architecture = {
    "encoder_type": "encoder/lstm",
    "encoder_params": encoder_params,
    "head_type": "head/cls/mlp",
    "head_params": head_params,
}
model_settings = {
    "window_size": 30,
    "num_classes": 3,
    # Classes 1 (rise) and 2 (fall) weigh twice as much as class 0.
    "class_weights": [1.0, 2.0, 2.0],
    "training_config": training_config,
    "feature_cols": feature_cols,
}
trainer_settings = {
    "max_epochs": 10,
    "batch_size": 64,
    "early_stopping_patience": 5,
    "train_val_test_split": (0.6, 0.2, 0.2),
    "split_strategy": "temporal",
    "num_workers": 4,
}
validator = TemporalValidator(
    **architecture,
    **model_settings,
    **trainer_settings,
)

n_train_signals = train_signals_df.height
print(f"\nStarting training on {n_train_signals} signals...")
print(f"Input size: {input_size}, Window size: {validator.window_size}")

# Fit on the normalized features, supervised by the encoded signal labels.
log_dir = Path("./logs/temporal_validator")
fit_options = {
    "X_train": features_df,
    "y_train": train_signals_df,
    "log_dir": log_dir,
    "accelerator": "auto",
}
validator.fit(**fit_options)

print("Training finished.")

# ============================================================================
# 6. Validate New Signals
# ============================================================================

validated_signals = validator.validate_signals(signals, features_df)

# Add shorter aliases for the probability columns. with_columns + alias adds
# copies under the new names; the probability_* source columns are kept.
probability_aliases = {
    "probability_none": "prob_neutral",
    "probability_rise": "prob_rise",
    "probability_fall": "prob_fall",
}
validated_df = validated_signals.value.with_columns(
    [pl.col(source).alias(short) for source, short in probability_aliases.items()]
)

# ============================================================================
# 7. Analyze Results
# ============================================================================

def _top_signals(frame: pl.DataFrame, direction: str, limit: int = 10) -> pl.DataFrame:
    """Return the ``limit`` rows of type ``direction`` with the highest ``prob_<direction>``."""
    of_direction = frame.filter(pl.col("signal_type") == direction)
    ranked = of_direction.sort(f"prob_{direction}", descending=True)
    return ranked.select(
        ["timestamp", "pair", "prob_rise", "prob_fall", "prob_neutral"]
    ).head(limit)


rule = "=" * 60
print("\n" + rule)
print("VALIDATION RESULTS")
print(rule)

# Rise-type signals, most probable first.
print("\nTop Rise Signals (high probability):")
rise_signals = _top_signals(validated_df, "rise")
print(rise_signals)

# Fall-type signals, most probable first.
print("\nTop Fall Signals (high probability):")
fall_signals = _top_signals(validated_df, "fall")
print(fall_signals)

# Signals whose probability for their own direction exceeds this cut-off
# are counted as high-confidence.
HIGH_CONF_THRESHOLD = 0.7


def _high_confidence(frame: pl.DataFrame, direction: str) -> pl.DataFrame:
    """Keep rows of type ``direction`` whose ``prob_<direction>`` is above the threshold."""
    is_direction = pl.col("signal_type") == direction
    is_confident = pl.col(f"prob_{direction}") > HIGH_CONF_THRESHOLD
    return frame.filter(is_direction & is_confident)


high_conf_rise = _high_confidence(validated_df, "rise")
high_conf_fall = _high_confidence(validated_df, "fall")

print(f"\nHigh-confidence signals (>{HIGH_CONF_THRESHOLD*100:.0f}%):")
print(f"  Rise signals: {high_conf_rise.height}")
print(f"  Fall signals: {high_conf_fall.height}")

# ============================================================================
# 8. Save Validator
# ============================================================================

# Single source for the output location so the save call and the message
# cannot drift apart.
MODEL_FILE = "./models/temporal_validator.pkl"
model_path = Path(MODEL_FILE)

# Nothing earlier in this script creates ./models, so make sure the target
# directory exists before saving (a no-op when it is already there).
model_path.parent.mkdir(parents=True, exist_ok=True)

validator.save(model_path)
print(f"\nValidator saved to {MODEL_FILE}")

# To load later:
# NOTE(review): the .pkl suffix suggests a pickle-based format; if so, only
# load files that come from a trusted source — confirm.
# loaded_validator = TemporalValidator.load(Path("./models/temporal_validator.pkl"))