# In [None]:
"""
Example: Training TemporalValidator with SignalFlow-NN

Flow:
1. Load raw data
2. Extract features (RAW values, normalization handled by preprocessor)
3. Detect signals
4. Label signals
5. Configure Preprocessor & Train Validator
6. Validate new signals
7. Analyze results
8. Save validator
"""
import signalflow as sf
from signalflow.nn.validator import TemporalValidator
from signalflow.nn.model.temporal_classificator import TrainingConfig
from signalflow.nn.data.ts_preprocessor import TimeSeriesPreprocessor, ScalerConfig  # <--- NEW IMPORT
from pathlib import Path
from datetime import datetime
import polars as pl
import torch

# ============================================================================
# 1. Load Raw Data
# ============================================================================

# Request spot data for five USDT pairs from the local DuckDB store
# "test.duckdb", using the start/end datetimes given below.
raw_data = sf.data.RawDataFactory.from_duckdb_spot_store(
    spot_store_path=Path("test.duckdb"),
    pairs=["BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT", "XRPUSDT"],
    start=datetime(2024, 1, 1),
    end=datetime(2025, 12, 31),
    data_types=["spot"],
)
# Wrap the raw data in a view; the feature set, the detector and the labeling
# step below all read from this object.
raw_data_view = sf.core.RawDataView(raw_data)

# ============================================================================
# 2. Feature Engineering
# ============================================================================

# Four pandas-ta based extractors: RSI(14), MACD(12/26/9), ATR(14) and
# Bollinger Bands(20, std=2.0).
feature_set = sf.feature.FeatureSet(extractors=[
    sf.feature.pandasta.PandasTaRsiExtractor(length=14),
    sf.feature.pandasta.PandasTaMacdExtractor(fast=12, slow=26, signal=9),
    sf.feature.pandasta.PandasTaAtrExtractor(length=14),
    sf.feature.pandasta.PandasTaBbandsExtractor(length=20, std=2.0),
])
features_df = feature_set.extract(raw_data_view)

# Build the list of feature columns: every column except the "pair" and
# "timestamp" keys. Do NOT normalize manually here — scaling is left to the
# preprocessor configured in section 5.
feature_cols = [c for c in features_df.columns if c not in ["pair", "timestamp"]]

print(f"Features shape: {features_df.shape} (full history, RAW values)")
print(f"Feature columns ({len(feature_cols)}): {feature_cols[:5]}...")

# ============================================================================
# 3. Signal Detection
# ============================================================================

# Run SmaCrossSignalDetector (fast_period=10, slow_period=30) over the raw
# data view.
detector = sf.detector.SmaCrossSignalDetector(fast_period=10, slow_period=30)
signals = detector.run(raw_data_view)

# Filter to actionable signals only: rows whose signal_type is "rise" or "fall".
actionable_signals = signals.value.filter(
    pl.col("signal_type").is_in(["rise", "fall"])
)

print(f"\nDetected {signals.value.height} total signals")
print(f"Actionable signals: {actionable_signals.height}")

# ============================================================================
# 4. Labeling
# ============================================================================

from signalflow.target import FixedHorizonLabeler

# The labeler reads the "close" column and writes its output to "label".
# NOTE(review): horizon=120 is presumably a look-ahead measured in bars —
# confirm the unit against FixedHorizonLabeler.
labeler = FixedHorizonLabeler(
    price_col="close",
    horizon=120,
    out_col="label",
    include_meta=True,
)

# Compute labels over the full spot DataFrame; signals are matched to these
# rows in the join below.
spot_df = raw_data_view.to_polars("spot")
labeled_full = labeler.compute(spot_df)

# Attach a label to each actionable signal by matching on (pair, timestamp).
# The inner join drops signals with no matching labeled row, and the filter
# removes rows whose label is null.
labeled_signals = (
    actionable_signals
    .select(["pair", "timestamp", "signal_type"])
    .join(
        labeled_full.select(["pair", "timestamp", "label"]),
        on=["pair", "timestamp"],
        how="inner",
    )
    .filter(pl.col("label").is_not_null())
)

# Encode the string label as an integer class id:
# "rise" -> 1, "fall" -> 2, anything else -> 0 (three classes in total,
# matching num_classes=3 passed to the validator in section 5).
labeled_signals = labeled_signals.with_columns(
    pl.when(pl.col("label") == "rise").then(1)
    .when(pl.col("label") == "fall").then(2)
    .otherwise(0)
    .cast(pl.Int64)
    .alias("label")
)

print(f"\nLabeled signals: {labeled_signals.height}")

# ============================================================================
# 5. Configure and Train Validator (WITH PREPROCESSOR)
# ============================================================================

# One encoder input per feature column.
input_size = len(feature_cols)

# Encoder config
encoder_params = {
    "input_size": input_size,
    "hidden_size": 64,
    "num_layers": 2,
    "dropout": 0.2,
    "bidirectional": False,
}

# Head config
head_params = {
    "hidden_sizes": [128, 64],
    "dropout": 0.3,
    "activation": "gelu",
}

# Training config
# NOTE(review): TrainingConfig is imported at the top of this cell, but a plain
# dict is passed instead — confirm the validator accepts a dict here.
# NOTE(review): scheduler_patience (5) equals early_stopping_patience (5) set
# on the validator below; depending on how both are wired, early stopping may
# end training before the plateau scheduler has any effect — confirm.
training_config = {
    "learning_rate": 1e-3,
    "weight_decay": 1e-5,
    "optimizer": "adamw",
    "scheduler": "reduce_on_plateau",
    "scheduler_patience": 5,
    "label_smoothing": 0.1,
}

# --- Configure Preprocessor ---
# Automatic normalization ("robust" method) applied separately to each group
# (trading pair), as requested by scope="group" with group_col="pair".
preprocessor = TimeSeriesPreprocessor(
    default_config=ScalerConfig(method="robust", scope="group"),
    group_col="pair",
    time_col="timestamp"
)
# -----------------------------------

# Create validator
validator = TemporalValidator(
    encoder_type="encoder/gru",
    encoder_params=encoder_params,
    head_type="head/cls/linear",
    head_params=head_params,
    
    preprocessor=preprocessor,  # the preprocessor configured above is attached here
    
    window_size=96,
    window_timeframe=15,
    num_classes=3,
    class_weights=[1.0, 1.0, 1.0],  # all classes weighted equally
    training_config=training_config,
    feature_cols=feature_cols,
    max_epochs=100,
    batch_size=256,
    early_stopping_patience=5,  
    train_val_test_split=(0.6, 0.2, 0.2),
    split_strategy="temporal",
    num_workers=4,
)

print(f"\n{'='*60}")
print(f"Starting training")
print(f"{'='*60}")

# fit() now handles scaling internally (fit on train -> transform all)
# The full feature history and the labeled signals are both passed in; the
# train/val/test split is configured on the validator above.
validator.fit(
    X_train=features_df,      
    y_train=labeled_signals,    
    log_dir=Path("./logs/temporal_validator"),
    accelerator="auto",
)

print("\nTraining finished.")

# ============================================================================
# 6. Validate New Signals
# ============================================================================

# Features are automatically scaled using the saved preprocessor state
# NOTE(review): `signals` and `features_df` are the same objects the training
# inputs were built from, so these predictions are not out-of-sample.
validated_signals = validator.validate_signals(signals, features_df)

# Add short aliases for the three probability columns; with_columns adds the
# aliased columns and keeps the original probability_* columns.
validated_df = validated_signals.value.with_columns([
    pl.col("probability_none").alias("prob_neutral"),
    pl.col("probability_rise").alias("prob_rise"),
    pl.col("probability_fall").alias("prob_fall"),
])

# ============================================================================
# 7. Analyze Results
# ============================================================================

print("\n" + "=" * 60)
print("VALIDATION RESULTS")
print("=" * 60)

# Top Rise Signals: the ten "rise" signals with the highest predicted
# rise probability.
print("\nTop Rise Signals (high probability):")
rise_signals = (
    validated_df
    .filter(pl.col("signal_type") == "rise")
    .sort("prob_rise", descending=True)
    .select(["timestamp", "pair", "prob_rise", "prob_fall", "prob_neutral"])
    .head(10)
)
print(rise_signals)

# High-confidence signals (>70% probability): a signal counts when the
# probability of its own direction exceeds the threshold. Both directions are
# reported so the summary covers every actionable signal type.
HIGH_CONF_THRESHOLD = 0.7
high_conf_rise = validated_df.filter(
    (pl.col("signal_type") == "rise") & (pl.col("prob_rise") > HIGH_CONF_THRESHOLD)
)
high_conf_fall = validated_df.filter(
    (pl.col("signal_type") == "fall") & (pl.col("prob_fall") > HIGH_CONF_THRESHOLD)
)

print(f"\nHigh-confidence signals (>{HIGH_CONF_THRESHOLD*100:.0f}%):")
print(f"  Rise signals: {high_conf_rise.height}")
print(f"  Fall signals: {high_conf_fall.height}")

# ============================================================================
# 8. Save Validator
# ============================================================================

# Make sure the target directory exists before saving (no-op if it already does).
Path("./best_models").mkdir(parents=True, exist_ok=True)

validator.save(Path("./best_models/temporal_validator.pkl"))
# The message must name the path actually written above.
print("\nValidator saved to ./best_models/temporal_validator.pkl")

# In [None]:
"""
Example: Training TemporalValidator with SignalFlow-NN

Flow:
1. Load raw data
2. Extract features (RAW values, normalization handled by preprocessor)
3. Detect signals
4. Label signals
5. Configure Preprocessor & Train Validator
6. Validate new signals
7. Calculate classification metrics
8. Analyze results
9. Save validator and metrics
"""
import signalflow as sf
from signalflow.nn.validator import TemporalValidator
from signalflow.nn.model.temporal_classificator import TrainingConfig
from signalflow.nn.data.ts_preprocessor import TimeSeriesPreprocessor, ScalerConfig
from pathlib import Path
from datetime import datetime
import polars as pl
import numpy as np
import torch

# ============================================================================
# 1. Load Raw Data
# ============================================================================

# Request spot data for five USDT pairs from the local DuckDB store
# "test.duckdb", using the start/end datetimes given below.
raw_data = sf.data.RawDataFactory.from_duckdb_spot_store(
    spot_store_path=Path("test.duckdb"),
    pairs=["BTCUSDT", "ETHUSDT", "SOLUSDT", "BNBUSDT", "XRPUSDT"],
    start=datetime(2022, 1, 1),
    end=datetime(2025, 12, 31),
    data_types=["spot"],
)
# Wrap the raw data in a view; the feature set, the detector and the labeling
# step below all read from this object.
raw_data_view = sf.core.RawDataView(raw_data)

# ============================================================================
# 2. Feature Engineering
# ============================================================================

# Four pandas-ta based extractors: RSI(14), MACD(12/26/9), ATR(14) and
# Bollinger Bands(20, std=2.0).
feature_set = sf.feature.FeatureSet(extractors=[
    sf.feature.pandasta.PandasTaRsiExtractor(length=14),
    sf.feature.pandasta.PandasTaMacdExtractor(fast=12, slow=26, signal=9),
    sf.feature.pandasta.PandasTaAtrExtractor(length=14),
    sf.feature.pandasta.PandasTaBbandsExtractor(length=20, std=2.0),
])
features_df = feature_set.extract(raw_data_view)

# Feature columns are every column except the "pair" and "timestamp" keys.
# Values stay raw here; scaling is left to the preprocessor in section 5.
feature_cols = [c for c in features_df.columns if c not in ["pair", "timestamp"]]

print(f"Features shape: {features_df.shape} (full history, RAW values)")
print(f"Feature columns ({len(feature_cols)}): {feature_cols[:5]}...")

# ============================================================================
# 3. Signal Detection
# ============================================================================

# Run SmaCrossSignalDetector (fast_period=10, slow_period=30) over the raw
# data view.
detector = sf.detector.SmaCrossSignalDetector(fast_period=10, slow_period=30)
signals = detector.run(raw_data_view)

# Keep only actionable signals: rows whose signal_type is "rise" or "fall".
actionable_signals = signals.value.filter(
    pl.col("signal_type").is_in(["rise", "fall"])
)

print(f"\nDetected {signals.value.height} total signals")
print(f"Actionable signals: {actionable_signals.height}")

# ============================================================================
# 4. Labeling
# ============================================================================

from signalflow.target import FixedHorizonLabeler

# The labeler reads the "close" column and writes its output to "label".
# NOTE(review): horizon=120 is presumably a look-ahead measured in bars —
# confirm the unit against FixedHorizonLabeler.
labeler = FixedHorizonLabeler(
    price_col="close",
    horizon=120,
    out_col="label",
    include_meta=True,
)

# Compute labels over the full spot DataFrame; signals are matched to these
# rows in the join below.
spot_df = raw_data_view.to_polars("spot")
labeled_full = labeler.compute(spot_df)

# Attach a label to each actionable signal by matching on (pair, timestamp).
# The inner join drops signals with no matching labeled row, and the filter
# removes rows whose label is null.
labeled_signals = (
    actionable_signals
    .select(["pair", "timestamp", "signal_type"])
    .join(
        labeled_full.select(["pair", "timestamp", "label"]),
        on=["pair", "timestamp"],
        how="inner",
    )
    .filter(pl.col("label").is_not_null())
)

# Encode the string label as an integer class id:
# "rise" -> 1, "fall" -> 2, anything else -> 0 (three classes in total,
# matching num_classes=3 passed to the validator in section 5).
labeled_signals = labeled_signals.with_columns(
    pl.when(pl.col("label") == "rise").then(1)
    .when(pl.col("label") == "fall").then(2)
    .otherwise(0)
    .cast(pl.Int64)
    .alias("label")
)

print(f"\nLabeled signals: {labeled_signals.height}")

# ============================================================================
# 5. Configure and Train Validator (WITH PREPROCESSOR)
# ============================================================================

# One encoder input per feature column.
input_size = len(feature_cols)

# Encoder hyperparameters (passed to the "encoder/gru" encoder below).
encoder_params = {
    "input_size": input_size,
    "hidden_size": 64,
    "num_layers": 2,
    "dropout": 0.2,
    "bidirectional": False,
}

# Classification head hyperparameters (passed to "head/cls/linear" below).
head_params = {
    "hidden_sizes": [128, 64],
    "dropout": 0.3,
    "activation": "gelu",
}

# Optimizer / scheduler settings.
# NOTE(review): TrainingConfig is imported at the top of this cell, but a plain
# dict is passed instead — confirm the validator accepts a dict here.
# NOTE(review): scheduler_patience (5) equals early_stopping_patience (5) set
# on the validator below; depending on how both are wired, early stopping may
# end training before the plateau scheduler has any effect — confirm.
training_config = {
    "learning_rate": 1e-3,
    "weight_decay": 1e-5,
    "optimizer": "adamw",
    "scheduler": "reduce_on_plateau",
    "scheduler_patience": 5,
    "label_smoothing": 0.1,
}

# "robust" scaling configured per group (scope="group"), with the trading pair
# as the group column.
preprocessor = TimeSeriesPreprocessor(
    default_config=ScalerConfig(method="robust", scope="group"),
    group_col="pair",
    time_col="timestamp"
)

validator = TemporalValidator(
    encoder_type="encoder/gru",
    encoder_params=encoder_params,
    head_type="head/cls/linear",
    head_params=head_params,
    
    preprocessor=preprocessor,
    
    window_size=96,
    window_timeframe=15,
    num_classes=3,
    class_weights=[1.0, 1.0, 1.0],  # all classes weighted equally
    training_config=training_config,
    feature_cols=feature_cols,
    max_epochs=100,
    batch_size=256,
    early_stopping_patience=5,  
    train_val_test_split=(0.6, 0.2, 0.2),
    split_strategy="temporal",
    num_workers=4,
)

print(f"\n{'='*60}")
print(f"Starting training")
print(f"{'='*60}")

# The full feature history and the labeled signals are both passed in; the
# train/val/test split is configured on the validator above.
validator.fit(
    X_train=features_df,      
    y_train=labeled_signals,    
    log_dir=Path("./logs/temporal_validator"),
    accelerator="auto",
)

print("\nTraining finished.")

# ============================================================================
# 6. Validate New Signals
# ============================================================================

# NOTE(review): `signals` and `features_df` are the same objects the training
# inputs were built from, so these predictions are not out-of-sample.
validated_signals = validator.validate_signals(signals, features_df)

# Add short aliases for the three probability columns; with_columns adds the
# aliased columns and keeps the original probability_* columns.
validated_df = validated_signals.value.with_columns([
    pl.col("probability_none").alias("prob_neutral"),
    pl.col("probability_rise").alias("prob_rise"),
    pl.col("probability_fall").alias("prob_fall"),
])

# ============================================================================
# 7. Classification Metrics Calculation
# ============================================================================

from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    balanced_accuracy_score,
    matthews_corrcoef,
)

def calculate_classification_metrics(
    validated_df: pl.DataFrame,
    labeled_signals: pl.DataFrame,
    class_names: list[str] = ["none", "rise", "fall"],
) -> dict:
    """
    Calculate comprehensive classification metrics.

    Predictions and ground truth are inner-joined on ("pair", "timestamp").
    Rows with a null/NaN probability or a null label are dropped, and the
    predicted class is the argmax over (prob_neutral, prob_rise, prob_fall).

    Class ids are taken to be 0..len(class_names)-1 in ``class_names`` order.
    Per-class scores, the confusion matrix and the classification report are
    computed with an explicit ``labels`` list, so they stay aligned with
    ``class_names`` even when a class is missing from the evaluated subset.
    Aggregate scores (accuracy, macro/weighted averages) are computed over the
    classes actually present.

    Args:
        validated_df: DataFrame with predictions; must contain "pair",
            "timestamp", "prob_neutral", "prob_rise" and "prob_fall".
        labeled_signals: DataFrame with ground truth; must contain "pair",
            "timestamp" and an integer "label" column.
        class_names: Class names for reporting, ordered by class id. The
            default list is only read, never mutated.

    Returns:
        Dictionary with all metrics, or an empty dict when there are no
        matching samples or none remain after NaN filtering. The "roc_auc_*"
        keys are absent when ROC-AUC cannot be computed.
    """
    key_cols = ["pair", "timestamp"]
    prob_cols = ["prob_neutral", "prob_rise", "prob_fall"]
    # Explicit class ids keep every per-class output aligned with class_names.
    class_ids = list(range(len(class_names)))

    # Join predictions with ground truth
    eval_df = (
        validated_df
        .select(key_cols + prob_cols)
        .join(
            labeled_signals.select(key_cols + ["label"]),
            on=key_cols,
            how="inner",
        )
    )

    if eval_df.height == 0:
        print("Warning: No matching samples for evaluation")
        return {}

    # Track original count before filtering
    n_total = eval_df.height

    # Filter out rows with null/NaN probabilities or null labels
    is_complete = pl.col("label").is_not_null()
    for col in prob_cols:
        is_complete = is_complete & pl.col(col).is_not_null() & pl.col(col).is_not_nan()
    eval_df = eval_df.filter(is_complete)

    n_valid = eval_df.height
    n_dropped = n_total - n_valid

    if n_dropped > 0:
        print(f"Warning: Dropped {n_dropped} samples with NaN values ({n_dropped/n_total*100:.1f}%)")

    if n_valid == 0:
        print("Warning: No valid samples after filtering NaN values")
        return {}

    # Get predicted class (argmax of probabilities)
    eval_df = eval_df.with_columns(
        pl.concat_list(prob_cols)
        .list.arg_max()
        .alias("pred_class")
    )

    # Extract arrays
    y_true = eval_df.get_column("label").to_numpy()
    y_pred = eval_df.get_column("pred_class").to_numpy()

    # Probability matrix for ROC-AUC
    y_proba = eval_df.select(prob_cols).to_numpy()

    metrics = {}

    # Basic metrics
    metrics["accuracy"] = accuracy_score(y_true, y_pred)
    metrics["balanced_accuracy"] = balanced_accuracy_score(y_true, y_pred)
    metrics["mcc"] = matthews_corrcoef(y_true, y_pred)

    # Averaged metrics (computed over the classes present in the data)
    for average in ("macro", "weighted"):
        metrics[f"precision_{average}"] = precision_score(
            y_true, y_pred, average=average, zero_division=0
        )
        metrics[f"recall_{average}"] = recall_score(
            y_true, y_pred, average=average, zero_division=0
        )
        metrics[f"f1_{average}"] = f1_score(
            y_true, y_pred, average=average, zero_division=0
        )

    # Per-class detailed metrics. labels=class_ids guarantees one entry per
    # class name, in order, so a missing class gets 0 instead of shifting the
    # remaining scores onto the wrong names.
    precision_per_class = precision_score(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )
    recall_per_class = recall_score(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )
    f1_per_class = f1_score(
        y_true, y_pred, labels=class_ids, average=None, zero_division=0
    )

    for name, p, r, f in zip(class_names, precision_per_class, recall_per_class, f1_per_class):
        metrics[f"precision_{name}"] = p
        metrics[f"recall_{name}"] = r
        metrics[f"f1_{name}"] = f

    # ROC-AUC (one-vs-rest and one-vs-one); skipped when only one class is
    # present, and sklearn's ValueError is reported rather than raised.
    try:
        unique_classes = np.unique(y_true)
        if len(unique_classes) > 1:
            metrics["roc_auc_ovr"] = roc_auc_score(
                y_true, y_proba, multi_class="ovr", average="macro"
            )
            metrics["roc_auc_ovo"] = roc_auc_score(
                y_true, y_proba, multi_class="ovo", average="macro"
            )
    except ValueError as e:
        print(f"ROC-AUC calculation failed: {e}")

    # Confusion matrix: always len(class_names) x len(class_names)
    metrics["confusion_matrix"] = confusion_matrix(y_true, y_pred, labels=class_ids)

    # Classification report (string); labels keeps it consistent with
    # target_names even when a class is absent.
    metrics["classification_report"] = classification_report(
        y_true, y_pred, labels=class_ids, target_names=class_names, zero_division=0
    )

    # Sample counts
    metrics["n_samples"] = n_valid
    metrics["n_dropped_nan"] = n_dropped
    metrics["nan_ratio"] = n_dropped / n_total if n_total > 0 else 0.0
    metrics["class_distribution"] = {
        int(k): int(v) for k, v in zip(*np.unique(y_true, return_counts=True))
    }

    return metrics
    
def print_metrics_report(metrics: dict, class_names: list[str] = ["none", "rise", "fall"]):
    """Pretty print classification metrics.

    Args:
        metrics: Metrics dictionary; missing numeric entries are shown as 0
            and missing ROC-AUC values as "N/A". An empty dict prints a
            one-line notice instead of a report full of zeros.
        class_names: Class names used for the per-class table and, when the
            confusion matrix is len(class_names) x len(class_names), for its
            row/column labels. The default list is only read, never mutated.

    Returns:
        None. Everything is written to stdout.
    """
    if not metrics:
        print("\nNo classification metrics to report.")
        return

    def _fmt_optional(value) -> str:
        # Optional metrics may be absent ("N/A"); format real numbers with the
        # same 4-decimal precision as every other metric.
        try:
            return f"{float(value):.4f}"
        except (TypeError, ValueError):
            return "N/A"

    def _section(title: str) -> None:
        # Shared sub-section banner.
        print("\n" + "-" * 40)
        print(title)
        print("-" * 40)

    print("\n" + "=" * 70)
    print("CLASSIFICATION METRICS REPORT")
    print("=" * 70)

    print(f"\nSamples evaluated: {metrics.get('n_samples', 'N/A')}")
    print(f"Class distribution: {metrics.get('class_distribution', 'N/A')}")

    _section("OVERALL METRICS")
    print(f"  Accuracy:          {metrics.get('accuracy', 0):.4f}")
    print(f"  Balanced Accuracy: {metrics.get('balanced_accuracy', 0):.4f}")
    print(f"  MCC:               {metrics.get('mcc', 0):.4f}")
    print(f"  ROC-AUC (OvR):     {_fmt_optional(metrics.get('roc_auc_ovr', 'N/A'))}")
    print(f"  ROC-AUC (OvO):     {_fmt_optional(metrics.get('roc_auc_ovo', 'N/A'))}")

    _section("MACRO-AVERAGED METRICS")
    print(f"  Precision: {metrics.get('precision_macro', 0):.4f}")
    print(f"  Recall:    {metrics.get('recall_macro', 0):.4f}")
    print(f"  F1-Score:  {metrics.get('f1_macro', 0):.4f}")

    _section("WEIGHTED-AVERAGED METRICS")
    print(f"  Precision: {metrics.get('precision_weighted', 0):.4f}")
    print(f"  Recall:    {metrics.get('recall_weighted', 0):.4f}")
    print(f"  F1-Score:  {metrics.get('f1_weighted', 0):.4f}")

    _section("PER-CLASS METRICS")
    print(f"{'Class':<12} {'Precision':<12} {'Recall':<12} {'F1-Score':<12}")
    print("-" * 48)
    for name in class_names:
        p = metrics.get(f"precision_{name}", 0)
        r = metrics.get(f"recall_{name}", 0)
        f1 = metrics.get(f"f1_{name}", 0)
        print(f"{name:<12} {p:<12.4f} {r:<12.4f} {f1:<12.4f}")

    _section("CONFUSION MATRIX")
    cm = metrics.get("confusion_matrix")
    if cm is not None:
        cm = np.atleast_2d(np.asarray(cm))
        n_classes = len(class_names)
        if cm.shape == (n_classes, n_classes):
            # Header
            header = "Pred →  " + "  ".join(f"{name:>8}" for name in class_names)
            print(header)
            print("True ↓")
            for i, name in enumerate(class_names):
                row = f"{name:<8}" + "  ".join(f"{cm[i, j]:>8}" for j in range(n_classes))
                print(row)
        else:
            # The matrix does not line up with class_names (e.g. a class was
            # absent when it was computed), so rows cannot be labeled safely.
            print(f"(shape {cm.shape} does not match {n_classes} class names; shown unlabeled)")
            for cm_row in cm:
                print("  ".join(f"{value:>8}" for value in cm_row))

    _section("FULL CLASSIFICATION REPORT")
    print(metrics.get("classification_report", "N/A"))


def calculate_metrics_by_signal_type(
    validated_df: pl.DataFrame,
    labeled_signals: pl.DataFrame,
) -> dict:
    """Compute classification metrics per signal direction.

    For each of "rise" and "fall", both frames are restricted to rows with
    that ``signal_type`` and handed to ``calculate_classification_metrics``.
    A direction is left out of the result when either restricted frame has
    no rows.

    Returns:
        Mapping of signal type to its metrics dictionary.
    """
    per_type = {}

    for direction in ("rise", "fall"):
        # Same predicate applied to predictions and to ground truth.
        matches_direction = pl.col("signal_type") == direction
        predictions_subset = validated_df.filter(matches_direction)
        truth_subset = labeled_signals.filter(matches_direction)

        if predictions_subset.height != 0 and truth_subset.height != 0:
            per_type[direction] = calculate_classification_metrics(
                predictions_subset,
                truth_subset,
                class_names=["none", "rise", "fall"],
            )

    return per_type


# Calculate metrics
# NOTE(review): `labeled_signals` is the same frame that was passed to
# validator.fit() as y_train, so these metrics cover rows used for training as
# well as held-out rows; they are not a pure test-set estimate.
print("\n" + "=" * 70)
print("CALCULATING CLASSIFICATION METRICS")
print("=" * 70)

# calculate_classification_metrics returns {} when nothing can be evaluated.
metrics = calculate_classification_metrics(validated_df, labeled_signals)
print_metrics_report(metrics)

# Metrics by signal type
print("\n" + "=" * 70)
print("METRICS BY SIGNAL TYPE")
print("=" * 70)

# One metrics dict per signal type ("rise" / "fall"); a type with no rows is
# omitted from the mapping.
metrics_by_type = calculate_metrics_by_signal_type(validated_df, labeled_signals)

# Short per-type summary; .get(..., 0) keeps the formatting safe if a metrics
# dict is empty.
for signal_type, signal_metrics in metrics_by_type.items():
    print(f"\n{'='*40}")
    print(f"Signal Type: {signal_type.upper()}")
    print(f"{'='*40}")
    print(f"  Samples: {signal_metrics.get('n_samples', 0)}")
    print(f"  Accuracy: {signal_metrics.get('accuracy', 0):.4f}")
    print(f"  Balanced Accuracy: {signal_metrics.get('balanced_accuracy', 0):.4f}")
    print(f"  F1 (macro): {signal_metrics.get('f1_macro', 0):.4f}")

# ============================================================================
# 8. Analyze Results
# ============================================================================

print("\n" + "=" * 60)
print("VALIDATION RESULTS")
print("=" * 60)

# Top Rise Signals: the ten "rise" signals with the highest predicted
# rise probability.
print("\nTop Rise Signals (high probability):")
rise_signals = (
    validated_df
    .filter(pl.col("signal_type") == "rise")
    .sort("prob_rise", descending=True)
    .select(["timestamp", "pair", "prob_rise", "prob_fall", "prob_neutral"])
    .head(10)
)
print(rise_signals)

# High-confidence signals: a signal counts when the probability of its own
# direction is strictly above the threshold (0.7, printed as 70%).
HIGH_CONF_THRESHOLD = 0.7
high_conf_rise = validated_df.filter(
    (pl.col("signal_type") == "rise") & (pl.col("prob_rise") > HIGH_CONF_THRESHOLD)
)
high_conf_fall = validated_df.filter(
    (pl.col("signal_type") == "fall") & (pl.col("prob_fall") > HIGH_CONF_THRESHOLD)
)

print(f"\nHigh-confidence signals (>{HIGH_CONF_THRESHOLD*100:.0f}%):")
print(f"  Rise signals: {high_conf_rise.height}")
print(f"  Fall signals: {high_conf_fall.height}")

# ============================================================================
# 9. Save Validator and Metrics
# ============================================================================

# Make sure the target directory exists before writing either artifact
# (no-op if it already does).
Path("./best_models").mkdir(parents=True, exist_ok=True)

validator.save(Path("./best_models/temporal_validator.pkl"))
# The message must name the path actually written above.
print("\nValidator saved to ./best_models/temporal_validator.pkl")

# Save metrics to JSON
import json

if metrics:
    # The text report is omitted from the JSON; the confusion matrix (a numpy
    # array) is stored as nested lists.
    metrics_to_save = {k: v for k, v in metrics.items()
                       if k not in ["confusion_matrix", "classification_report"]}
    confusion = metrics.get("confusion_matrix")
    if confusion is not None:
        metrics_to_save["confusion_matrix"] = confusion.tolist()

    # Convert numpy int64 keys/values to regular int so they serialize cleanly.
    if "class_distribution" in metrics_to_save:
        metrics_to_save["class_distribution"] = {
            int(k): int(v) for k, v in metrics_to_save["class_distribution"].items()
        }

    # default=str is a last-resort fallback: any remaining non-JSON value is
    # written as its string representation.
    with open("./best_models/classification_metrics.json", "w", encoding="utf-8") as f:
        json.dump(metrics_to_save, f, indent=2, default=str)

    print("Metrics saved to ./best_models/classification_metrics.json")
else:
    # The metrics calculation returns {} when no samples could be evaluated;
    # there is nothing to write in that case.
    print("No classification metrics to save.")