In [None]:
import os
import pandas as pd
import torch
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import confusion_matrix
from itertools import product
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Mount Google Drive (skip if already mounted)
# Colab-only: `google.colab` is not importable outside a Colab runtime.
from google.colab import drive
drive.mount('/content/drive')
# Add custom KAN paths
# The two package directories live on the mounted Drive, so the mount above
# must have succeeded before these paths resolve.
import sys
sys.path.append('/content/drive/MyDrive/packages/efficient_kan')
sys.path.append('/content/drive/MyDrive/packages/fastkan')

# Import KAN and FastKAN classes
# Wildcard imports: every public name of both modules lands in the notebook
# namespace, and the later import can shadow names from the earlier one.
# NOTE(review): the `KAN` class used by the trainer is presumably provided by
# `kan` -- confirm that `fastkan` does not export a conflicting `KAN`.
from kan import *
from fastkan import *



In [None]:
"""
Kolmogorov-Arnold Networks (KAN) for Crayfish Sex Classification
================================================================

Original Paper Authors:
    Yasin Atilkan¹, Berk Kirik², Eren Tuna Acikbas³, Fatih Ekinci⁴,
    Koray Acici¹'⁴, Tunc Asuroglu⁵'⁶*, Recep Benzer⁷,
    Mehmet Serdar Guzel⁴'⁸, Semra Benzer⁹

Affiliations:
    ¹ Department of Artificial Intelligence and Data Engineering, Ankara University, Turkey
    ² Department of Biomedical Engineering, Ankara University, Turkey
    ³ Department of Petroleum and Natural Gas Engineering, Middle East Technical University, Turkey
    ⁴ Institute of Artificial Intelligence, Ankara University, Turkey
    ⁵ Faculty of Medicine and Health Technology, Tampere University, Finland
    ⁶ VTT Technical Research Centre of Finland, Finland
    ⁷ Department of Management Information System, Ankara Medipol University, Turkey
    ⁸ Department of Computer Engineering, Ankara University, Turkey
    ⁹ Department of Science Education, Gazi University, Turkey
    * Corresponding author

Implementation follows the methodology from:
"Enhancing Crayfish Sex Identification with Kolmogorov-Arnold Networks
and Stacked Autoencoders" (Atilkan et al., 2024)

Key Features:
- Grid search for optimal hyperparameter selection
- 10-Fold Cross-Validation (sequential, unshuffled KFold splits; not stratified)
- Per-fold data standardization (prevents data leakage)
- Early stopping mechanism for overfitting mitigation
- Comprehensive evaluation on both test folds and full dataset
- Statistical analysis and visualization
- Best model identification and storage

Implementation Date: 2024
License: MIT
"""


import os
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Any
import warnings
from pathlib import Path
import pickle

import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import DataLoader as TorchDataLoader, TensorDataset
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (
    confusion_matrix, accuracy_score, precision_score,
    recall_score, f1_score, matthews_corrcoef, classification_report
)
import matplotlib.pyplot as plt
import seaborn as sns
from itertools import product
from tqdm import tqdm

# Suppress every Python warning for the rest of the session. This also hides
# deprecation/future warnings raised by the libraries imported above.
warnings.filterwarnings('ignore')


class Config:
    """Static settings shared by data loading, grid search and reporting.

    Everything is a class attribute; the class is used as a namespace and is
    never expected to carry per-instance state.
    """

    # Input workbook and the root directory for all generated artefacts.
    DATA_PATH = '/content/drive/MyDrive/crawfish_data/kerevit_knn_final.xlsx'
    OUTPUT_PATH = '/content/drive/MyDrive/crawfish_results/ml_results/kan_grid_search'

    # Column names selected from the workbook: model inputs and the label.
    FEATURES = ['W', 'KB', 'KE', 'AB', 'AE', 'K_Sag', 'K_Sol', 'U_Sag', 'U_Sol', 'KE_Sag', 'KE_Sol']
    TARGET = 'CINSIYET'

    # Search space: the grid search evaluates the Cartesian product of these lists.
    GRID_SEARCH_PARAMS = {
        'layers_hidden': [
            [11, 256, 128, 64, 32, 1],
            [11, 128, 64, 32, 1],
            [11, 32, 1],
            [11, 128, 64, 32, 16, 1],
        ],
        'grid_size': [8, 7, 6, 5, 4],
        'spline_order': [5, 4, 3],
        'scale_base': [1.0, 2.0, 3.0],
        'scale_spline': [1.0, 2.0],
        'batch_size': [128, 64, 32],
        'optimizer': ['Adam', 'SGD'],
    }

    # Training constants that are not part of the search.
    LEARNING_RATE = 0.0005
    NUM_EPOCHS = 50
    N_FOLDS = 10
    RANDOM_STATE = 42

    # Early-stopping knobs handed to EarlyStopping.
    EARLY_STOPPING_PATIENCE = 10
    EARLY_STOPPING_DELTA = 0.0

    # Plot settings for the confusion-matrix images.
    FIG_SIZE = (10, 8)
    DPI = 300
    CLASS_NAMES = ['Female (D)', 'Male (E)']

    @classmethod
    def display_config(cls, logger: logging.Logger) -> None:
        """Write a short summary of the active settings to ``logger`` at INFO level."""
        summary = (
            "Configuration Summary:",
            f"  Total hyperparameter combinations: {cls._get_total_combinations()}",
            f"  Learning rate (fixed): {cls.LEARNING_RATE}",
            f"  Maximum epochs: {cls.NUM_EPOCHS}",
            f"  Cross-validation folds: {cls.N_FOLDS}",
            f"  Early stopping patience: {cls.EARLY_STOPPING_PATIENCE} epochs",
            f"  Early stopping delta: {cls.EARLY_STOPPING_DELTA}",
        )
        for line in summary:
            logger.info(line)

    @classmethod
    def _get_total_combinations(cls) -> int:
        """Return the size of the Cartesian product of all search-space lists."""
        total = 1
        for candidates in cls.GRID_SEARCH_PARAMS.values():
            total *= len(candidates)
        return total


def setup_logging(output_path: str) -> logging.Logger:
    """Configure and return the shared ``'KAN_Training'`` logger.

    The logger writes DEBUG-and-above records to a timestamped file inside
    ``output_path`` and INFO-and-above records to the console, then emits a
    session banner.

    Args:
        output_path: Directory for the log file; created if it does not exist.

    Returns:
        The configured logger. Repeated calls return the same logger object
        with freshly installed handlers.
    """
    Path(output_path).mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    log_file = os.path.join(output_path, f'kan_training_{timestamp}.log')

    logger = logging.getLogger('KAN_Training')
    logger.setLevel(logging.DEBUG)
    # The logger is a process-wide singleton, so a re-run (typical in a
    # notebook) still carries the handlers of the previous call. Detach AND
    # close them; merely dropping the list would leave the old log file open.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()
    # Keep records away from the root logger to avoid duplicated console lines.
    logger.propagate = False

    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter(
        '%(asctime)s | %(levelname)-8s | %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )
    file_handler.setFormatter(file_formatter)

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_formatter = logging.Formatter('%(message)s')
    console_handler.setFormatter(console_formatter)

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    logger.info("=" * 80)
    logger.info("KAN Training for Crayfish Sex Classification")
    logger.info("Based on: Atilkan et al. (2024)")
    logger.info("=" * 80)
    logger.info(f"Session started: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f"Log file: {log_file}")
    logger.info("=" * 80)

    return logger


class EarlyStopping:
    """Track validation loss and signal when training should stop.

    Call the instance once per epoch with the current validation loss and the
    model. A loss counts as an improvement when it is at most
    ``best_loss - delta``; on improvement the model weights are snapshotted
    and the patience counter is reset. After ``patience`` consecutive
    non-improving calls, ``early_stop`` becomes ``True`` and the best snapshot
    is loaded back into the model.

    Attributes are plain public fields: ``counter``, ``best_loss``,
    ``early_stop`` and ``best_model_state``.
    """

    def __init__(self, patience: int = 10, verbose: bool = False, delta: float = 0.0):
        """Store the stopping criteria.

        Args:
            patience: Number of consecutive non-improving epochs tolerated.
            verbose: Stored on the instance; not consulted by this class.
            delta: Minimum decrease of the loss that counts as an improvement.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.delta = delta
        self.best_model_state = None

    @staticmethod
    def _snapshot(model: torch.nn.Module) -> Dict[str, Any]:
        """Return an independent copy of the model's current state dict."""
        import copy

        # state_dict() hands out tensors that share storage with the live
        # parameters; without a deep copy the "best" weights would silently
        # follow every later optimizer step.
        return copy.deepcopy(model.state_dict())

    def __call__(self, val_loss: float, model: torch.nn.Module) -> None:
        """Update the tracker with this epoch's validation loss.

        Args:
            val_loss: Validation loss of the epoch that just finished.
            model: Model being trained; snapshotted on improvement and
                restored to the best snapshot when stopping is triggered.
        """
        if self.best_loss is None:
            # First call: whatever we have is the best so far.
            self.best_loss = val_loss
            self.best_model_state = self._snapshot(model)
        elif val_loss > self.best_loss - self.delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                model.load_state_dict(self.best_model_state)
        else:
            self.best_loss = val_loss
            self.best_model_state = self._snapshot(model)
            self.counter = 0

    def reset(self) -> None:
        """Forget all tracked state so the instance can be reused."""
        self.counter = 0
        self.best_loss = None
        self.early_stop = False
        self.best_model_state = None


def calculate_specificity(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Return the specificity (true-negative rate) for 0/1 labels.

    Class ``0`` is the negative class. The counts are taken directly from the
    label arrays, so the result does not depend on which classes happen to be
    present in a particular fold.

    Args:
        y_true: Ground-truth labels encoded as 0 (negative) / 1 (positive).
        y_pred: Predicted labels in the same encoding.

    Returns:
        ``TN / (TN + FP)``, or ``0.0`` when ``y_true`` contains no negative
        sample. Specificity is undefined in that case; 0.0 is the fallback
        used consistently, including for a fold that holds only positives.
    """
    true_labels = np.asarray(y_true)
    predicted_labels = np.asarray(y_pred)

    is_negative = true_labels == 0
    tn = int(np.count_nonzero(is_negative & (predicted_labels == 0)))
    fp = int(np.count_nonzero(is_negative & (predicted_labels == 1)))

    negatives = tn + fp
    return tn / negatives if negatives > 0 else 0.0


def calculate_all_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Compute the full set of classification scores for one prediction run.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.

    Returns:
        A dict with the keys ``'Accuracy'``, ``'Precision'``, ``'Recall'``,
        ``'Specificity'``, ``'F1-Score'`` and ``'MCC'``, in that order.
    """
    # Precision/recall/F1 report 0 instead of warning when a denominator is empty.
    empty_denominator = {'zero_division': 0}

    scores: Dict[str, float] = {}
    scores['Accuracy'] = accuracy_score(y_true, y_pred)
    scores['Precision'] = precision_score(y_true, y_pred, **empty_denominator)
    scores['Recall'] = recall_score(y_true, y_pred, **empty_denominator)
    scores['Specificity'] = calculate_specificity(y_true, y_pred)
    scores['F1-Score'] = f1_score(y_true, y_pred, **empty_denominator)
    scores['MCC'] = matthews_corrcoef(y_true, y_pred)
    return scores


def save_confusion_matrix(
    cm: np.ndarray,
    filename: str,
    class_names: List[str],
    title: str = 'Confusion Matrix'
) -> None:
    """Render ``cm`` as an annotated heatmap and write it to ``filename``.

    Args:
        cm: Confusion matrix of integer counts (cells are formatted with ``'d'``).
        filename: Destination path handed to matplotlib's ``savefig``.
        class_names: Tick labels used on both axes.
        title: Heading drawn above the heatmap.
    """
    fig, ax = plt.subplots(figsize=Config.FIG_SIZE)

    heatmap_options = dict(
        annot=True, fmt='d', cmap='Blues',
        xticklabels=class_names, yticklabels=class_names,
        cbar=True, square=True, linewidths=1, linecolor='gray',
    )
    sns.heatmap(cm, ax=ax, **heatmap_options)

    ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
    ax.set_ylabel('True Label', fontsize=14, fontweight='bold')
    ax.set_xlabel('Predicted Label', fontsize=14, fontweight='bold')

    fig.tight_layout()
    fig.savefig(filename, dpi=Config.DPI, bbox_inches='tight')
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def print_section_header(title: str, logger: logging.Logger, level: int = 1) -> None:
    """Log ``title`` framed by two 80-character rules, preceded by an empty record.

    Args:
        title: Text placed between the rules.
        logger: Logger that receives the four INFO records.
        level: ``1`` frames the title with ``=``; any other value uses ``-``.
    """
    logger.info("")

    rule = ("=" if level == 1 else "-") * 80
    for line in (rule, title, rule):
        logger.info(line)


def format_metrics_table(metrics: Dict[str, float]) -> str:
    """Render ``metrics`` as a fixed-width text table.

    Each entry becomes one row (name left-aligned in 20 columns, value
    right-aligned with four decimals); a 40-dash rule opens and closes the
    table. Rows follow the dict's iteration order.
    """
    rule = "-" * 40
    rows = [f"  {name:<20} {score:>8.4f}" for name, score in metrics.items()]
    return "\n".join([rule, *rows, rule])


class DataManager:
    """Read the dataset described by the configuration and report basic statistics."""

    def __init__(self, config: Config, logger: logging.Logger):
        """Keep references to the configuration and the logger."""
        self.logger = logger
        self.config = config

    def load_data(self) -> Tuple[pd.DataFrame, pd.Series, pd.Series]:
        """Read the Excel workbook and split it into features, target and row ids.

        Returns:
            ``(X, y, sira)``: copies of the configured feature columns, the
            target column, and the ``'Sira'`` column. When the workbook has no
            ``'Sira'`` column, a 0-based running index named ``'Sira'`` is
            returned instead.

        Raises:
            Exception: Whatever ``pandas.read_excel`` raises; it is logged and
                re-raised unchanged.
        """
        cfg = self.config
        log = self.logger

        log.info(f"Loading data: {cfg.DATA_PATH}")

        try:
            frame = pd.read_excel(cfg.DATA_PATH)
            log.info(f"Successfully loaded {len(frame)} samples")
        except Exception as exc:
            log.error(f"Failed to load data: {str(exc)}")
            raise

        features = frame[cfg.FEATURES].copy()
        target = frame[cfg.TARGET].copy()

        if 'Sira' in frame.columns:
            row_ids = frame['Sira'].copy()
        else:
            row_ids = pd.Series(range(len(frame)), name='Sira')

        self._log_dataset_info(features, target)
        return features, target, row_ids

    def _log_dataset_info(self, X: pd.DataFrame, y: pd.Series) -> None:
        """Log the sample count, feature count and per-class share of ``y``."""
        log = self.logger
        sample_count = len(y)

        log.info("")
        log.info("Dataset Statistics:")
        log.info(f"  Total samples: {len(X)}")
        log.info(f"  Number of features: {len(self.config.FEATURES)}")
        log.info("  Class distribution:")

        for cls, count in y.value_counts().items():
            percentage = count / sample_count * 100
            log.info(f"    Class {cls}: {count:3d} samples ({percentage:5.1f}%)")


class KANTrainer:
    """KAN model training with comprehensive grid search."""

    def __init__(self, config: Config, logger: logging.Logger):
        """Store the configuration and logger used by every training method.

        Args:
            config: Settings object; the methods read training constants and
                the search space from it.
            logger: Destination for progress and result messages.
        """
        self.config = config
        self.logger = logger

    def train_single_fold(
        self,
        model: torch.nn.Module,
        train_loader: TorchDataLoader,
        val_loader: TorchDataLoader,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        fold: int,
        verbose: bool = False
    ) -> Dict[str, Any]:
        """Train ``model`` on one fold with early stopping on the validation loss.

        Runs at most ``config.NUM_EPOCHS`` epochs. After every epoch the mean
        per-batch validation loss is handed to an ``EarlyStopping`` tracker;
        when it triggers, the tracker restores the best weights and training
        ends. If all epochs run, the model keeps its last-epoch weights.

        Args:
            model: Network to train in place. Assumed to emit one logit per
                sample (e.g. output shape ``(batch, 1)``) -- TODO confirm for
                architectures other than those in the search space.
            train_loader: Batches of ``(features, targets)`` for optimisation.
            val_loader: Batches of ``(features, targets)`` for the stopping criterion.
            optimizer: Optimizer already bound to ``model``'s parameters.
            criterion: Loss taking ``(outputs, targets)`` of equal shape.
            fold: Fold number; accepted for the caller's bookkeeping, unused here.
            verbose: Forwarded to ``EarlyStopping``.

        Returns:
            A dict with ``'train_loss'`` and ``'val_loss'`` (one mean batch
            loss per epoch), ``'epochs_trained'`` and ``'early_stopped'``.
        """
        early_stopping = EarlyStopping(
            patience=self.config.EARLY_STOPPING_PATIENCE,
            verbose=verbose,
            delta=self.config.EARLY_STOPPING_DELTA
        )

        history = {
            'train_loss': [],
            'val_loss': [],
            'epochs_trained': 0,
            'early_stopped': False
        }

        for epoch in range(self.config.NUM_EPOCHS):
            model.train()
            train_loss = 0.0
            for X_batch, y_batch in train_loader:
                optimizer.zero_grad()
                # reshape(-1) instead of a bare squeeze(): a trailing batch of
                # a single sample must stay 1-D to match the 1-D target,
                # otherwise the loss rejects the shape mismatch.
                outputs = model(X_batch).reshape(-1)
                loss = criterion(outputs, y_batch)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()

            avg_train_loss = train_loss / len(train_loader)
            history['train_loss'].append(avg_train_loss)

            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for X_val, y_val in val_loader:
                    val_outputs = model(X_val).reshape(-1)
                    loss = criterion(val_outputs, y_val)
                    val_loss += loss.item()

            avg_val_loss = val_loss / len(val_loader)
            history['val_loss'].append(avg_val_loss)

            early_stopping(avg_val_loss, model)
            if early_stopping.early_stop:
                history['epochs_trained'] = epoch + 1
                history['early_stopped'] = True
                break
        else:
            # Loop finished without a break: every configured epoch was run.
            history['epochs_trained'] = self.config.NUM_EPOCHS

        return history

    def evaluate_model(
        self,
        model: torch.nn.Module,
        X_tensor: torch.Tensor,
        y_tensor: torch.Tensor
    ) -> Tuple[np.ndarray, np.ndarray, Dict[str, float]]:
        """Run ``model`` on ``X_tensor`` and score the predictions against ``y_tensor``.

        The raw outputs are treated as logits: a sigmoid turns them into
        probabilities, and probabilities strictly above 0.5 are predicted as
        class 1.

        Args:
            model: Trained network; switched to eval mode here.
            X_tensor: Input features, evaluated in a single forward pass.
            y_tensor: Ground-truth labels for the same samples.

        Returns:
            ``(predictions, probabilities, metrics)``: the 0/1 predictions,
            the class-1 probabilities, and the dict produced by
            ``calculate_all_metrics``.
        """
        model.eval()
        with torch.no_grad():
            # reshape(-1) keeps a 1-D result even for a single sample; a bare
            # squeeze() would yield 0-d arrays the metric functions cannot use.
            outputs = model(X_tensor).reshape(-1)
            probabilities = torch.sigmoid(outputs).cpu().numpy()
            predictions = (probabilities > 0.5).astype(int)
            actuals = y_tensor.cpu().numpy()

        metrics = calculate_all_metrics(actuals, predictions)
        return predictions, probabilities, metrics

    def grid_search(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        sira: pd.Series
    ) -> Dict[str, Any]:
        """Evaluate every hyperparameter combination with K-fold cross-validation.

        For each combination from ``config.GRID_SEARCH_PARAMS`` a fresh KAN is
        trained on every fold (features standardised with statistics of the
        training part only) and scored on the held-out part. Combinations are
        ranked by mean fold accuracy; the first one reaching the highest mean
        wins.

        NOTE(review): the folds come from ``KFold(shuffle=False)``, so they
        follow row order and are not stratified; if the workbook is sorted by
        class, folds may be heavily imbalanced -- confirm this is intended.
        NOTE(review): the held-out part of each fold is also used as the
        early-stopping validation set, so the reported accuracies are
        presumably optimistic.

        Args:
            X: Feature table, one row per sample.
            y: Labels, either ``'D'``/``'E'`` or already encoded as 0/1.
            sira: Sample identifiers; accepted for signature symmetry with the
                final-training method and unused here.

        Returns:
            A dict with ``'best_params'``, ``'best_accuracy'``,
            ``'best_configuration_id'``, ``'grid_results'`` (one summary dict
            per combination) and ``'total_training_time_seconds'``.

        Raises:
            ValueError: If the search space yields no combination at all.
            NameError: If the ``KAN`` class is not available in the namespace.
        """
        print_section_header("HYPERPARAMETER OPTIMIZATION VIA GRID SEARCH", self.logger)

        # Explicit cast: do not rely on replace() implicitly downcasting the
        # object column to integers before the tensor conversion below.
        y_numeric = y.replace({'D': 0, 'E': 1}).astype(int)

        search_space = self.config.GRID_SEARCH_PARAMS
        param_combinations = list(product(
            search_space['layers_hidden'],
            search_space['grid_size'],
            search_space['spline_order'],
            search_space['scale_base'],
            search_space['scale_spline'],
            search_space['batch_size'],
            search_space['optimizer']
        ))

        total_configs = len(param_combinations)
        if total_configs == 0:
            # Fail with a clear message instead of a ZeroDivisionError /
            # AttributeError further down in the summary section.
            raise ValueError("GRID_SEARCH_PARAMS yields no hyperparameter combinations")

        self.logger.info(f"Total hyperparameter configurations: {total_configs}")
        self.logger.info(f"Cross-validation folds per configuration: {self.config.N_FOLDS}")
        self.logger.info(f"Total training iterations: {total_configs * self.config.N_FOLDS}")
        self.logger.info("")

        kf = KFold(n_splits=self.config.N_FOLDS, shuffle=False)

        grid_results = []
        # Start below any reachable accuracy so the first configuration is
        # always recorded, even if its mean accuracy is 0.0.
        best_accuracy = float('-inf')
        best_params = None
        best_config_id = None

        start_time = datetime.now()

        pbar = tqdm(
            enumerate(param_combinations, 1),
            total=total_configs,
            desc="Grid Search Progress",
            ncols=120,
            # {postfix} is required for set_postfix_str() below to be shown.
            bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]'
        )

        for config_id, (layers_hidden, grid_size, spline_order, scale_base,
                       scale_spline, batch_size, optimizer_name) in pbar:

            self.logger.debug("")
            self.logger.debug(f"Configuration ID: {config_id}/{total_configs}")
            self.logger.debug(f"  Architecture: {layers_hidden}")
            self.logger.debug(f"  Grid size: {grid_size}, Spline order: {spline_order}")
            self.logger.debug(f"  Scale base: {scale_base}, Scale spline: {scale_spline}")
            self.logger.debug(f"  Batch size: {batch_size}, Optimizer: {optimizer_name}")

            fold_accuracies = []

            for fold, (train_idx, test_idx) in enumerate(kf.split(X), 1):
                X_train = X.iloc[train_idx].copy()
                X_test = X.iloc[test_idx].copy()
                y_train = y_numeric.iloc[train_idx].copy()
                y_test = y_numeric.iloc[test_idx].copy()

                # Fit the scaler on the training part only, then apply it to
                # the held-out part, so no test statistics leak into training.
                scaler = StandardScaler()
                X_train_scaled = scaler.fit_transform(X_train)
                X_test_scaled = scaler.transform(X_test)

                X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
                y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)
                X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
                y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)

                train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
                train_loader = TorchDataLoader(train_dataset, batch_size=batch_size, shuffle=True)

                val_dataset = TensorDataset(X_test_tensor, y_test_tensor)
                val_loader = TorchDataLoader(val_dataset, batch_size=batch_size, shuffle=False)

                try:
                    # A new model per fold: no weights carry over between folds.
                    model = KAN(
                        layers_hidden=list(layers_hidden),
                        grid_size=grid_size,
                        spline_order=spline_order,
                        scale_base=scale_base,
                        scale_spline=scale_spline
                    )
                except NameError:
                    self.logger.error("KAN class not found. Please ensure KAN is properly imported.")
                    raise

                if optimizer_name == 'Adam':
                    optimizer = optim.Adam(model.parameters(), lr=self.config.LEARNING_RATE)
                else:
                    optimizer = optim.SGD(model.parameters(), lr=self.config.LEARNING_RATE, momentum=0.9)

                criterion = torch.nn.BCEWithLogitsLoss()

                # The returned loss history is not needed for the search.
                self.train_single_fold(
                    model=model,
                    train_loader=train_loader,
                    val_loader=val_loader,
                    optimizer=optimizer,
                    criterion=criterion,
                    fold=fold,
                    verbose=False
                )

                _, _, metrics = self.evaluate_model(model, X_test_tensor, y_test_tensor)
                fold_accuracies.append(metrics['Accuracy'])

                self.logger.debug(f"  Fold {fold:2d}: Accuracy = {metrics['Accuracy']:.4f}")

            mean_accuracy = np.mean(fold_accuracies)
            std_accuracy = np.std(fold_accuracies)
            min_accuracy = np.min(fold_accuracies)
            max_accuracy = np.max(fold_accuracies)

            grid_results.append({
                'configuration_id': config_id,
                'architecture': str(layers_hidden),
                'grid_size': grid_size,
                'spline_order': spline_order,
                'scale_base': scale_base,
                'scale_spline': scale_spline,
                'batch_size': batch_size,
                'optimizer': optimizer_name,
                'mean_accuracy': mean_accuracy,
                'std_accuracy': std_accuracy,
                'min_accuracy': min_accuracy,
                'max_accuracy': max_accuracy
            })

            # Strict comparison: on ties the earlier configuration is kept.
            if mean_accuracy > best_accuracy:
                best_accuracy = mean_accuracy
                best_config_id = config_id
                best_params = {
                    'layers_hidden': list(layers_hidden),
                    'grid_size': grid_size,
                    'spline_order': spline_order,
                    'scale_base': scale_base,
                    'scale_spline': scale_spline,
                    'batch_size': batch_size,
                    'optimizer': optimizer_name
                }

                self.logger.info("")
                self.logger.info("-" * 80)
                self.logger.info("NEW BEST CONFIGURATION FOUND")
                self.logger.info("-" * 80)
                self.logger.info(f"  Configuration ID: {config_id}/{total_configs}")
                self.logger.info(f"  Mean Accuracy: {mean_accuracy:.4f} (±{std_accuracy:.4f})")
                self.logger.info(f"  Accuracy Range: [{min_accuracy:.4f}, {max_accuracy:.4f}]")
                self.logger.info("-" * 80)

            # ETA from the mean wall-clock time of the configurations so far.
            elapsed_time = (datetime.now() - start_time).total_seconds()
            avg_time_per_config = elapsed_time / config_id
            remaining_configs = total_configs - config_id
            eta_seconds = avg_time_per_config * remaining_configs
            eta_str = str(timedelta(seconds=int(eta_seconds)))

            pbar.set_postfix_str(
                f"Best: {best_accuracy:.4f} | Current: {mean_accuracy:.4f} | ETA: {eta_str}",
                refresh=True
            )

        pbar.close()

        total_time = (datetime.now() - start_time).total_seconds()
        total_time_str = str(timedelta(seconds=int(total_time)))

        print_section_header("GRID SEARCH COMPLETED", self.logger)
        self.logger.info("")
        self.logger.info("Summary of Grid Search Results:")
        self.logger.info("-" * 80)
        self.logger.info(f"Total configurations evaluated: {total_configs}")
        self.logger.info(f"Total training time: {total_time_str}")
        self.logger.info(f"Average time per configuration: {total_time/total_configs:.1f}s")
        self.logger.info("")
        self.logger.info(f"Best configuration ID: #{best_config_id}")
        self.logger.info(f"Best cross-validation accuracy: {best_accuracy:.4f}")
        self.logger.info("-" * 80)
        self.logger.info("")
        self.logger.info("Optimal Hyperparameters:")
        self.logger.info("-" * 80)
        for param_name, param_value in best_params.items():
            self.logger.info(f"  {param_name:22s}: {param_value}")
        self.logger.info("-" * 80)

        self.logger.info("")
        self.logger.info("Top 5 Configurations by Accuracy:")
        self.logger.info("-" * 80)
        sorted_results = sorted(grid_results, key=lambda x: x['mean_accuracy'], reverse=True)
        for rank, result in enumerate(sorted_results[:5], 1):
            self.logger.info(
                f"  {rank}. Config #{result['configuration_id']:4d} | "
                f"Accuracy: {result['mean_accuracy']:.4f} (±{result['std_accuracy']:.4f}) | "
                f"{result['optimizer']:4s}"
            )
        self.logger.info("-" * 80)

        return {
            'best_params': best_params,
            'best_accuracy': best_accuracy,
            'best_configuration_id': best_config_id,
            'grid_results': grid_results,
            'total_training_time_seconds': total_time
        }

    def train_final_models_with_best_config(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        sira: pd.Series,
        best_params: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Train final models with optimal hyperparameters and save detailed results."""
        print_section_header("TRAINING FINAL MODELS WITH OPTIMAL HYPERPARAMETERS", self.logger)

        y_numeric = y.replace({'D': 0, 'E': 1})
        kf = KFold(n_splits=self.config.N_FOLDS, shuffle=False)

        fold_results = []
        fold_full_results = []
        conf_matrices = []
        fold_models = []
        fold_scalers = []
        fold_detailed_results = []

        detailed_results_dir = os.path.join(self.config.OUTPUT_PATH, 'best_config_detailed_results')
        Path(detailed_results_dir).mkdir(parents=True, exist_ok=True)

        detailed_results_file = os.path.join(detailed_results_dir, 'fold_by_fold_results.txt')

        with open(detailed_results_file, 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write("BEST CONFIGURATION - DETAILED FOLD-BY-FOLD RESULTS\n")
            f.write("=" * 80 + "\n\n")

            f.write("Optimal Hyperparameters:\n")
            f.write("-" * 80 + "\n")
            for param_name, param_value in best_params.items():
                f.write(f"  {param_name:22s}: {param_value}\n")
            f.write("-" * 80 + "\n\n")

            for fold, (train_idx, test_idx) in enumerate(kf.split(X), 1):
                self.logger.info(f"\nTraining Fold {fold}/{self.config.N_FOLDS}")

                X_train = X.iloc[train_idx].copy()
                X_test = X.iloc[test_idx].copy()
                y_train = y_numeric.iloc[train_idx].copy()
                y_test = y_numeric.iloc[test_idx].copy()
                sira_test = sira.iloc[test_idx]

                scaler = StandardScaler()
                X_train_scaled = scaler.fit_transform(X_train)
                X_test_scaled = scaler.transform(X_test)

                X_train_tensor = torch.tensor(X_train_scaled, dtype=torch.float32)
                y_train_tensor = torch.tensor(y_train.values, dtype=torch.float32)
                X_test_tensor = torch.tensor(X_test_scaled, dtype=torch.float32)
                y_test_tensor = torch.tensor(y_test.values, dtype=torch.float32)

                train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
                train_loader = TorchDataLoader(
                    train_dataset,
                    batch_size=best_params['batch_size'],
                    shuffle=True
                )

                val_dataset = TensorDataset(X_test_tensor, y_test_tensor)
                val_loader = TorchDataLoader(
                    val_dataset,
                    batch_size=best_params['batch_size'],
                    shuffle=False
                )

                model = KAN(
                    layers_hidden=best_params['layers_hidden'],
                    grid_size=best_params['grid_size'],
                    spline_order=best_params['spline_order'],
                    scale_base=best_params['scale_base'],
                    scale_spline=best_params['scale_spline']
                )

                if best_params['optimizer'] == 'Adam':
                    optimizer = optim.Adam(model.parameters(), lr=self.config.LEARNING_RATE)
                else:
                    optimizer = optim.SGD(model.parameters(), lr=self.config.LEARNING_RATE, momentum=0.9)

                criterion = torch.nn.BCEWithLogitsLoss()

                history = self.train_single_fold(
                    model, train_loader, val_loader,
                    optimizer, criterion, fold, verbose=False
                )

                pred_test, prob_test, metrics_test = self.evaluate_model(
                    model, X_test_tensor, y_test_tensor
                )

                cm_test = confusion_matrix(y_test_tensor.cpu().numpy(), pred_test)
                conf_matrices.append(cm_test)

                y_test_labels = ['D' if x == 0 else 'E' for x in y_test_tensor.cpu().numpy()]
                pred_test_labels = ['D' if x == 0 else 'E' for x in pred_test]
                class_report = classification_report(y_test_labels, pred_test_labels, target_names=['D', 'E'])

                f.write(f"\nFold {fold}:\n")
                f.write("-" * 80 + "\n")
                f.write(f"Accuracy:     {metrics_test['Accuracy']:.16f}\n")
                f.write(f"Precision:    {metrics_test['Precision']:.16f}\n")
                f.write(f"Recall:       {metrics_test['Recall']:.16f}\n")
                f.write(f"Specificity:  {metrics_test['Specificity']:.16f}\n")
                f.write(f"F1-Score:     {metrics_test['F1-Score']:.16f}\n")
                f.write(f"MCC:          {metrics_test['MCC']:.16f}\n")
                f.write("-" * 80 + "\n")
                f.write(f"\nClassification Report:\n")
                f.write(class_report)
                f.write(f"\nConfusion Matrix:\n")
                f.write(f" {cm_test}\n")
                f.write("\n" + "=" * 80 + "\n")

                fold_results.append({
                    'Sira': sira_test.values,
                    'Predicted': pred_test,
                    'Actual': y_test_tensor.cpu().numpy().astype(int),
                    'Probability': prob_test,
                    'Fold': fold
                })

                fold_detailed_results.append({
                    'fold': fold,
                    'accuracy': metrics_test['Accuracy'],
                    'precision': metrics_test['Precision'],
                    'recall': metrics_test['Recall'],
                    'specificity': metrics_test['Specificity'],
                    'f1_score': metrics_test['F1-Score'],
                    'mcc': metrics_test['MCC'],
                    'confusion_matrix': cm_test,
                    'classification_report': class_report,
                    'metrics': metrics_test
                })

                self.logger.info(f"  Test fold accuracy: {metrics_test['Accuracy']:.4f}")

                X_full_scaled = scaler.transform(X)
                X_full_tensor = torch.tensor(X_full_scaled, dtype=torch.float32)
                y_full_tensor = torch.tensor(y_numeric.values, dtype=torch.float32)

                pred_full, prob_full, metrics_full = self.evaluate_model(
                    model, X_full_tensor, y_full_tensor
                )

                cm_full = confusion_matrix(y_full_tensor.cpu().numpy(), pred_full)

                fold_full_results.append({
                    'Fold': fold,
                    'Predictions': pred_full,
                    'Probabilities': prob_full,
                    'Confusion_Matrix': cm_full,
                    **metrics_full
                })

                fold_models.append(model)
                fold_scalers.append(scaler)

                self.logger.info(f"  Full dataset accuracy: {metrics_full['Accuracy']:.4f}")

            total_cm = np.sum(conf_matrices, axis=0)
            f.write("\n" + "=" * 80 + "\n")
            f.write("TOTAL CONFUSION MATRIX (All Folds Combined):\n")
            f.write("=" * 80 + "\n")
            f.write(f" {total_cm}\n")
            f.write("\n")

            tn, fp, fn, tp = total_cm.ravel()
            overall_accuracy = (tp + tn) / (tp + tn + fp + fn)
            overall_precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            overall_recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            overall_specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
            overall_f1 = 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0

            overall_mcc = matthews_corrcoef(
                np.concatenate([fold_data['Actual'] for fold_data in fold_results]),
                np.concatenate([fold_data['Predicted'] for fold_data in fold_results])
            )

            f.write("\nOVERALL METRICS (Cross-Validation):\n")
            f.write("-" * 80 + "\n")
            f.write(f"Accuracy:     {overall_accuracy:.16f}\n")
            f.write(f"Precision:    {overall_precision:.16f}\n")
            f.write(f"Recall:       {overall_recall:.16f}\n")
            f.write(f"Specificity:  {overall_specificity:.16f}\n")
            f.write(f"F1-Score:     {overall_f1:.16f}\n")
            f.write(f"MCC:          {overall_mcc:.16f}\n")
            f.write("-" * 80 + "\n")

        self.logger.info(f"\nDetailed results saved to: {detailed_results_file}")

        best_fold_idx = max(range(len(fold_detailed_results)),
                           key=lambda i: fold_detailed_results[i]['accuracy'])
        best_fold_num = best_fold_idx + 1
        best_fold_accuracy = fold_detailed_results[best_fold_idx]['accuracy']

        self.logger.info("")
        self.logger.info("=" * 80)
        self.logger.info(f"BEST FOLD IDENTIFIED: Fold {best_fold_num}")
        self.logger.info(f"Best Fold Test Accuracy: {best_fold_accuracy:.4f}")
        self.logger.info("=" * 80)

        best_model_dir = os.path.join(self.config.OUTPUT_PATH, 'best_model')
        Path(best_model_dir).mkdir(parents=True, exist_ok=True)

        best_model_path = os.path.join(best_model_dir, f'best_model_fold{best_fold_num}.pth')
        best_scaler_path = os.path.join(best_model_dir, f'best_scaler_fold{best_fold_num}.pkl')
        best_params_path = os.path.join(best_model_dir, 'best_hyperparameters.pkl')

        torch.save(fold_models[best_fold_idx].state_dict(), best_model_path)
        with open(best_scaler_path, 'wb') as f:
            pickle.dump(fold_scalers[best_fold_idx], f)
        with open(best_params_path, 'wb') as f:
            pickle.dump(best_params, f)

        self.logger.info(f"Best model saved to: {best_model_path}")
        self.logger.info(f"Best scaler saved to: {best_scaler_path}")
        self.logger.info(f"Best hyperparameters saved to: {best_params_path}")

        best_model = fold_models[best_fold_idx]
        best_scaler = fold_scalers[best_fold_idx]

        X_full_scaled = best_scaler.transform(X)
        X_full_tensor = torch.tensor(X_full_scaled, dtype=torch.float32)
        y_full_tensor = torch.tensor(y_numeric.values, dtype=torch.float32)

        pred_full_best, prob_full_best, metrics_full_best = self.evaluate_model(
            best_model, X_full_tensor, y_full_tensor
        )

        best_predictions_df = pd.DataFrame({
            'Sira': sira.values,
            'Actual': y.values,
            'Predicted': ['D' if p == 0 else 'E' for p in pred_full_best],
            'Probability_Class_0': 1 - prob_full_best,
            'Probability_Class_1': prob_full_best
        })

        best_predictions_file = os.path.join(best_model_dir, 'best_model_full_dataset_predictions.xlsx')
        best_predictions_df.to_excel(best_predictions_file, index=False)
        self.logger.info(f"Best model predictions saved to: {best_predictions_file}")

        cm_best_full = confusion_matrix(y_full_tensor.cpu().numpy(), pred_full_best)
        cm_best_file = os.path.join(best_model_dir, 'best_model_confusion_matrix.png')
        save_confusion_matrix(
            cm_best_full, cm_best_file, self.config.CLASS_NAMES,
            f'Best Model (Fold {best_fold_num}) - Full Dataset Confusion Matrix'
        )
        self.logger.info(f"Best model confusion matrix saved to: {cm_best_file}")

        best_model_metrics_file = os.path.join(best_model_dir, 'best_model_metrics.txt')
        with open(best_model_metrics_file, 'w', encoding='utf-8') as f:
            f.write("=" * 80 + "\n")
            f.write(f"BEST MODEL - FOLD {best_fold_num}\n")
            f.write("=" * 80 + "\n\n")

            f.write("Test Fold Metrics:\n")
            f.write("-" * 80 + "\n")
            f.write(f"Accuracy:     {fold_detailed_results[best_fold_idx]['accuracy']:.16f}\n")
            f.write(f"Precision:    {fold_detailed_results[best_fold_idx]['precision']:.16f}\n")
            f.write(f"Recall:       {fold_detailed_results[best_fold_idx]['recall']:.16f}\n")
            f.write(f"Specificity:  {fold_detailed_results[best_fold_idx]['specificity']:.16f}\n")
            f.write(f"F1-Score:     {fold_detailed_results[best_fold_idx]['f1_score']:.16f}\n")
            f.write(f"MCC:          {fold_detailed_results[best_fold_idx]['mcc']:.16f}\n")
            f.write("-" * 80 + "\n\n")

            f.write("Full Dataset Metrics:\n")
            f.write("-" * 80 + "\n")
            f.write(f"Accuracy:     {metrics_full_best['Accuracy']:.16f}\n")
            f.write(f"Precision:    {metrics_full_best['Precision']:.16f}\n")
            f.write(f"Recall:       {metrics_full_best['Recall']:.16f}\n")
            f.write(f"Specificity:  {metrics_full_best['Specificity']:.16f}\n")
            f.write(f"F1-Score:     {metrics_full_best['F1-Score']:.16f}\n")
            f.write(f"MCC:          {metrics_full_best['MCC']:.16f}\n")
            f.write("-" * 80 + "\n\n")

            f.write("Confusion Matrix (Full Dataset):\n")
            f.write(f" {cm_best_full}\n")

        self.logger.info(f"Best model detailed metrics saved to: {best_model_metrics_file}")

        return {
            'fold_results': fold_results,
            'fold_full_results': fold_full_results,
            'conf_matrices': conf_matrices,
            'fold_detailed_results': fold_detailed_results,
            'best_fold_num': best_fold_num,
            'best_fold_accuracy': best_fold_accuracy,
            'best_model_path': best_model_path,
            'best_scaler_path': best_scaler_path,
            'best_predictions': best_predictions_df,
            'best_model_metrics': metrics_full_best
        }


class ResultsManager:
    """Write the Excel report and confusion-matrix images for a finished run.

    The manager only reads from the dictionaries produced by the grid search
    and final-training steps; all output goes under ``config.OUTPUT_PATH``.
    """

    # Single definition of the metric columns shared by every summary sheet,
    # so the per-fold, CV and full-dataset sheets cannot drift apart.
    _METRIC_NAMES = ('Accuracy', 'Precision', 'Recall', 'Specificity', 'F1-Score', 'MCC')

    # Encoding applied to the label series for the 'Actual' column of the
    # per-fold full-dataset sheets.
    _LABEL_ENCODING = {'D': 0, 'E': 1}

    def __init__(self, config: Config, logger: logging.Logger):
        self.config = config
        self.logger = logger

    def save_results(
        self,
        grid_search_results: Dict,
        training_results: Dict,
        X: pd.DataFrame,
        y: pd.Series,
        sira: pd.Series
    ) -> None:
        """Save all results to one Excel workbook and generate visualizations.

        Sheets are written in this order: ``Grid_Search_Results``,
        ``Optimal_Parameters``, ``CV_Predictions``, ``Full_Dataset_Per_Fold``,
        ``CV_Summary``, ``Full_Dataset_Summary``, ``Best_Model_Info`` and one
        ``Fold<N>_Full_Predictions`` sheet per entry of
        ``training_results['fold_full_results']``.

        Args:
            grid_search_results: Must provide ``'grid_results'`` (rows with a
                ``'mean_accuracy'`` field) and ``'best_params'``.
            training_results: Must provide ``'fold_results'``,
                ``'fold_full_results'``, ``'conf_matrices'``,
                ``'best_fold_num'``, ``'best_fold_accuracy'`` and
                ``'best_model_metrics'``.
            X: Accepted for interface compatibility; not read by this method.
            y: Labels for the full dataset; ``'D'``/``'E'`` are encoded as
                0/1, any other value is passed through unchanged.
            sira: Values for the ``'Sira'`` column of the per-fold sheets
                (presumably a per-sample identifier -- confirm with caller).

        Raises:
            ValueError: If the summed confusion matrix does not hold exactly
                four cells (i.e. is not 2x2).
        """
        print_section_header("SAVING RESULTS AND GENERATING REPORTS", self.logger)

        output_excel = os.path.join(self.config.OUTPUT_PATH, 'kan_comprehensive_results.xlsx')

        # Build every frame before opening the writer, so a failure while
        # assembling data does not leave a half-written workbook behind.
        grid_df = pd.DataFrame(grid_search_results['grid_results']).sort_values(
            'mean_accuracy', ascending=False
        )
        best_params_df = pd.DataFrame([grid_search_results['best_params']])
        cv_predictions = self._cv_predictions_frame(training_results['fold_results'])
        per_fold_metrics = self._per_fold_metrics_frame(training_results['fold_full_results'])
        total_cm = np.sum(training_results['conf_matrices'], axis=0)

        # A list of pairs (not a dict) keeps the write order explicit and does
        # not silently collapse entries if two folds share a number.
        sheets = [
            ('Grid_Search_Results', grid_df),
            ('Optimal_Parameters', best_params_df),
            ('CV_Predictions', cv_predictions),
            ('Full_Dataset_Per_Fold', per_fold_metrics),
            ('CV_Summary', self._cv_summary_frame(total_cm, cv_predictions)),
            ('Full_Dataset_Summary', self._full_dataset_summary_frame(per_fold_metrics)),
            ('Best_Model_Info', self._best_model_frame(training_results)),
        ]

        # Series.map avoids the deprecated silent downcasting that
        # Series.replace performs on object columns (FutureWarning on recent
        # pandas); the .get fallback keeps unknown labels untouched, exactly
        # as replace did.
        encoding = self._LABEL_ENCODING
        y_numeric = y.map(lambda label: encoding.get(label, label))

        for fold_data in training_results['fold_full_results']:
            fold_pred_df = pd.DataFrame({
                'Sira': sira.values,
                'Predicted': fold_data['Predictions'],
                'Actual': y_numeric.values,
                'Probability': fold_data['Probabilities']
            })
            sheets.append((f"Fold{fold_data['Fold']}_Full_Predictions", fold_pred_df))

        with pd.ExcelWriter(output_excel, engine='openpyxl') as writer:
            for sheet_name, frame in sheets:
                frame.to_excel(writer, sheet_name=sheet_name, index=False)

        self.logger.info(f"Comprehensive results saved to: {output_excel}")

        self._generate_visualizations(training_results)

    @staticmethod
    def _cv_predictions_frame(fold_results) -> pd.DataFrame:
        """Stack the held-out predictions of every fold, ordered by 'Sira'."""
        columns = ('Sira', 'Predicted', 'Actual', 'Probability', 'Fold')
        per_fold = [
            pd.DataFrame({column: fold_data[column] for column in columns})
            for fold_data in fold_results
        ]
        combined = pd.concat(per_fold, ignore_index=True)
        return combined.sort_values('Sira').reset_index(drop=True)

    def _per_fold_metrics_frame(self, fold_full_results) -> pd.DataFrame:
        """Return one row per fold holding 'Fold' plus every metric column."""
        return pd.DataFrame([
            {'Fold': r['Fold'], **{m: r[m] for m in self._METRIC_NAMES}}
            for r in fold_full_results
        ])

    @staticmethod
    def _binary_cm_metrics(cm) -> Dict:
        """Derive accuracy/precision/recall/specificity/F1 from a 2x2 matrix.

        The cells are read in flattened order as ``tn, fp, fn, tp``. Every
        ratio falls back to 0 when its denominator is zero.
        """
        counts = np.asarray(cm).ravel()
        if counts.size != 4:
            # Same exception type as the bare tuple-unpack would raise, but
            # with a message that names the actual problem.
            raise ValueError(
                f"Expected a 2x2 confusion matrix, got shape {np.shape(cm)}"
            )
        tn, fp, fn, tp = counts
        total = tp + tn + fp + fn
        return {
            'Accuracy': (tp + tn) / total if total > 0 else 0,
            'Precision': tp / (tp + fp) if (tp + fp) > 0 else 0,
            'Recall': tp / (tp + fn) if (tp + fn) > 0 else 0,
            'Specificity': tn / (tn + fp) if (tn + fp) > 0 else 0,
            'F1-Score': 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
        }

    def _cv_summary_frame(self, total_cm, cv_predictions: pd.DataFrame) -> pd.DataFrame:
        """Build the pooled cross-validation summary (Metric/Value rows).

        Ratios come from the summed confusion matrix; MCC is computed from
        the pooled 'Actual' and 'Predicted' columns of ``cv_predictions``.
        """
        values = self._binary_cm_metrics(total_cm)
        values['MCC'] = matthews_corrcoef(
            cv_predictions['Actual'].values,
            cv_predictions['Predicted'].values
        )
        return pd.DataFrame({
            'Metric': list(self._METRIC_NAMES),
            'Value': [values[m] for m in self._METRIC_NAMES]
        })

    def _full_dataset_summary_frame(self, per_fold_metrics: pd.DataFrame) -> pd.DataFrame:
        """Summarise each metric across folds as Mean/Std/Min/Max columns."""
        summary = {'Metric': list(self._METRIC_NAMES)}
        for column, method in (('Mean', 'mean'), ('Std', 'std'), ('Min', 'min'), ('Max', 'max')):
            summary[column] = [
                getattr(per_fold_metrics[m], method)() for m in self._METRIC_NAMES
            ]
        return pd.DataFrame(summary)

    @staticmethod
    def _best_model_frame(training_results: Dict) -> pd.DataFrame:
        """Return the single-row 'Best_Model_Info' table."""
        best_metrics = training_results['best_model_metrics']
        return pd.DataFrame([{
            'Best_Fold': training_results['best_fold_num'],
            'Best_Fold_Test_Accuracy': training_results['best_fold_accuracy'],
            'Best_Model_Full_Dataset_Accuracy': best_metrics['Accuracy'],
            'Best_Model_Precision': best_metrics['Precision'],
            'Best_Model_Recall': best_metrics['Recall'],
            'Best_Model_Specificity': best_metrics['Specificity'],
            'Best_Model_F1_Score': best_metrics['F1-Score'],
            'Best_Model_MCC': best_metrics['MCC']
        }])

    def _generate_visualizations(self, training_results: Dict) -> None:
        """Save the pooled CV confusion matrix and one image per fold.

        Reads ``training_results['conf_matrices']`` (summed element-wise for
        the pooled image) and the ``'Fold'`` / ``'Confusion_Matrix'`` entries
        of ``training_results['fold_full_results']``.
        """
        self.logger.info("Generating visualizations...")

        total_cm = np.sum(training_results['conf_matrices'], axis=0)
        cm_file = os.path.join(self.config.OUTPUT_PATH, 'confusion_matrix_cv_total.png')
        save_confusion_matrix(
            total_cm, cm_file, self.config.CLASS_NAMES,
            'Cross-Validation Total Confusion Matrix'
        )

        for fold_data in training_results['fold_full_results']:
            fold_num = fold_data['Fold']
            cm = fold_data['Confusion_Matrix']
            cm_file = os.path.join(
                self.config.OUTPUT_PATH,
                f'confusion_matrix_fold{fold_num}_full_dataset.png'
            )
            save_confusion_matrix(
                cm, cm_file, self.config.CLASS_NAMES,
                f'Fold {fold_num} - Full Dataset Confusion Matrix'
            )

        self.logger.info("All visualizations generated successfully")


def main():
    """Run the whole KAN pipeline: load data, grid-search, train, report.

    Any exception raised by a stage is logged with its traceback and then
    re-raised. A "Session ended" footer is logged in every case.
    """
    logger = setup_logging(Config.OUTPUT_PATH)
    Config.display_config(logger)

    rule = "=" * 80

    try:
        X, y, sira = DataManager(Config, logger).load_data()

        trainer = KANTrainer(Config, logger)
        search_outcome = trainer.grid_search(X, y, sira)

        # Written before final training starts, so this file exists even if
        # a later stage raises.
        grid_file = os.path.join(Config.OUTPUT_PATH, 'grid_search_results.xlsx')
        ranked_grid = pd.DataFrame(search_outcome['grid_results']).sort_values(
            'mean_accuracy', ascending=False
        )
        ranked_grid.to_excel(grid_file, index=False)
        logger.info(f"\nGrid search results saved to: {grid_file}")

        final_run = trainer.train_final_models_with_best_config(
            X, y, sira, search_outcome['best_params']
        )

        ResultsManager(Config, logger).save_results(
            search_outcome, final_run, X, y, sira
        )

        print_section_header("TRAINING COMPLETED SUCCESSFULLY", logger)
        logger.info(f"\nAll outputs saved to: {Config.OUTPUT_PATH}")
        logger.info(f"Best model saved to: {final_run['best_model_path']}")
        logger.info(f"Best fold: {final_run['best_fold_num']} (Accuracy: {final_run['best_fold_accuracy']:.4f})")

    except Exception as exc:
        # Record the failure (with traceback) in the run log, then let the
        # caller see the original exception.
        logger.error(f"Training failed: {str(exc)}", exc_info=True)
        raise

    finally:
        # Footer is emitted on success and on failure alike.
        logger.info("")
        logger.info(rule)
        logger.info(f"Session ended: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        logger.info(rule)


# Run the pipeline only when this file is executed as a script (or as a
# notebook cell); importing it as a module does not start training.
if __name__ == "__main__":
    main()