In [None]:
import os
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from PIL import Image

# Global plot styling: seaborn "whitegrid" theme and a 12x6 default figure size.
sns.set_theme(style="whitegrid")
plt.rcParams['figure.figsize'] = (12, 6)

# Define path to raw data
# ".." (relative to the current working directory) is appended to the import
# search path before the project config module is imported.
import sys
sys.path.append("..")
from chest_x_ray_classification.config import RAW_DATA_DIR

# DATA_DIR is the dataset root used by the cells below.
# NOTE(review): assumed to be a pathlib.Path (it is used with .resolve() and `/`) - confirm in config.
DATA_DIR = RAW_DATA_DIR

# Report, for each expected split folder, whether it exists under DATA_DIR.
print(f"Checking directory: {DATA_DIR.resolve()}")
for split in ['train', 'test', 'val']:
    if not (DATA_DIR / split).exists():
        print(f" Warning: Could not find '{split}' folder!")
    else:
        print(f" Found split: {split}")

## Task 1.1 - Quantify Class Imbalance

In [None]:
def get_dataset_stats(data_dir):
    """Index every image under ``data_dir/<split>/<label>/`` into a DataFrame.

    The splits ``train``, ``val`` and ``test`` are visited in that order;
    splits that are missing (or are not directories) are skipped. Inside each
    split every sub-directory is treated as a class label, and every regular
    file ending in ``.jpeg``, ``.jpg`` or ``.png`` (case-insensitive) becomes
    one row.

    Labels and files are visited in sorted order so that the row order, and
    therefore any later ``DataFrame.sample(random_state=...)`` call, does not
    depend on the arbitrary order in which the OS lists directory entries.

    Args:
        data_dir: Dataset root, as a ``pathlib.Path`` or a string path.

    Returns:
        pandas.DataFrame with columns ``split``, ``label`` and ``filename``
        (full path as a string). The columns are present even when no image
        is found, in which case the frame is empty.
    """
    image_exts = ('.jpeg', '.jpg', '.png')
    root = Path(data_dir)
    records = []

    for split in ['train', 'val', 'test']:
        split_path = root / split
        if not split_path.is_dir():
            continue

        # Dynamically list the classes (normal, pneumonia, tuberculosis)
        label_dirs = sorted(p for p in split_path.iterdir() if p.is_dir())

        for label_path in label_dirs:
            for img_path in sorted(label_path.iterdir()):
                # is_file() keeps directories with image-like names out of the index.
                if img_path.is_file() and img_path.name.lower().endswith(image_exts):
                    records.append({
                        'split': split,
                        'label': label_path.name,
                        'filename': str(img_path)
                    })

    # Explicit columns keep the schema stable for an empty scan.
    return pd.DataFrame(records, columns=['split', 'label', 'filename'])
# Build the image index once; the cells below all read from `df`.
df = get_dataset_stats(DATA_DIR)

# Visualization
if not df.empty:
    plt.figure(figsize=(10, 6))
    # One bar per (split, label) pair; hue colours the bars by class label.
    ax = sns.countplot(data=df, x='split', hue='label', palette='viridis')
    plt.title('Class Distribution Across Splits (Normal vs Pneumonia vs TB)')
    plt.ylabel('Number of Images')
    
    # Add count labels
    # Each patch on the axes is annotated with its height, centred above it;
    # patches whose height is NaN are skipped.
    for p in ax.patches:
        height = p.get_height()
        if not np.isnan(height):
             ax.annotate(f'{int(height)}', (p.get_x() + p.get_width() / 2., height),
                         ha='center', va='bottom')
    plt.show()
    
    # Same counts as the chart, printed as a (split, label) -> size table.
    print("\nExact Counts Table:")
    print(df.groupby(['split', 'label']).size())
else:
    print("No images found. Please verify the folder names.")

## Task 1.1 - Visualize Examples 

In [None]:
def visualize_samples(df, num_samples=5):
    """Show a grid of example images: one row per class, one column per sample.

    Up to ``num_samples`` files are drawn per class with a fixed random seed.
    Each image is titled with its file name and array shape, and the first
    rendered panel of each row carries the class name as its y-label.

    Args:
        df: DataFrame with a ``label`` column and a ``filename`` column of
            image paths.
        num_samples: Maximum number of images (columns) shown per class.

    Returns:
        None. The figure is displayed with ``plt.show()``. If ``df`` has no
        labels or ``num_samples`` is below 1, a message is printed and no
        figure is created.
    """
    labels = df['label'].unique()
    n_labels = len(labels)

    # plt.subplots rejects a zero-sized grid, so bail out before creating it.
    if n_labels == 0 or num_samples < 1:
        print("No samples to visualize.")
        return

    # squeeze=False always returns a 2-D array of axes, so axes[i, j] is valid
    # even when there is a single class or a single column.
    fig, axes = plt.subplots(
        n_labels, num_samples, figsize=(15, 4 * n_labels), squeeze=False
    )

    # Start with every panel hidden; only panels that receive an image are
    # switched back on, so missing/unreadable samples leave a blank slot.
    for ax in axes.ravel():
        ax.axis('off')

    for i, label in enumerate(labels):
        class_df = df[df['label'] == label]

        # A class may have fewer images than requested.
        current_samples = min(num_samples, len(class_df))
        if current_samples == 0:
            continue

        sample_files = class_df.sample(current_samples, random_state=42)['filename'].values

        row_labelled = False
        for j, img_path in enumerate(sample_files):
            img = cv2.imread(img_path)
            if img is None:
                continue
            # OpenCV loads as BGR; matplotlib expects RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            ax = axes[i, j]
            ax.axis('on')
            ax.imshow(img)

            # Hide ticks, grid and frame individually instead of axis('off'):
            # axis('off') would also hide the y-label that names the class.
            ax.grid(False)
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

            if not row_labelled:
                ax.set_ylabel(label, fontsize=14, fontweight='bold')
                row_labelled = True
            ax.set_title(f"{os.path.basename(img_path)}\n{img.shape}")

    plt.tight_layout()
    plt.show()

# Draw the sample grid only when the image index `df` has rows.
if not df.empty:
    visualize_samples(df)

## Task 1.1 - Pixel Intensity Histograms

In [None]:
def plot_intensity_histograms(df, sample_size=1000, max_pixels_per_image=10_000):
    """Plot a per-class kernel density estimate of grayscale pixel intensity.

    A random subset of at most ``sample_size`` rows is drawn from ``df`` with
    a fixed seed, each image is read as grayscale, and one KDE curve per class
    is drawn on a shared axis limited to 0-255.

    Pixels are collected as NumPy arrays and each image contributes at most
    ``max_pixels_per_image`` randomly chosen pixels. Collecting every pixel of
    every image in a Python list (and running a KDE over it) needs memory and
    time proportional to the total pixel count, which is impractical for
    hundreds of full-resolution images.

    Args:
        df: DataFrame with ``label`` and ``filename`` columns.
        sample_size: Maximum number of images sampled across all classes.
        max_pixels_per_image: Cap on pixels taken from each image; ``None``
            uses every pixel.

    Returns:
        None. The figure is displayed with ``plt.show()``. If no image could
        be read, a message is printed and no figure is created.
    """
    # Sample data to speed up processing
    sample_df = df.sample(min(len(df), sample_size), random_state=42)
    rng = np.random.default_rng(42)  # seeded so the pixel subsample is repeatable

    intensities_by_label = {}
    for label in df['label'].unique():
        subset = sample_df[sample_df['label'] == label]
        chunks = []

        for img_path in subset['filename']:
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            if img is None:
                continue
            pixels = img.ravel()
            if max_pixels_per_image is not None and pixels.size > max_pixels_per_image:
                pixels = rng.choice(pixels, size=max_pixels_per_image, replace=False)
            chunks.append(pixels)

        # Classes with no readable image in the sample are left out of the plot.
        if chunks:
            intensities_by_label[label] = np.concatenate(chunks)

    if not intensities_by_label:
        print("No readable images found; skipping intensity plot.")
        return

    plt.figure(figsize=(12, 6))
    for label, intensities in intensities_by_label.items():
        # Plot KDE (Kernel Density Estimate) for smooth curves
        sns.kdeplot(intensities, label=label, fill=True, alpha=0.3)

    plt.title('Pixel Intensity Distribution by Class')
    plt.xlabel('Pixel Intensity (0=Black, 255=White)')
    plt.xlim(0, 255)
    plt.legend()
    plt.show()

# Plot intensity distributions only when the image index `df` has rows.
if not df.empty:
    plot_intensity_histograms(df)

## Task 1.1 - Edge & Texture Analysis (Laplacian Variance)

In [None]:
def analyze_texture_sharpness(df, sample_size=1000):
    """Compare image sharpness across classes with a box plot.

    A random subset of at most ``sample_size`` rows is drawn from ``df`` with
    a fixed seed. For each readable image the variance of its Laplacian is
    computed, and the per-class distributions are drawn on a log-scaled
    y-axis.

    Args:
        df: DataFrame with ``label`` and ``filename`` columns.
        sample_size: Maximum number of images analysed.

    Returns:
        None. The figure is displayed with ``plt.show()``. If no image could
        be read, a message is printed and no figure is created.
    """
    sample_df = df.sample(min(len(df), sample_size), random_state=42)
    results = []

    for label, filename in zip(sample_df['label'], sample_df['filename']):
        img = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue
        # Laplacian Variance is a standard measure for image blur/sharpness
        results.append({
            'label': label,
            'sharpness': cv2.Laplacian(img, cv2.CV_64F).var()
        })

    # Without rows the frame would lack the 'label'/'sharpness' columns the
    # box plot needs, so stop here with a clear message.
    if not results:
        print("No readable images found; skipping sharpness plot.")
        return

    results_df = pd.DataFrame(results)

    plt.figure(figsize=(10, 6))
    sns.boxplot(data=results_df, x='label', y='sharpness', palette='Set2')
    plt.title('Image Sharpness (Laplacian Variance) by Class')
    plt.ylabel('Variance (Higher = Sharper)')
    plt.yscale('log')
    plt.show()

# Run the sharpness comparison only when the image index `df` has rows.
if not df.empty:
    analyze_texture_sharpness(df)

## Missing/Corrupt Data Handling

In [None]:
def check_for_corrupt_images(df):
    """Scan every file listed in ``df`` and report the ones PIL cannot verify.

    Each path in the ``filename`` column is opened with ``Image.open`` and
    checked with ``Image.verify()``. A file is recorded as corrupt when either
    call raises ``IOError`` or ``SyntaxError``; the path and the error text
    are printed.

    NOTE(review): per the Pillow documentation ``verify()`` does not decode
    the pixel data, so some damaged files may still pass this check.

    Args:
        df: DataFrame with a ``filename`` column of image paths. All rows are
            scanned; no sampling is done.

    Returns:
        list: Paths of the files that failed verification (empty if none).
    """
    corrupt_files = []

    print(f"Scanning {len(df)} images for corruption")
    for filename in df['filename']:
        try:
            # Attempt to open with PIL
            with Image.open(filename) as img:
                img.verify()  # Verify file integrity
        except (IOError, SyntaxError) as e:
            corrupt_files.append(filename)
            print(f" Corrupt file found: {filename} ({e})")

    if not corrupt_files:
        # The whole frame was scanned, so the summary must not say "sampled".
        print("No corrupt images found.")
    else:
        print(f"Found {len(corrupt_files)} corrupt images.")

    return corrupt_files

# Scan for unreadable files only when the image index `df` has rows.
# Note: `corrupt_list` is defined only when this branch runs.
if not df.empty:
    corrupt_list = check_for_corrupt_images(df)

## Task 1.2 - The Preprocessing Pipeline Showcase

In [None]:
def preprocess_preview(img_path, target_size=256):
    """Load an image and return it alongside a CLAHE-enhanced, letterboxed copy.

    The image is read as grayscale, contrast-enhanced with CLAHE
    (clipLimit=2.0, 8x8 tiles), resized so that its longer side equals
    ``target_size`` while keeping the aspect ratio, and centred on a black
    ``target_size`` x ``target_size`` canvas.

    Args:
        img_path: Path of the image file to read.
        target_size: Side length, in pixels, of the square output canvas.

    Returns:
        tuple: ``(original, processed)`` where ``original`` is the grayscale
        image as read and ``processed`` is the padded ``uint8`` canvas.

    Raises:
        ValueError: If the image cannot be read from ``img_path``.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise ValueError(f"Could not read image (missing or undecodable): {img_path}")

    # 1. CLAHE (Contrast Limited Adaptive Histogram Equalization)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    img_clahe = clahe.apply(img)

    # 2. Resize with Padding (Preserve Aspect Ratio)
    h, w = img.shape
    scale = target_size / max(h, w)
    # Clamp to [1, target_size]: for very elongated images the shorter side
    # would otherwise truncate to 0, which cv2.resize rejects.
    new_w = min(target_size, max(1, int(w * scale)))
    new_h = min(target_size, max(1, int(h * scale)))
    img_resized = cv2.resize(img_clahe, (new_w, new_h))

    # Create blank canvas
    canvas = np.zeros((target_size, target_size), dtype=np.uint8)

    # Center the image
    x_offset = (target_size - new_w) // 2
    y_offset = (target_size - new_h) // 2
    canvas[y_offset:y_offset + new_h, x_offset:x_offset + new_w] = img_resized

    return img, canvas

# Visualize Before vs After
# Uses the first file in the index as the example image.
if not df.empty:
    sample_file = df['filename'].iloc[0]
    original, processed = preprocess_preview(sample_file)

    # Left panel: grayscale image as read; right panel: padded, CLAHE-enhanced canvas.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.imshow(original, cmap='gray')
    ax1.set_title("Original")
    ax2.imshow(processed, cmap='gray')
    ax2.set_title("Processed (CLAHE + Pad/Resize)")
    plt.show()

In [None]:
from skimage.feature import local_binary_pattern

def extract_lbp_features(df, sample_size=5000):
    """Compare a Local Binary Pattern (LBP) texture summary across classes.

    The sampling budget is split evenly between the classes found in ``df``:
    each class contributes at most ``sample_size // n_classes`` images, drawn
    with a fixed seed. Every readable image is loaded as grayscale, resized to
    224x224, converted to a uniform LBP map (radius 3, 24 points), and
    summarised by the mean of that map. The per-class distributions of the
    means are shown as a box plot.

    Args:
        df: DataFrame with ``label`` and ``filename`` columns.
        sample_size: Total image budget shared equally between the classes.

    Returns:
        None. The figure is displayed with ``plt.show()``. If ``df`` has no
        labels or no image could be read, a message is printed and no figure
        is created.
    """
    unique_labels = df['label'].unique()
    n_classes = len(unique_labels)
    if n_classes == 0:
        # Avoids dividing the budget by zero classes below.
        print("No labelled images available; skipping LBP analysis.")
        return

    # Balanced sampling: Try to get equal numbers of each class for fair comparison
    samples_per_class = sample_size // n_classes
    print(
        f"Extracting LBP texture features for up to {samples_per_class} images "
        f"per class across {n_classes} classes..."
    )

    # LBP Settings
    radius = 3
    n_points = 8 * radius

    lbp_means = []
    labels = []

    for label in unique_labels:
        class_subset = df[df['label'] == label]

        # A class may hold fewer images than its share of the budget.
        n = min(len(class_subset), samples_per_class)
        subset = class_subset.sample(n, random_state=42)

        for filename in subset['filename']:
            img = cv2.imread(filename, cv2.IMREAD_GRAYSCALE)
            if img is None:
                continue

            # Resize for consistency/speed
            img = cv2.resize(img, (224, 224))

            # Compute LBP (Uniform method is rotation invariant)
            lbp = local_binary_pattern(img, n_points, radius, method='uniform')

            # We use the mean LBP value as a texture proxy
            lbp_means.append(lbp.mean())
            labels.append(label)

    if not lbp_means:
        print("No readable images found; skipping LBP plot.")
        return

    # Visualize Results
    results_df = pd.DataFrame({'label': labels, 'lbp_mean': lbp_means})

    plt.figure(figsize=(10, 6))
    # Boxplot to show the distribution differences between the classes
    sns.boxplot(data=results_df, x='label', y='lbp_mean', palette='Set2')
    plt.title("Texture Analysis: LBP Mean Values (Normal vs Pneumonia vs TB)")
    plt.ylabel("Mean LBP Value (Texture Complexity)")
    plt.show()

# Run the LBP texture comparison only when the image index `df` has rows.
if not df.empty:
    extract_lbp_features(df)