# In [None]:
import os
import numpy as np
import nibabel as nib
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import tensorflow as tf
from tensorflow.keras import layers, models, mixed_precision
from scipy.ndimage import zoom, rotate
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from tensorflow.keras.regularizers import l2
from collections import Counter
from sklearn.metrics import roc_curve, auc, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

class AdamW(tf.keras.optimizers.Optimizer):
    """Custom AdamW optimizer implementation.

    Adam with decoupled weight decay: every step subtracts the bias-corrected
    Adam update AND, as a separate term, ``lr * weight_decay * var`` (the decay
    term is not scaled by the adaptive ``m / sqrt(v)`` factor). With
    ``amsgrad=True`` the denominator uses a running maximum of the second
    moment instead of the current estimate.

    NOTE(review): this class relies on private hooks of the *legacy* Keras
    optimizer API (``_set_hyper``, ``add_slot``, ``_decayed_lr``,
    ``_get_hyper``, ``_resource_scatter_add``, ``_serialize_hyperparameter``).
    In newer TensorFlow releases these appear to live on
    ``tf.keras.optimizers.legacy.Optimizer`` rather than
    ``tf.keras.optimizers.Optimizer`` -- confirm against the installed TF
    version before using it. In the visible code, ``train_model`` uses the
    built-in ``tf.keras.optimizers.AdamW`` instead of this class.
    """
    
    def __init__(
        self,
        learning_rate=0.001,
        weight_decay=0.004,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        amsgrad=False,
        name="AdamW",
        **kwargs
    ):
        """Initialize optimizer.

        Args:
            learning_rate: Base step size; registered as a hyperparameter and
                read back through ``_decayed_lr`` when applying updates.
            weight_decay: Decoupled decay coefficient; multiplied by the
                learning rate and by the variable itself at each step.
            beta_1: Decay rate of the first-moment (``m``) moving average.
            beta_2: Decay rate of the second-moment (``v``) moving average.
            epsilon: Added to ``sqrt(v)`` (or ``sqrt(vhat)``) in the
                denominator to avoid division by zero.
            amsgrad: If True, keep an extra ``vhat`` slot holding the running
                maximum of ``v`` and divide by that instead.
            name: Optimizer name forwarded to the base class.
            **kwargs: Forwarded unchanged to the base optimizer constructor.
        """
        super().__init__(name, **kwargs)
        # Tensor-valued hyperparameters (retrievable per dtype via _get_hyper).
        self._set_hyper("learning_rate", learning_rate)
        self._set_hyper("weight_decay", weight_decay)
        self._set_hyper("beta_1", beta_1)
        self._set_hyper("beta_2", beta_2)
        # Plain Python attributes (not registered as hyperparameters).
        self.epsilon = epsilon
        self.amsgrad = amsgrad
        
    def _create_slots(self, var_list):
        """Create slots for optimizer variables.

        Every trainable variable gets an ``m`` and a ``v`` slot; a ``vhat``
        slot is added only when ``amsgrad`` is enabled.
        """
        # Create slots for first and second moments
        for var in var_list:
            self.add_slot(var, "m")  # First moment
            self.add_slot(var, "v")  # Second moment
            if self.amsgrad:
                self.add_slot(var, "vhat")  # Moved avg of second moment
                
    # NOTE(review): @tf.function on an optimizer hook that the base class
    # already invokes inside its own update context is unusual (the sparse
    # variant below is not decorated); verify it traces as intended.
    @tf.function
    def _resource_apply_dense(self, grad, var, apply_state=None):
        """Apply gradients to dense variables.

        Updates the ``m``/``v`` (and optionally ``vhat``) slots, then
        subtracts the bias-corrected Adam step plus the decoupled weight-decay
        term from ``var``.

        Returns:
            A grouped op covering the variable and slot assignments.
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        weight_decay = self._get_hyper("weight_decay", var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        # 1-based step index used for the bias-correction powers.
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)

        # Get slots
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        # Adam updates
        m_t = m.assign(
            beta_1_t * m + (1.0 - beta_1_t) * grad,
            use_locking=self._use_locking
        )
        v_t = v.assign(
            beta_2_t * v + (1.0 - beta_2_t) * tf.square(grad),
            use_locking=self._use_locking
        )

        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(
                tf.maximum(vhat, v_t),
                use_locking=self._use_locking
            )
            denom = tf.sqrt(vhat_t) + epsilon_t
        else:
            denom = tf.sqrt(v_t) + epsilon_t

        # Bias correction
        # Both corrections are folded into the step size:
        # lr * sqrt(1 - beta_2^t) / (1 - beta_1^t). Epsilon above is added to
        # the *uncorrected* sqrt(v), so its effective scale differs slightly
        # from the textbook formula.
        lr_corr = lr_t * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)
        
        # Update variable with momentum and weight decay
        # The decay term uses the raw lr_t (not lr_corr) and is not divided
        # by denom -- this is what makes the decay "decoupled".
        var_update = var.assign_sub(
            lr_corr * m_t / denom + lr_t * weight_decay * var,
            use_locking=self._use_locking
        )

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
            
        return tf.group(*updates)
        
    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
        """Apply gradients to sparse variables.

        ``grad`` holds the gradient rows for ``indices``. The moment slots are
        first decayed on every row, then the new gradient contributions are
        scatter-added only at ``indices``. The final step (Adam update plus
        weight decay) is applied to the whole variable, including rows that
        received no gradient this step.

        Returns:
            A grouped op covering the variable and slot assignments.
        """
        var_dtype = var.dtype.base_dtype
        lr_t = self._decayed_lr(var_dtype)
        weight_decay = self._get_hyper("weight_decay", var_dtype)
        beta_1_t = self._get_hyper("beta_1", var_dtype)
        beta_2_t = self._get_hyper("beta_2", var_dtype)
        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
        # 1-based step index used for the bias-correction powers.
        local_step = tf.cast(self.iterations + 1, var_dtype)
        beta_1_power = tf.pow(beta_1_t, local_step)
        beta_2_power = tf.pow(beta_2_t, local_step)

        # Get slots
        m = self.get_slot(var, "m")
        v = self.get_slot(var, "v")

        # Sparse updates for momentum
        # Decay all rows first; the control dependency orders the scatter-add
        # after that assignment.
        m_scaled_g_values = grad * (1 - beta_1_t)
        m_t = m.assign(m * beta_1_t, use_locking=self._use_locking)
        with tf.control_dependencies([m_t]):
            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

        # Sparse updates for variance
        v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
        v_t = v.assign(v * beta_2_t, use_locking=self._use_locking)
        with tf.control_dependencies([v_t]):
            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

        if self.amsgrad:
            vhat = self.get_slot(var, "vhat")
            vhat_t = vhat.assign(
                tf.maximum(vhat, v_t),
                use_locking=self._use_locking
            )
            denom = tf.sqrt(vhat_t) + epsilon_t
        else:
            denom = tf.sqrt(v_t) + epsilon_t

        # Bias correction
        lr_corr = lr_t * tf.sqrt(1 - beta_2_power) / (1 - beta_1_power)
        
        # Update variable with momentum
        # (also applies the decoupled weight-decay term, as in the dense path)
        var_update = var.assign_sub(
            lr_corr * m_t / denom + lr_t * weight_decay * var,
            use_locking=self._use_locking
        )

        updates = [var_update, m_t, v_t]
        if self.amsgrad:
            updates.append(vhat_t)
            
        return tf.group(*updates)

    def get_config(self):
        """Return configuration of the optimizer.

        Extends the base config with this optimizer's hyperparameters so it
        can be re-created from the returned dict.
        """
        config = super().get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "weight_decay": self._serialize_hyperparameter("weight_decay"),
            "beta_1": self._serialize_hyperparameter("beta_1"),
            "beta_2": self._serialize_hyperparameter("beta_2"),
            "epsilon": self.epsilon,
            "amsgrad": self.amsgrad,
        })
        return config
    
class ChannelAttention(layers.Layer):
    """Channel attention gate (CBAM style) for channels-last feature maps.

    Axes 1, 2 and 3 are squeezed with both a mean and a max, and the last
    axis is treated as channels. Both descriptors pass through the same
    two-layer MLP. The sigmoid of their sum is returned with the reduced
    axes kept at size 1, so the gate can be multiplied back onto the input.

    NOTE(review): a second ``ChannelAttention`` class is defined later in this
    file. That later definition rebinds the name, so this copy is never used
    at runtime; consider deleting one of them.

    Args:
        ratio: Channel reduction factor of the MLP's hidden layer.
    """

    def __init__(self, ratio=8, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        channel = input_shape[-1]
        # Guard: for channel < ratio, channel // ratio is 0, which would
        # create a zero-unit Dense layer and an all-zero bottleneck.
        hidden_units = max(1, channel // self.ratio)
        self.shared_dense1 = layers.Dense(hidden_units,
                                          activation='relu',
                                          kernel_initializer='he_normal',
                                          use_bias=True,
                                          bias_initializer='zeros')
        self.shared_dense2 = layers.Dense(channel,
                                          kernel_initializer='he_normal',
                                          use_bias=True,
                                          bias_initializer='zeros')

    def call(self, inputs):
        # Mean descriptor over the three reduced axes (dims kept for broadcasting).
        avg_pool = tf.reduce_mean(inputs, axis=[1, 2, 3], keepdims=True)
        avg_pool = self.shared_dense2(self.shared_dense1(avg_pool))

        # Max descriptor, pushed through the *same* MLP weights.
        max_pool = tf.reduce_max(inputs, axis=[1, 2, 3], keepdims=True)
        max_pool = self.shared_dense2(self.shared_dense1(max_pool))

        return tf.nn.sigmoid(avg_pool + max_pool)

    def get_config(self):
        """Include ``ratio`` so the layer round-trips through model save/load."""
        config = super(ChannelAttention, self).get_config()
        config.update({"ratio": self.ratio})
        return config

class SpatialAttention(layers.Layer):
    """Spatial attention gate (CBAM style) built on a single-filter Conv3D.

    The input's channel axis (last) is summarised by its mean and its max.
    The resulting two-channel map is convolved down to one channel, and the
    sigmoid of that is returned as a gate shared by all channels.

    NOTE(review): a second ``SpatialAttention`` class is defined later in this
    file. That later definition rebinds the name, so this copy is never used
    at runtime; consider deleting one of them.

    Args:
        kernel_size: Kernel size of the gating convolution.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size

    def build(self, input_shape):
        # One output filter: a single gate map, no bias term.
        self.conv = layers.Conv3D(
            1,
            self.kernel_size,
            padding='same',
            kernel_initializer='he_normal',
            use_bias=False,
        )

    def call(self, inputs):
        channel_stats = [
            tf.reduce_mean(inputs, axis=-1, keepdims=True),
            tf.reduce_max(inputs, axis=-1, keepdims=True),
        ]
        gate_logits = self.conv(tf.concat(channel_stats, axis=-1))
        return tf.nn.sigmoid(gate_logits)

class ChannelAttention(layers.Layer):
    """Channel attention gate (CBAM style) for channels-last feature maps.

    Axes 1, 2 and 3 are squeezed with both a mean and a max, and the last
    axis is treated as channels. Both descriptors pass through the same
    two-layer MLP. The sigmoid of their sum is returned with the reduced
    axes kept at size 1, so the gate can be multiplied back onto the input.

    Args:
        ratio: Channel reduction factor of the MLP's hidden layer.
    """

    def __init__(self, ratio=8, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        channel = input_shape[-1]
        # Guard: for channel < ratio, channel // ratio is 0, which would
        # create a zero-unit Dense layer and an all-zero bottleneck.
        hidden_units = max(1, channel // self.ratio)
        self.shared_dense1 = layers.Dense(hidden_units,
                                          activation='relu',
                                          kernel_initializer='he_normal',
                                          use_bias=True,
                                          bias_initializer='zeros')
        self.shared_dense2 = layers.Dense(channel,
                                          kernel_initializer='he_normal',
                                          use_bias=True,
                                          bias_initializer='zeros')

    def call(self, inputs):
        # Mean descriptor over the three reduced axes (dims kept for broadcasting).
        avg_pool = tf.reduce_mean(inputs, axis=[1, 2, 3], keepdims=True)
        avg_pool = self.shared_dense2(self.shared_dense1(avg_pool))

        # Max descriptor, pushed through the *same* MLP weights.
        max_pool = tf.reduce_max(inputs, axis=[1, 2, 3], keepdims=True)
        max_pool = self.shared_dense2(self.shared_dense1(max_pool))

        return tf.nn.sigmoid(avg_pool + max_pool)

    def get_config(self):
        """Include ``ratio`` so the layer round-trips through model save/load."""
        config = super(ChannelAttention, self).get_config()
        config.update({"ratio": self.ratio})
        return config

class SpatialAttention(layers.Layer):
    """Spatial attention gate (CBAM style) built on a single-filter Conv3D.

    The input's channel axis (last) is summarised by its mean and its max.
    The resulting two-channel map is convolved down to one channel, and the
    sigmoid of that is returned as a gate shared by all channels.

    Args:
        kernel_size: Kernel size of the gating convolution.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size

    def build(self, input_shape):
        # One output filter: a single gate map, no bias term.
        self.conv = layers.Conv3D(
            1,
            self.kernel_size,
            padding='same',
            kernel_initializer='he_normal',
            use_bias=False,
        )

    def call(self, inputs):
        channel_stats = [
            tf.reduce_mean(inputs, axis=-1, keepdims=True),
            tf.reduce_max(inputs, axis=-1, keepdims=True),
        ]
        gate_logits = self.conv(tf.concat(channel_stats, axis=-1))
        return tf.nn.sigmoid(gate_logits)

def create_attention_block(x, filters):
    """Refine ``x`` with channel attention, then spatial attention.

    Each gate is produced by its attention layer and multiplied element-wise
    onto the running tensor. The gates are meant to broadcast against ``x``.

    Args:
        x: Feature tensor to refine.
        filters: Not used by this function (accepted for call-site
            compatibility).

    Returns:
        The doubly gated tensor.
    """
    channel_gate = ChannelAttention()(x)
    refined = layers.Multiply()([x, channel_gate])

    spatial_gate = SpatialAttention()(refined)
    return layers.Multiply()([refined, spatial_gate])
    

def create_residual_block(x, filters):
    """Residual 3-D conv block with attention, ending in 2x max-pooling.

    Main path: two 3x3x3 L2-regularised convolutions with BatchNorm (LeakyReLU
    after the first only), followed by ``create_attention_block``. A 1x1x1
    conv + BatchNorm projection of the input is added back, then a LeakyReLU
    and a pool-size-2 ``MaxPooling3D`` are applied.

    Args:
        x: Input feature tensor.
        filters: Output channel count of every convolution in the block.

    Returns:
        The pooled output tensor.
    """
    # Projection shortcut (built first, no L2 penalty) so channels match `filters`.
    shortcut = layers.Conv3D(filters, kernel_size=1, padding='same')(x)
    shortcut = layers.BatchNormalization()(shortcut)

    out = x
    for stage in range(2):
        out = layers.Conv3D(filters, kernel_size=3, padding='same',
                            kernel_regularizer=l2(1e-4))(out)
        out = layers.BatchNormalization()(out)
        # Only the first conv is activated here; the second is activated after the Add.
        if stage == 0:
            out = layers.LeakyReLU(alpha=0.1)(out)

    out = create_attention_block(out, filters)

    merged = layers.Add()([out, shortcut])
    merged = layers.LeakyReLU(alpha=0.1)(merged)
    return layers.MaxPooling3D(pool_size=2)(merged)

class SingleScanDataGenerator:
    """Load, preprocess and optionally augment single NIfTI scans.

    ``get_data`` expects the layout::

        <base_path>/ADNI/<subset>/<class>/*.nii[.gz]

    where ``<class>`` is one of ``self.classes``.

    Args:
        base_path: Root directory that contains the ``ADNI`` folder.
        target_shape: Spatial shape every volume is resampled to.
    """

    def __init__(self, base_path, target_shape=(128, 128, 128)):
        self.base_path = base_path
        # Stored as a tuple so `target_shape + (1,)` in create_dataset also
        # works when a list is passed in.
        self.target_shape = tuple(target_shape)
        self.classes = ['AD', 'CN']

    def preprocess_image(self, img):
        """Clean, resample and intensity-normalise one volume.

        The steps are:
        1. Replace NaN/Inf with finite values.
        2. Drop trailing singleton axes, e.g. a ``(X, Y, Z, 1)`` NIfTI.
        3. Resample to ``target_shape`` with an order-2 spline.
        4. Clip to the 1st/99th percentile and rescale to [0, 1].
        5. Z-score the result.

        Args:
            img: Array with three spatial axes, optionally followed by
                singleton axes.

        Returns:
            Float array of shape ``target_shape``. A constant volume comes
            back as all zeros instead of NaNs.

        Raises:
            ValueError: If ``img`` does not reduce to exactly three axes.
        """
        img = np.nan_to_num(np.asarray(img, dtype=np.float64))

        # Some NIfTI files carry a trailing singleton (time/channel) axis.
        while img.ndim > 3 and img.shape[-1] == 1:
            img = img[..., 0]
        if img.ndim != 3:
            raise ValueError(f"Expected a 3-D volume, got shape {img.shape}")

        zoom_factors = [t / s for t, s in zip(self.target_shape, img.shape)]
        img_resized = zoom(img, zoom_factors, order=2)

        # Robust rescaling: clip outliers to the 1st/99th percentile.
        p1, p99 = np.percentile(img_resized, (1, 99))
        value_range = p99 - p1
        if value_range > 0:
            img_normalized = (np.clip(img_resized, p1, p99) - p1) / value_range
        else:
            # Constant volume: dividing by a zero range would produce NaNs.
            img_normalized = np.zeros_like(img_resized)

        mean = np.mean(img_normalized)
        std = np.std(img_normalized) + 1e-10
        return (img_normalized - mean) / std

    def augment_image(self, img):
        """Randomly augment a preprocessed volume (training only).

        Every random draw uses NumPy's global RNG. The transforms are:
        - a rotation of up to ±20° in each axis plane (p=0.5 each);
        - a flip along each axis (p=0.5 each);
        - intensity scaling in [0.85, 1.15];
        - a brightness offset in [-0.1, 0.1];
        - Gaussian noise with sigma 0.02;
        - a sign-preserving gamma in [0.8, 1.2].

        Voxels rotated in from outside the volume are filled with the input's
        minimum (the darkest value) rather than 0. After z-scoring, 0 is the
        mean intensity, not background.

        The result is clipped to the input's own value range. The input is
        z-scored by ``preprocess_image``, so a fixed [-1, 1] clip would
        saturate many voxels. It would also make augmented training data
        differ in range from the un-augmented validation data.

        Args:
            img: 3-D array as returned by ``preprocess_image``.

        Returns:
            Augmented array of the same shape.
        """
        lo, hi = float(np.min(img)), float(np.max(img))

        for plane in [(0, 1), (0, 2), (1, 2)]:
            if np.random.random() > 0.5:
                angle = np.random.uniform(-20, 20)
                img = rotate(img, angle, axes=plane, reshape=False, cval=lo)

        for axis in [0, 1, 2]:
            if np.random.random() > 0.5:
                img = np.flip(img, axis=axis)

        img = img * np.random.uniform(0.85, 1.15)    # intensity scaling
        img = img + np.random.uniform(-0.1, 0.1)     # brightness shift
        img = img + np.random.normal(0, 0.02, img.shape)

        gamma = np.random.uniform(0.8, 1.2)
        img = np.sign(img) * np.abs(img) ** gamma    # gamma that keeps sign

        return np.clip(img, lo, hi)

    def get_data(self, subset='train'):
        """Collect scan paths and class labels for one subset.

        Each candidate file is opened with ``nib.load`` to confirm it is
        readable. Files that fail are reported and skipped.

        Args:
            subset: Sub-directory name under ``<base_path>/ADNI``.

        Returns:
            ``(data_paths, labels)``: parallel lists of file paths and
            class-name strings.

        Raises:
            ValueError: If the subset directory is missing or no readable
                scan was found.
        """
        data_paths = []
        labels = []
        class_counts = {condition: 0 for condition in self.classes}

        subset_path = os.path.join(self.base_path, 'ADNI', subset)
        if not os.path.exists(subset_path):
            raise ValueError(f"Data path not found: {subset_path}")

        for condition in self.classes:
            condition_path = os.path.join(subset_path, condition)
            if not os.path.exists(condition_path):
                print(f"Warning: class directory not found: {condition_path}")
                continue
            # Sorted for a deterministic file order across runs/filesystems.
            scans = sorted(f for f in os.listdir(condition_path)
                           if f.endswith(('.nii', '.nii.gz')))
            for scan in scans:
                scan_path = os.path.join(condition_path, scan)
                try:
                    nib.load(scan_path)  # only verifies the file can be opened
                except Exception as e:
                    print(f"Error loading {scan_path}: {e}")
                    continue
                data_paths.append(scan_path)
                labels.append(condition)
                class_counts[condition] += 1

        if not data_paths:
            raise ValueError("No valid scans found in the data directory")

        print("\nClass distribution:", class_counts)
        print(f"Total scans: {len(data_paths)}")
        return data_paths, labels

    def create_dataset(self, data_paths, labels, batch_size, augment=True, label_encoder=None):
        """Build an infinite, batched ``tf.data.Dataset`` of (volume, one-hot) pairs.

        Each pass over the data visits every path once in a fresh random
        order. Files that fail to load or preprocess are reported and skipped.

        Args:
            data_paths: Scan file paths.
            labels: Class names parallel to ``data_paths``.
            batch_size: Batch size; incomplete final batches are dropped.
            augment: Whether to apply ``augment_image`` to every sample.
            label_encoder: A fitted ``LabelEncoder``. If None, a new one is
                fitted on ``labels``.

        Returns:
            ``(dataset, label_encoder)``.

        Raises:
            RuntimeError: When iterated, if a full pass yields no sample at
                all, instead of looping forever.
        """
        if label_encoder is None:
            label_encoder = LabelEncoder()
            label_encoder.fit(labels)

        encoded_labels = label_encoder.transform(labels)
        num_classes = len(label_encoder.classes_)
        one_hot = np.eye(num_classes, dtype=np.float32)

        def generator():
            while True:
                yielded = 0
                for idx in np.random.permutation(len(data_paths)):
                    try:
                        img = nib.load(data_paths[idx]).get_fdata()
                        img = self.preprocess_image(img)
                        if augment:
                            img = self.augment_image(img)
                    except Exception as e:
                        print(f"Error processing {data_paths[idx]}: {e}")
                        continue
                    # Plain NumPy: from_generator converts to tensors itself.
                    yield (np.expand_dims(img, axis=-1).astype(np.float32),
                           one_hot[encoded_labels[idx]])
                    yielded += 1
                if yielded == 0:
                    raise RuntimeError(
                        "No scan could be loaded or preprocessed; "
                        "refusing to loop forever without data")

        dataset = tf.data.Dataset.from_generator(
            generator,
            output_signature=(
                tf.TensorSpec(shape=self.target_shape + (1,), dtype=tf.float32),
                tf.TensorSpec(shape=(num_classes,), dtype=tf.float32),
            ),
        )

        # No .cache(): the generator is infinite (and augmented), so an
        # in-memory cache would never complete and would keep every yielded
        # volume in RAM.
        dataset = dataset.batch(batch_size, drop_remainder=True)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)

        return dataset, label_encoder

def create_3d_cnn(input_shape, num_classes):
    """Build the bottleneck-residual 3D CNN classifier.

    Architecture:
    - a 1x1x1 Conv(32)/BN/LeakyReLU stem;
    - four bottleneck residual stages of width 64, 128, 256 and 512. Each
      stage is 1x1x1 reduce to width/4, then 3x3x3, then 1x1x1 expand, with
      a projected shortcut, LeakyReLU and 2x max-pooling;
    - global average pooling, 0.5 dropout and a softmax Dense layer.

    The attention blocks defined in this file are not used here.

    Args:
        input_shape: Per-sample shape for ``layers.Input``, e.g.
            ``(128, 128, 128, 1)``.
        num_classes: Number of softmax outputs.

    Returns:
        An uncompiled Keras ``Model``.
    """
    def conv_bn(tensor, n_filters, kernel, activate=True):
        # L2-regularised 'same' Conv3D -> BatchNorm -> optional LeakyReLU(0.1).
        tensor = layers.Conv3D(n_filters, kernel_size=kernel, padding='same',
                               kernel_regularizer=l2(1e-4))(tensor)
        tensor = layers.BatchNormalization()(tensor)
        if activate:
            tensor = layers.LeakyReLU(alpha=0.1)(tensor)
        return tensor

    inputs = layers.Input(shape=input_shape)
    features = conv_bn(inputs, 32, 1)

    for width in (64, 128, 256, 512):
        residual = features
        narrow = width // 4

        features = conv_bn(features, narrow, 1)                  # reduce channels
        features = conv_bn(features, narrow, 3)                  # spatial mixing
        features = conv_bn(features, width, 1, activate=False)   # expand channels

        if residual.shape[-1] != width:
            # Un-regularised projection so the shortcut matches `width` channels.
            residual = layers.Conv3D(width, kernel_size=1, padding='same')(residual)
            residual = layers.BatchNormalization()(residual)

        merged = layers.Add()([features, residual])
        merged = layers.LeakyReLU(alpha=0.1)(merged)
        features = layers.MaxPooling3D(pool_size=2)(merged)

    pooled = layers.GlobalAveragePooling3D()(features)
    pooled = layers.Dropout(0.5)(pooled)
    probabilities = layers.Dense(num_classes, activation='softmax')(pooled)

    return models.Model(inputs=inputs, outputs=probabilities)





def _compute_class_weights(train_labels, class_names):
    """Return per-class loss weights aligned with ``class_names``.

    For a class present in training, the weight is
    ``(1 / count) * (n_train / n_classes_in_train)``, i.e. inverse-frequency
    balancing. A class with no training samples gets weight 1.0. Such a class
    can only appear in validation, and a finite weight keeps the loss defined.
    """
    counts = Counter(train_labels)
    total = len(train_labels)
    n_present = len(counts)
    return np.array(
        [(1 / counts[name]) * (total / n_present) if counts[name] else 1.0
         for name in class_names],
        dtype=np.float32,
    )


def _prune_epoch_checkpoints(checkpoint_dir, keep=3):
    """Delete all but the ``keep`` newest ``model_epoch_<N>.keras`` files.

    Files are ordered by the integer epoch ``N``. A plain string sort would
    put ``model_epoch_10`` before ``model_epoch_2`` and delete the newest
    checkpoint once training passes epoch 9.
    """
    prefix, suffix = 'model_epoch_', '.keras'
    checkpoints = []
    for fname in os.listdir(checkpoint_dir):
        if fname.startswith(prefix) and fname.endswith(suffix):
            stem = fname[len(prefix):-len(suffix)]
            if stem.isdigit():
                checkpoints.append((int(stem), fname))
    checkpoints.sort()
    for _, fname in checkpoints[:max(0, len(checkpoints) - keep)]:
        os.remove(os.path.join(checkpoint_dir, fname))


def train_model(epochs=100, base_path='/kaggle/input/finaldata'):
    """Train the 3D CNN on the AD/CN scans with a custom training loop.

    Training uses batch size 1 with 8-step gradient accumulation. The loss is
    a class-weighted focal loss (gamma 2) plus the model's L2 regularisation
    penalties. The optimizer is the built-in AdamW on a cosine-restart
    schedule.

    After every epoch the model is evaluated on up to 25 test batches. An
    epoch checkpoint is saved (the newest 3 are kept), and the best-accuracy
    and best-loss checkpoints in ``model_checkpoints/`` are updated.

    Args:
        epochs: Number of training epochs.
        base_path: Dataset root passed to ``SingleScanDataGenerator``.

    Returns:
        ``(model, label_encoder, history, test_dataset)``, where ``history``
        maps metric names to per-epoch lists.
    """
    print("\nGPU Information:")
    print(tf.config.list_physical_devices('GPU'))

    data_gen = SingleScanDataGenerator(base_path=base_path)
    train_data_paths, train_labels = data_gen.get_data(subset='TRAIN')
    test_data_paths, test_labels = data_gen.get_data(subset='TEST')

    # Fit on train+test so both datasets share one label -> index mapping.
    label_encoder = LabelEncoder()
    label_encoder.fit(train_labels + test_labels)
    # Derived from the encoder (not just the train labels) so it matches the
    # one-hot width produced by create_dataset.
    num_classes = len(label_encoder.classes_)

    class_weights = tf.constant(
        _compute_class_weights(train_labels, label_encoder.classes_), dtype=tf.float32)
    print("\nClass weights:", {i: w for i, w in enumerate(class_weights.numpy())})

    batch_size = 1
    gradient_accumulation_steps = 8

    train_dataset, _ = data_gen.create_dataset(
        train_data_paths, train_labels, batch_size, augment=True, label_encoder=label_encoder
    )
    test_dataset, _ = data_gen.create_dataset(
        test_data_paths, test_labels, batch_size, augment=False, label_encoder=label_encoder
    )

    # At least one optimizer update / validation batch, even for tiny datasets.
    steps_per_epoch = max(1, min(len(train_data_paths) // (batch_size * gradient_accumulation_steps), 50))
    validation_steps = max(1, min(len(test_data_paths) // batch_size, 25))

    print(f"\nTraining with:")
    print(f"Steps per epoch: {steps_per_epoch}")
    print(f"Validation steps: {validation_steps}")
    print(f"Batch size: {batch_size}")
    print(f"Gradient accumulation steps: {gradient_accumulation_steps}")
    print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")

    model = create_3d_cnn((128, 128, 128, 1), num_classes)

    lr_schedule = tf.keras.optimizers.schedules.CosineDecayRestarts(
        1e-4,          # initial learning rate
        10000,         # first decay period, in optimizer steps
        t_mul=2.0,
        m_mul=0.9,
        alpha=1e-5
    )

    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=lr_schedule,
        weight_decay=1e-4,
        beta_1=0.9,
        beta_2=0.999,
        epsilon=1e-7,
        clipnorm=1.0
    )

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
    train_auc = tf.keras.metrics.AUC(name='train_auc')

    val_loss = tf.keras.metrics.Mean(name='val_loss')
    val_accuracy = tf.keras.metrics.CategoricalAccuracy(name='val_accuracy')
    val_auc = tf.keras.metrics.AUC(name='val_auc')

    def weighted_categorical_focal_loss(y_true, y_pred, gamma=2.0):
        """Class-weighted focal loss, averaged over the batch."""
        y_true = tf.cast(y_true, tf.float32)
        y_pred = tf.cast(y_pred, tf.float32)

        epsilon = tf.keras.backend.epsilon()
        y_pred = tf.clip_by_value(y_pred, epsilon, 1.0 - epsilon)

        y_true_indices = tf.argmax(y_true, axis=-1)
        sample_weights = tf.cast(tf.gather(class_weights, y_true_indices), tf.float32)

        # Non-target entries have ce_loss == 0 and therefore contribute 0.
        ce_loss = -y_true * tf.math.log(y_pred)
        pt = tf.exp(-ce_loss)
        focal_loss = tf.pow(1 - pt, gamma) * ce_loss
        weighted_focal_loss = focal_loss * tf.expand_dims(sample_weights, -1)

        return tf.reduce_mean(tf.reduce_sum(weighted_focal_loss, axis=-1))

    history = {'loss': [], 'accuracy': [], 'auc': [],
               'val_loss': [], 'val_accuracy': [], 'val_auc': []}

    best_val_accuracy = 0.0
    best_val_loss = float('inf')
    # Defined up front so the summary below still works when epochs == 0.
    current_val_accuracy = float('nan')
    current_val_loss = float('nan')

    checkpoint_dir = 'model_checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)

    # One persistent iterator. Calling iter(train_dataset) every step would
    # rebuild the input pipeline each time, restarting the generator and
    # discarding prefetched batches.
    train_iterator = iter(train_dataset)

    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}")

        for metric in (train_loss, train_accuracy, train_auc,
                       val_loss, val_accuracy, val_auc):
            metric.reset_state()

        accumulated_gradients = None

        for step in range(steps_per_epoch * gradient_accumulation_steps):
            x_batch, y_batch = next(train_iterator)

            with tf.GradientTape() as tape:
                predictions = tf.cast(model(x_batch, training=True), tf.float32)
                loss = weighted_categorical_focal_loss(y_batch, predictions)
                # kernel_regularizer penalties are not applied automatically
                # in a custom loop; without this the L2 terms have no effect.
                total_loss = (loss + tf.add_n(model.losses)) if model.losses else loss
                scaled_loss = total_loss / gradient_accumulation_steps

            gradients = tape.gradient(scaled_loss, model.trainable_variables)
            # Missing gradients count as zero, so accumulation never sees None.
            gradients = [
                tf.zeros_like(var) if grad is None else tf.clip_by_norm(grad, 1.0)
                for grad, var in zip(gradients, model.trainable_variables)
            ]

            if accumulated_gradients is None:
                accumulated_gradients = gradients
            else:
                accumulated_gradients = [
                    acc + grad for acc, grad in zip(accumulated_gradients, gradients)
                ]

            # The reported loss is the focal loss only, comparable to val_loss.
            train_loss.update_state(loss)
            train_accuracy.update_state(y_batch, predictions)
            train_auc.update_state(y_batch, predictions)

            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.apply_gradients(zip(accumulated_gradients, model.trainable_variables))
                accumulated_gradients = None

        for x_val, y_val in test_dataset.take(validation_steps):
            val_predictions = tf.cast(model(x_val, training=False), tf.float32)
            v_loss = weighted_categorical_focal_loss(y_val, val_predictions)

            val_loss.update_state(v_loss)
            val_accuracy.update_state(y_val, val_predictions)
            val_auc.update_state(y_val, val_predictions)

        current_val_accuracy = val_accuracy.result().numpy()
        current_val_loss = val_loss.result().numpy()

        history['loss'].append(train_loss.result().numpy())
        history['accuracy'].append(train_accuracy.result().numpy())
        history['auc'].append(train_auc.result().numpy())
        history['val_loss'].append(current_val_loss)
        history['val_accuracy'].append(current_val_accuracy)
        history['val_auc'].append(val_auc.result().numpy())

        print(f"Loss: {train_loss.result():.4f}")
        print(f"Accuracy: {train_accuracy.result():.4f}")
        print(f"AUC: {train_auc.result():.4f}")
        print(f"Val Loss: {current_val_loss:.4f}")
        print(f"Val Accuracy: {current_val_accuracy:.4f}")
        print(f"Val AUC: {val_auc.result():.4f}")

        epoch_model_path = f'model_checkpoints/model_epoch_{epoch+1}.keras'
        model.save(epoch_model_path)
        print(f"Saved epoch model: {epoch_model_path}")

        if current_val_accuracy > best_val_accuracy:
            best_val_accuracy = current_val_accuracy
            model.save('model_checkpoints/best_model_accuracy.keras')
            print(f"\nNew best validation accuracy: {best_val_accuracy:.4f}")
            print("Saved best accuracy model")

        if current_val_loss < best_val_loss:
            best_val_loss = current_val_loss
            model.save('model_checkpoints/best_model_loss.keras')
            print(f"\nNew best validation loss: {best_val_loss:.4f}")
            print("Saved best loss model")

        # Keep only the 3 most recent epoch checkpoints (ordered numerically).
        _prune_epoch_checkpoints(checkpoint_dir, keep=3)

    final_model_path = 'model_checkpoints/final_model.keras'
    model.save(final_model_path)
    print(f"\nFinal model saved at: {final_model_path}")

    with open('model_checkpoints/model_summary.txt', 'w', encoding='utf-8') as f:
        f.write(f"Training Summary\n")
        f.write(f"================\n")
        f.write(f"Best validation accuracy: {best_val_accuracy:.4f}\n")
        f.write(f"Best validation loss: {best_val_loss:.4f}\n")
        f.write(f"Final validation accuracy: {current_val_accuracy:.4f}\n")
        f.write(f"Final validation loss: {current_val_loss:.4f}\n")

    return model, label_encoder, history, test_dataset


if __name__ == "__main__":
    try:
        # Configure memory growth for GPU
        gpus = tf.config.list_physical_devices('GPU')
        if gpus:
            for gpu in gpus:
                try:
                    tf.config.experimental.set_memory_growth(gpu, True)
                except RuntimeError as e:
                    print(e)
        
        # Create results directory
        os.makedirs('results', exist_ok=True)
        
        # Train model
        model, label_encoder, history, test_dataset = train_model()
        
        # Save the model
        model.save('results/alzheimers_classification_model.keras')
        print("\nModel saved as 'results/alzheimers_classification_model.keras'")
                        
        def create_visualization_results(model, history, test_dataset, label_encoder, save_dir='./', validation_steps=25):
            """
            Create comprehensive visualizations for model analysis with progress tracking
            """
            print("\nGenerating visualizations...")
            
            # Create results directory if it doesn't exist
            os.makedirs(save_dir, exist_ok=True)
            
            print("1/5: Creating training history plots...")
            # Set style for better visualizations
            plt.style.use('seaborn-v0_8-darkgrid')
            
            # 1. Training History Plot
            plt.figure(figsize=(20, 10))
            
            # Plot training metrics
            plt.subplot(2, 2, 1)
            plt.plot(history['accuracy'], label='Training Accuracy', linewidth=2)
            plt.plot(history['val_accuracy'], label='Validation Accuracy', linewidth=2)
            plt.title('Model Accuracy', fontsize=14, pad=10)
            plt.xlabel('Epoch', fontsize=12)
            plt.ylabel('Accuracy', fontsize=12)
            plt.legend(fontsize=10)
            plt.grid(True, linestyle='--', alpha=0.7)
            
            plt.subplot(2, 2, 2)
            plt.plot(history['loss'], label='Training Loss', linewidth=2)
            plt.plot(history['val_loss'], label='Validation Loss', linewidth=2)
            plt.title('Model Loss', fontsize=14, pad=10)
            plt.xlabel('Epoch', fontsize=12)
            plt.ylabel('Loss', fontsize=12)
            plt.legend(fontsize=10)
            plt.grid(True, linestyle='--', alpha=0.7)
            
            plt.subplot(2, 2, 3)
            plt.plot(history['auc'], label='Training AUC', linewidth=2)
            plt.plot(history['val_auc'], label='Validation AUC', linewidth=2)
            plt.title('Model AUC', fontsize=14, pad=10)
            plt.xlabel('Epoch', fontsize=12)
            plt.ylabel('AUC', fontsize=12)
            plt.legend(fontsize=10)
            plt.grid(True, linestyle='--', alpha=0.7)
            
            # Add learning curve analysis
            plt.subplot(2, 2, 4)
            train_sizes = np.linspace(0.1, 1.0, len(history['accuracy']))
            plt.plot(train_sizes, history['accuracy'], 'o-', label='Training Accuracy', linewidth=2)
            plt.plot(train_sizes, history['val_accuracy'], 'o-', label='Validation Accuracy', linewidth=2)
            plt.title('Learning Curve Analysis', fontsize=14, pad=10)
            plt.xlabel('Training Set Size (fraction)', fontsize=12)
            plt.ylabel('Accuracy', fontsize=12)
            plt.legend(fontsize=10)
            plt.grid(True, linestyle='--', alpha=0.7)
            
            plt.tight_layout()
            plt.savefig(f'{save_dir}/training_metrics.png', dpi=300, bbox_inches='tight')
            plt.close()
            
            print("2/5: Making predictions on test dataset...")
            # Get predictions on test set - limiting to validation_steps
            test_batches = test_dataset.take(validation_steps)
            y_pred_proba_list = []
            y_true_list = []
            
            print("Processing test batches...")
            for i, (x_batch, y_batch) in enumerate(test_batches):
                if i >= validation_steps:
                    break
                batch_pred = model.predict(x_batch, verbose=0)
                y_pred_proba_list.append(batch_pred)
                y_true_list.append(np.argmax(y_batch, axis=1))
                if (i + 1) % 5 == 0:
                    print(f"Processed {i + 1}/{validation_steps} batches")
            
            y_pred_proba = np.vstack(y_pred_proba_list)
            y_true = np.concatenate(y_true_list)
            y_pred = np.argmax(y_pred_proba, axis=1)
            
            print("3/5: Generating ROC curve...")
            # Plot ROC Curve
            plt.figure(figsize=(10, 10))
            fpr = dict()
            tpr = dict()
            roc_auc = dict()
            
            for i in range(len(label_encoder.classes_)):
                y_true_binary = (y_true == i).astype(int)
                y_score = y_pred_proba[:, i]
                fpr[i], tpr[i], _ = roc_curve(y_true_binary, y_score)
                roc_auc[i] = auc(fpr[i], tpr[i])
                
                plt.plot(fpr[i], tpr[i], label=f'{label_encoder.classes_[i]} (AUC = {roc_auc[i]:.2f})')
            
            plt.plot([0, 1], [0, 1], 'k--', label='Random (AUC = 0.50)')
            plt.xlabel('False Positive Rate', fontsize=12)
            plt.ylabel('True Positive Rate', fontsize=12)
            plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=14)
            plt.legend(fontsize=10)
            plt.grid(True, alpha=0.3)
            plt.savefig(f'{save_dir}/roc_curve.png', dpi=300, bbox_inches='tight')
            plt.close()
            
            print("4/5: Creating confusion matrix...")
            # Enhanced Confusion Matrix
            plt.figure(figsize=(12, 8))
            cm = confusion_matrix(y_true, y_pred)
            cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
            
            sns.heatmap(cm_normalized, annot=cm, fmt='d', cmap='Blues',
                        xticklabels=label_encoder.classes_,
                        yticklabels=label_encoder.classes_,
                        annot_kws={'size': 12})
            plt.title('Confusion Matrix\n(Numbers: Raw Counts, Colors: Normalized)', fontsize=14, pad=20)
            plt.xlabel('Predicted Label', fontsize=12)
            plt.ylabel('True Label', fontsize=12)
            plt.tight_layout()
            plt.savefig(f'{save_dir}/confusion_matrix_enhanced.png', dpi=300, bbox_inches='tight')
            plt.close()
            
            print("5/5: Analyzing training progress...")
            # Training Progress Analysis
            plt.figure(figsize=(15, 5))
            epochs = range(1, len(history['accuracy']) + 1)
            
            plt.subplot(1, 2, 1)
            acc_improvement = np.diff(history['val_accuracy'])
            plt.plot(epochs[1:], acc_improvement, 'b-', label='Validation Accuracy Improvement')
            plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
            plt.title('Validation Accuracy Improvement per Epoch', fontsize=12)
            plt.xlabel('Epoch', fontsize=10)
            plt.ylabel('Accuracy Improvement', fontsize=10)
            plt.legend()
            
            plt.subplot(1, 2, 2)
            loss_improvement = -np.diff(history['val_loss'])
            plt.plot(epochs[1:], loss_improvement, 'g-', label='Validation Loss Improvement')
            plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
            plt.title('Validation Loss Improvement per Epoch', fontsize=12)
            plt.xlabel('Epoch', fontsize=10)
            plt.ylabel('Loss Improvement', fontsize=10)
            plt.legend()
            
            plt.tight_layout()
            plt.savefig(f'{save_dir}/training_progress_analysis.png', dpi=300, bbox_inches='tight')
            plt.close()
            
            # Create and save performance summary
            performance_summary = {
                'Final Training Accuracy': history['accuracy'][-1],
                'Final Validation Accuracy': history['val_accuracy'][-1],
                'Best Validation Accuracy': max(history['val_accuracy']),
                'Final Training Loss': history['loss'][-1],
                'Final Validation Loss': history['val_loss'][-1],
                'Best Validation Loss': min(history['val_loss']),
                'Final Training AUC': history['auc'][-1],
                'Final Validation AUC': history['val_auc'][-1],
                'Best Validation AUC': max(history['val_auc']),
                'Number of Epochs': len(history['accuracy'])
            }
            
            # Save performance summary
            with open(f'{save_dir}/performance_summary.txt', 'w') as f:
                f.write('Model Performance Summary\n')
                f.write('=======================\n\n')
                for metric, value in performance_summary.items():
                    f.write(f'{metric}: {value:.4f}\n')
            
            print(f"\nAll visualizations have been saved to: {save_dir}")
            return performance_summary
        # Create visualizations and get performance summary
        performance_summary = create_visualization_results(
            model=model,
            history=history,
            test_dataset=test_dataset,
            label_encoder=label_encoder,
            save_dir='./results'
        )

        # Print summary to console
        print("\nModel Performance Summary:")
        print("=======================")
        for metric, value in performance_summary.items():
            print(f"{metric}: {value:.4f}")
            
        # Save model summary
        with open('results/model_summary.txt', 'w') as f:
            model.summary(print_fn=lambda x: f.write(x + '\n'))
        print("\nModel summary saved as 'results/model_summary.txt'")
        
    except Exception as e:
        print(f"\nAn error occurred during training/evaluation: {str(e)}")
        import traceback
        traceback.print_exc()
        raise e