# Chest X-Ray Multi‑Class Project — Role Notebook

**Dataset:** Kaggle “Lungs Disease Dataset (4 types)” by Omkar Manohar Dalvi  
**Classes (5):** Normal plus the four disease types — Bacterial Pneumonia, Viral Pneumonia, COVID‑19, Tuberculosis

> Use this notebook in **Google Colab**. If you’re running locally, adapt the Drive mount steps accordingly.

## Role — Member 5: Explainability & Visualization

**Responsibilities**  
- Implement **Grad‑CAM** for saliency/attention visualization  
- Generate heatmaps for correct and incorrect predictions per class  
- Create augmentation & performance visual assets for the report/presentation

## Environment & Paths

- The code below mounts Google Drive (for persistence) and prepares base paths.  
- Set `DATASET_DIR` to where the extracted dataset resides (after Kaggle download).

## Grad‑CAM Approach

We compute Grad‑CAM on the last convolutional layer of the backbone and overlay the resulting heatmap on the input image.

In [None]:
# === Colab & Paths ===
import os, sys, glob, json, random, shutil, time
from pathlib import Path

# Try to mount Google Drive so that outputs persist. Outside Colab the import
# (or the mount) raises; the error is printed and IN_COLAB stays False so the
# rest of the cell still runs.
IN_COLAB = False
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    IN_COLAB = True
except Exception as e:
    print("Not running on Colab or Drive not available:", e)

# Project root inside Drive (you can change this).
# NOTE(review): this path is used even when IN_COLAB is False — point it at a
# writable local directory when running outside Colab.
PROJECT_ROOT = Path('/content/drive/MyDrive/Chest_XRay_Project')
PROJECT_ROOT.mkdir(parents=True, exist_ok=True)

# Where the dataset will live (after download & unzip). Adjust as needed.
# DATASET_DIR is only named here, not created.
DATASET_DIR = PROJECT_ROOT / 'lungs_dataset'

# Working directories that this cell creates up front.
OUTPUTS_DIR, MODELS_DIR, REPORTS_DIR = (
    PROJECT_ROOT / subdir for subdir in ('outputs', 'models', 'reports')
)
for directory in (OUTPUTS_DIR, MODELS_DIR, REPORTS_DIR):
    directory.mkdir(parents=True, exist_ok=True)

# Echo the resolved locations so a wrong root is noticed immediately.
for label, location in (
    ("PROJECT_ROOT:", PROJECT_ROOT),
    ("DATASET_DIR :", DATASET_DIR),
    ("OUTPUTS_DIR :", OUTPUTS_DIR),
    ("MODELS_DIR  :", MODELS_DIR),
    ("REPORTS_DIR :", REPORTS_DIR),
):
    print(label, location)

In [None]:
import tensorflow as tf, numpy as np, matplotlib.pyplot as plt, json
from tensorflow.keras.preprocessing import image
from pathlib import Path
from PIL import Image

# Load classes & best model
# Read classes.json as UTF-8 explicitly (the standard JSON encoding) instead of
# relying on the platform's default locale encoding.
# NOTE(review): later code indexes CLASS_NAMES by integer, so the file is
# presumably a JSON list ordered by class index — confirm against the writer.
with open(PROJECT_ROOT / 'classes.json', encoding='utf-8') as f:
    CLASS_NAMES = json.load(f)

# Best checkpoint saved under MODELS_DIR; printing the summary makes the layer
# names visible for the Grad-CAM layer lookup that follows.
best_model_path = MODELS_DIR / 'best_model.keras'
model = tf.keras.models.load_model(best_model_path)
model.summary()

# Helper: find last conv layer name
def find_last_conv(model):
    """Return the name of the last convolutional layer reachable from ``model``.

    The search proceeds level by level. The model's own layers are scanned
    first, from the last layer backwards. If none of them is a convolution,
    the layers of each nested container (any layer exposing a list/tuple
    ``layers`` attribute) are scanned the same way, then their containers, and
    so on until a match is found or nothing is left to descend into.

    Args:
        model: A Keras model, or any object exposing a ``layers`` list.

    Returns:
        The ``name`` attribute of the first matching layer in that order.

    Raises:
        ValueError: If no Conv2D / SeparableConv2D / DepthwiseConv2D layer is
            found at any depth.

    NOTE(review): a name found inside a nested model may not be resolvable via
    ``model.get_layer(name)`` on the outer model — confirm against the caller.
    """
    conv_types = (
        tf.keras.layers.Conv2D,
        tf.keras.layers.SeparableConv2D,
        tf.keras.layers.DepthwiseConv2D,
    )

    frontier = [model]
    while frontier:
        deeper = []
        for container in frontier:
            sublayers = getattr(container, 'layers', None)
            # Plain layers have no usable `layers` list; skip them explicitly
            # instead of swallowing whatever exception iteration would raise.
            if not isinstance(sublayers, (list, tuple)):
                continue
            for layer in reversed(sublayers):
                if isinstance(layer, conv_types):
                    return layer.name
                deeper.append(layer)
        # Only descend once the whole current level has been ruled out, so
        # shallower convolutions always take precedence over deeper ones.
        frontier = deeper
    raise ValueError("No Conv layer found")

# Resolve once which layer Grad-CAM will read its activations from, and print
# it so it can be checked against the model summary.
last_conv_name = find_last_conv(model)
print("Last conv layer:", last_conv_name)

# Size every image is resized to in load_img.
# NOTE(review): presumably matches the input size the model was trained with — confirm.
IMG_SIZE = (224,224)

def load_img(path):
    """Load an image file as a model-ready batch plus a displayable PIL image.

    The image is converted to RGB, resized to ``IMG_SIZE`` and scaled by
    1/255.

    Args:
        path: Path (str or ``Path``) of the image file to open.

    Returns:
        A tuple ``(batch, img)`` where ``batch`` is a float array with a
        leading batch axis of length 1 and values in [0, 1], and ``img`` is the
        resized RGB ``PIL.Image`` the array was built from.

    NOTE(review): assumes the model expects inputs scaled to [0, 1] — confirm
    against the training preprocessing.
    """
    # Close the file handle deterministically; convert()/resize() return new
    # images, so `img` stays valid after the source image is closed.
    with Image.open(path) as src:
        img = src.convert('RGB').resize(IMG_SIZE)
    arr = np.array(img) / 255.0
    return arr[np.newaxis, ...], img

def grad_cam(model, img_array, last_conv_layer_name, class_index=None):
    """Compute a Grad-CAM heatmap for one image.

    Builds a model that returns both the activations of the named layer and
    the predictions, differentiates the chosen class score with respect to
    those activations, and weights each activation channel by its mean
    gradient.

    Args:
        model: Keras model to explain.
        img_array: Input batch; only the first sample's activations are used.
            Assumes a single-image batch — TODO confirm for other batch sizes.
        last_conv_layer_name: Name of the layer whose output is visualised;
            looked up with ``model.get_layer``.
        class_index: Index of the class score to explain. Defaults to the
            arg-max of the first sample's predictions.

    Returns:
        A tuple ``(heatmap, class_index)``: ``heatmap`` is a NumPy array with
        negative values clipped to zero and scaled by its maximum (all zeros
        if nothing is positive); ``class_index`` is the explained class as a
        Python ``int``.

    Raises:
        ValueError: If no gradient flows from the class score to the layer's
            output.
    """
    # `model.inputs` is already a list; wrapping it in another list produces a
    # nested input structure that does not match the model's own.
    grad_model = tf.keras.models.Model(
        model.inputs,
        [model.get_layer(last_conv_layer_name).output, model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        if class_index is None:
            class_index = tf.argmax(predictions[0])
        loss = predictions[:, class_index]

    grads = tape.gradient(loss, conv_outputs)
    if grads is None:
        # Fail with an actionable message rather than an opaque error from
        # reduce_mean(None) below.
        raise ValueError(
            f"No gradient between the class score and layer '{last_conv_layer_name}'"
        )

    # One weight per channel: the gradient averaged over batch and spatial axes.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
    conv_outputs = conv_outputs[0]
    heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1).numpy()
    # Keep only positive evidence, then normalise; the epsilon avoids 0/0 when
    # the map is entirely non-positive.
    heatmap = np.maximum(heatmap, 0)
    heatmap /= (heatmap.max() + 1e-8)
    return heatmap, int(class_index)

def overlay_heatmap(img_pil, heatmap, alpha=0.35):
    """Blend a heatmap onto an image as a grey-level overlay.

    Args:
        img_pil: PIL image to draw on. Assumed to be 3-channel RGB — TODO
            confirm for other modes.
        heatmap: 2-D array of intensities. Presumably already in [0, 1], since
            it is scaled by 255 and cast to uint8.
        alpha: Weight of the heatmap in the blend; the image gets ``1 - alpha``.

    Returns:
        Float array of the blended image, clipped to [0, 1].
    """
    # Go through an 8-bit PIL image so the heatmap can be resized to the
    # picture's size, then return to floats.
    heat_u8 = np.uint8(255 * heatmap)
    heat_resized = Image.fromarray(heat_u8).resize(img_pil.size)
    heat_gray = np.array(heat_resized) / 255.0

    # Repeat the single channel three times so it lines up with the image.
    heat_rgb = np.repeat(heat_gray[..., np.newaxis], 3, axis=-1)

    base = np.array(img_pil) / 255.0
    blended = (1 - alpha) * base + alpha * heat_rgb
    return np.clip(blended, 0, 1)

# Demo on one test image per class
test_root = DATASET_DIR / 'test'
image_suffixes = {'.png', '.jpg', '.jpeg'}
examples = []
for cls in CLASS_NAMES:
    cls_dir = test_root / cls
    # Sort so the same image is picked on every run (glob order depends on the
    # filesystem), and compare suffixes case-insensitively so files such as
    # "IMG.JPG" are not skipped.
    files = sorted(
        f for f in cls_dir.glob('*')
        if f.is_file() and f.suffix.lower() in image_suffixes
    )
    if files:
        examples.append(files[0])
    else:
        # Make a missing or empty class folder visible instead of silently
        # producing fewer figures.
        print(f"No test images found for class '{cls}' in {cls_dir}")

for p in examples:
    arr, img_pil = load_img(p)
    preds = model.predict(arr, verbose=0)[0]
    pred_idx = int(np.argmax(preds))
    # Explain the predicted class; the returned index equals pred_idx here.
    heat, _ = grad_cam(model, arr, last_conv_name, class_index=pred_idx)
    over = overlay_heatmap(img_pil, heat)

    plt.figure()
    plt.imshow(over)
    # The true label is taken from the name of the folder holding the image.
    plt.title(f"True: {p.parent.name} | Pred: {CLASS_NAMES[pred_idx]} | Conf: {preds[pred_idx]:.3f}")
    plt.axis('off')
    plt.show()

In [None]:
# === Augmentation gallery (reuses Member 2 layers) ===
# Random augmentation pipeline applied in this order: horizontal flip,
# rotation, zoom, contrast. The numeric factors are interpreted by the
# respective Keras layers.
augment = tf.keras.Sequential()
for aug_layer in (
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.05),
    tf.keras.layers.RandomZoom(0.05),
    tf.keras.layers.RandomContrast(0.1),
):
    augment.add(aug_layer)

def show_aug_gallery(path):
    """Display an image followed by five random augmentations of it.

    The image is loaded with ``load_img``, passed five times through the
    module-level ``augment`` pipeline in training mode, and each result is
    shown in its own matplotlib figure.

    Args:
        path: Path of the image file to load.

    Returns:
        None. The function only draws figures.
    """
    arr, img_pil = load_img(path)
    imgs = [img_pil]
    for _ in range(5):
        # `arr` is already the scaled single-image batch, so reuse it instead
        # of rebuilding it from the PIL image.
        x = augment(arr, training=True).numpy()[0]
        # Augmentation (e.g. contrast changes) may push values outside [0, 1];
        # without clipping the uint8 cast would wrap around and show speckles.
        x = np.clip(x, 0.0, 1.0)
        # Round rather than truncate so values like 253.9999 map to 254.
        imgs.append(Image.fromarray(np.rint(x * 255).astype('uint8')))

    for i, im in enumerate(imgs):
        plt.figure()
        plt.imshow(im)
        # Index 0 is the untouched input image, not an augmentation.
        plt.title("Original" if i == 0 else f"Augmented sample {i}")
        plt.axis('off')
        plt.show()

# Show the gallery for the first collected example; skipped when no example
# images were found.
if examples:
    show_aug_gallery(str(examples[0]))