# Chest X-Ray Multi‑Class Project — Role Notebook

**Dataset:** Kaggle “Lungs Disease Dataset (4 types)” by Omkar Manohar Dalvi  
**Classes (5):** Normal, plus the four disease types: Bacterial Pneumonia, Viral Pneumonia, COVID‑19, Tuberculosis

> Use this notebook in **Google Colab**. If you’re running locally, adapt the Drive mount steps accordingly.

## Role — Member 2: Preprocessing & Augmentation

**Responsibilities**  
- Implement resizing, normalization, augmentation  
- Produce balanced input pipelines for train/val/test  
- Export sample augmentations for QA  
- Ensure deterministic splits & reproducibility

## Environment & Paths

- The code below mounts Google Drive (for persistence) and prepares base paths.  
- Set `DATASET_DIR` to where the extracted dataset resides (after Kaggle download).

## Augmentation Strategy

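We use **tf.data** pipelines with Keras preprocessing layers (`tf.keras.layers.Random*`) for efficient input processing. Augmentations applied by the code below:
- Random horizontal flip
- Small rotations
- Random contrast (brightness jitter is not applied in the code below)
- Random zoom (no random crop is applied in the code below)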

> Keep medical realism: avoid extreme shears/rotations.

In [None]:
# === Colab & Paths ===
import os, sys, glob, json, random, shutil, time
from pathlib import Path

# If in Colab, mount Drive. Anywhere else the import (or the mount itself)
# raises; the reason is printed and the notebook carries on without Drive.
try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    IN_COLAB = True
except Exception as e:
    print("Not running on Colab or Drive not available:", e)
    IN_COLAB = False

# Project root (you can change this).
# With Drive mounted it lives inside Drive so outputs persist. Without Drive
# the '/content/drive/...' path is either not creatable (local machine) or a
# plain local folder that only looks like Drive, so use a folder under the
# current working directory instead.
if IN_COLAB:
    PROJECT_ROOT = Path('/content/drive/MyDrive/Chest_XRay_Project')
else:
    PROJECT_ROOT = Path.cwd() / 'Chest_XRay_Project'
PROJECT_ROOT.mkdir(parents=True, exist_ok=True)

# Where the dataset will live (after download & unzip). Adjust as needed.
# DATASET_DIR is not created here: it must already contain the extracted data.
DATASET_DIR = PROJECT_ROOT / 'lungs_dataset'
OUTPUTS_DIR = PROJECT_ROOT / 'outputs'
MODELS_DIR = PROJECT_ROOT / 'models'
REPORTS_DIR = PROJECT_ROOT / 'reports'

# Make sure every output location exists before later cells write to it.
for p in [OUTPUTS_DIR, MODELS_DIR, REPORTS_DIR]:
    p.mkdir(parents=True, exist_ok=True)

print("PROJECT_ROOT:", PROJECT_ROOT)
print("DATASET_DIR :", DATASET_DIR)
print("OUTPUTS_DIR :", OUTPUTS_DIR)
print("MODELS_DIR  :", MODELS_DIR)
print("REPORTS_DIR :", REPORTS_DIR)

In [None]:
# === TensorFlow setup ===
import tensorflow as tf
print("TF version:", tf.__version__)

# Constants shared by the loaders and pipelines defined further down.
SEED = 42
BATCH_SIZE = 32
IMG_SIZE = (224, 224)

# Class names are the sub-folder names of the training split, sorted
# alphabetically. The list stays empty when that folder does not exist.
CLASS_NAMES = []
split = 'train'
split_dir = DATASET_DIR / split
if split_dir.exists():
    CLASS_NAMES = sorted(entry.name for entry in split_dir.iterdir() if entry.is_dir())
print("Detected classes:", CLASS_NAMES)

# === tf.data loaders ===
def make_ds(split, shuffle=True):
    """Build a batched image dataset from ``DATASET_DIR/<split>``.

    Labels are inferred from the sub-folder names and requested in
    categorical form. Images are resized to ``IMG_SIZE`` and grouped into
    batches of ``BATCH_SIZE``.

    Args:
        split: Name of the sub-folder of ``DATASET_DIR`` to read,
            e.g. ``'train'``, ``'val'`` or ``'test'``.
        shuffle: Forwarded to the loader's ``shuffle`` option. ``SEED`` is
            always passed as the loader's seed.

    Returns:
        The dataset created by ``tf.keras.utils.image_dataset_from_directory``.
    """
    loader_options = dict(
        labels='inferred',
        label_mode='categorical',
        image_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        seed=SEED,
        shuffle=shuffle,
    )
    source_dir = str(DATASET_DIR / split)
    return tf.keras.utils.image_dataset_from_directory(source_dir, **loader_options)

train_ds = make_ds('train', shuffle=True)
val_ds   = make_ds('val',   shuffle=False)
test_ds  = make_ds('test',  shuffle=False)

# Label indices are inferred independently for each split from that split's
# own sub-folders. A split with a missing or extra class folder would get
# shifted one-hot labels without any error, so fail loudly instead.
for _split_name, _split_ds in (('val', val_ds), ('test', test_ds)):
    if list(_split_ds.class_names) != list(train_ds.class_names):
        raise ValueError(
            f"Class folders of split '{_split_name}' {list(_split_ds.class_names)} "
            f"do not match the training split {list(train_ds.class_names)}"
        )

# Use the order the loader actually assigned as the source of truth, so the
# names used for plotting/export always line up with the one-hot label index.
CLASS_NAMES = list(train_ds.class_names)

# === Normalization layer ===
# Maps 0-255 pixel values to the [0, 1] range.
normalizer = tf.keras.layers.Rescaling(1./255)

# === Augmentation — using Keras preprocessing layers ===
# Deliberately mild factors to keep the X-rays medically plausible.
data_augment = tf.keras.Sequential([
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.05),
    tf.keras.layers.RandomZoom(0.05),
    tf.keras.layers.RandomContrast(0.1),
], name="augmentation")

def prep_pipeline(ds, training=False, shuffle_buffer=8):
    """Normalise an image dataset and, for training, augment and shuffle it.

    Args:
        ds: ``tf.data.Dataset`` yielding ``(images, labels)`` elements. As
            produced by ``make_ds`` each element is a whole batch.
        training: When true, apply ``data_augment`` and shuffle the elements;
            otherwise only normalisation and prefetching are applied.
        shuffle_buffer: Number of dataset *elements* kept in the shuffle
            buffer. Because the elements coming from ``make_ds`` are full
            batches (about 19 MB each for 32 float32 images of 224x224x3), the
            buffer is kept small: a buffer of 1000 such batches would hold
            most of the training set in memory and delay the first batch
            until it has been filled.

    Returns:
        The transformed dataset with ``prefetch(tf.data.AUTOTUNE)`` applied.
    """
    AUTOTUNE = tf.data.AUTOTUNE
    ds = ds.map(lambda x, y: (normalizer(x), y), num_parallel_calls=AUTOTUNE)
    if training:
        def _augment(x, y):
            x = data_augment(x, training=True)
            # Contrast adjustment can push pixels slightly outside the
            # [0, 1] range produced by the normaliser; bring them back so
            # training and evaluation inputs share the same value range.
            return tf.clip_by_value(x, 0.0, 1.0), y

        ds = ds.map(_augment, num_parallel_calls=AUTOTUNE)
        ds = ds.shuffle(shuffle_buffer, seed=SEED)
    return ds.prefetch(AUTOTUNE)

# Wrap each split in its final input pipeline; only the training split is
# augmented and shuffled (``training`` defaults to False for the others).
train_ds = prep_pipeline(train_ds, training=True)
val_ds = prep_pipeline(val_ds)
test_ds = prep_pipeline(test_ds)

# Save class names for other notebooks
classes_path = PROJECT_ROOT / 'classes.json'
with open(classes_path, 'w') as f:
    f.write(json.dumps(CLASS_NAMES))
print("Saved classes.json:", classes_path)

In [None]:
# === Preview a few augmented samples ===
import matplotlib.pyplot as plt

def show_batch(ds, title="Batch preview"):
    """Plot up to nine images from the first batch of ``ds`` in a 3x3 grid.

    Args:
        ds: Iterable dataset yielding ``(images, labels)`` batches, where each
            label row is a score/one-hot vector (its argmax is used as the
            class index).
        title: Figure title.
    """
    images, labels = next(iter(ds))
    # A batch may contain fewer than nine images (small split / small batch).
    n_show = min(9, int(images.shape[0]))
    plt.figure(figsize=(8,8))
    for i in range(n_show):
        plt.subplot(3,3,i+1)
        img = images[i]
        if img.dtype.is_floating:
            # imshow expects float images in [0, 1]; augmentation can
            # overshoot that range slightly, so clip for display only.
            img = tf.clip_by_value(img, 0.0, 1.0)
        plt.imshow(img.numpy())
        idx = int(tf.argmax(labels[i]).numpy())
        # Fall back to the numeric index when no name is known for it.
        plt.title(CLASS_NAMES[idx] if idx < len(CLASS_NAMES) else str(idx))
        plt.axis('off')
    plt.suptitle(title)
    plt.show()

show_batch(train_ds, title="Augmented training samples")