# In [None]:
import os
import pandas as pd
import torch
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import confusion_matrix
from itertools import product
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Mount Google Drive (skip if already mounted)
# NOTE: google.colab is Colab-specific; this cell assumes a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Add custom KAN paths
# These directories must be on sys.path *before* the star imports below,
# which load modules named `kan` and `fastkan` from them.
import sys
sys.path.append('/content/drive/MyDrive/packages/efficient_kan')
sys.path.append('/content/drive/MyDrive/packages/fastkan')

# Import KAN and FastKAN classes
# NOTE(review): the training code below references the name `KAN`, presumably
# provided by `kan` (efficient_kan) — confirm. Because `fastkan` is star-imported
# second, any same-named public object it defines would shadow the `kan` one.
from kan import *
from fastkan import *

# In [None]:
"""
Kolmogorov-Arnold Networks (KAN) for Crayfish Sex Classification
================================================================

Original Paper Authors:
    Yasin Atilkan¹, Berk Kirik², Eren Tuna Acikbas³, Fatih Ekinci⁴,
    Koray Acici¹'⁴, Tunc Asuroglu⁵'⁶*, Recep Benzer⁷,
    Mehmet Serdar Guzel⁴'⁸, Semra Benzer⁹

Affiliations:
    ¹ Department of Artificial Intelligence and Data Engineering, Ankara University, Turkey
    ² Department of Biomedical Engineering, Ankara University, Turkey
    ³ Department of Petroleum and Natural Gas Engineering, Middle East Technical University, Turkey
    ⁴ Institute of Artificial Intelligence, Ankara University, Turkey
    ⁵ Faculty of Medicine and Health Technology, Tampere University, Finland
    ⁶ VTT Technical Research Centre of Finland, Finland
    ⁷ Department of Management Information System, Ankara Medipol University, Turkey
    ⁸ Department of Computer Engineering, Ankara University, Turkey
    ⁹ Department of Science Education, Gazi University, Turkey
    * Corresponding author

Implementation follows the methodology from:
"Enhancing Crayfish Sex Identification with Kolmogorov-Arnold Networks
and Stacked Autoencoders" (Atilkan et al., 2025)

Key Features:
- Grid search for optimal hyperparameter selection
- Pre-split train/test data from Stacked Autoencoder
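- Standardization fitted on the training split only (test data is only transformed)
- Note: the test split also serves as the early-stopping validation set and as the
  hyperparameter-selection criterion, so reported test metrics are optimistic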
- Early stopping mechanism for overfitting mitigation
- Comprehensive evaluation metrics
- Best model identification and storage

Implementation Date: 2025
"""


import os
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Any
import warnings
from pathlib import Path
import pickle

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    confusion_matrix, accuracy_score, precision_score,
    recall_score, f1_score, matthews_corrcoef, classification_report
)
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import product
from tqdm import tqdm

# Suppress every warning category for the rest of the session.
# NOTE(review): this also hides warnings that may point at real problems;
# consider a narrower filter.
warnings.filterwarnings('ignore')


class Config:
    """Static settings for the KAN grid search.

    Groups data locations, the hyperparameter search space, the training
    schedule and plotting options. Everything is a class attribute; the class
    is used directly, never instantiated.
    """

    # Pre-split features/labels produced by the Stacked Autoencoder (.npy files)
    TRAIN_FEATURES_PATH = 'train_features_stackedautoencoder.npy'
    TRAIN_LABELS_PATH = 'train_labels_stackedautoencoder.npy'
    TEST_FEATURES_PATH = 'test_features_stackedautoencoder.npy'
    TEST_LABELS_PATH = 'test_labels_stackedautoencoder.npy'

    # Root directory for logs, spreadsheets, plots and saved models
    OUTPUT_PATH = '/content/drive/MyDrive/crawfish_results/ml_results/kan_grid_search_sae'

    # Search space: the grid search trains every combination of these values
    GRID_SEARCH_PARAMS = {
        'layers_hidden': [
            [18432, 1024, 512, 256, 128, 64, 32, 1],
            [18432, 512, 256, 128, 64, 32, 1],
            [18432, 256, 128, 64, 32, 1],
            [18432, 64, 32, 1]
        ],
        'grid_size': [8, 7, 6, 5, 4],
        'learning_rate': [0.0005, 0.005, 0.001, 0.0001],
        'spline_order': [5, 4, 3],
        'scale_base': [1.0, 2.0, 3.0],
        'scale_spline': [1.0, 2.0],
        'batch_size': [128, 64, 32],
        'optimizer': ['Adam', 'SGD']
    }

    # Training schedule
    NUM_EPOCHS = 50
    EARLY_STOPPING_PATIENCE = 10
    EARLY_STOPPING_DELTA = 0.0

    # Plot settings
    FIG_SIZE = (10, 8)
    DPI = 300
    CLASS_NAMES = ['Female (D)', 'Male (E)']

    @classmethod
    def display_config(cls, logger: logging.Logger) -> None:
        """Log a short summary of the search space and the training schedule."""
        summary = (
            "Configuration Summary:",
            f"  Total hyperparameter combinations: {cls._get_total_combinations()}",
            f"  Learning rates: {cls.GRID_SEARCH_PARAMS['learning_rate']}",
            f"  Maximum epochs: {cls.NUM_EPOCHS}",
            f"  Early stopping patience: {cls.EARLY_STOPPING_PATIENCE} epochs",
            f"  Early stopping delta: {cls.EARLY_STOPPING_DELTA}",
        )
        for entry in summary:
            logger.info(entry)

    @classmethod
    def _get_total_combinations(cls) -> int:
        """Return the size of the Cartesian product of all search-space lists."""
        total = 1
        for candidates in cls.GRID_SEARCH_PARAMS.values():
            total *= len(candidates)
        return total


def setup_logging(output_path: str) -> logging.Logger:
    """Configure and return the ``KAN_Training`` logger.

    Creates ``output_path`` if needed and attaches two handlers:
      * a UTF-8 file handler at DEBUG level writing to
        ``<output_path>/kan_training_<YYYYmmdd_HHMMSS>.log``;
      * a console (stderr) handler at INFO level that prints bare messages.

    Handlers left over from a previous call are closed and removed first, so
    calling this again (e.g. re-running a notebook cell) neither duplicates
    output nor leaks open log files. Propagation to ancestor loggers is
    disabled.

    Args:
        output_path: Directory that receives the log file.

    Returns:
        The configured logger, after a session-start banner has been logged.
    """
    Path(output_path).mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(output_path, f'kan_training_{timestamp}.log')

    logger = logging.getLogger('KAN_Training')
    logger.setLevel(logging.DEBUG)
    # Close stale handlers before dropping them: merely clearing the list
    # would leave their log files open for the lifetime of the process.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    logger.propagate = False

    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter(
        '%(asctime)s | %(levelname)-8s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(file_formatter)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_formatter = logging.Formatter('%(message)s')
    console_handler.setFormatter(console_formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    logger.info("=" * 80)
    logger.info("KAN Training for Crayfish Sex Classification (Stacked Autoencoder)")
    logger.info("Based on: Atilkan et al. (2025)")
    logger.info("=" * 80)
    logger.info(f"Session started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f"Log file: {log_file}")
    logger.info("=" * 80)

    return logger


class EarlyStopping:
    """Early stopping mechanism to prevent overfitting.

    Call the instance once per epoch with that epoch's validation loss. It
    keeps the lowest loss seen so far together with a *copy* of the model
    weights at that point. A loss counts as an improvement only if it is
    finite and ``<= best_loss - delta``. After ``patience`` consecutive calls
    without an improvement, ``early_stop`` is set to True and the saved
    weights are loaded back into the model.
    """

    def __init__(self, patience: int = 10, verbose: bool = False, delta: float = 0.0):
        self.patience = patience
        self.verbose = verbose  # accepted for API compatibility; not read by this class
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.delta = delta
        self.best_model_state = None

    @staticmethod
    def _snapshot(model: torch.nn.Module) -> Dict[str, Any]:
        """Return a detached copy of ``model.state_dict()``.

        ``state_dict()`` returns references to the live parameter tensors,
        which the optimizer later modifies in place. Without cloning, the
        "best" snapshot would silently track the current weights.
        """
        return {
            key: value.detach().clone() if torch.is_tensor(value) else value
            for key, value in model.state_dict().items()
        }

    def __call__(self, val_loss: float, model: torch.nn.Module) -> None:
        """Record one epoch's validation loss and update the stopping state."""
        # Comparisons with NaN are always False, so without the finiteness
        # check a NaN loss (e.g. a diverging run) would be taken as an
        # improvement: it would replace the best weights and reset patience.
        improved = bool(np.isfinite(val_loss)) and (
            self.best_loss is None or val_loss <= self.best_loss - self.delta
        )
        if improved:
            self.best_loss = val_loss
            self.best_model_state = self._snapshot(model)
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
            # Nothing to restore if no finite loss has been recorded yet.
            if self.best_model_state is not None:
                model.load_state_dict(self.best_model_state)

    def reset(self) -> None:
        """Reset early stopping state."""
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model_state = None


def calculate_specificity(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Calculate specificity (true negative rate) for binary 0/1 labels.

    Specificity = TN / (TN + FP), with class 0 as the negative class.

    When ``y_true`` contains no negatives (TN + FP == 0) the rate is
    undefined and 0.0 is returned, in line with the ``zero_division=0``
    convention used for the other metrics. This holds even when only the
    positive class is present, so such a set is not reported as a perfect 1.0.

    Args:
        y_true: Ground-truth labels (0 or 1; integer or float dtype).
        y_pred: Predicted labels (0 or 1).

    Returns:
        The specificity as a Python float in [0, 1].
    """
    actual = np.asarray(y_true)
    predicted = np.asarray(y_pred)
    is_negative = actual == 0
    tn = int(np.sum(is_negative & (predicted == 0)))
    fp = int(np.sum(is_negative & (predicted == 1)))
    negatives = tn + fp
    return tn / negatives if negatives > 0 else 0.0


def calculate_all_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Compute the binary-classification metrics reported by this script.

    Returns a dict with the keys 'Accuracy', 'Precision', 'Recall',
    'Specificity', 'F1-Score' and 'MCC', in that order. Precision, recall
    and F1 use ``zero_division=0``, so an undefined ratio is reported as 0.
    """
    def _zero_safe(scorer):
        # Bind zero_division=0 so every scorer shares the (y_true, y_pred) call shape.
        return lambda truth, guess: scorer(truth, guess, zero_division=0)

    scorers = (
        ('Accuracy', accuracy_score),
        ('Precision', _zero_safe(precision_score)),
        ('Recall', _zero_safe(recall_score)),
        ('Specificity', calculate_specificity),
        ('F1-Score', _zero_safe(f1_score)),
        ('MCC', matthews_corrcoef),
    )
    return {name: scorer(y_true, y_pred) for name, scorer in scorers}


def save_confusion_matrix(
    cm: np.ndarray,
    filename: str,
    class_names: List[str],
    title: str = 'Confusion Matrix'
) -> None:
    """Render ``cm`` as an annotated heatmap and save it to ``filename``.

    Figure size and resolution come from ``Config.FIG_SIZE`` and
    ``Config.DPI``. Rows are labelled as true labels and columns as predicted
    labels, both ticked with ``class_names``. Integer matrices are annotated
    as integers. Other dtypes, such as a normalised matrix, are annotated with
    two decimals.

    The figure is always closed, even if plotting or saving fails, so
    repeated calls do not accumulate open matplotlib figures.
    """
    fig, ax = plt.subplots(figsize=Config.FIG_SIZE)
    try:
        # fmt='d' raises for float values, so use it only for integer counts.
        annot_fmt = 'd' if np.issubdtype(np.asarray(cm).dtype, np.integer) else '.2f'
        sns.heatmap(
            cm, annot=True, fmt=annot_fmt, cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar=True, square=True, linewidths=1, linecolor='gray',
            ax=ax
        )

        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.set_ylabel('True Label', fontsize=14, fontweight='bold')
        ax.set_xlabel('Predicted Label', fontsize=14, fontweight='bold')
        fig.tight_layout()
        fig.savefig(filename, dpi=Config.DPI, bbox_inches='tight')
    finally:
        plt.close(fig)


def print_section_header(title: str, logger: logging.Logger, level: int = 1) -> None:
    """Log a blank line, then ``title`` framed between two 80-character rules.

    Level 1 uses '=' rules; every other level uses '-' rules.
    """
    rule = ("=" if level == 1 else "-") * 80
    for line in ("", rule, title, rule):
        logger.info(line)


def format_metrics_table(metrics: Dict[str, float]) -> str:
    """Render ``metrics`` as a plain-text table framed by 40-dash rules.

    Each row holds the metric name left-aligned in 20 columns and the value
    right-aligned in 8 columns with 4 decimals.
    """
    border = "-" * 40
    rows = [f"  {name:<20} {score:>8.4f}" for name, score in metrics.items()]
    return "\n".join([border, *rows, border])


class DataManager:
    """Load the pre-split Stacked Autoencoder arrays and log basic statistics."""

    def __init__(self, config: "Config", logger: logging.Logger):
        self.config = config
        self.logger = logger

    def load_data(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Load pre-split dataset from Stacked Autoencoder .npy files.

        Reads the four paths ``TRAIN_FEATURES_PATH``, ``TRAIN_LABELS_PATH``,
        ``TEST_FEATURES_PATH`` and ``TEST_LABELS_PATH`` from the config.

        Returns:
            ``(train_features, train_labels, test_features, test_labels)``.

        Raises:
            Any error raised by ``np.load``. It is logged first, then re-raised.
            ValueError: if a split's feature and label counts differ.
        """
        self.logger.info("Loading pre-split data from Stacked Autoencoder...")

        try:
            train_features = np.load(self.config.TRAIN_FEATURES_PATH)
            train_labels = np.load(self.config.TRAIN_LABELS_PATH)
            test_features = np.load(self.config.TEST_FEATURES_PATH)
            test_labels = np.load(self.config.TEST_LABELS_PATH)
        except Exception as e:
            self.logger.error(f"Failed to load data: {str(e)}")
            raise

        self.logger.info("Successfully loaded data:")
        self.logger.info(f"  Train features shape: {train_features.shape}")
        self.logger.info(f"  Train labels shape: {train_labels.shape}")
        self.logger.info(f"  Test features shape: {test_features.shape}")
        self.logger.info(f"  Test labels shape: {test_labels.shape}")

        self._validate_split("Train", train_features, train_labels)
        self._validate_split("Test", test_features, test_labels)

        self._log_dataset_info(train_labels, test_labels)
        return train_features, train_labels, test_features, test_labels

    def _validate_split(self, split_name: str, features: np.ndarray, labels: np.ndarray) -> None:
        """Raise ValueError if a split has a different number of feature rows and labels."""
        # Fail here with a clear message instead of deep inside tensor/dataset code.
        if len(features) != len(labels):
            message = (
                f"{split_name} split has {len(features)} feature rows "
                f"but {len(labels)} labels"
            )
            self.logger.error(message)
            raise ValueError(message)

    def _log_dataset_info(self, train_labels: np.ndarray, test_labels: np.ndarray) -> None:
        """Log dataset statistics."""
        self.logger.info("")
        self.logger.info("Dataset Statistics:")
        self.logger.info(f"  Total train samples: {len(train_labels)}")
        self.logger.info(f"  Total test samples: {len(test_labels)}")

        self._log_class_distribution("Train", train_labels)
        self._log_class_distribution("Test", test_labels)

    def _log_class_distribution(self, split_name: str, labels: np.ndarray) -> None:
        """Log the count and percentage of each distinct label in ``labels``."""
        classes, counts = np.unique(labels, return_counts=True)
        self.logger.info(f"  {split_name} class distribution:")
        for cls, count in zip(classes, counts):
            percentage = count / len(labels) * 100
            self.logger.info(f"    Class {cls}: {count:3d} samples ({percentage:5.1f}%)")


class KANTrainer:
    """KAN model training with comprehensive grid search.

    Every hyperparameter configuration is trained on the training split and
    scored by accuracy on the test split. The best configuration is then
    retrained, and its model and reports are saved.

    NOTE(review): the test split doubles as the early-stopping validation set
    and as the model-selection criterion, so the reported test metrics are
    optimistically biased. Carving a validation split out of the training data
    would give an unbiased test estimate.
    """

    def __init__(self, config: "Config", logger: logging.Logger):
        self.config = config
        self.logger = logger

    @staticmethod
    def _build_optimizer(
        optimizer_name: str,
        model: torch.nn.Module,
        learning_rate: float
    ) -> torch.optim.Optimizer:
        """Return Adam for ``'Adam'``; any other name yields SGD with momentum 0.9."""
        if optimizer_name == 'Adam':
            return optim.Adam(model.parameters(), lr=learning_rate)
        return optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

    @staticmethod
    def _run_epoch(
        model: torch.nn.Module,
        loader: TorchDataLoader,
        criterion: torch.nn.Module,
        optimizer: "torch.optim.Optimizer | None" = None
    ) -> float:
        """Make one pass over ``loader`` and return the per-sample mean loss.

        If ``optimizer`` is given, each batch is trained (forward, backward,
        step). Otherwise the model is evaluated with gradients disabled. Batch
        losses are weighted by batch size, so a short final batch is not
        over-weighted. This assumes ``criterion`` averages over the batch, as
        ``BCEWithLogitsLoss`` does by default.

        Raises:
            ValueError: if ``loader`` yields no samples.
        """
        training = optimizer is not None
        model.train(training)
        total_loss = 0.0
        total_samples = 0
        with torch.set_grad_enabled(training):
            for X_batch, y_batch in loader:
                if training:
                    optimizer.zero_grad()
                # reshape(-1) rather than squeeze(): squeeze() also drops the
                # batch dimension of a 1-sample batch, producing a 0-d output
                # that no longer matches the (1,)-shaped target.
                outputs = model(X_batch).reshape(-1)
                loss = criterion(outputs, y_batch)
                if training:
                    loss.backward()
                    optimizer.step()
                batch_samples = y_batch.shape[0]
                total_loss += loss.item() * batch_samples
                total_samples += batch_samples
        if total_samples == 0:
            raise ValueError("DataLoader yielded no samples")
        return total_loss / total_samples

    def train_model(
        self,
        model: torch.nn.Module,
        train_loader: TorchDataLoader,
        val_loader: TorchDataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        verbose: bool = False
    ) -> Dict[str, Any]:
        """Train model with early stopping.

        Runs up to ``config.NUM_EPOCHS`` epochs. After each epoch the
        validation loss goes to ``EarlyStopping``, which restores its best
        weights when patience runs out. If the epoch budget runs out first,
        those best weights are restored here too. In both cases ``model``
        therefore ends with the lowest-validation-loss weights observed.

        Args:
            model: Network to train, modified in place.
            train_loader: Batches used for optimisation.
            val_loader: Batches used for the early-stopping loss.
            optimizer: Optimizer bound to ``model``'s parameters.
            criterion: Loss on (flattened logits, targets).
            verbose: If True, log per-epoch losses at DEBUG level.

        Returns:
            Dict with ``train_loss`` and ``val_loss`` (per-epoch mean losses),
            ``epochs_trained`` and ``early_stopped``.
        """
        early_stopping = EarlyStopping(
            patience=self.config.EARLY_STOPPING_PATIENCE,
            verbose=verbose,
            delta=self.config.EARLY_STOPPING_DELTA
        )

        history = {
            'train_loss': [],
            'val_loss': [],
            'epochs_trained': 0,
            'early_stopped': False
        }

        for epoch in range(self.config.NUM_EPOCHS):
            avg_train_loss = self._run_epoch(model, train_loader, criterion, optimizer)
            history['train_loss'].append(avg_train_loss)

            avg_val_loss = self._run_epoch(model, val_loader, criterion)
            history['val_loss'].append(avg_val_loss)

            if verbose:
                self.logger.debug(
                    f"  Epoch {epoch + 1:3d}/{self.config.NUM_EPOCHS} | "
                    f"train_loss: {avg_train_loss:.6f} | val_loss: {avg_val_loss:.6f}"
                )

            # When patience is exhausted, EarlyStopping restores its best weights itself.
            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                history['epochs_trained'] = epoch + 1
                history['early_stopped'] = True
                break
        else:
            history['epochs_trained'] = self.config.NUM_EPOCHS
            # No early stop: restore the best weights here as well, so both
            # exit paths hand back the same kind of model.
            if early_stopping.best_model_state is not None:
                model.load_state_dict(early_stopping.best_model_state)

        return history

    def evaluate_model(
        self,
        model: torch.nn.Module,
        X_tensor: torch.Tensor,
        y_tensor: torch.Tensor
    ) -> Tuple[np.ndarray, np.ndarray, Dict[str, float]]:
        """Evaluate model and calculate metrics.

        Runs the whole ``X_tensor`` through the model in one forward pass.
        It applies a sigmoid and thresholds at 0.5.

        Returns:
            ``(predictions, probabilities, metrics)``: 0/1 int predictions and
            probabilities as 1-D NumPy arrays, plus the metrics dict.
        """
        model.eval()
        with torch.no_grad():
            # reshape(-1) keeps a 1-D result even for a single sample.
            logits = model(X_tensor).reshape(-1)
            probabilities = torch.sigmoid(logits).cpu().numpy()
        predictions = (probabilities > 0.5).astype(int)
        actuals = y_tensor.cpu().numpy()

        metrics = calculate_all_metrics(actuals, predictions)
        return predictions, probabilities, metrics

    def _check_input_dimension(self, n_features: int) -> None:
        """Raise ValueError if any architecture's input width differs from ``n_features``."""
        mismatched = [
            arch for arch in self.config.GRID_SEARCH_PARAMS['layers_hidden']
            if arch[0] != n_features
        ]
        if mismatched:
            message = (
                f"Feature dimension {n_features} does not match the input width "
                f"of architecture(s) {mismatched}"
            )
            self.logger.error(message)
            raise ValueError(message)

    def grid_search(
        self,
        train_features: np.ndarray,
        train_labels: np.ndarray,
        test_features: np.ndarray,
        test_labels: np.ndarray
    ) -> Dict[str, Any]:
        """Perform grid search for optimal hyperparameters using pre-split data.

        Trains one KAN per combination in ``config.GRID_SEARCH_PARAMS``. Each
        is scored by test-set accuracy, and the first configuration reaching
        the highest accuracy is kept.

        Returns:
            Dict with ``best_params``, ``best_accuracy``,
            ``best_configuration_id``, ``grid_results`` (one dict per
            configuration), ``total_training_time_seconds`` and ``scaler`` (the
            StandardScaler fitted on the training features).

        Raises:
            ValueError: if the search space is empty, or if an architecture's
                input width does not match the feature dimension.
        """
        print_section_header("HYPERPARAMETER OPTIMIZATION VIA GRID SEARCH", self.logger)

        param_combinations = list(product(
            self.config.GRID_SEARCH_PARAMS['layers_hidden'],
            self.config.GRID_SEARCH_PARAMS['grid_size'],
            self.config.GRID_SEARCH_PARAMS['learning_rate'],
            self.config.GRID_SEARCH_PARAMS['spline_order'],
            self.config.GRID_SEARCH_PARAMS['scale_base'],
            self.config.GRID_SEARCH_PARAMS['scale_spline'],
            self.config.GRID_SEARCH_PARAMS['batch_size'],
            self.config.GRID_SEARCH_PARAMS['optimizer']
        ))

        total_configs = len(param_combinations)
        if total_configs == 0:
            raise ValueError("GRID_SEARCH_PARAMS produced no hyperparameter combinations")
        self.logger.info(f"Total hyperparameter configurations: {total_configs}")
        self.logger.info("")

        # Standardize features: statistics are fitted on the training split only
        scaler = StandardScaler()
        train_features_scaled = scaler.fit_transform(train_features)
        test_features_scaled = scaler.transform(test_features)
        # Fail fast instead of after hours of training on a shape error.
        self._check_input_dimension(train_features_scaled.shape[1])

        # Convert to tensors
        X_train_tensor = torch.tensor(train_features_scaled, dtype=torch.float32)
        y_train_tensor = torch.tensor(train_labels, dtype=torch.float32)
        X_test_tensor = torch.tensor(test_features_scaled, dtype=torch.float32)
        y_test_tensor = torch.tensor(test_labels, dtype=torch.float32)

        # The datasets are the same for every configuration; only batch size varies.
        train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
        test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

        grid_results = []
        best_accuracy = 0.0
        best_params = None
        best_config_id = None
        best_scaler = None

        start_time = datetime.now()

        pbar = tqdm(
            enumerate(param_combinations, 1),
            total=total_configs,
            desc="Grid Search Progress",
            ncols=120,
            bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}]'
        )

        for config_id, (layers_hidden, grid_size, learning_rate, spline_order,
                        scale_base, scale_spline, batch_size, optimizer_name) in pbar:

            self.logger.debug("")
            self.logger.debug(f"Configuration ID: {config_id}/{total_configs}")
            self.logger.debug(f"  Architecture: {layers_hidden}")
            self.logger.debug(f"  Grid size: {grid_size}, Spline order: {spline_order}")
            self.logger.debug(f"  Learning rate: {learning_rate}")
            self.logger.debug(f"  Scale base: {scale_base}, Scale spline: {scale_spline}")
            self.logger.debug(f"  Batch size: {batch_size}, Optimizer: {optimizer_name}")

            # Create data loaders
            train_loader = TorchDataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            test_loader = TorchDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

            try:
                model = KAN(
                    layers_hidden=list(layers_hidden),
                    grid_size=grid_size,
                    spline_order=spline_order,
                    scale_base=scale_base,
                    scale_spline=scale_spline
                )
            except NameError:
                self.logger.error("KAN class not found. Please ensure KAN is properly imported.")
                raise

            optimizer = self._build_optimizer(optimizer_name, model, learning_rate)
            criterion = torch.nn.BCEWithLogitsLoss()

            self.train_model(
                model=model,
                train_loader=train_loader,
                val_loader=test_loader,
                optimizer=optimizer,
                criterion=criterion,
                verbose=False
            )

            _, _, metrics = self.evaluate_model(model, X_test_tensor, y_test_tensor)
            accuracy = metrics['Accuracy']
            # Release this configuration's weights and optimizer state before
            # the next model is built, to lower peak memory.
            del model, optimizer

            grid_results.append({
                'configuration_id': config_id,
                'architecture': str(layers_hidden),
                'grid_size': grid_size,
                'learning_rate': learning_rate,
                'spline_order': spline_order,
                'scale_base': scale_base,
                'scale_spline': scale_spline,
                'batch_size': batch_size,
                'optimizer': optimizer_name,
                'accuracy': accuracy,
                'precision': metrics['Precision'],
                'recall': metrics['Recall'],
                'f1_score': metrics['F1-Score'],
                'mcc': metrics['MCC']
            })

            # Always accept the first configuration, so best_params is set
            # even if every configuration scores 0.0 accuracy.
            if best_params is None or accuracy > best_accuracy:
                best_accuracy = accuracy
                best_config_id = config_id
                best_params = {
                    'layers_hidden': list(layers_hidden),
                    'grid_size': grid_size,
                    'learning_rate': learning_rate,
                    'spline_order': spline_order,
                    'scale_base': scale_base,
                    'scale_spline': scale_spline,
                    'batch_size': batch_size,
                    'optimizer': optimizer_name
                }
                best_scaler = scaler

                self.logger.info("")
                self.logger.info("-" * 80)
                self.logger.info("NEW BEST CONFIGURATION FOUND")
                self.logger.info("-" * 80)
                self.logger.info(f"  Configuration ID: {config_id}/{total_configs}")
                self.logger.info(f"  Test Accuracy: {accuracy:.4f}")
                self.logger.info(f"  Learning Rate: {learning_rate}")
                self.logger.info("-" * 80)

            elapsed_time = (datetime.now() - start_time).total_seconds()
            avg_time_per_config = elapsed_time / config_id
            remaining_configs = total_configs - config_id
            eta_seconds = avg_time_per_config * remaining_configs
            eta_str = str(timedelta(seconds=int(eta_seconds)))

            pbar.set_postfix_str(
                f"Best: {best_accuracy:.4f} | Current: {accuracy:.4f} | ETA: {eta_str}",
                refresh=True
            )

        pbar.close()

        total_time = (datetime.now() - start_time).total_seconds()
        total_time_str = str(timedelta(seconds=int(total_time)))

        print_section_header("GRID SEARCH COMPLETED", self.logger)
        self.logger.info("")
        self.logger.info("Summary of Grid Search Results:")
        self.logger.info("-" * 80)
        self.logger.info(f"Total configurations evaluated: {total_configs}")
        self.logger.info(f"Total training time: {total_time_str}")
        self.logger.info(f"Average time per configuration: {total_time/total_configs:.1f}s")
        self.logger.info("")
        self.logger.info(f"Best configuration ID: #{best_config_id}")
        self.logger.info(f"Best test accuracy: {best_accuracy:.4f}")
        self.logger.info("-" * 80)
        self.logger.info("")
        self.logger.info("Optimal Hyperparameters:")
        self.logger.info("-" * 80)
        for param_name, param_value in best_params.items():
            self.logger.info(f"  {param_name:22s}: {param_value}")
        self.logger.info("-" * 80)

        self.logger.info("")
        self.logger.info("Top 5 Configurations by Accuracy:")
        self.logger.info("-" * 80)
        sorted_results = sorted(grid_results, key=lambda x: x['accuracy'], reverse=True)
        for rank, result in enumerate(sorted_results[:5], 1):
            self.logger.info(
                f"  {rank}. Config #{result['configuration_id']:4d} | "
                f"Accuracy: {result['accuracy']:.4f} | "
                f"LR: {result['learning_rate']} | {result['optimizer']:4s}"
            )
        self.logger.info("-" * 80)

        return {
            'best_params': best_params,
            'best_accuracy': best_accuracy,
            'best_configuration_id': best_config_id,
            'grid_results': grid_results,
            'total_training_time_seconds': total_time,
            'scaler': best_scaler
        }

    def train_final_model(
        self,
        train_features: np.ndarray,
        train_labels: np.ndarray,
        test_features: np.ndarray,
        test_labels: np.ndarray,
        best_params: Dict[str, Any],
        scaler: StandardScaler
    ) -> Dict[str, Any]:
        """Train final model with optimal hyperparameters.

        A fresh KAN built from ``best_params`` is trained, evaluated on the
        test split, and saved under ``<OUTPUT_PATH>/best_model``. The saved
        files are: the weights (``best_model.pth``), the scaler and
        hyperparameters (pickles), the test predictions (xlsx), a
        confusion-matrix PNG, and a metrics text file.

        NOTE(review): no random seed is set, so this run starts from a new
        random initialisation. Its metrics can differ from the grid-search
        score of the same configuration.

        Returns:
            Dict with ``model``, ``scaler``, ``predictions``, ``probabilities``,
            ``confusion_matrix``, ``metrics``, ``classification_report`` and
            ``history``.
        """
        print_section_header("TRAINING FINAL MODEL WITH OPTIMAL HYPERPARAMETERS", self.logger)

        # Standardize features (the scaler is re-fitted on the training split)
        train_features_scaled = scaler.fit_transform(train_features)
        test_features_scaled = scaler.transform(test_features)

        # Convert to tensors
        X_train_tensor = torch.tensor(train_features_scaled, dtype=torch.float32)
        y_train_tensor = torch.tensor(train_labels, dtype=torch.float32)
        X_test_tensor = torch.tensor(test_features_scaled, dtype=torch.float32)
        y_test_tensor = torch.tensor(test_labels, dtype=torch.float32)

        # Create data loaders
        train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
        train_loader = TorchDataLoader(
            train_dataset,
            batch_size=best_params['batch_size'],
            shuffle=True
        )

        test_dataset = TensorDataset(X_test_tensor, y_test_tensor)
        test_loader = TorchDataLoader(
            test_dataset,
            batch_size=best_params['batch_size'],
            shuffle=False
        )

        # Create model
        model = KAN(
            layers_hidden=best_params['layers_hidden'],
            grid_size=best_params['grid_size'],
            spline_order=best_params['spline_order'],
            scale_base=best_params['scale_base'],
            scale_spline=best_params['scale_spline']
        )

        optimizer = self._build_optimizer(
            best_params['optimizer'], model, best_params['learning_rate']
        )
        criterion = torch.nn.BCEWithLogitsLoss()

        # Train model
        self.logger.info("Training final model...")
        history = self.train_model(
            model, train_loader, test_loader,
            optimizer, criterion, verbose=True
        )

        # Evaluate on test set
        pred_test, prob_test, metrics_test = self.evaluate_model(
            model, X_test_tensor, y_test_tensor
        )

        y_test_np = y_test_tensor.cpu().numpy()
        # Fixed labels keep the matrix 2x2 (matching the two class names used
        # for the plot) even if a class is absent from the data or predictions.
        cm_test = confusion_matrix(y_test_np, pred_test, labels=[0, 1])

        # Classification report (explicit labels so both rows always exist)
        y_test_labels = ['D' if x == 0 else 'E' for x in y_test_np]
        pred_test_labels = ['D' if x == 0 else 'E' for x in pred_test]
        class_report = classification_report(
            y_test_labels, pred_test_labels,
            labels=['D', 'E'], target_names=['D', 'E']
        )

        self.logger.info("")
        self.logger.info("Final Model Test Results:")
        self.logger.info("-" * 80)
        self.logger.info(f"Accuracy:     {metrics_test['Accuracy']:.4f}")
        self.logger.info(f"Precision:    {metrics_test['Precision']:.4f}")
        self.logger.info(f"Recall:       {metrics_test['Recall']:.4f}")
        self.logger.info(f"Specificity:  {metrics_test['Specificity']:.4f}")
        self.logger.info(f"F1-Score:     {metrics_test['F1-Score']:.4f}")
        self.logger.info(f"MCC:          {metrics_test['MCC']:.4f}")
        self.logger.info("-" * 80)
        self.logger.info(f"\nClassification Report:\n{class_report}")
        self.logger.info(f"Confusion Matrix:\n{cm_test}")

        # Save model and results
        best_model_dir = os.path.join(self.config.OUTPUT_PATH, 'best_model')
        Path(best_model_dir).mkdir(parents=True, exist_ok=True)

        best_model_path = os.path.join(best_model_dir, 'best_model.pth')
        best_scaler_path = os.path.join(best_model_dir, 'best_scaler.pkl')
        best_params_path = os.path.join(best_model_dir, 'best_hyperparameters.pkl')

        torch.save(model.state_dict(), best_model_path)
        with open(best_scaler_path, 'wb') as f:
            pickle.dump(scaler, f)
        with open(best_params_path, 'wb') as f:
            pickle.dump(best_params, f)

        self.logger.info(f"\nBest model saved to: {best_model_path}")
        self.logger.info(f"Best scaler saved to: {best_scaler_path}")
        self.logger.info(f"Best hyperparameters saved to: {best_params_path}")

        # Save predictions
        predictions_df = pd.DataFrame({
            'Actual': test_labels,
            'Predicted': pred_test,
            'Probability': prob_test
        })
        predictions_file = os.path.join(best_model_dir, 'test_predictions.xlsx')
        predictions_df.to_excel(predictions_file, index=False)
        self.logger.info(f"Predictions saved to: {predictions_file}")

        # Save confusion matrix
        cm_file = os.path.join(best_model_dir, 'confusion_matrix.png')
        save_confusion_matrix(
            cm_test, cm_file, self.config.CLASS_NAMES,
            'Best Model - Test Set Confusion Matrix'
        )
        self.logger.info(f"Confusion matrix saved to: {cm_file}")

        # Save detailed metrics
        metrics_file = os.path.join(best_model_dir, 'final_metrics.txt')
        with open(metrics_file, 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write("FINAL MODEL METRICS\n")
            f.write("=" * 80 + "\n\n")
            f.write("Optimal Hyperparameters:\n")
            f.write("-" * 80 + "\n")
            for param_name, param_value in best_params.items():
                f.write(f"  {param_name:22s}: {param_value}\n")
            f.write("-" * 80 + "\n\n")
            f.write("Test Set Metrics:\n")
            f.write("-" * 80 + "\n")
            f.write(f"Accuracy:     {metrics_test['Accuracy']:.16f}\n")
            f.write(f"Precision:    {metrics_test['Precision']:.16f}\n")
            f.write(f"Recall:       {metrics_test['Recall']:.16f}\n")
            f.write(f"Specificity:  {metrics_test['Specificity']:.16f}\n")
            f.write(f"F1-Score:     {metrics_test['F1-Score']:.16f}\n")
            f.write(f"MCC:          {metrics_test['MCC']:.16f}\n")
            f.write("-" * 80 + "\n\n")
            f.write(f"Classification Report:\n{class_report}\n")
            f.write(f"\nConfusion Matrix:\n{cm_test}\n")

        self.logger.info(f"Detailed metrics saved to: {metrics_file}")

        return {
            'model': model,
            'scaler': scaler,
            'predictions': pred_test,
            'probabilities': prob_test,
            'confusion_matrix': cm_test,
            'metrics': metrics_test,
            'classification_report': class_report,
            'history': history
        }


class ResultsManager:
    """Write grid-search results, optimal parameters and final metrics to one Excel workbook."""

    def __init__(self, config: "Config", logger: logging.Logger):
        self.config = config
        self.logger = logger

    def save_results(
        self,
        grid_search_results: Dict,
        training_results: Dict
    ) -> None:
        """Save results to ``kan_comprehensive_results.xlsx`` in ``config.OUTPUT_PATH``.

        The directory is created if missing. Sheets written:
          * ``Grid_Search_Results``: one row per configuration, sorted by
            accuracy (descending) when an ``accuracy`` column exists;
          * ``Optimal_Parameters``: ``grid_search_results['best_params']``;
          * ``Final_Metrics``: ``training_results['metrics']``.

        Args:
            grid_search_results: Dict with at least ``grid_results`` (list of
                per-configuration dicts) and ``best_params``.
            training_results: Dict with at least ``metrics``.
        """
        print_section_header("SAVING RESULTS AND GENERATING REPORTS", self.logger)

        # Don't rely on an earlier step having created the output directory.
        Path(self.config.OUTPUT_PATH).mkdir(parents=True, exist_ok=True)
        output_excel = os.path.join(self.config.OUTPUT_PATH, 'kan_comprehensive_results.xlsx')

        with pd.ExcelWriter(output_excel, engine='openpyxl') as writer:
            # Grid search results (an empty list yields no 'accuracy' column to sort by)
            grid_df = pd.DataFrame(grid_search_results['grid_results'])
            if 'accuracy' in grid_df.columns:
                grid_df = grid_df.sort_values('accuracy', ascending=False)
            grid_df.to_excel(writer, sheet_name='Grid_Search_Results', index=False)

            # Best parameters
            best_params_df = pd.DataFrame([grid_search_results['best_params']])
            best_params_df.to_excel(writer, sheet_name='Optimal_Parameters', index=False)

            # Final metrics
            final_metrics_df = pd.DataFrame([training_results['metrics']])
            final_metrics_df.to_excel(writer, sheet_name='Final_Metrics', index=False)

        self.logger.info(f"Comprehensive results saved to: {output_excel}")


def main():
    """Main execution pipeline for KAN training.

    Steps: set up logging, load data, run the grid search, save its table,
    retrain the best configuration, then write the combined workbook. Every
    output goes under ``Config.OUTPUT_PATH``. Any exception is logged with its
    traceback and re-raised. The session-end banner is always logged.
    """

    logger = setup_logging(Config.OUTPUT_PATH)
    Config.display_config(logger)

    try:
        # Load data
        data_manager = DataManager(Config, logger)
        train_features, train_labels, test_features, test_labels = data_manager.load_data()

        # Initialize trainer
        trainer = KANTrainer(Config, logger)

        # Grid search
        grid_search_results = trainer.grid_search(
            train_features, train_labels,
            test_features, test_labels
        )

        # Save grid search results
        grid_file = os.path.join(Config.OUTPUT_PATH, 'grid_search_results.xlsx')
        grid_df = pd.DataFrame(grid_search_results['grid_results'])
        grid_df = grid_df.sort_values('accuracy', ascending=False)
        grid_df.to_excel(grid_file, index=False)
        logger.info(f"\nGrid search results saved to: {grid_file}")

        # Train final model
        training_results = trainer.train_final_model(
            train_features, train_labels,
            test_features, test_labels,
            grid_search_results['best_params'],
            grid_search_results['scaler']
        )

        # Save comprehensive results
        results_manager = ResultsManager(Config, logger)
        results_manager.save_results(grid_search_results, training_results)

        print_section_header("TRAINING COMPLETED SUCCESSFULLY", logger)
        logger.info(f"\nAll outputs saved to: {Config.OUTPUT_PATH}")
        logger.info(f"Best test accuracy: {grid_search_results['best_accuracy']:.4f}")
        # The value above is the grid-search score. The final model is a
        # separately trained instance, so report its own test accuracy too.
        logger.info(f"Final model test accuracy: {training_results['metrics']['Accuracy']:.4f}")

    except Exception as e:
        logger.exception(f"Training failed: {str(e)}")
        raise

    finally:
        logger.info("")
        logger.info("=" * 80)
        logger.info(f"Session ended: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        logger.info("=" * 80)


if __name__ == "__main__":
    main()