In [1]:
pip install opencv-python numpy matplotlib scikit-learn tensorflow scikit-image torch torchvision


Note: you may need to restart the kernel to use updated packages.


In [2]:
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from skimage.segmentation import slic
from skimage.color import label2rgb, rgb2lab
from tensorflow.keras.applications import EfficientNetV2B0
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, concatenate, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import AdamW
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import tensorflow as tf

print("All imports successful!")


2025-03-03 17:51:49.735271: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-03 17:51:49.736908: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-03 17:51:49.768956: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-03 17:51:49.769648: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


All imports successful!


# Cell 2: Configuration and Constants
This cell defines constants and paths, making it easy to adjust parameters.


In [3]:
DATASET_ROOT = "/home/w2sg-arnav/msusir/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection"
ORIGINAL_DIR = os.path.join(DATASET_ROOT, "Cotton Leaf Disease Detection Dataset", "Original Dataset")
AUGMENTED_DIR = os.path.join(DATASET_ROOT, "Cotton Leaf Disease Detection Dataset", "Augmented Dataset")

IMAGE_SIZE = (224, 224)  # Adjust as needed for your models (e.g., 384x384 for ViT)
CLASSES = [
    "Bacterial Blight",
    "Curl Virus",
    "Healthy Leaf",
    "Herbicide Growth Damage",
    "Leaf Hopper Jassids",
    "Leaf Redding",
    "Leaf Variegation",
]
NUM_CLASSES = len(CLASSES)
BATCH_SIZE = 32  # Adjust based on your GPU memory
EPOCHS = 30      # Adjust as needed
LEARNING_RATE = 1e-4


# Cell 4: Data Loading Functions

Functions to load and organize the image data and labels.

In [4]:
def load_data(data_dir, use_augmentation=False):
    """Loads images and labels from a directory structure."""
    image_files = []
    labels = []

    for class_name in CLASSES:
        class_dir = os.path.join(data_dir, class_name)
        for filename in os.listdir(class_dir):
            if filename.endswith((".jpg", ".jpeg", ".png")):
                image_files.append(os.path.join(class_dir, filename))
                labels.append(class_name)  # Store class name as string

    return image_files, labels


def create_tf_dataset(image_files, labels, preprocess_fn, batch_size=BATCH_SIZE):
    """Creates a TensorFlow dataset for efficient data loading."""
    dataset = tf.data.Dataset.from_tensor_slices((image_files, labels))

    def _load_and_preprocess(image_path, label):
        image_string = tf.io.read_file(image_path)
        image = tf.image.decode_image(image_string, channels=3, expand_animations=False) # Handle potential errors
        image = tf.image.convert_image_dtype(image, tf.float32)
        [image, ] = tf.py_function(preprocess_fn, [image_path], [tf.float32])
        image.set_shape([None, None, 3])
        return image, label

    dataset = dataset.map(_load_and_preprocess)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    return dataset


# Cell 6: Preprocessing Functions

These functions perform image segmentation, noise reduction, and orientation normalization.  These are *examples* and should be refined/replaced with more robust methods as needed.

In [5]:
def segment_leaf(image):
    """Segments the leaf from the background using SLIC."""
    image_lab = cv2.cvtColor(image, cv2.COLOR_RGB2Lab)
    segments = slic(image_lab, n_segments=100, compactness=10, sigma=1, start_label=1)
    largest_segment_label = np.argmax(np.bincount(segments.ravel())[1:]) + 1
    leaf_mask = (segments == largest_segment_label).astype(np.uint8)
    segmented_image = cv2.bitwise_and(image, image, mask=leaf_mask)
    return segmented_image, leaf_mask

def reduce_noise(image):
    """Applies Non-Local Means Denoising."""
    denoised_image = cv2.fastNlMeansDenoisingColored(image, None, 10, 10, 7, 21)
    return denoised_image

def normalize_orientation(image, mask):
    """Basic orientation normalization."""
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if contours:
        largest_contour = max(contours, key=cv2.contourArea)
        if len(largest_contour) >= 5:
            ellipse = cv2.fitEllipse(largest_contour)
            angle = ellipse[2]
            (h, w) = image.shape[:2]
            center = (w // 2, h // 2)
            rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
            rotated_image = cv2.warpAffine(image, rotation_matrix, (w, h))
            rotated_mask = cv2.warpAffine(mask, rotation_matrix, (w, h))
            return rotated_image, rotated_mask
    return image, mask

def preprocess_image(image_path):
    """Loads, preprocesses, and returns a single image."""
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    segmented_image, leaf_mask = segment_leaf(image)
    denoised_image = reduce_noise(segmented_image)
    normalized_image, _ = normalize_orientation(denoised_image, leaf_mask)
    resized_image = cv2.resize(normalized_image, IMAGE_SIZE)
    final_image = resized_image.astype(np.float32) / 255.0
    return final_image


# Cell 8: Data Loading and Splitting

This cell loads the data, splits it into training, validation, and test sets, and converts labels to the correct format.

In [6]:
image_files, labels = load_data(ORIGINAL_DIR)  # Or AUGMENTED_DIR

train_files, test_files, train_labels, test_labels = train_test_split(
    image_files, labels, test_size=0.2, random_state=42, stratify=labels
)
train_files, val_files, train_labels, val_labels = train_test_split(
    train_files, train_labels, test_size=0.25, random_state=42, stratify=train_labels
)

# Convert labels to numerical format
label_to_index = {label: i for i, label in enumerate(CLASSES)}
train_labels_int = np.array([label_to_index[label] for label in train_labels])
val_labels_int = np.array([label_to_index[label] for label in val_labels])
test_labels_int = np.array([label_to_index[label] for label in test_labels])

# One-hot encode
train_labels_onehot = to_categorical(train_labels_int, num_classes=NUM_CLASSES)
val_labels_onehot = to_categorical(val_labels_int, num_classes=NUM_CLASSES)
test_labels_onehot = to_categorical(test_labels_int, num_classes=NUM_CLASSES)

# Create TensorFlow datasets
train_ds = create_tf_dataset(train_files, train_labels_onehot, preprocess_image)
val_ds = create_tf_dataset(val_files, val_labels_onehot, preprocess_image)
test_ds = create_tf_dataset(test_files, test_labels_onehot, preprocess_image)


FileNotFoundError: [Errno 2] No such file or directory: '/home/w2sg-arnav/msusir/SAR-CLD-2024 A Comprehensive Dataset for Cotton Leaf Disease Detection/Cotton Leaf Disease Detection Dataset/Original Dataset/Bacterial Blight'

# Cell 10:  Data Visualization (Optional)

This cell displays a few preprocessed images to visually check the data.

In [None]:
# Cell 11: Visualize some preprocessed images
plt.figure(figsize=(12, 6))
for i in range(min(5, len(train_files))):  # Display up to 5 images
    image = preprocess_image(train_files[i])  # Load and preprocess
    plt.subplot(1, 5, i + 1)
    plt.imshow(image)
    plt.title(train_labels[i])
    plt.axis("off")
plt.show()

# Cell 12: Model Definition (ViT and CNN Ensemble)

This cell defines the Vision Transformer (ViT) and CNN (EfficientNetV2B0) models and creates an ensemble.

In [None]:
# Cell 13: Model Definition
def create_vit_model(input_shape=IMAGE_SIZE + (3,), num_classes=NUM_CLASSES):
    vit_model = tf.keras.applications.vit.ViT(
        input_shape=input_shape,
        num_classes=num_classes,
        classifier_activation='softmax',
        include_top=True,
        include_preprocessing=False # Important: We do our own preprocessing
    )
    return vit_model


def create_cnn_model(input_shape=IMAGE_SIZE + (3,), num_classes=NUM_CLASSES):
    base_model = EfficientNetV2B0(input_shape=input_shape, include_top=False, weights='imagenet')
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    x = Dropout(0.5)(x) # Add dropout for regularization
    predictions = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=base_model.input, outputs=predictions)
    return model

def create_ensemble(vit_model, cnn_model, input_shape=IMAGE_SIZE+(3,)):
    vit_input = Input(shape=input_shape)
    cnn_input = Input(shape=input_shape)

    vit_output = vit_model(vit_input)
    cnn_output = cnn_model(cnn_input)

    # Simple Averaging Ensemble
    merged = concatenate([vit_output, cnn_output])
    # Optionally add a dense layer after concatenation for further processing
    # merged = Dense(512, activation='relu')(merged)
    output = Dense(NUM_CLASSES, activation='softmax')(merged) # Combine predictions

    ensemble_model = Model(inputs=[vit_input, cnn_input], outputs=output)
    return ensemble_model


# Create the individual models
vit_model = create_vit_model()
cnn_model = create_cnn_model()
# Create the ensemble model
ensemble_model = create_ensemble(vit_model, cnn_model)

ensemble_model.compile(optimizer=AdamW(learning_rate=LEARNING_RATE),
                      loss=CategoricalCrossentropy(),
                      metrics=['accuracy'])

ensemble_model.summary()

In [None]:
import official.vision.configs.backbones as backbones
print(dir(backbones))

['Backbone', 'DilatedResNet', 'EfficientNet', 'List', 'MobileDet', 'MobileNet', 'Optional', 'ResNet', 'RevNet', 'SpineNet', 'SpineNetMobile', 'Transformer', 'Tuple', 'VisionTransformer', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__spec__', 'dataclasses', 'hyperparams']


# Cell 14:  Training Callbacks

This cell sets up callbacks for model checkpointing, early stopping, and learning rate scheduling.

In [None]:
# Cell 15: Training Callbacks
checkpoint_filepath = 'best_ensemble_model.h5'  # Save the best model
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
    monitor='val_accuracy',
    mode='max',
    verbose=1
)

early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,  # Stop if validation loss doesn't improve for 10 epochs
    restore_best_weights=True,
    verbose=1
)

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-6, verbose=1)

callbacks = [model_checkpoint_callback, early_stopping_callback, reduce_lr]

Epoch 1/30


2025-03-02 16:01:35.507821: W tensorflow/core/framework/op_kernel.cc:1829] UNKNOWN: error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object

Traceback (most recent call last):

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 267, in __call__
    return func(device, token, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 145, in __call__
    outputs = self._call(device, args)
              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 152, in _call
    ret = self._func(*args)
          ^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/li

UnknownError: Graph execution error:

Detected at node EagerPyFunc defined at (most recent call last):
<stack traces unavailable>
Detected at node EagerPyFunc defined at (most recent call last):
<stack traces unavailable>
2 root error(s) found.
  (0) UNKNOWN:  error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object

Traceback (most recent call last):

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 267, in __call__
    return func(device, token, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 145, in __call__
    outputs = self._call(device, args)
              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 152, in _call
    ret = self._func(*args)
          ^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_28622/4293546220.py", line 35, in preprocess_image
    image = cv2.imread(image_path)
            ^^^^^^^^^^^^^^^^^^^^^^

cv2.error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object



	 [[{{node EagerPyFunc}}]]
	 [[IteratorGetNext]]
	 [[IteratorGetNext/_2]]
  (1) UNKNOWN:  error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object

Traceback (most recent call last):

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 267, in __call__
    return func(device, token, args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 145, in __call__
    outputs = self._call(device, args)
              ^^^^^^^^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/ops/script_ops.py", line 152, in _call
    ret = self._func(*args)
          ^^^^^^^^^^^^^^^^^

  File "/home/w2sg-arnav/.local/lib/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_28622/4293546220.py", line 35, in preprocess_image
    image = cv2.imread(image_path)
            ^^^^^^^^^^^^^^^^^^^^^^

cv2.error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object



	 [[{{node EagerPyFunc}}]]
	 [[IteratorGetNext]]
0 successful operations.
0 derived errors ignored. [Op:__inference_multi_step_on_iterator_504345]

b/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_28622/4293546220.py", line 35, in preprocess_image
    image = cv2.imread(image_path)
            ^^^^^^^^^^^^^^^^^^^^^^

cv2.error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object



2025-03-02 16:01:35.570963: W tensorflow/core/framework/op_kernel.cc:1829] UNKNOWN: error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object

Traceback (most recent call last):

  File "/home/w2sg-arn

b/python3.12/site-packages/tensorflow/python/autograph/impl/api.py", line 643, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^

  File "/tmp/ipykernel_28622/4293546220.py", line 35, in preprocess_image
    image = cv2.imread(image_path)
            ^^^^^^^^^^^^^^^^^^^^^^

cv2.error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object



2025-03-02 16:01:35.577952: W tensorflow/core/framework/op_kernel.cc:1829] UNKNOWN: error: OpenCV(4.11.0) :-1: error: (-5:Bad argument) in function 'imread'
> Overload resolution failed:
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object
>  - Expected 'filename' to be a str or path-like object

Traceback (most recent call last):

  File "/home/w2sg-arn

# Cell 16: Training Loop

This cell trains the ensemble model.  Since it's an ensemble, we need to provide *both* inputs.

In [None]:
# Cell 17: Training

# Prepare data for the ensemble (both models get the same preprocessed images)
def prepare_ensemble_data(dataset):
    for images, labels in dataset:
        yield [images, images], labels  # Input for both models, and the labels

train_ds_ensemble = train_ds.map(lambda x, y: ([x, x], y))
val_ds_ensemble = val_ds.map(lambda x, y: ([x, x], y))
test_ds_ensemble = test_ds.map(lambda x, y: ([x, x], y))


history = ensemble_model.fit(
    train_ds_ensemble,
    validation_data=val_ds_ensemble,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=1  # Show progress
)

Epoch 1/30


ValueError: Argument `output` must have rank (ndim) `target.ndim - 1`. Received: target.shape=(None, 7), output.shape=(None, 7)

# Cell 18:  Evaluation

This cell evaluates the trained model on the test set.

In [None]:
# Cell 19: Evaluation

loss, accuracy = ensemble_model.evaluate(test_ds_ensemble, verbose=1)
print(f"Test Loss: {loss:.4f}, Test Accuracy: {accuracy:.4f}")

# --- Further Evaluation (ROC, AUC, Specificity, Sensitivity) ---
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report

# Get predictions on the test set
y_pred_probs = ensemble_model.predict(test_ds_ensemble)  # Probabilities
y_pred = np.argmax(y_pred_probs, axis=1)  # Predicted classes
y_true = test_labels_int  # True integer labels

# --- Confusion Matrix ---
conf_mat = confusion_matrix(y_true, y_pred)
print("Confusion Matrix:\n", conf_mat)

# --- Classification Report (Precision, Recall, F1-score) ---
class_report = classification_report(y_true, y_pred, target_names=CLASSES)
print("Classification Report:\n", class_report)
# --- ROC Curves and AUC ---
fpr = dict()
tpr = dict()
roc_auc = dict()

for i in range(NUM_CLASSES):
    fpr[i], tpr[i], _ = roc_curve(test_labels_onehot[:, i], y_pred_probs[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])

# Plot all ROC curves
plt.figure(figsize=(8, 6))
for i in range(NUM_CLASSES):
    plt.plot(fpr[i], tpr[i], label=f'ROC curve of class {CLASSES[i]} (area = {roc_auc[i]:.2f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC)')
plt.legend(loc="lower right")
plt.show()


# --- Display Training History ---
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.legend()
plt.title('Accuracy')

plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.legend()
plt.title('Loss')
plt.show()

# Cell 20:  Grad-CAM Visualization (Example)

In [None]:
# Cell 21: Grad-CAM
def make_gradcam_heatmap(img_array, model, last_conv_layer_name, pred_index=None):
    """Generates a Grad-CAM heatmap for a given image and model."""

    # Create a model that maps the input image to the activations of the last conv layer
    grad_model = tf.keras.models.Model(
        [model.inputs], [model.get_layer(last_conv_layer_name).output, model.output]
    )

    # Compute the gradient of the top predicted class for the input image
    with tf.GradientTape() as tape:
        last_conv_layer_output, preds = grad_model(img_array)
        if pred_index is None:
            pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    # This is the gradient of the output neuron with regard to the output feature map
    grads = tape.gradient(class_channel, last_conv_layer_output)

    # Pool the gradients over all the axes, leaving out the channel dimension
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))

    # Multiply each channel in the feature map by "how important this channel is"
    last_conv_layer_output = last_conv_layer_output.numpy()[0]
    pooled_grads = pooled_grads.numpy()
    for i in range(pooled_grads.shape[-1]):
        last_conv_layer_output[:, :, i] *= pooled_grads[i]

    # Average the channels of the feature map to obtain the heatmap
    heatmap = np.mean(last_conv_layer_output, axis=-1)

    # Normalize the heatmap
    heatmap = np.maximum(heatmap, 0) / np.max(heatmap)
    return heatmap

# --- Example Usage ---
# Choose an image for visualization
img_path = test_files[0]  # Example: Use the first test image
img_array = preprocess_image(img_path)
img_array = np.expand_dims(img_array, axis=0)  # Add batch dimension


# --- Grad-CAM for the CNN model ---
last_conv_layer_name_cnn = "top_activation" # Find the name of the last conv layer in your CNN

heatmap_cnn = make_gradcam_heatmap(img_array, cnn_model, last_conv_layer_name_cnn)

# Display heatmap
plt.matshow(heatmap_cnn)
plt.title("Grad-CAM Heatmap (CNN)")
plt.show()

# --- Grad-CAM for ViT ---
#   ViT's structure is different. You need to access the attention weights.
#   This requires a different approach, and there isn't a single "last convolutional layer."
#   The following is a *simplified conceptual example* and needs adaptation
#   for a specific ViT implementation.  It's often more complex than for CNNs.

def vit_gradcam_simplified(img_array, model, layer_name="encoder_layer_11"):
    """Simplified Grad-CAM for ViT (Conceptual Example)."""

    grad_model = tf.keras.models.Model(
        [model.inputs], [model.get_layer(layer_name).output, model.output]
    )
  
    with tf.GradientTape() as tape:
        outputs, preds = grad_model(img_array)
        pred_index = tf.argmax(preds[0])
        class_channel = preds[:, pred_index]

    grads = tape.gradient(class_channel, outputs)
    #This part is different to a CNN, you need to adapt based on the way ViT is implemented
    weights = tf.reduce_mean(grads, axis=(1)) # Average across the sequence length
    heatmap = tf.reduce_sum(weights * outputs, axis=-1) # Linear combination
    heatmap = tf.squeeze(heatmap, axis=0)

    heatmap = np.maximum(heatmap, 0) / np.max(heatmap)
    return heatmap

# Find an appropriate layer name within your ViT model.
# You might inspect the model summary and choose a layer within the transformer encoder.
last_conv_layer_name_vit = 'transformer_block_11'  # Example layer name -  CHANGE THIS!

if last_conv_layer_name_vit:
    try:
        heatmap_vit = vit_gradcam_simplified(img_array, vit_model, last_conv_layer_name_vit)
        plt.matshow(heatmap_vit)
        plt.title("Grad-CAM Heatmap (ViT - Simplified)")
        plt.show()
    except Exception as e:
      print("Grad-CAM failed for ViT: ", e)
else:
    print("Grad-CAM for ViT skipped - no suitable layer name provided.")


# --- Overlay Heatmap on Original Image (for both CNN and ViT, if available) ---
def display_gradcam(img_path, heatmap, alpha=0.4):
    """Overlays the Grad-CAM heatmap on the original image."""
    img = cv2.imread(img_path)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    heatmap = np.uint8(255 * heatmap)
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    superimposed_img = heatmap * alpha + img
    superimposed_img = np.clip(superimposed_img, 0, 255).astype(np.uint8)

    plt.imshow(superimposed_img)
    plt.axis('off')

# Display for CNN
display_gradcam(img_path, heatmap_cnn)
plt.title("Grad-CAM Overlay (CNN)")
plt.show()

# Display for ViT (if heatmap was successfully generated)
if 'heatmap_vit' in locals():  # Check if heatmap_vit exists
    display_gradcam(img_path, heatmap_vit)
    plt.title("Grad-CAM Overlay (ViT)")
    plt.show()