# In [None]:
# ====================== IMPORTS ======================
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # Suppress TensorFlow warnings
import logging
logging.getLogger('tensorflow').disabled = True

import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.applications import EfficientNetB3
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau, EarlyStopping, ModelCheckpoint
from tensorflow.keras.regularizers import l2
from tensorflow.keras.optimizers.schedules import CosineDecay
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from sklearn.metrics import classification_report, confusion_matrix, precision_recall_curve, average_precision_score
from lime import lime_image
from lime.wrappers.scikit_image import SegmentationAlgorithm
from skimage.segmentation import mark_boundaries
import warnings
from pathlib import Path
import cv2

warnings.filterwarnings('ignore')

# ====================== CUSTOM LAYERS ======================
class ChannelAttention(Layer):
    """Channel attention: rescale every channel of the input by a learned sigmoid gate.

    The gate is built from a spatial average-pooled and a spatial max-pooled
    descriptor of the input. Both go through one shared bottleneck Dense layer
    (``channels // ratio`` ReLU units, at least 1); the two results are summed
    and projected back to ``channels`` with a sigmoid Dense layer.

    Args:
        ratio: Reduction factor of the shared bottleneck Dense layer.
    """

    def __init__(self, ratio=8, **kwargs):
        super(ChannelAttention, self).__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        self.channels = input_shape[-1]
        # Keep at least one hidden unit: with fewer channels than `ratio` the integer
        # division would give a zero-width bottleneck, which carries no information,
        # so the gate could not depend on the input.
        hidden_units = max(1, self.channels // self.ratio)
        self.shared_dense = Dense(hidden_units, activation='relu',
                                  kernel_initializer='he_normal', kernel_regularizer=l2(1e-3))
        self.attention_dense = Dense(self.channels, activation='sigmoid',
                                     kernel_initializer='he_normal', kernel_regularizer=l2(1e-3))
        super(ChannelAttention, self).build(input_shape)

    def call(self, inputs):
        # Pool over axes 1 and 2 (height/width for a 4D channels-last input); keepdims
        # keeps singleton axes so the gate broadcasts back over the feature map.
        avg_desc = tf.reduce_mean(inputs, axis=[1, 2], keepdims=True)
        max_desc = tf.reduce_max(inputs, axis=[1, 2], keepdims=True)
        attention = self.attention_dense(self.shared_dense(avg_desc) + self.shared_dense(max_desc))
        return inputs * attention

    def get_config(self):
        # Serialize `ratio` so saved models rebuild the layer with the same bottleneck.
        config = super(ChannelAttention, self).get_config()
        config.update({'ratio': self.ratio})
        return config

class GNNFusionLayer(Layer):
    """Dense (ReLU) projection applied to a single-node feature tensor.

    The input is expected to have a size-1 node axis at position 1, i.e.
    ``(batch, 1, features)``. That axis is squeezed, a ``units``-wide ReLU
    Dense layer is applied, and the axis is restored, giving
    ``(batch, 1, units)``. Despite the name, no message passing between nodes
    takes place.

    Args:
        units: Output feature size of the projection.
    """

    def __init__(self, units=128, **kwargs):
        super(GNNFusionLayer, self).__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        self.dense = Dense(self.units, activation='relu', kernel_regularizer=l2(1e-3))
        super(GNNFusionLayer, self).build(input_shape)

    def call(self, inputs):
        # tf.squeeze(axis=1) fails unless axis 1 has size 1, so only the single-node
        # layout is supported.
        x = tf.squeeze(inputs, axis=1)
        x = self.dense(x)
        return tf.expand_dims(x, axis=1)

    def get_config(self):
        # Serialize `units` so a reloaded model rebuilds each instance with its own
        # width (e.g. 128 vs 64) instead of the constructor default.
        config = super(GNNFusionLayer, self).get_config()
        config.update({'units': self.units})
        return config

# ====================== DATA PIPELINE WITHOUT AUGMENTATION ======================
def create_datasets(train_dir, valid_dir, img_size=128, batch_size=32):
    """Build unaugmented train/validation pipelines and median-frequency class weights.

    Images are read from one sub-directory per class, resized to
    ``img_size`` x ``img_size``, one-hot labelled and passed through the
    EfficientNet ``preprocess_input``. The training set is shuffled; the
    validation set is not, so its predictions line up with its labels.

    Args:
        train_dir: Training root with one sub-directory per class.
        valid_dir: Validation root; it must contain every training class folder.
        img_size: Side length images are resized to.
        batch_size: Batch size of both datasets.

    Returns:
        ``(train_ds, val_ds, class_names, class_weights)``. ``class_weights`` maps
        each class index to ``median(count) / count`` over the non-empty training
        classes; empty classes get a neutral weight of 1.0.

    Raises:
        ValueError: raised by Keras when ``valid_dir`` lacks a training class folder.
    """
    train_ds = tf.keras.utils.image_dataset_from_directory(
        train_dir,
        label_mode='categorical',
        image_size=(img_size, img_size),
        batch_size=batch_size,
        shuffle=True
    )
    # Take the class order from the loader itself and force the validation loader to
    # use it, so label indices, class names and weight keys always agree.
    class_names = list(train_ds.class_names)

    val_ds = tf.keras.utils.image_dataset_from_directory(
        valid_dir,
        label_mode='categorical',
        class_names=class_names,
        image_size=(img_size, img_size),
        batch_size=batch_size,
        shuffle=False
    )

    # Count only the image types the loader reads (recursively, like the loader), so
    # directories and stray files such as Thumbs.db do not skew the weights.
    image_exts = {'.bmp', '.gif', '.jpeg', '.jpg', '.png'}
    class_counts = {
        i: sum(1 for p in (Path(train_dir) / name).rglob('*')
               if p.is_file() and p.suffix.lower() in image_exts)
        for i, name in enumerate(class_names)
    }
    present = [count for count in class_counts.values() if count > 0]
    median_freq = np.median(present) if present else 1.0
    # Keras expects a weight for every class index; an empty class never contributes
    # to the loss, so its weight only needs to be defined, not meaningful.
    class_weights = {i: (median_freq / count if count else 1.0)
                     for i, count in class_counts.items()}

    preprocess_fn = tf.keras.applications.efficientnet.preprocess_input
    train_ds = train_ds.map(lambda x, y: (preprocess_fn(x), y),
                            num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.map(lambda x, y: (preprocess_fn(x), y),
                        num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

    return train_ds, val_ds, class_names, class_weights

# ====================== REGULARIZED MODEL ARCHITECTURE ======================
def _conv_stage(x, filters, first_conv_name, second_conv_name, out_dropout):
    """Apply one regularized CNN stage to ``x`` and return the stage output.

    Layout: Conv3x3 ('same') -> BN -> LeakyReLU(0.1) -> Dropout(0.3) ->
    Conv3x3 ('valid') -> BN -> LeakyReLU(0.1) -> MaxPool(2) ->
    ChannelAttention -> Dropout(out_dropout). Both convolutions carry an
    L2(1e-3) kernel penalty and explicit names so they can be looked up later
    (e.g. as the Grad-CAM target).
    """
    x = Conv2D(filters, 3, padding='same', kernel_regularizer=l2(1e-3), name=first_conv_name)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.1)(x)
    x = Dropout(0.3)(x)

    # Unpadded conv: each stage trims 2 pixels per spatial axis before pooling.
    x = Conv2D(filters, 3, kernel_regularizer=l2(1e-3), name=second_conv_name)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.1)(x)
    x = MaxPool2D(2)(x)
    x = ChannelAttention()(x)
    return Dropout(out_dropout)(x)


def build_model(input_shape=(128, 128, 3), num_classes=5):
    """Build the hybrid custom-CNN + frozen EfficientNetB3 classifier (uncompiled).

    Both streams read the same input tensor. Their globally average-pooled
    features are concatenated, projected to 256 units, run through two
    single-node ``GNNFusionLayer`` blocks (128 and 64 units) and a regularized
    Dense(512) head that ends in a softmax over ``num_classes``.

    The custom convolutions are named ``conv2d_1`` ... ``conv2d_6``;
    ``conv2d_6`` is the last convolution of the custom stream.

    Args:
        input_shape: Input image shape ``(height, width, channels)``.
        num_classes: Number of output classes.

    Returns:
        A Keras ``Model`` named ``"Regularized_Hybrid_CNN_EffNet_GNN"``.
    """
    inputs = Input(shape=input_shape)

    # CNN stream: three stages with 32/64/128 filters. Every stage, the last one
    # included, applies its dropout to the attention output that flows onward, so
    # the 0.5 dropout actually regularizes the features pooled below.
    x = _conv_stage(inputs, 32, 'conv2d_1', 'conv2d_2', out_dropout=0.4)
    x = _conv_stage(x, 64, 'conv2d_3', 'conv2d_4', out_dropout=0.4)
    cnn_out = _conv_stage(x, 128, 'conv2d_5', 'conv2d_6', out_dropout=0.5)

    # EfficientNet stream: ImageNet-pretrained backbone on the same input, frozen.
    effnet = EfficientNetB3(include_top=False, weights='imagenet', input_tensor=inputs)
    effnet.trainable = False

    # Feature fusion: global average pooling of both streams, concatenated.
    cnn_pool = GlobalAveragePooling2D()(cnn_out)
    effnet_pool = GlobalAveragePooling2D()(effnet.output)
    fused = Concatenate()([cnn_pool, effnet_pool])

    # "Graph" processing on a single node: project to 256 features, add a node axis.
    graph_nodes = Dense(256, kernel_regularizer=l2(1e-3))(fused)
    graph_nodes = Reshape((1, 256))(graph_nodes)

    gn_out = GNNFusionLayer(128)(graph_nodes)
    gn_out = Dropout(0.4)(gn_out)
    gn_out = GNNFusionLayer(64)(gn_out)
    gn_out = Dropout(0.4)(gn_out)
    gn_out = Flatten()(gn_out)

    # Classification head with heavy dropout.
    x = Dense(512, kernel_regularizer=l2(1e-3))(gn_out)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.1)(x)
    x = Dropout(0.6)(x)
    outputs = Dense(num_classes, activation='softmax')(x)

    return Model(inputs, outputs, name="Regularized_Hybrid_CNN_EffNet_GNN")

# ====================== GRAD-CAM ANALYSIS ======================
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Compute a Grad-CAM heatmap for the first image of a batch.

    Args:
        img_array: Batched model input; only element 0 is turned into a map.
        model: Functional Keras model that contains ``last_conv_layer_name``.
        last_conv_layer_name: Layer whose activations are weighted by the
            spatially averaged gradients of the class score.
        pred_index: Class to explain; defaults to the top prediction of image 0.

    Returns:
        NumPy array with values in [0, 1]. For a 4D channels-last conv output
        this is the layer's (height, width) grid. It is all zeros when no
        location contributes positively; dividing by a zero maximum would
        otherwise produce NaNs.

    Raises:
        ValueError: if the class score does not depend on the chosen layer.
    """
    # Model mapping the input to (conv activations, predictions).
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(last_conv_layer_name).output, model.output]
    )

    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # Gradient of the chosen class score w.r.t. the conv feature map.
    grads = tape.gradient(class_channel, last_conv_layer_output)
    if grads is None:
        raise ValueError(
            f"Layer '{last_conv_layer_name}' is not connected to the model output; "
            "cannot compute Grad-CAM gradients."
        )

    # One importance weight per channel: mean gradient over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Channel-weighted sum of the first image's activations.
    heatmap = last_conv_layer_output[0] @ pooled_grads[..., tf.newaxis]
    # Drop only the trailing channel axis, so a 1x1 map keeps its spatial axes.
    heatmap = tf.squeeze(heatmap, axis=-1)

    # Keep positive evidence and scale to [0, 1]; divide_no_nan returns 0 instead
    # of NaN when the map has no positive values.
    heatmap = tf.nn.relu(heatmap)
    heatmap = tf.math.divide_no_nan(heatmap, tf.reduce_max(heatmap))
    return heatmap.numpy()

class GradCAMCallback(Callback):
    """Keras callback that saves a Grad-CAM overlay for one sample image after each epoch.

    Each epoch writes ``gradcam_analysis/<model name>_epoch_<n>.png`` and appends
    a record (uint8 heatmap, PIL overlay image, raw prediction) to
    ``self.history``. Any failure is printed and skipped, so visualization
    problems never stop training.

    Args:
        image_path: Path of the image to explain.
        model: Model to explain; Keras may reassign it through the ``model`` setter.
        img_size: Side length the image is resized to.
        last_conv_layer_name: Layer used as the Grad-CAM target.
    """

    def __init__(self, image_path, model, img_size=128, last_conv_layer_name='conv2d_6'):
        super().__init__()
        self.image_path = image_path
        self.img_size = img_size
        self.last_conv_layer_name = last_conv_layer_name
        self._model = model
        self.history = []
        os.makedirs('gradcam_analysis', exist_ok=True)

        # Load and preprocess sample image (kept with a leading batch axis).
        img = tf.keras.preprocessing.image.load_img(image_path, target_size=(img_size, img_size))
        self.sample_image = tf.keras.applications.efficientnet.preprocess_input(
            tf.keras.preprocessing.image.img_to_array(img)
        )
        self.sample_image = np.expand_dims(self.sample_image, axis=0)

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, value):
        self._model = value

    def on_epoch_end(self, epoch, logs=None):
        try:
            heatmap = make_gradcam_heatmap(
                self.sample_image,
                self.model,
                self.last_conv_layer_name
            )

            # Convert [0, 1] scores to 8-bit colormap indices. NaN or out-of-range
            # values would otherwise become arbitrary uint8 values.
            heatmap = (255 * np.clip(np.nan_to_num(heatmap), 0.0, 1.0)).astype(np.uint8)

            # `plt.cm.get_cmap` was removed in Matplotlib 3.9; `plt.get_cmap` is
            # still supported. Keep only the RGB columns of the colormap.
            jet_colors = plt.get_cmap("jet")(np.arange(256))[:, :3]
            jet_heatmap = jet_colors[heatmap]

            # Upsample the coarse colorized heatmap to the image size.
            jet_heatmap = tf.keras.preprocessing.image.array_to_img(jet_heatmap)
            jet_heatmap = jet_heatmap.resize((self.img_size, self.img_size))
            jet_heatmap = tf.keras.preprocessing.image.img_to_array(jet_heatmap)

            original_img = tf.keras.preprocessing.image.array_to_img(self.sample_image[0])
            original_img = original_img.resize((self.img_size, self.img_size))
            original_img = tf.keras.preprocessing.image.img_to_array(original_img)

            # Superimpose the heatmap (40% weight) on the original image.
            superimposed_img = jet_heatmap * 0.4 + original_img
            superimposed_img = tf.keras.preprocessing.image.array_to_img(superimposed_img)

            pred = self.model.predict(self.sample_image, verbose=0)

            fig = plt.figure(figsize=(10, 5))
            try:
                plt.subplot(1, 2, 1)
                plt.imshow(original_img / 255.0)
                plt.title("Original Image")
                plt.axis('off')

                plt.subplot(1, 2, 2)
                plt.imshow(superimposed_img)
                plt.title(f"Grad-CAM (Epoch {epoch+1})")
                plt.axis('off')

                plt.tight_layout()
                filename = f"gradcam_analysis/{self.model.name}_epoch_{epoch+1}.png"
                plt.savefig(filename)
            finally:
                # Close even when plotting/saving fails so figures don't pile up.
                plt.close(fig)

            self.history.append({
                'epoch': epoch+1,
                'heatmap': heatmap,
                'superimposed_img': superimposed_img,
                'prediction': pred
            })

        except Exception as e:
            # Best-effort visualization: report and keep training.
            print(f"Error generating Grad-CAM at epoch {epoch+1}: {str(e)}")

# ====================== IMPROVED LIME CALLBACK ======================
class LimeExplainer(Callback):
    """Keras callback that records a LIME explanation of one sample image per epoch.

    After each epoch it explains LIME's top label for the image and saves
    ``lime_explanations/<model name>_epoch_<n>.png``. It also appends the
    explanation image, mask, predicted class and confidence to ``self.history``.
    When training ends it saves a summary grid of all epochs. Failures inside an
    epoch are printed and skipped.

    Args:
        image_path: Path of the image to explain (must exist).
        class_names: Class labels, indexed like the model's output units.
        img_size: Side length the image is resized to.
        model: Model to explain; Keras may reassign it through the ``model`` setter.

    Raises:
        FileNotFoundError: if ``image_path`` does not exist.
    """

    def __init__(self, image_path, class_names, img_size=128, model=None):
        super().__init__()
        self.class_names = class_names
        self.explainer = lime_image.LimeImageExplainer()
        self.segmenter = SegmentationAlgorithm('quickshift', kernel_size=1, max_dist=200, ratio=0.2)
        self._model = model
        self.img_size = img_size

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found at {image_path}")

        img = tf.keras.preprocessing.image.load_img(image_path, target_size=(img_size, img_size))
        pixels = tf.keras.preprocessing.image.img_to_array(img)
        # Unprocessed 0-255 copy. LIME perturbs and displays this image, and
        # predict_fn applies the model preprocessing itself, so the explanation
        # describes exactly what the model sees.
        self.raw_image = np.clip(pixels, 0, 255).astype(np.uint8)
        self.sample_image = tf.keras.applications.efficientnet.preprocess_input(pixels.copy())
        self.history = []
        os.makedirs('lime_explanations', exist_ok=True)

    @property
    def model(self):
        return self._model

    @model.setter
    def model(self, value):
        self._model = value

    def on_epoch_end(self, epoch, logs=None):
        try:
            img_to_explain = self.raw_image

            def predict_fn(images):
                # LIME passes a batch of perturbed raw images; preprocess a float
                # copy the same way the training pipeline does.
                processed = tf.keras.applications.efficientnet.preprocess_input(
                    np.array(images, dtype=np.float32)
                )
                return self.model.predict(processed, verbose=0)

            explanation = self.explainer.explain_instance(
                img_to_explain,
                predict_fn,
                top_labels=5,
                hide_color=0,
                num_samples=1000,
                segmentation_fn=self.segmenter
            )

            temp, mask = explanation.get_image_and_mask(
                explanation.top_labels[0],
                positive_only=True,
                num_features=5,
                hide_rest=False
            )

            pred = self.model.predict(np.expand_dims(self.sample_image, axis=0), verbose=0)
            self.history.append({
                'epoch': epoch+1,
                'image': temp,
                'mask': mask,
                'prediction': self.class_names[np.argmax(pred)],
                'confidence': float(np.max(pred))
            })

            self._save_epoch_visualization(epoch+1, img_to_explain, temp, mask, pred)

        except Exception as e:
            # Best-effort explanation: report and keep training.
            print(f"Error generating LIME explanation at epoch {epoch+1}: {str(e)}")

    def _save_epoch_visualization(self, epoch, original_img, explanation_img, mask, pred):
        """Save the original image next to the LIME boundary overlay for one epoch."""
        fig = plt.figure(figsize=(12, 6))
        try:
            plt.subplot(1, 2, 1)
            plt.imshow(original_img)
            plt.title(f"Original Image (Epoch {epoch})")
            plt.axis('off')

            plt.subplot(1, 2, 2)
            plt.imshow(mark_boundaries(explanation_img, mask))
            plt.title(f"Pred: {self.class_names[np.argmax(pred)]} ({np.max(pred):.2f})")
            plt.axis('off')

            plt.tight_layout()
            filename = f"lime_explanations/{self.model.name}_epoch_{epoch}.png"
            plt.savefig(filename)
        finally:
            plt.close(fig)

    def on_train_end(self, logs=None):
        """Save one summary figure with a row per recorded epoch."""
        if not self.history:
            print("No LIME explanations were generated during training")
            return

        rows = len(self.history)
        fig = plt.figure(figsize=(20, 5 * rows))
        try:
            for idx, item in enumerate(self.history):
                plt.subplot(rows, 2, 2 * idx + 1)
                plt.imshow(self.raw_image)
                plt.title(f"Original - Epoch {item['epoch']}")
                plt.axis('off')

                plt.subplot(rows, 2, 2 * idx + 2)
                plt.imshow(mark_boundaries(item['image'], item['mask']))
                plt.title(f"Epoch {item['epoch']}: {item['prediction']} ({item['confidence']:.2f})")
                plt.axis('off')

            plt.tight_layout()
            summary_filename = f"lime_explanations/{self.model.name}_summary.png"
            plt.savefig(summary_filename)
        finally:
            plt.close(fig)

# ====================== UTILITY FUNCTIONS ======================
def save_model_with_architecture(model, base_filename):
    """Save ``model`` under a name that embeds its architecture name.

    Args:
        model: Object with a ``name`` attribute and a ``save(path)`` method.
        base_filename: Filename prefix, e.g. ``'best_model'``.

    Returns:
        The path written, ``<base_filename>_<model.name>.keras``.
    """
    target_path = f"{base_filename}_{model.name}.keras"
    model.save(target_path)
    print(f"Model saved as: {target_path}")
    return target_path

def _to_json_compatible(obj):
    """``json.dump`` fallback: turn NumPy scalars/arrays into plain Python values."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_training_history(history, model):
    """Write ``history.history`` to ``training_history_<model.name>.json``.

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch values (as returned by ``Model.fit``).
        model: Object providing ``name``, used in the filename.

    Returns:
        The JSON filename that was written.
    """
    history_filename = f"training_history_{model.name}.json"
    with open(history_filename, 'w', encoding='utf-8') as f:
        # Callbacks may log NumPy scalars (e.g. a float32 learning rate), which the
        # json module cannot serialize by itself.
        json.dump(history.history, f, indent=4, default=_to_json_compatible)
    print(f"Training history saved as: {history_filename}")
    return history_filename

def save_classification_report(y_true, y_pred, class_names, model):
    """Compute a per-class classification report, save it as JSON and return it.

    ``labels`` is pinned to every class index (0 .. len(class_names)-1), so
    ``target_names`` still lines up when a class is absent from both ``y_true``
    and ``y_pred``. Without it, sklearn raises a class-count/target_names
    mismatch ValueError in that case.

    Args:
        y_true: Integer ground-truth class indices.
        y_pred: Integer predicted class indices.
        class_names: Display name for each class index.
        model: Object providing ``name``, used in the filename.

    Returns:
        The report as a dict (sklearn ``output_dict=True`` format).
    """
    report = classification_report(
        y_true,
        y_pred,
        labels=list(range(len(class_names))),
        target_names=class_names,
        output_dict=True,
    )
    report_filename = f"classification_report_{model.name}.json"
    with open(report_filename, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=4)
    print(f"Classification report saved as: {report_filename}")
    return report

def plot_pr_curves(y_true, y_pred_probs, class_names, model_name):
    """Plot one-vs-rest precision-recall curves and save ``pr_curves_<model_name>.png``.

    Args:
        y_true: One-hot ground truth indexed ``[sample, class]``.
        y_pred_probs: Predicted class scores with the same layout.
        class_names: Legend label for each class column.
        model_name: Used in the title and filename.

    Returns:
        The PNG filename that was written, matching the other save/plot helpers.
    """
    plt.figure(figsize=(10, 8))
    for i, class_name in enumerate(class_names):
        precision, recall, _ = precision_recall_curve(y_true[:, i], y_pred_probs[:, i])
        ap = average_precision_score(y_true[:, i], y_pred_probs[:, i])
        plt.plot(recall, precision, lw=2,
                 label=f'{class_name} (AP={ap:.2f})')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(f'Precision-Recall Curves - {model_name}')
    plt.legend()
    filename = f'pr_curves_{model_name}.png'
    plt.savefig(filename)
    plt.close()
    return filename

def plot_training_metrics(history, model_name):
    """Save side-by-side train/validation accuracy and loss curves.

    Args:
        history: Object whose ``history`` dict holds ``accuracy``,
            ``val_accuracy``, ``loss`` and ``val_loss`` per-epoch lists.
        model_name: Used in the panel titles and the filename.

    Returns:
        The PNG filename, ``training_metrics_<model_name>.png``.
    """
    curves = history.history
    panels = (('accuracy', 'Accuracy'), ('loss', 'Loss'))

    plt.figure(figsize=(12, 6))
    for position, (key, label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves[key])
        plt.plot(curves[f'val_{key}'])
        plt.title(f'{model_name} {label}')
        plt.ylabel(label)
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')

    plt.tight_layout()
    out_path = f"training_metrics_{model_name}.png"
    plt.savefig(out_path)
    plt.close()
    return out_path

# ====================== MAIN TRAINING FUNCTION ======================
def _save_confusion_matrix(y_true_labels, y_pred_labels, class_names, model_name):
    """Save an annotated confusion-matrix heatmap and return its filename.

    ``labels`` is pinned to every class index so the matrix is always
    ``len(class_names)`` square and the tick labels line up with its rows and
    columns, even when a class is missing from both label vectors.
    """
    matrix = confusion_matrix(y_true_labels, y_pred_labels,
                              labels=list(range(len(class_names))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix,
                annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.title(f'Confusion Matrix - {model_name}')
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    filename = f'confusion_matrix_{model_name}.png'
    plt.savefig(filename)
    plt.close()
    return filename


def main(train_dir=r"C:\Users\G.SAI\Desktop\skin_lesion_research_dataset\Data_set\Train",
         valid_dir=r"C:\Users\G.SAI\Desktop\skin_lesion_research_dataset\Data_set\Val",
         sample_image_path=r"C:\Users\G.SAI\Desktop\M_data\my_data\train\Monkeypox\MKP_09_03_9.jpg",
         img_size=128, batch_size=32, epochs=15):
    """Train, evaluate and export the hybrid CNN/EfficientNet/GNN model.

    Steps:
    1. Build the datasets and the model.
    2. Train with a cosine-decayed Adam, checkpointing on validation PR-AUC
       (plus LIME/Grad-CAM callbacks).
    3. Reload the best checkpoint, then write metrics, confusion matrix,
       classification report, PR curves, training plots, models and history
       to the working directory.

    Args:
        train_dir: Training root with one sub-directory per class.
        valid_dir: Validation root with the same class sub-directories.
        sample_image_path: Image explained by the LIME and Grad-CAM callbacks.
        img_size: Side length images are resized to.
        batch_size: Batch size for training and validation.
        epochs: Maximum number of epochs (EarlyStopping may stop sooner).
    """
    # Create datasets
    train_ds, val_ds, class_names, class_weights = create_datasets(train_dir, valid_dir, img_size, batch_size)

    # Build model
    model = build_model(input_shape=(img_size, img_size, 3), num_classes=len(class_names))
    checkpoint_path = f'best_model_{model.name}.keras'

    # Adam with a cosine-decayed learning rate over the whole run.
    optimizer = tf.keras.optimizers.Adam(
        CosineDecay(
            initial_learning_rate=1e-4,
            decay_steps=len(train_ds) * epochs
        )
    )

    # Strong label smoothing (0.3) as an extra regularizer.
    model.compile(
        optimizer=optimizer,
        loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.3),
        metrics=['accuracy',
                 tf.keras.metrics.Precision(name='precision'),
                 tf.keras.metrics.Recall(name='recall'),
                 tf.keras.metrics.AUC(name='auc'),
                 tf.keras.metrics.AUC(name='pr_auc', curve='PR')]
    )

    callbacks = [
        EarlyStopping(monitor='val_pr_auc', patience=10, mode='max', restore_best_weights=True),
        # No ReduceLROnPlateau: the learning rate comes from a CosineDecay schedule,
        # and Keras' optimizer refuses (TypeError) to have a schedule-driven learning
        # rate overwritten by a callback once the plateau triggers.
        ModelCheckpoint(checkpoint_path, save_best_only=True, monitor='val_pr_auc', mode='max'),
        LimeExplainer(sample_image_path, class_names, img_size, model),
        # Grad-CAM targets the last convolution of the custom CNN stream.
        GradCAMCallback(sample_image_path, model, img_size, 'conv2d_6'),
    ]

    # Training
    print(f"Starting training for model: {model.name}")
    history = model.fit(
        train_ds,
        epochs=epochs,
        validation_data=val_ds,
        class_weight=class_weights,
        callbacks=callbacks,
        verbose=1
    )

    # Evaluation
    print("\nFinal Evaluation:")
    best_model = tf.keras.models.load_model(checkpoint_path, custom_objects={
        'ChannelAttention': ChannelAttention,
        'GNNFusionLayer': GNNFusionLayer
    })
    # Read metrics by name: the positional order of evaluate()'s list output is not
    # guaranteed to follow the compile() order across Keras versions.
    results = best_model.evaluate(val_ds, verbose=0, return_dict=True)
    print(f"Validation Accuracy: {results['accuracy']:.4f}")
    print(f"Validation AUC: {results['auc']:.4f}")
    print(f"Validation PR-AUC: {results['pr_auc']:.4f}")

    # Predictions (val_ds is unshuffled, so labels and predictions stay aligned).
    y_true = np.concatenate([y for _, y in val_ds], axis=0)
    y_pred_probs = best_model.predict(val_ds, verbose=0)
    y_pred_labels = np.argmax(y_pred_probs, axis=1)
    y_true_labels = np.argmax(y_true, axis=1)

    _save_confusion_matrix(y_true_labels, y_pred_labels, class_names, best_model.name)

    # Classification Report (pinned to all class indices, see _save_confusion_matrix).
    print("\nClassification Report:")
    save_classification_report(y_true_labels, y_pred_labels, class_names, best_model)
    print(classification_report(y_true_labels, y_pred_labels,
                                labels=list(range(len(class_names))),
                                target_names=class_names))

    plot_pr_curves(y_true, y_pred_probs, class_names, best_model.name)
    plot_training_metrics(history, model.name)

    # Save models
    best_model_path = save_model_with_architecture(best_model, 'best_model')
    final_model_path = save_model_with_architecture(model, 'final_model')

    history_path = save_training_history(history, model)

    print("\nTraining completed. All outputs saved with model architecture names:")
    print(f"- Best model: {best_model_path}")
    print(f"- Final model: {final_model_path}")
    print(f"- Training metrics: training_metrics_{model.name}.png")
    print(f"- Confusion matrix: confusion_matrix_{best_model.name}.png")
    print(f"- PR curves: pr_curves_{best_model.name}.png")
    print(f"- Classification report: classification_report_{best_model.name}.json")
    print(f"- Training history: {history_path}")
    print("- LIME explanations saved in lime_explanations/ directory")
    print("- Grad-CAM analysis saved in gradcam_analysis/ directory")

# Run the training/evaluation pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()