# 01: Data Exploration

This notebook explores the ThickBloodSmears_150 dataset: it visualizes the class distribution,
examines sample images, and analyzes image properties to characterize the dataset before preprocessing.

## Objectives
1. Load and explore dataset structure
2. Visualize class distribution
3. Display sample images from each class (infected and uninfected)
4. Analyze image properties (resolution, color distribution)
5. Check for data quality issues
6. Assess preprocessing requirements

In [None]:
# Import required libraries: stdlib -> third-party -> project-local
import json
import os
import sys
from collections import Counter
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Make the project root (the notebook directory's parent) importable so that
# `src.*` resolves. The membership check keeps sys.path from accumulating a
# duplicate entry every time this cell is re-run.
PROJECT_ROOT = str(Path(os.getcwd()).parent)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Import from project (requires PROJECT_ROOT on sys.path)
from src.data.dataset_loader import ThickBloodSmearsLoader

# Plot style shared by every figure in this notebook
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

print("✓ Dependencies loaded successfully")

## 1. Dataset Overview

In [None]:
# Initialize the dataset loader.
# If the dataset directory is missing, `dataset_path` is set to None and every
# later analysis cell is skipped by its `if dataset_path and dataset_path.exists()` guard.
print("Loading ThickBloodSmears_150 dataset...")
print("="*60)

# Define dataset path - update this to match your local dataset location
dataset_path = Path("../data/raw/ThickBloodSmears_150")  # Adjust as needed

# Reset on every run so a loader left over from an earlier execution is never reused.
loader = None

if not dataset_path.exists():
    print(f"⚠️  Dataset not found at {dataset_path}")
    print("Please download the ThickBloodSmears_150 dataset and place it at the path above")
    dataset_path = None
else:
    loader = ThickBloodSmearsLoader(
        data_dir=str(dataset_path),
        image_size=224
    )
    print(f"✓ Dataset loaded from {dataset_path}")
    print(f"Total images: {len(loader.image_files)}")
    print(f"Classes: {loader.class_names}")
    if len(loader.image_files) == 0:
        # Later statistics (percentages, min/max) are meaningless on an empty dataset.
        print("⚠️  No images were found - check the directory layout expected by ThickBloodSmearsLoader")

## 2. Class Distribution Analysis

In [None]:
if dataset_path and dataset_path.exists():
    # Labels are integer class indices into loader.class_names.
    class_counts = Counter(loader.labels)
    total_labels = len(loader.labels)
    classes = list(loader.class_names)
    counts = [class_counts.get(i, 0) for i in range(len(classes))]

    print("Class Distribution:")
    print("="*40)
    for class_name, count in zip(classes, counts):
        # Guard against an empty label list (avoids ZeroDivisionError).
        percentage = (count / total_labels) * 100 if total_labels else 0.0
        bar = "█" * int(percentage / 5)  # one block per 5 %
        print(f"{class_name:15} | {count:3d} images ({percentage:5.1f}%) {bar}")

    # Choose colours by class *name* so "green = uninfected, red = infected"
    # holds regardless of the order of loader.class_names; names that match
    # neither fall back to the positional palette.
    default_palette = ['#2ecc71', '#e74c3c']
    colors = []
    for i, class_name in enumerate(classes):
        lowered = str(class_name).lower()
        if 'uninfected' in lowered:  # checked first: 'uninfected' contains 'infected'
            colors.append('#2ecc71')
        elif 'infected' in lowered or 'parasitized' in lowered:
            colors.append('#e74c3c')
        else:
            colors.append(default_palette[i % len(default_palette)])

    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Bar plot
    axes[0].bar(classes, counts, color=colors, alpha=0.7, edgecolor='black', linewidth=2)
    axes[0].set_ylabel('Number of Images', fontsize=12, fontweight='bold')
    axes[0].set_title('Class Distribution', fontsize=14, fontweight='bold')
    axes[0].grid(axis='y', alpha=0.3)

    # Add value labels on bars
    for i, count in enumerate(counts):
        axes[0].text(i, count + 1, str(count), ha='center', fontweight='bold')

    # Pie chart - matplotlib cannot draw a pie whose wedges sum to zero
    if sum(counts) > 0:
        axes[1].pie(counts, labels=classes, colors=colors, autopct='%1.1f%%',
                    startangle=90, textprops={'fontsize': 11, 'fontweight': 'bold'})
    else:
        axes[1].axis('off')
    axes[1].set_title('Class Distribution (Percentage)', fontsize=14, fontweight='bold')

    plt.tight_layout()
    plt.show()

    # Imbalance ratio = largest class / smallest class (inf if any class is empty).
    # Also read by the recommendations in section 7.
    imbalance_ratio = max(counts) / min(counts) if min(counts) > 0 else float('inf')

    print(f"\nImbalance Ratio: {imbalance_ratio:.2f}:1")
    if imbalance_ratio > 1.5:
        print("⚠️  Significant class imbalance detected - weighted loss and stratified splitting recommended")
    else:
        print("✓ Classes are relatively balanced")

## 3. Sample Images Visualization

In [None]:
if dataset_path and dataset_path.exists():
    # Show the first `samples_per_class` images of each class (in dataset order).
    samples_per_class = 3
    n_classes = len(loader.class_names)

    sample_indices_by_class = {i: [] for i in range(n_classes)}
    for idx, label in enumerate(loader.labels):
        if len(sample_indices_by_class[label]) < samples_per_class:
            sample_indices_by_class[label].append(idx)

    # squeeze=False keeps `axes` 2-D even with a single class or a single
    # sample column, so axes[row, col] indexing always works.
    fig, axes = plt.subplots(n_classes, samples_per_class,
                             figsize=(15, 8), squeeze=False)

    for class_idx, class_name in enumerate(loader.class_names):
        picked = sample_indices_by_class[class_idx]
        for sample_idx in range(samples_per_class):
            ax = axes[class_idx, sample_idx]
            ax.axis('off')  # also blanks the panel when a class has too few images
            if sample_idx >= len(picked):
                continue
            image_path = loader.image_files[picked[sample_idx]]
            img = cv2.imread(str(image_path))
            if img is None:  # unreadable file: label the panel instead of crashing
                ax.set_title(f"{class_name}\n(unreadable: {image_path.name})", fontsize=10)
                continue
            ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            ax.set_title(f"{class_name}\n({image_path.name})", fontsize=10, fontweight='bold')

    plt.suptitle('Sample Images by Class', fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

    print("✓ Sample images displayed")

## 4. Image Properties Analysis

In [None]:
if dataset_path and dataset_path.exists():
    print("Analyzing image properties...")
    print("="*60)

    # Per-image properties of every readable file; also reused by the
    # data-quality section.
    image_properties = {
        'heights': [],
        'widths': [],
        'channels': [],
        'file_sizes': [],
        'formats': []
    }
    unreadable_count = 0

    for image_path in loader.image_files:
        # IMREAD_UNCHANGED preserves the stored channel count; the default
        # IMREAD_COLOR always converts to 3 channels, so the channel statistic
        # could never report anything else. (It also does not apply EXIF rotation.)
        img = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
        if img is None:
            unreadable_count += 1
            continue
        image_properties['heights'].append(img.shape[0])
        image_properties['widths'].append(img.shape[1])
        image_properties['channels'].append(img.shape[2] if img.ndim == 3 else 1)
        image_properties['file_sizes'].append(os.path.getsize(image_path) / 1024)  # in KB
        image_properties['formats'].append(image_path.suffix.lower())

    if unreadable_count:
        print(f"⚠️  {unreadable_count} unreadable image(s) skipped (see section 6)")

    if not image_properties['heights']:
        # np.min/np.max raise on empty input; nothing to summarise.
        print("⚠️  No readable images - property statistics skipped")
    else:
        # Statistics
        print("\nImage Dimensions:")
        print(f"  Height: {np.mean(image_properties['heights']):.1f} ± {np.std(image_properties['heights']):.1f}")
        print(f"          (min: {np.min(image_properties['heights'])}, max: {np.max(image_properties['heights'])})")
        print(f"  Width:  {np.mean(image_properties['widths']):.1f} ± {np.std(image_properties['widths']):.1f}")
        print(f"          (min: {np.min(image_properties['widths'])}, max: {np.max(image_properties['widths'])})")

        print(f"\nImage Channels: {Counter(image_properties['channels'])}")

        print("\nFile Sizes:")
        print(f"  Mean: {np.mean(image_properties['file_sizes']):.1f} KB")
        print(f"  Range: {np.min(image_properties['file_sizes']):.1f} - {np.max(image_properties['file_sizes']):.1f} KB")

        formats_count = Counter(image_properties['formats'])
        print(f"\nImage Formats: {formats_count}")

        # Visualize dimensions, file sizes and formats
        fig, axes = plt.subplots(2, 2, figsize=(14, 10))

        hist_specs = [
            (axes[0, 0], 'heights', 'skyblue', 'Height (pixels)', 'Image Height Distribution'),
            (axes[0, 1], 'widths', 'lightcoral', 'Width (pixels)', 'Image Width Distribution'),
            (axes[1, 0], 'file_sizes', 'lightgreen', 'File Size (KB)', 'File Size Distribution'),
        ]
        for ax, key, color, xlabel, title in hist_specs:
            ax.hist(image_properties[key], bins=20, color=color, edgecolor='black', alpha=0.7)
            ax.set_xlabel(xlabel, fontweight='bold')
            ax.set_ylabel('Count', fontweight='bold')
            ax.set_title(title, fontweight='bold')
            ax.grid(alpha=0.3)

        axes[1, 1].bar(list(formats_count.keys()), list(formats_count.values()),
                       color='plum', edgecolor='black', alpha=0.7)
        axes[1, 1].set_xlabel('File Format', fontweight='bold')
        axes[1, 1].set_ylabel('Count', fontweight='bold')
        axes[1, 1].set_title('Image Format Distribution', fontweight='bold')
        axes[1, 1].grid(axis='y', alpha=0.3)

        plt.tight_layout()
        plt.show()

## 5. Color Distribution Analysis

In [None]:
if dataset_path and dataset_path.exists():
    print("Analyzing color distribution (Giemsa staining patterns)...")
    print("="*60)

    # Mean R/G/B per class, pooled over the first few images of each class
    # (kept small so the cell stays fast).
    n_color_samples = 5
    label_array = np.asarray(loader.labels)

    color_stats = {}
    for class_idx, class_name in enumerate(loader.class_names):
        sample_ids = np.flatnonzero(label_array == class_idx)[:n_color_samples]

        # Running per-channel sums plus a pixel count give the same pooled mean as
        # concatenating every pixel, without building multi-million-element
        # Python lists.
        channel_sums = np.zeros(3, dtype=np.float64)
        pixel_count = 0
        for idx in sample_ids:
            image_path = loader.image_files[idx]
            img = cv2.imread(str(image_path))
            if img is None:
                print(f"⚠️  Skipping unreadable image: {image_path.name}")
                continue
            img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            channel_sums += img_rgb.reshape(-1, 3).sum(axis=0, dtype=np.float64)
            pixel_count += img_rgb.shape[0] * img_rgb.shape[1]

        # NaN (instead of an empty-mean warning) when a class has no readable samples.
        channel_means = channel_sums / pixel_count if pixel_count else np.full(3, np.nan)
        color_stats[class_name] = dict(zip(('R', 'G', 'B'), channel_means))

    # Display color statistics
    print("\nMean Color Values (RGB):")
    for class_name, rgb in color_stats.items():
        print(f"  {class_name:15}: R={rgb['R']:.1f}, G={rgb['G']:.1f}, B={rgb['B']:.1f}")

    # One panel per channel, one bar per class
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    for ax, channel in zip(axes, ('R', 'G', 'B')):
        for class_name, rgb in color_stats.items():
            ax.bar(class_name, rgb[channel], alpha=0.7, label=class_name)

        ax.set_ylabel(f'{channel} Channel Intensity', fontweight='bold')
        ax.set_title(f'Mean {channel} Channel Intensity by Class', fontweight='bold')
        ax.set_ylim([0, 255])
        ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

    print("\n📌 Giemsa staining produces characteristic colors:")
    print("  - RBCs: Lysed during thick-smear staining (pale background)")
    print("  - Parasites: Red/purple chromatin, blue cytoplasm")
    print("  - WBCs: Blue nucleus, light blue cytoplasm")

## 6. Data Quality Assessment

In [None]:
if dataset_path and dataset_path.exists():
    print("Data Quality Assessment")
    print("="*60)

    # Every problem found below; any entry downgrades the final summary.
    quality_issues = []

    # 1) Unreadable / corrupted files
    corrupted_count = 0
    for image_path in loader.image_files:
        try:
            img = cv2.imread(str(image_path))
        except Exception as exc:  # e.g. cv2.error - record it instead of aborting the cell
            corrupted_count += 1
            quality_issues.append(f"Error reading: {image_path.name} ({exc})")
            continue
        if img is None:
            corrupted_count += 1
            quality_issues.append(f"Corrupted: {image_path.name}")

    print(f"Corrupted images: {corrupted_count}/{len(loader.image_files)}")
    if corrupted_count == 0:
        print("✓ All images readable")
    else:
        print(f"⚠️  {corrupted_count} corrupted images found")

    # 2) Dimension consistency - reuses `image_properties` from section 4
    heights = image_properties['heights']
    widths = image_properties['widths']

    print("\nImage size consistency:")
    if not heights:
        print("⚠️  No readable images - size checks skipped")
        quality_issues.append("No readable images for size checks")
    else:
        # A dimension counts as inconsistent when its std exceeds 10% of its mean.
        for dim_name, values in (("Heights", heights), ("Widths", widths)):
            if np.std(values) > np.mean(values) * 0.1:
                print(f"⚠️  {dim_name} vary significantly (need resizing)")
                quality_issues.append(f"{dim_name} vary significantly (need resizing)")
            else:
                print(f"✓ {dim_name} relatively consistent")

        aspect_ratios = [h / w for h, w in zip(heights, widths)]
        mean_ratio = np.mean(aspect_ratios)
        print(f"\nAspect ratio range: {np.min(aspect_ratios):.2f} - {np.max(aspect_ratios):.2f}")
        if np.std(aspect_ratios) < 0.1:
            # Low spread alone does not imply square images; check the mean too.
            shape_note = "mostly square" if abs(mean_ratio - 1) < 0.1 else f"mean H/W {mean_ratio:.2f}"
            print(f"✓ Aspect ratios consistent ({shape_note})")
        else:
            print("⚠️  Variable aspect ratios detected")
            quality_issues.append("Variable aspect ratios")

    # Summary
    print("\n" + "="*60)
    print("DATA QUALITY SUMMARY")
    print("="*60)

    if not quality_issues:
        print("✓ Overall data quality: GOOD")
        print("  - No corrupted images")
        print("  - Consistent dimensions")
        print("  - Ready for preprocessing")
    else:
        print(f"⚠️  Issues detected: {len(quality_issues)}")
        for issue in quality_issues[:5]:
            print(f"  - {issue}")
        if len(quality_issues) > 5:
            print(f"  ... and {len(quality_issues) - 5} more issues")

## 7. Summary and Recommendations

In [None]:
if dataset_path and dataset_path.exists():
    # Summarise findings from the earlier sections into preprocessing advice.
    # Reads `image_properties` (section 4) and `imbalance_ratio` (section 2).
    print("\n" + "="*60)
    print("DATASET SUMMARY & PREPROCESSING RECOMMENDATIONS")
    print("="*60)

    recommendations = []

    # Recommendation 1: Image resizing - report the sizes actually measured
    distinct_sizes = set(zip(image_properties['widths'], image_properties['heights']))
    if len(distinct_sizes) == 1:
        only_w, only_h = next(iter(distinct_sizes))
        current_size = f"uniform {only_w}×{only_h} pixels"
    elif distinct_sizes:
        current_size = f"variable ({len(distinct_sizes)} distinct sizes)"
    else:
        current_size = "unknown (no readable images)"

    print("\n1. IMAGE RESIZING")
    print(f"   Current size: {current_size}")
    print("   Recommended: 224×224 pixels")
    print("   Reason: Standard input for transfer learning models")
    recommendations.append("Resize all images to 224×224")

    # Recommendation 2: Class imbalance (same 1.5:1 threshold as section 2)
    print("\n2. CLASS IMBALANCE HANDLING")
    print(f"   Imbalance ratio: {imbalance_ratio:.2f}:1")
    if imbalance_ratio > 1.5:
        print("   Recommended: Weighted loss functions")
        print("   Also: Stratified train/val/test split")
        recommendations.append("Use weighted BCE or focal loss")
        recommendations.append("Implement stratified splitting (70/15/15)")
    else:
        print("   Status: Relatively balanced")

    # Recommendation 3: Augmentation
    print("\n3. DATA AUGMENTATION")
    print(f"   Dataset size: {len(loader.image_files)} images")
    print("   Recommended: Medical-safe augmentation")
    print("   - Rotation: 0-360°")
    print("   - Flips: Horizontal & vertical")
    print("   - Brightness/Contrast: ±20%")
    print("   - Elastic deformations: moderate")
    recommendations.append("Apply medical-safe augmentation (rotation, flips, brightness)")

    # Recommendation 4: Preprocessing
    print("\n4. MICROSCOPY-SPECIFIC PREPROCESSING")
    print("   Step 1: CLAHE (Contrast Limited Adaptive Histogram Equalization)")
    print("   Step 2: Color normalization (handle staining variations)")
    print("   Step 3: Pixel normalization (ImageNet standard)")
    recommendations.append("Apply CLAHE for contrast enhancement")
    recommendations.append("Normalize color (Giemsa staining variation)")

    # Recommendation 5: Model selection
    print("\n5. MODEL ARCHITECTURE")
    print(f"   Small dataset ({len(loader.image_files)} images) with medical focus")
    print("   Recommended:")
    print("   - Transfer Learning (ResNet50, DenseNet121) - PRIMARY")
    print("   - Medical CNN with attention mechanism - SECONDARY")
    print("   - Ensemble of multiple models - OPTIMAL")
    recommendations.append("Use transfer learning with pre-trained weights")
    recommendations.append("Consider ensemble of 3-5 models for production")

    # Recommendation 6: Validation strategy (kept consistent with the printed split)
    print("\n6. VALIDATION & EVALUATION")
    print("   Metric priority:")
    print("   1. Sensitivity (Recall) - CRITICAL for screening")
    print("   2. Specificity - Important for reducing false alarms")
    print("   3. NPV (Negative Predictive Value) - Key for clinical trust")
    print("   Split strategy: 70% train / 15% val / 15% test (stratified)")
    recommendations.append("Use a stratified 70/15/15 split (stratified k-fold CV is an option for this small dataset)")
    recommendations.append("Prioritize sensitivity in early stopping and model selection")

    print("\n" + "="*60)
    print("NEXT STEPS:")
    print("="*60)
    for i, rec in enumerate(recommendations, 1):
        print(f"{i}. {rec}")

    print("\nProceeding to: 02_image_preprocessing.ipynb")

In [None]:
# Final recap; `dataset_loaded` mirrors the guard used by the analysis cells above.
dataset_loaded = bool(dataset_path and dataset_path.exists())

print("\n✓ Data exploration complete!")
print("\nKey findings:")
print(f"  • Total images: {len(loader.image_files) if dataset_loaded else 'N/A'}")
print("  • Classes: Infected / Uninfected (binary classification)")
print("  • Staining: Giemsa (thick blood smear)")
print("  • Magnification: ×1000 (oil immersion)")
print("\nReady for preprocessing in next notebook!")