In [None]:
"""
Example: Training TemporalValidator with SignalFlow-NN

Flow:
1. Load raw data
2. Extract features (RAW values, normalization handled by preprocessor)
3. Detect signals
4. Label signals
5. Configure Preprocessor & Train Validator
6. Validate new signals
7. Analyze results
8. Save validator
"""
import signalflow as sf
from signalflow.nn.validator import TemporalValidator
from signalflow.nn.model.temporal_classificator import TrainingConfig
from signalflow.nn.data.ts_preprocessor import TimeSeriesPreprocessor, ScalerConfig  # <--- NEW IMPORT
from pathlib import Path
from datetime import datetime
import polars as pl
import torch

# ============================================================================
# 1. Load Raw Data
# ============================================================================

# Read SOLUSDT spot data for calendar year 2025 from the local DuckDB store,
# then wrap it in the view object that the later steps consume.
raw_data = sf.data.RawDataFactory.from_duckdb_spot_store(
    spot_store_path=Path("test.duckdb"),
    data_types=["spot"],
    pairs=["SOLUSDT"],
    start=datetime(2025, 1, 1),
    end=datetime(2025, 12, 31),
)
raw_data_view = sf.core.RawDataView(raw_data)

# ============================================================================
# 2. Feature Engineering
# ============================================================================

# Technical-indicator extractors, all taken from the pandas-ta wrapper module.
ta = sf.feature.pandasta
feature_set = sf.feature.FeatureSet(
    extractors=[
        ta.PandasTaRsiExtractor(length=14),
        ta.PandasTaMacdExtractor(fast=12, slow=26, signal=9),
        ta.PandasTaAtrExtractor(length=14),
        ta.PandasTaBbandsExtractor(length=20, std=2.0),
    ]
)
features_df = feature_set.extract(raw_data_view)

# Collect the feature column names (but do NOT normalize manually here):
# everything except the identifier columns is treated as a model input.
non_feature_cols = ("pair", "timestamp")
feature_cols = [
    name for name in features_df.columns if name not in non_feature_cols
]

print(f"Features shape: {features_df.shape} (full history, RAW values)")
print(f"Feature columns ({len(feature_cols)}): {feature_cols[:5]}...")

# ============================================================================
# 3. Signal Detection
# ============================================================================

detector = sf.detector.SmaCrossSignalDetector(fast_period=10, slow_period=30)
signals = detector.run(raw_data_view)

# Keep only the signals whose type is "rise" or "fall"; any other signal
# type is not acted on in this example.
is_actionable = pl.col("signal_type").is_in(["rise", "fall"])
actionable_signals = signals.value.filter(is_actionable)

print(f"\nDetected {signals.value.height} total signals")
print(f"Actionable signals: {actionable_signals.height}")

# ============================================================================
# 4. Labeling
# ============================================================================

from signalflow.target import FixedHorizonLabeler

# Label the full spot history using the "close" column and a horizon of 120;
# the result is written to the "label" column.
labeler = FixedHorizonLabeler(
    price_col="close",
    horizon=120,
    out_col="label",
    include_meta=True,
)

spot_df = raw_data_view.to_polars("spot")
labeled_full = labeler.compute(spot_df)

# Integer class encoding: "rise" -> 1, "fall" -> 2, anything else -> 0.
label_as_int = (
    pl.when(pl.col("label") == "rise").then(1)
    .when(pl.col("label") == "fall").then(2)
    .otherwise(0)
    .cast(pl.Int64)
    .alias("label")
)

# Attach a label to every actionable signal by matching on (pair, timestamp),
# drop rows without a label, then replace the label with its integer code.
signal_keys = actionable_signals.select(["pair", "timestamp", "signal_type"])
label_lookup = labeled_full.select(["pair", "timestamp", "label"])
labeled_signals = (
    signal_keys
    .join(label_lookup, on=["pair", "timestamp"], how="inner")
    .filter(pl.col("label").is_not_null())
    .with_columns(label_as_int)
)

print(f"\nLabeled signals: {labeled_signals.height}")

# ============================================================================
# 5. Configure and Train Validator (WITH PREPROCESSOR)
# ============================================================================

# One model input per feature column.
input_size = len(feature_cols)

# Automatic normalization (robust scaling), applied to each group (pair)
# separately. This replaces any manual scaling of the feature frame.
preprocessor = TimeSeriesPreprocessor(
    default_config=ScalerConfig(method="robust", scope="group"),
    group_col="pair",
    time_col="timestamp",
)

# Settings for the "encoder/gru" encoder selected below.
encoder_params = {
    "input_size": input_size,
    "hidden_size": 64,
    "num_layers": 2,
    "dropout": 0.2,
    "bidirectional": False,
}

# Settings for the "head/cls/linear" classification head selected below.
head_params = {
    "hidden_sizes": [128, 64],
    "dropout": 0.3,
    "activation": "gelu",
}

# Optimizer / scheduler settings.
# NOTE(review): scheduler_patience equals early_stopping_patience (both 5).
# If both watch the same metric, early stopping may end training before the
# learning rate is ever reduced; confirm against the library's behavior.
training_config = {
    "learning_rate": 1e-3,
    "weight_decay": 1e-5,
    "optimizer": "adamw",
    "scheduler": "reduce_on_plateau",
    "scheduler_patience": 5,
    "label_smoothing": 0.1,
}

validator = TemporalValidator(
    # Model architecture
    encoder_type="encoder/gru",
    encoder_params=encoder_params,
    head_type="head/cls/linear",
    head_params=head_params,
    # Input pipeline: the preprocessor is handed to the validator here.
    preprocessor=preprocessor,
    feature_cols=feature_cols,
    window_size=30,
    # Classification setup: three classes, equally weighted.
    num_classes=3,
    class_weights=[1.0, 1.0, 1.0],
    # Training loop
    training_config=training_config,
    max_epochs=100,
    batch_size=64,
    early_stopping_patience=5,
    train_val_test_split=(0.6, 0.2, 0.2),
    split_strategy="temporal",
    num_workers=4,
)

print("\n" + "=" * 60)
print("Starting training")
print("=" * 60)

# The raw feature frame and the labeled signals are passed in unscaled.
# Per the original author's note, fit() handles scaling internally
# (fit on train -> transform all).
validator.fit(
    X_train=features_df,
    y_train=labeled_signals,
    log_dir=Path("./logs/temporal_validator"),
    accelerator="auto",
)

print("\nTraining finished.")

# ============================================================================
# 6. Validate New Signals
# ============================================================================

# Per the original author's note, features are scaled automatically using
# the preprocessor state saved during training.
validated_signals = validator.validate_signals(signals, features_df)

# Add shorter aliases for the three probability columns; the original
# "probability_*" columns stay in the frame.
validated_df = validated_signals.value.with_columns(
    pl.col("probability_none").alias("prob_neutral"),
    pl.col("probability_rise").alias("prob_rise"),
    pl.col("probability_fall").alias("prob_fall"),
)

# ============================================================================
# 7. Analyze Results
# ============================================================================

print(f"\n{'=' * 60}")
print("VALIDATION RESULTS")
print(f"{'=' * 60}")

# Predicate shared by both reports below: rows whose signal type is "rise".
is_rise = pl.col("signal_type") == "rise"

# The ten rise signals with the highest predicted rise probability.
print("\nTop Rise Signals (high probability):")
rise_only = validated_df.filter(is_rise)
ranked_rise = rise_only.sort("prob_rise", descending=True)
rise_signals = ranked_rise.select(
    ["timestamp", "pair", "prob_rise", "prob_fall", "prob_neutral"]
).head(10)
print(rise_signals)

# Rise signals whose rise probability is strictly above the threshold.
HIGH_CONF_THRESHOLD = 0.7
high_conf_rise = validated_df.filter(
    is_rise & (pl.col("prob_rise") > HIGH_CONF_THRESHOLD)
)

print(f"\nHigh-confidence signals (>{HIGH_CONF_THRESHOLD*100:.0f}%):")
print(f"  Rise signals: {high_conf_rise.height}")

# ============================================================================
# 8. Save Validator
# ============================================================================

# Persist the trained validator. One path variable feeds both the save call
# and the message, so the reported location cannot drift from the real one
# (the message previously named ./models/ while the file went to
# ./best_models/).
model_path = Path("./best_models/temporal_validator.pkl")

# Create the target directory up front; this is a no-op if it already exists
# or if save() would have created it itself.
model_path.parent.mkdir(parents=True, exist_ok=True)

validator.save(model_path)
print(f"\nValidator saved to {model_path}")