In [3]:
import argparse
import logging
import os
from typing import Optional, Sequence, Tuple

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import (
    accuracy_score,
    recall_score,
    f1_score,
    confusion_matrix,
    roc_auc_score,
)
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm


def setup_logging(log_level: str = "INFO") -> None:
    """
    Initialise root logging with a timestamped console handler.

    Args:
        log_level (str, optional): Name of a logging level, matched
            case-insensitively against the ``logging`` module. Defaults to "INFO".

    Raises:
        ValueError: If the name does not resolve to an integer level.
    """
    level_name = log_level.upper()
    resolved = getattr(logging, level_name, None)

    if isinstance(resolved, int):
        console = logging.StreamHandler()
        logging.basicConfig(
            level=resolved,
            format="%(asctime)s - %(levelname)s - %(message)s",
            handlers=[console],
        )
        return

    # Anything that is not an int attribute of `logging` is not a usable level.
    raise ValueError(f"Invalid log level: {log_level}")


def load_data(fmri_path: str, labels_path: str) -> Tuple:
    """
    Load fMRI features and diagnosis labels from CSV files and align them by subject.

    Both files are indexed by their 'IID' column. Rows are returned in the
    order of the labels file. Labelled subjects that have no fMRI row are
    dropped with a warning instead of aborting the whole load.

    Args:
        fmri_path (str): Path to the fMRI data CSV file.
        labels_path (str): Path to the labels CSV file.

    Returns:
        Tuple: (features as NumPy array, labels as NumPy array), row-aligned.

    Raises:
        ValueError: If either file lacks an 'IID' column, the labels file
            lacks a 'DIA' column, or duplicate IIDs in the fMRI file make
            features and labels impossible to align one-to-one.
    """
    logging.info(f"Loading fMRI data from {fmri_path}")
    fmri_data = pd.read_csv(fmri_path)
    logging.info(f"Loading labels data from {labels_path}")
    labels_data = pd.read_csv(labels_path)

    # Validate the schema up front so errors surface before any filtering work.
    if 'IID' not in fmri_data.columns or 'IID' not in labels_data.columns:
        raise ValueError("Both fMRI and labels data must contain an 'IID' column.")
    if "DIA" not in labels_data.columns:
        raise ValueError("Labels data must contain a 'DIA' column.")

    fmri_data = fmri_data.set_index("IID")
    labels_data = labels_data.set_index("IID")

    # Keep only labelled subjects that actually have features; a plain
    # .loc[...] lookup would raise KeyError on the first missing IID.
    has_features = labels_data.index.isin(fmri_data.index)
    n_missing = int((~has_features).sum())
    if n_missing:
        logging.warning(
            f"Dropping {n_missing} labelled IID(s) with no matching fMRI row"
        )
        labels_data = labels_data[has_features]

    filtered_fmri_data = fmri_data.loc[labels_data.index]
    logging.info(f"Filtered fMRI data shape: {filtered_fmri_data.shape}")

    # Duplicate IIDs in the fMRI file make .loc return extra rows, which would
    # silently shift features against labels.
    if len(filtered_fmri_data) != len(labels_data):
        raise ValueError(
            "fMRI data contains duplicate IIDs; features and labels cannot be aligned."
        )

    # DIA is the classification target.
    labels = labels_data["DIA"]

    return filtered_fmri_data.values, labels.values


def prepare_tensors(
    X: torch.Tensor,
    y: torch.Tensor,
    test_size: float = 0.2,
    val_size: float = 0.1,
    random_state: int = 42,
    device: torch.device = torch.device("cpu"),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Produce stratified train/validation/test splits as tensors on ``device``.

    The test split is carved out first; the validation split is then taken
    from what remains. Features become float32 tensors, labels become long tensors.

    Args:
        X (torch.Tensor): Feature matrix (the caller in this file passes the
            array returned by ``load_data``).
        y (torch.Tensor): Labels, also used as the stratification target.
        test_size (float, optional): Fraction of all samples held out for testing. Defaults to 0.2.
        val_size (float, optional): Fraction of the remaining samples used for validation. Defaults to 0.1.
        random_state (int, optional): Seed shared by both splits. Defaults to 42.
        device (torch.device, optional): Destination device for every tensor. Defaults to CPU.

    Returns:
        Tuple: (X_train, X_val, X_test, y_train, y_val, y_test)
    """
    logging.info("Splitting data into training and testing sets")
    X_rest, X_test, y_rest, y_test = train_test_split(
        X,
        y,
        test_size=test_size,
        random_state=random_state,
        stratify=y,
    )

    logging.info("Splitting training data into training and validation sets")
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest,
        y_rest,
        test_size=val_size,
        random_state=random_state,
        stratify=y_rest,
    )

    logging.info("Converting data to PyTorch tensors")
    feature_tensors = [
        torch.tensor(part, dtype=torch.float32).to(device)
        for part in (X_train, X_val, X_test)
    ]
    label_tensors = [
        torch.tensor(part, dtype=torch.long).to(device)
        for part in (y_train, y_val, y_test)
    ]

    # Report split sizes in train / validation / test order.
    for split_name, part in zip(("Training", "Validation", "Testing"), feature_tensors):
        logging.info(f"{split_name} set size: {part.shape[0]}")

    return (*feature_tensors, *label_tensors)


class MLP(nn.Module):
    """Multi-Layer Perceptron (MLP) classifier: Linear -> ReLU -> Dropout stacks plus a linear head."""

    def __init__(
        self,
        input_size: int,
        hidden_layers: list,
        dropout: float = 0.2,
        num_classes: int = 2,
    ):
        """
        Initialize the MLP model.

        Args:
            input_size (int): Number of input features.
            hidden_layers (list): List of integers specifying the number of neurons in each hidden layer.
            dropout (float, optional): Dropout rate applied after every hidden activation. Defaults to 0.2.
            num_classes (int, optional): Number of output logits. Defaults to 2.
        """
        super().__init__()
        layers = []
        previous_size = input_size

        for idx, layer_size in enumerate(hidden_layers, start=1):
            layers.append(nn.Linear(previous_size, layer_size))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout))
            # Log before advancing previous_size so the message shows the
            # real input -> output widths of the layer just added.
            logging.debug(f"Added layer {idx}: Linear({previous_size} -> {layer_size})")
            previous_size = layer_size

        # Output head producing raw logits; no softmax is applied here.
        layers.append(nn.Linear(previous_size, num_classes))
        self.network = nn.Sequential(*layers)
        logging.info(f"MLP architecture: {self.network}")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the logits produced by passing ``x`` through the layer stack."""
        return self.network(x)


def train(
    model: nn.Module,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: ReduceLROnPlateau,
    X_train: torch.Tensor,
    y_train: torch.Tensor,
    X_val: torch.Tensor,
    y_val: torch.Tensor,
    num_epochs: int = 10000,
    device: torch.device = torch.device("cpu"),
    checkpoint_path: str = "best_model.pth",
    log_interval: int = 1000,
) -> None:
    """
    Run full-batch training, checkpointing the weights with the lowest validation loss.

    Each epoch performs one optimiser step on the whole training set, scores
    the validation set without gradients, feeds that loss to the scheduler,
    and overwrites the checkpoint whenever the validation loss improves.

    Args:
        model (nn.Module): The MLP model.
        criterion (nn.Module): Loss function.
        optimizer (optim.Optimizer): Optimizer.
        scheduler (ReduceLROnPlateau): Learning rate scheduler, stepped on validation loss.
        X_train (torch.Tensor): Training features.
        y_train (torch.Tensor): Training labels.
        X_val (torch.Tensor): Validation features.
        y_val (torch.Tensor): Validation labels.
        num_epochs (int, optional): Number of training epochs. Defaults to 10000.
        device (torch.device, optional): Accepted for interface symmetry; not read in this function. Defaults to CPU.
        checkpoint_path (str, optional): Path to save the best model. Defaults to "best_model.pth".
        log_interval (int, optional): Interval (in epochs) for logging. Defaults to 1000.
    """
    lowest_val_loss = float("inf")
    logging.info("Starting training")

    epochs = tqdm(range(1, num_epochs + 1), desc="Training")
    for epoch in epochs:
        # --- optimisation step on the full training set ---
        model.train()
        optimizer.zero_grad()
        train_loss = criterion(model(X_train), y_train)
        train_loss.backward()
        optimizer.step()

        # --- validation pass (dropout disabled, no gradients) ---
        model.eval()
        with torch.no_grad():
            val_loss = criterion(model(X_val), y_val)

        # The plateau scheduler decides on LR reductions from validation loss.
        scheduler.step(val_loss)

        # Checkpoint only on strict improvement.
        val_value = val_loss.item()
        if val_value < lowest_val_loss:
            lowest_val_loss = val_value
            torch.save(model.state_dict(), checkpoint_path)
            logging.debug(f"Saved new best model with validation loss: {lowest_val_loss:.4f}")

        # Periodic progress report; the first epoch is always reported.
        if epoch % log_interval == 0 or epoch == 1:
            lr_now = optimizer.param_groups[0]['lr']
            logging.info(
                f"Epoch [{epoch}/{num_epochs}], "
                f"Train Loss: {train_loss.item():.4f}, "
                f"Val Loss: {val_loss.item():.4f}, "
                f"LR: {lr_now:.6f}"
            )

    logging.info("Training completed")


def evaluate(
    model: nn.Module,
    X_test: torch.Tensor,
    y_test: torch.Tensor,
) -> dict:
    """
    Evaluate the trained model on the test set.

    Hard predictions (argmax over logits) drive accuracy, sensitivity,
    specificity and F1. AUROC is computed from the softmax probability of
    class 1 so that it reflects the ranking quality of the scores rather
    than a single thresholded decision.

    Args:
        model (nn.Module): The trained MLP model producing two logits per sample.
        X_test (torch.Tensor): Test features.
        y_test (torch.Tensor): Test labels (0 or 1).

    Returns:
        dict: Metric name -> float. Specificity or AUROC is NaN when the
            test labels do not contain the class needed to define it.
    """
    logging.info("Evaluating the model on the test set")
    model.eval()
    with torch.no_grad():
        outputs = model(X_test)
        _, predicted = torch.max(outputs, 1)
        # Continuous score for the positive class, needed for a meaningful AUROC.
        positive_scores = torch.softmax(outputs, dim=1)[:, 1]

    # Convert predictions, scores and labels to numpy arrays
    y_pred = predicted.cpu().numpy()
    y_score = positive_scores.cpu().numpy()
    y_true = y_test.cpu().numpy()

    # Calculate metrics
    acc = accuracy_score(y_true, y_pred)
    sensitivity = recall_score(y_true, y_pred, pos_label=1)  # Recall for class 1
    f1 = f1_score(y_true, y_pred)

    # Specificity (recall for class 0). Fixing labels=[0, 1] guarantees a 2x2
    # matrix even when only one class appears in the labels and predictions.
    tn, fp, _fn, _tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    negatives = tn + fp
    specificity = tn / negatives if negatives > 0 else float("nan")

    # AUROC (binary case) is undefined when only one class is present.
    if len(set(y_true.tolist())) < 2:
        logging.warning("Only one class present in test labels; AUROC is undefined.")
        auroc = float("nan")
    else:
        auroc = roc_auc_score(y_true, y_score)

    metrics = {
        "Accuracy": acc,
        "Sensitivity (Recall for class 1)": sensitivity,
        "Specificity (Recall for class 0)": specificity,
        "F1-Score": f1,
        "AUROC": auroc,
    }

    logging.info("Evaluation Metrics:")
    for metric, value in metrics.items():
        logging.info(f"{metric}: {value:.4f}")

    return metrics


def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """
    Parse command-line arguments.

    Args:
        argv (Optional[Sequence[str]], optional): Argument strings to parse.
            When None (the default) argparse reads ``sys.argv[1:]``. Pass an
            explicit list when running inside a notebook kernel: the kernel
            puts its own options on ``sys.argv``, which this parser cannot
            interpret.

    Returns:
        argparse.Namespace: Parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Train an MLP model for ADHD classification with auto learning rate adjustment."
    )

    # Data paths
    parser.add_argument(
        "--fmri_path",
        type=str,
        required=True,
        help="Path to the fMRI data CSV file.",
    )
    parser.add_argument(
        "--labels_path",
        type=str,
        required=True,
        help="Path to the labels CSV file.",
    )

    # Training parameters
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=10000,
        help="Number of training epochs.",
    )
    parser.add_argument(
        "--initial_lr",
        type=float,
        default=0.0001,
        help="Initial learning rate for the optimizer.",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=10,
        help="Number of epochs with no improvement after which learning rate will be reduced.",
    )
    parser.add_argument(
        "--factor",
        type=float,
        default=0.5,
        help="Factor by which the learning rate will be reduced.",
    )
    parser.add_argument(
        "--min_lr",
        type=float,
        default=1e-6,
        help="Minimum learning rate.",
    )
    parser.add_argument(
        "--hidden_layers",
        type=int,
        nargs='+',
        default=[512, 256, 128, 64, 32],
        help="List of hidden layer sizes. Example: --hidden_layers 512 256 128",
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.2,
        help="Dropout rate between layers.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=None,
        help="Batch size for training. If not set, uses full batch.",
    )

    # Checkpoint and logging
    parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="best_model.pth",
        help="Path to save the best model checkpoint.",
    )
    parser.add_argument(
        "--log_level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level.",
    )

    # Device configuration
    parser.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["cpu", "cuda"],
        help="Device to run the training on ('cpu' or 'cuda').",
    )

    # argv=None makes argparse fall back to sys.argv[1:], as before.
    return parser.parse_args(argv)


def main():
    """
    Run the full pipeline: parse arguments, load data, train, evaluate, save metrics.

    The best checkpoint written during training is reloaded before evaluation,
    and the metrics are written next to it as ``<checkpoint stem>_metrics.txt``.

    Raises:
        FileNotFoundError: If either input CSV path does not point to a file.
    """
    # Parse command-line arguments
    args = parse_arguments()

    # Setup logging
    setup_logging(args.log_level)

    # Set device, falling back to CPU when CUDA was requested but is unavailable
    if args.device == "cuda" and not torch.cuda.is_available():
        logging.warning("CUDA is not available. Falling back to CPU.")
        device = torch.device("cpu")
    else:
        device = torch.device(args.device)
    logging.info(f"Using device: {device}")

    # train() always runs on the full batch, so tell the user when a
    # requested batch size has no effect instead of ignoring it silently.
    if args.batch_size is not None:
        logging.warning(
            f"--batch_size={args.batch_size} was given but training uses the full batch; "
            "the value is ignored."
        )

    # Validate data paths
    if not os.path.isfile(args.fmri_path):
        logging.error(f"fMRI data file not found at {args.fmri_path}")
        raise FileNotFoundError(f"fMRI data file not found at {args.fmri_path}")
    if not os.path.isfile(args.labels_path):
        logging.error(f"Labels data file not found at {args.labels_path}")
        raise FileNotFoundError(f"Labels data file not found at {args.labels_path}")

    # Load and preprocess data
    X, y = load_data(args.fmri_path, args.labels_path)

    # Prepare tensors
    X_train, X_val, X_test, y_train, y_val, y_test = prepare_tensors(
        X, y, device=device
    )

    # Initialize the model
    input_size = X_train.shape[1]
    model = MLP(input_size, hidden_layers=args.hidden_layers, dropout=args.dropout).to(device)

    # Define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.initial_lr)

    # Define learning rate scheduler
    # NOTE(review): `verbose` is reported as deprecated in recent PyTorch
    # releases - confirm against the installed version before upgrading.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=args.factor,
        patience=args.patience,
        verbose=True,
        min_lr=args.min_lr,
    )

    # Train the model
    train(
        model=model,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        num_epochs=args.num_epochs,
        device=device,
        checkpoint_path=args.checkpoint_path,
        log_interval=1000,  # Can also be made an argument if needed
    )

    # Load the best model; map_location keeps the weights on the device
    # selected above regardless of where the checkpoint was written from.
    logging.info("Loading the best model for evaluation")
    model.load_state_dict(torch.load(args.checkpoint_path, map_location=device))

    # Evaluate the model
    metrics = evaluate(model, X_test, y_test)

    # Save metrics next to the checkpoint
    metrics_path = os.path.splitext(args.checkpoint_path)[0] + "_metrics.txt"
    with open(metrics_path, 'w', encoding="utf-8") as f:
        for metric, value in metrics.items():
            f.write(f"{metric}: {value:.4f}\n")
    logging.info(f"Saved evaluation metrics to {metrics_path}")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()


usage: ipykernel_launcher.py [-h] --fmri_path FMRI_PATH --labels_path
                             LABELS_PATH [--num_epochs NUM_EPOCHS]
                             [--initial_lr INITIAL_LR] [--patience PATIENCE]
                             [--factor FACTOR] [--min_lr MIN_LR]
                             [--hidden_layers HIDDEN_LAYERS [HIDDEN_LAYERS ...]]
                             [--dropout DROPOUT] [--batch_size BATCH_SIZE]
                             [--checkpoint_path CHECKPOINT_PATH]
                             [--device {cpu,cuda}]
ipykernel_launcher.py: error: ambiguous option: --f=/home/songlinzhao/.local/share/jupyter/runtime/kernel-v3e959ad9022331b61b784349e060baa32158abfd9.json could match --fmri_path, --factor


SystemExit: 2