In [5]:
import os
import pandas as pd
import numpy as np
import nibabel as nib
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, callbacks
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from scipy import ndimage
import warnings
# Silence every Python warning for the rest of the session (this also hides
# deprecation and numerical warnings emitted by the libraries used below).
warnings.filterwarnings('ignore')

# Seed NumPy's global RNG and TensorFlow's global RNG with the same fixed
# value so repeated runs draw the same random numbers from those generators.
np.random.seed(42)
tf.random.set_seed(42)

class ADNIDataProcessor:
    """Locate, preprocess and label MRI volumes described by a labels CSV.

    The CSV is expected to provide the columns ``'Image Data ID'``,
    ``'Subject'`` and ``'Group'``; NIfTI files (``.nii`` / ``.nii.gz``) are
    searched recursively below ``data_dir``.
    """

    def __init__(self, data_dir, labels_path):
        """Store the data locations and create the label encoder.

        Args:
            data_dir: Root directory that is searched recursively for NIfTI files.
            labels_path: Path of the labels CSV file.
        """
        self.data_dir = data_dir
        self.labels_path = labels_path
        self.label_encoder = LabelEncoder()

    def load_labels(self):
        """Read the labels CSV and keep only rows whose group is CN, MCI or AD.

        Returns:
            The filtered ``pandas.DataFrame``.

        Raises:
            KeyError: If the ``'Group'`` column is missing (other missing
                columns only produce a printed warning here).
        """
        labels_df = pd.read_csv(self.labels_path)
        print(f"Labels shape: {labels_df.shape}")
        print(f"Columns: {labels_df.columns.tolist()}")

        # Check for the expected columns
        required_cols = ['Image Data ID', 'Subject', 'Group']
        missing_cols = [col for col in required_cols if col not in labels_df.columns]
        for col in missing_cols:
            print(f"Warning: Column '{col}' not found in CSV")
        if 'Group' in missing_cols:
            # Everything below needs this column; fail with a clear message
            # instead of an opaque pandas lookup error.
            raise KeyError("Column 'Group' is required in the labels CSV")

        print(f"Group distribution:\n{labels_df['Group'].value_counts()}")

        # Filter only the rows with valid groups (CN, MCI, AD)
        valid_groups = ['CN', 'MCI', 'AD']
        labels_df = labels_df[labels_df['Group'].isin(valid_groups)]
        print(f"After filtering valid groups: {labels_df.shape[0]} samples")

        return labels_df

    def preprocess_mri(self, nii_path, target_shape=(64, 64, 64)):
        """Load one NIfTI volume and return it masked, normalised and resized.

        Steps: zero out voxels at or below the 5th percentile of the positive
        intensities, clip to the 1st-99th percentile of what remains and
        rescale, resample to ``target_shape`` with linear interpolation, then
        standardise to zero mean / unit variance.

        Args:
            nii_path: Path of the NIfTI file.
            target_shape: Desired output shape (three spatial dimensions).

        Returns:
            The preprocessed array of shape ``target_shape``, or ``None`` if
            the file could not be loaded or processed (the error is printed).
        """
        try:
            # Load NIfTI file
            nii_img = nib.load(nii_path)
            img_data = nii_img.get_fdata()

            # Volumes stored with extra singleton axes, e.g. (X, Y, Z, 1),
            # are reduced to 3D so the three zoom factors below line up.
            if img_data.ndim > 3:
                img_data = np.squeeze(img_data)
            if img_data.ndim != 3:
                raise ValueError(f"expected a 3D volume, got shape {img_data.shape}")

            # A percentile of an empty selection is undefined and would
            # propagate NaNs into the output, so reject such volumes.
            foreground = img_data[img_data > 0]
            if foreground.size == 0:
                raise ValueError("volume contains no positive-intensity voxels")

            # Simple background suppression: keep voxels brighter than the
            # 5th percentile of the positive intensities.
            brain_mask = img_data > np.percentile(foreground, 5)
            img_data = img_data * brain_mask

            masked = img_data[img_data > 0]
            if masked.size == 0:
                raise ValueError("no voxels left after background masking")

            # Robust normalization using percentiles
            p1, p99 = np.percentile(masked, [1, 99])
            img_data = np.clip(img_data, p1, p99)
            img_data = (img_data - p1) / (p99 - p1 + 1e-8)

            # Resize to target shape
            zoom_factors = [target_shape[i] / img_data.shape[i] for i in range(3)]
            img_resized = ndimage.zoom(img_data, zoom_factors, order=1)

            # Final normalization
            img_resized = (img_resized - np.mean(img_resized)) / (np.std(img_resized) + 1e-8)

            return img_resized
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the run.
            print(f"Error processing {nii_path}: {e}")
            return None

    def find_matching_files(self, labels_df):
        """Pair label rows with NIfTI files found below ``self.data_dir``.

        A row is matched first by its image ID appearing in a file name. Rows
        whose image ID appears in no file name fall back to the first
        still-unused NIfTI file whose directory path contains the subject ID.
        Each file is used for at most one row, so the result never contains
        the same file twice.

        Args:
            labels_df: DataFrame with the columns ``'Image Data ID'``,
                ``'Subject'`` and ``'Group'``.

        Returns:
            A list of dicts with the keys ``'file_path'``, ``'image_id'``,
            ``'subject_id'`` and ``'label'``, in the row order of ``labels_df``.
        """
        # Use the correct column names from the ADNI CSV
        image_id_col = 'Image Data ID'
        subject_col = 'Subject'
        label_col = 'Group'

        total = len(labels_df)
        print(f"Looking for files matching {total} entries...")

        # Walk the directory tree once instead of once (or twice) per row.
        nifti_files = [
            (root, file, os.path.join(root, file))
            for root, _dirs, files in os.walk(self.data_dir)
            for file in files
            if file.endswith(('.nii', '.nii.gz'))
        ]

        rows = [
            (
                str(row[image_id_col]).strip(),
                str(row[subject_col]).strip(),
                str(row[label_col]).strip(),
            )
            for _, row in labels_df.iterrows()
        ]

        results = [None] * total
        claimed = set()        # file paths already assigned to a row
        needs_fallback = []    # row positions with no image-ID hit at all
        found_count = 0

        # Pass 1: match by image ID. Running this for all rows first ensures
        # the subject fallback cannot take a file that belongs to another
        # row by ID.
        for position, (image_id, subject_id, label) in enumerate(rows):
            id_hits = [path for _root, file, path in nifti_files if image_id in file]
            file_path = next((path for path in id_hits if path not in claimed), None)

            if file_path is not None:
                claimed.add(file_path)
                results[position] = {
                    'file_path': file_path,
                    'image_id': image_id,
                    'subject_id': subject_id,
                    'label': label
                }
                found_count += 1
            elif not id_hits:
                needs_fallback.append(position)
            # else: every file carrying this ID is already taken, so the row
            # repeats an earlier one and is dropped.

            # Report by row position; the DataFrame index is not sequential
            # once rows have been filtered out.
            if position % 50 == 0:
                print(f"Processed {position}/{total} entries, found {found_count} matches so far...")

        # Pass 2: alternative search by subject ID, restricted to unused files.
        for position in needs_fallback:
            image_id, subject_id, label = rows[position]
            file_path = next(
                (
                    path
                    for root, _file, path in nifti_files
                    if subject_id in root and path not in claimed
                ),
                None,
            )
            if file_path is not None:
                claimed.add(file_path)
                results[position] = {
                    'file_path': file_path,
                    'image_id': image_id,
                    'subject_id': subject_id,
                    'label': label
                }

        return [entry for entry in results if entry is not None]

    def load_data(self, max_samples=None, target_shape=(64, 64, 64)):
        """Load, optionally class-balance, preprocess and encode the dataset.

        Args:
            max_samples: If truthy, at most ``max_samples // n_classes`` files
                are sampled per class (fixed random state) before loading.
            target_shape: Shape every volume is resampled to.

        Returns:
            ``(X, y_categorical, classes)``: the stacked volumes, the one-hot
            labels and the class names in encoder order. When nothing could
            be loaded, three empty arrays are returned so callers can test
            ``len(X) == 0``.
        """
        # Load labels
        labels_df = self.load_labels()

        # Find matching files
        matched_data = self.find_matching_files(labels_df)
        print(f"Found {len(matched_data)} matching files")

        # Balance classes if needed (skipped when nothing matched: an empty
        # DataFrame has no 'label' column to group on).
        if max_samples and matched_data:
            # Group by label and sample equally
            matched_df = pd.DataFrame(matched_data)
            labels_present = matched_df['label'].unique()
            samples_per_class = max_samples // len(labels_present)

            balanced_data = []
            for label in labels_present:
                label_rows = matched_df[matched_df['label'] == label]
                balanced_data.extend(
                    label_rows.sample(
                        n=min(samples_per_class, len(label_rows)),
                        random_state=42
                    ).to_dict('records')
                )

            matched_data = balanced_data

        # Load and preprocess images
        X, y = [], []

        for i, data in enumerate(matched_data):
            img = self.preprocess_mri(data['file_path'], target_shape)
            if img is not None:
                X.append(img)
                y.append(data['label'])

            if i % 50 == 0:
                print(f"Loaded {i}/{len(matched_data)} images...")

        if not X:
            # Encoding an empty label list would fail; hand back empty
            # arrays so the caller's emptiness check can run.
            print("No images could be loaded")
            return np.empty((0, *target_shape)), np.empty((0, 0)), np.array([], dtype=object)

        X = np.array(X)
        y = np.array(y)

        # Encode labels
        y_encoded = self.label_encoder.fit_transform(y)
        y_categorical = to_categorical(y_encoded)

        print(f"Final data shape: {X.shape}")
        print(f"Labels shape: {y_categorical.shape}")
        print(f"Label mapping: {dict(zip(self.label_encoder.classes_, range(len(self.label_encoder.classes_))))}")

        return X, y_categorical, self.label_encoder.classes_

def create_improved_3d_cnn_model(input_shape, num_classes):
    """Build a 3D CNN classifier with L2 weight decay, BatchNorm and Dropout.

    The network stacks three convolutional stages (32, 64 and 128 filters),
    each made of two Conv3D+BatchNorm pairs followed by 2x2x2 max pooling and
    dropout. Global average and global max pooling are concatenated and fed
    through two Dense+BatchNorm+Dropout layers into a softmax output.

    Args:
        input_shape: Shape of a single input sample, passed to ``layers.Input``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The uncompiled Keras ``Model``.
    """
    weight_decay = tf.keras.regularizers.l2(0.001)

    def conv_bn(tensor, filters):
        # 3x3x3 'same' convolution with ReLU, followed by batch normalisation.
        tensor = layers.Conv3D(
            filters,
            (3, 3, 3),
            activation='relu',
            padding='same',
            kernel_regularizer=weight_decay,
        )(tensor)
        return layers.BatchNormalization()(tensor)

    def dense_bn_dropout(tensor, units, rate):
        # Regularised fully connected layer, batch normalisation, dropout.
        tensor = layers.Dense(units, activation='relu', kernel_regularizer=weight_decay)(tensor)
        tensor = layers.BatchNormalization()(tensor)
        return layers.Dropout(rate)(tensor)

    inputs = layers.Input(shape=input_shape)

    # Convolutional stages: (filters, dropout rate after pooling)
    features = inputs
    for filters, drop_rate in ((32, 0.1), (64, 0.2), (128, 0.3)):
        features = conv_bn(features, filters)
        features = conv_bn(features, filters)
        features = layers.MaxPooling3D((2, 2, 2))(features)
        features = layers.Dropout(drop_rate)(features)

    # Combine average- and max-pooled summaries of the feature maps
    avg_pooled = layers.GlobalAveragePooling3D()(features)
    max_pooled = layers.GlobalMaxPooling3D()(features)
    head = layers.Concatenate()([avg_pooled, max_pooled])

    # Classification head: (units, dropout rate)
    for units, drop_rate in ((256, 0.4), (64, 0.3)):
        head = dense_bn_dropout(head, units, drop_rate)

    outputs = layers.Dense(num_classes, activation='softmax')(head)

    return models.Model(inputs=inputs, outputs=outputs)

def create_data_augmentation():
    """Return a function that randomly augments one 3D volume.

    The returned ``augment_3d(x, y)`` applies, each with its own probability:
    a small rotation (within +/-5 degrees) in one of the three planes spanned
    by the first three axes, a flip along axis 0, additive Gaussian noise
    (sigma 0.05) and a global intensity scaling (factor 0.9-1.1). Randomness
    comes from NumPy's global RNG.

    Returns:
        A callable ``augment_3d(x, y) -> (x_augmented, y)``. ``x`` may be a
        NumPy array or any object exposing ``.numpy()``; ``y`` is returned
        unchanged. The augmented array is C-contiguous and keeps the
        floating dtype of the input (non-float input is promoted to float64).
    """
    # Candidate rotation planes, as pairs of axes.
    rotation_planes = [(0, 1), (0, 2), (1, 2)]

    def augment_3d(x, y):
        # Convert to numpy if tensor
        if hasattr(x, 'numpy'):
            x = x.numpy()
        x = np.asarray(x)
        if not np.issubdtype(x.dtype, np.floating):
            x = x.astype(np.float64)
        out_dtype = x.dtype

        # Random rotation (small angles only for medical images)
        if np.random.random() > 0.7:
            angle = np.random.uniform(-5, 5)
            # np.random.choice only accepts 1-D input, so draw an index
            # rather than choosing from the list of tuples directly.
            axes = rotation_planes[np.random.randint(len(rotation_planes))]
            x = ndimage.rotate(x, angle, axes=axes, reshape=False, order=1)

        # Random flip along the first axis
        # NOTE(review): assumed to be the left-right axis - confirm against
        # the orientation of the loaded volumes.
        if np.random.random() > 0.5:
            x = np.flip(x, axis=0)

        # Gaussian noise
        if np.random.random() > 0.8:
            noise = np.random.normal(0, 0.05, x.shape)
            x = x + noise

        # Intensity scaling
        if np.random.random() > 0.8:
            scale = np.random.uniform(0.9, 1.1)
            x = x * scale

        # np.flip yields a negative-stride view and the noise is float64;
        # return a contiguous array in the caller's dtype.
        return np.ascontiguousarray(x, dtype=out_dtype), y

    return augment_3d

def train_model_with_kfold(data_dir=None, labels_path=None, use_augmentation=False):
    """Train the 3D CNN with stratified 5-fold cross-validation.

    Loads up to 300 class-balanced samples through ``ADNIDataProcessor``,
    trains a fresh model per fold with class weighting, early stopping and
    learning-rate reduction, and prints the per-fold and mean validation
    accuracy.

    Args:
        data_dir: Root directory of the NIfTI files. ``None`` (default) uses
            the hard-coded location below.
        labels_path: Path of the labels CSV. ``None`` (default) uses the
            hard-coded location below.
        use_augmentation: When True, every training volume is passed through
            the function returned by ``create_data_augmentation()``. Off by
            default.

    Returns:
        ``(best_model, best_history)`` for the fold with the highest
        validation accuracy, or ``(None, None)`` if no data was loaded.
    """
    DATA_DIR = r"C:\Users\hp\Downloads\ADNI1_Complete 1Yr 1.5T\ADNI" if data_dir is None else data_dir
    LABELS_PATH = r"C:\Users\hp\Downloads\ADNI_labels.csv" if labels_path is None else labels_path
    TARGET_SHAPE = (64, 64, 64)
    BATCH_SIZE = 4
    EPOCHS = 30

    processor = ADNIDataProcessor(DATA_DIR, LABELS_PATH)
    X, y, class_names = processor.load_data(target_shape=TARGET_SHAPE, max_samples=300)

    if len(X) == 0:
        print("No data loaded!")
        return None, None

    # Add the trailing channel axis expected by Conv3D.
    X = X[..., np.newaxis]
    y_labels = np.argmax(y, axis=1)
    unique, counts = np.unique(y_labels, return_counts=True)
    print(f"Class distribution: {dict(zip([class_names[i] for i in unique], counts))}")

    # Inverse-frequency class weights; identical for every fold, so computed once.
    class_weights = {
        int(cls): float(len(y_labels) / (len(unique) * count))
        for cls, count in zip(unique, counts)
    }

    augment_fn = create_data_augmentation() if use_augmentation else None

    def augment_sample(volume, label):
        # Run the NumPy augmentation eagerly inside the tf.data pipeline.
        augmented = tf.py_function(
            func=lambda v: augment_fn(v, None)[0],
            inp=[volume],
            Tout=volume.dtype,
        )
        # py_function drops static shape information; restore it for batching.
        augmented.set_shape(volume.shape)
        return augmented, label

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    fold_scores = []
    best_model, best_history = None, None
    best_accuracy = float('-inf')

    for fold, (train_idx, val_idx) in enumerate(skf.split(X, y_labels)):
        print(f"\n=== Fold {fold + 1} ===")

        X_train_fold, X_val_fold = X[train_idx], X[val_idx]
        # Integer class indices, as required by the sparse loss below.
        y_train_indices, y_val_indices = y_labels[train_idx], y_labels[val_idx]

        model = create_improved_3d_cnn_model(X_train_fold.shape[1:], len(class_names))

        model.compile(
            optimizer=optimizers.Adam(learning_rate=0.0005),
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )

        callbacks_list = [
            callbacks.EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True, verbose=1),
            callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.3, patience=4, min_lr=1e-7, verbose=1)
        ]

        # Create datasets using integer labels
        train_dataset = tf.data.Dataset.from_tensor_slices((X_train_fold, y_train_indices))
        val_dataset = tf.data.Dataset.from_tensor_slices((X_val_fold, y_val_indices))

        train_dataset = train_dataset.shuffle(100)
        if use_augmentation:
            train_dataset = train_dataset.map(augment_sample)
        train_dataset = train_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
        val_dataset = val_dataset.batch(BATCH_SIZE)

        history = model.fit(
            train_dataset,
            validation_data=val_dataset,
            epochs=EPOCHS,
            callbacks=callbacks_list,
            class_weight=class_weights,
            verbose=1
        )

        val_loss, val_accuracy = model.evaluate(val_dataset, verbose=0)
        fold_scores.append(val_accuracy)
        print(f"Fold {fold + 1} Validation Accuracy: {val_accuracy:.4f}")

        # Keep the fold that actually scored best, not simply the first one.
        if val_accuracy > best_accuracy:
            best_accuracy = val_accuracy
            best_model = model
            best_history = history

    print(f"\nCross-validation results:")
    print(f"Mean CV Accuracy: {np.mean(fold_scores):.4f} ± {np.std(fold_scores):.4f}")

    return best_model, best_history

def evaluate_model_comprehensive(model, X_test, y_test, class_names):
    """Print a classification report, confusion matrix and per-class accuracy.

    Args:
        model: Object whose ``predict(X_test, verbose=0)`` returns one score
            per class for every sample.
        X_test: Inputs passed to ``model.predict``.
        y_test: One-hot encoded true labels (reduced with ``argmax`` on axis 1).
        class_names: Class names; position ``i`` names class index ``i``.

    Returns:
        ``(y_pred, y_pred_classes)``: the raw predictions and their argmax
        class indices.
    """
    # Predictions
    y_pred = model.predict(X_test, verbose=0)
    y_pred_classes = np.argmax(y_pred, axis=1)
    y_true_classes = np.argmax(y_test, axis=1)

    # Fix the label set explicitly: otherwise a class that appears in neither
    # the true nor the predicted labels shrinks the report/matrix and no
    # longer lines up with class_names.
    label_ids = list(range(len(class_names)))

    # Classification report
    print("\nDetailed Classification Report:")
    print(classification_report(y_true_classes, y_pred_classes,
                                labels=label_ids,
                                target_names=class_names, digits=4))

    # Confusion matrix
    cm = confusion_matrix(y_true_classes, y_pred_classes, labels=label_ids)
    print("\nConfusion Matrix:")
    print(cm)

    # Per-class accuracy (row-wise recall); 0 for classes with no true samples
    print("\nPer-class Accuracy:")
    for i, class_name in enumerate(class_names):
        row_total = np.sum(cm[i, :])
        class_acc = cm[i, i] / row_total if row_total > 0 else 0
        print(f"{class_name}: {class_acc:.4f}")

    return y_pred, y_pred_classes

if __name__ == "__main__":
    # Script entry point: run the cross-validated training, then persist the
    # returned model if there is one.
    print("Starting enhanced ADNI 3D CNN training...")
    model, history = train_model_with_kfold()

    if model is None:
        # Training returned no model.
        print("Training failed. Please check your data paths and setup.")
    else:
        print("\nTraining completed successfully!")

        # Write the model to disk using the .keras file format
        model.save('enhanced_adni_model.keras')
        print("Model saved as 'enhanced_adni_model.keras'")

Starting enhanced ADNI 3D CNN training...
Labels shape: (2294, 12)
Columns: ['Image Data ID', 'Subject', 'Group', 'Sex', 'Age', 'Visit', 'Modality', 'Description', 'Type', 'Acq Date', 'Format', 'Downloaded']
Group distribution:
Group
MCI    1113
CN      705
AD      476
Name: count, dtype: int64
After filtering valid groups: 2294 samples
Looking for files matching 2294 entries...
Processed 0/2294 entries, found 1 matches so far...
Processed 50/2294 entries, found 27 matches so far...
Processed 100/2294 entries, found 50 matches so far...
Processed 150/2294 entries, found 74 matches so far...
Processed 200/2294 entries, found 106 matches so far...
Processed 250/2294 entries, found 143 matches so far...
Processed 300/2294 entries, found 176 matches so far...
Processed 350/2294 entries, found 197 matches so far...
Processed 400/2294 entries, found 222 matches so far...
Processed 450/2294 entries, found 240 matches so far...
Processed 500/2294 entries, found 273 matches so far...
Processed 