In [None]:
import sys

# Treat the session as Google Colab when the `google.colab` package is already
# loaded in this interpreter. Later cells branch on IS_COLAB to choose paths
# and setup steps.
IS_COLAB = 'google.colab' in sys.modules
print(f"Running in Google Colab: {IS_COLAB}")

In [None]:
import platform
import psutil
import subprocess
import os

if IS_COLAB:
    # Print the VM's software/hardware profile so results can be tied to a runtime type.
    print("Google Colab Environment Specifications:")
    print("="*50)

    # Get system info
    print(f"Operating System: {platform.system()} {platform.release()}")
    print(f"Architecture: {platform.machine()}")
    print(f"Python Version: {platform.python_version()}")

    # Memory info (bytes converted to GiB)
    memory = psutil.virtual_memory()
    print(f"Total RAM: {memory.total / (1024**3):.1f} GB")
    print(f"Available RAM: {memory.available / (1024**3):.1f} GB")

    # CPU info
    print(f"CPU Cores: {psutil.cpu_count(logical=False)} physical, {psutil.cpu_count(logical=True)} logical")

    # GPU info: query nvidia-smi for one "name, memory" CSV line per device.
    try:
        result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader,nounits'],
                                capture_output=True, text=True)
    except OSError:
        # Raised when the nvidia-smi executable is missing or cannot be started.
        # Only this failure is caught so genuine bugs and Ctrl-C are not hidden.
        print("GPU: Not detected")
    else:
        if result.returncode == 0:
            gpu_info = [line for line in result.stdout.strip().split('\n') if line.strip()]
            if not gpu_info:
                print("GPU: Not detected")
            for i, gpu in enumerate(gpu_info):
                # Split on the LAST ", " so a comma inside the device name is kept.
                # A separate name (vram_mb) avoids overwriting the psutil `memory` object.
                name, sep, vram_mb = gpu.rpartition(', ')
                if sep:
                    print(f"GPU {i}: {name}, {vram_mb} MB VRAM")
                else:
                    # Unexpected line format: show it verbatim instead of failing.
                    print(f"GPU {i}: {gpu}")
        else:
            print("GPU: Not detected or nvidia-smi unavailable")

    # Disk space
    disk = psutil.disk_usage('/')
    print(f"Disk Space: {disk.free / (1024**3):.1f} GB free / {disk.total / (1024**3):.1f} GB total")

    print("="*50)

    # The project checkout is expected at this path; the next cell clones it if absent.
    if not os.path.exists('/content/aai521_3proj'):
        print("WARNING: Cloning project repository required.")
        print("="*50)
else:
    print("Not running in Google Colab environment")

In [None]:
import os
import sys

# Make the project importable: on Colab work from a checkout under /content,
# locally rely on the notebook living one level below the repository root.
if IS_COLAB:
    print("Running in Google Colab environment.")
    if os.path.exists('/content/aai521_3proj'):
        # Checkout already present (e.g. the cell was re-run): refresh it in place.
        print("Repository already exists. Pulling latest changes...")
        %cd /content/aai521_3proj
        !git pull
    else:
        print("Cloning repository...")
        # The clone lands in the current working directory; the absolute paths used
        # elsewhere in this cell assume that directory is /content - TODO confirm.
        !git clone https://github.com/swapnilprakashpatil/aai521_3proj.git
        %cd aai521_3proj    
    # Both branches above leave the working directory inside the checkout,
    # so requirements.txt is resolved relative to the repository root.
    %pip install -r requirements.txt --quiet
    sys.path.append('/content/aai521_3proj/src')
    %ls
else:
    print("Running in local environment. Installing packages...")
    # Relative paths: requirements.txt and src/ are looked up one directory above the working directory.
    %pip install -r ../requirements.txt --quiet
    sys.path.append('../src')

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import cv2
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Reload modules to pick up latest changes
import importlib

# Only modules that are already loaded are reloaded, and always in this order.
for _module_name in ('config', 'data_loader', 'preprocessing', 'augmentation', 'visualizations'):
    if _module_name in sys.modules:
        importlib.reload(sys.modules[_module_name])

# Import custom modules
from config import (
    GERMANY_TRAIN, LOUISIANA_EAST_TRAIN,
    PROCESSED_TRAIN_DIR, CLASS_NAMES, CLASS_COLORS,
    PATCH_SIZE, PATCH_OVERLAP, MIN_FLOOD_PIXELS, SELECTED_REGIONS
)

from data_loader import DatasetLoader, load_tile_data
from preprocessing import ImagePreprocessor, PatchExtractor
from augmentation import get_training_augmentation, DualImageAugmentation
from visualizations import (
    plot_flood_statistics,
    plot_class_distribution,
    plot_augmentation_samples,
    plot_tile_overview,
    plot_clahe_comparison,
    plot_advanced_preprocessing,
    plot_patch_samples,
    plot_processed_sample
)

# Set style: matplotlib stylesheet plus seaborn's 'husl' colour palette for all later figures
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Render figures inline in the notebook output
%matplotlib inline

# Echo the threshold imported from config so the value in effect is visible in the output
print(f"Configuration loaded: MIN_FLOOD_PIXELS = {MIN_FLOOD_PIXELS}")


In [None]:
# One loader per training region; both are reused by later cells.
germany_loader = DatasetLoader(GERMANY_TRAIN, 'Germany')
louisiana_loader = DatasetLoader(LOUISIANA_EAST_TRAIN, 'Louisiana-East')

# Report how many tiles each loader lists.
for _region_label, _region_loader in (('Germany', germany_loader),
                                      ('Louisiana-East', louisiana_loader)):
    print(f"{_region_label} tiles: {len(_region_loader.get_tile_list())}")

In [None]:
# Per-region flood statistics as reported by each loader.
germany_stats = germany_loader.get_flood_statistics()
louisiana_stats = louisiana_loader.get_flood_statistics()

# Tabulate with one row per region (regions become the index after transposing).
_stats_by_region = {'Germany': germany_stats, 'Louisiana-East': louisiana_stats}
stats_df = pd.DataFrame(_stats_by_region).transpose()

print(stats_df)

fig = plot_flood_statistics(germany_stats, louisiana_stats)
plt.show()


In [None]:
# Load a sample tile from Germany (the first tile the loader lists)
sample_tile_name = germany_loader.get_tile_list()[0]
print(f"Loading sample tile: {sample_tile_name}")

sample_data = load_tile_data(GERMANY_TRAIN, sample_tile_name, 'Germany')

# Summarise array shapes, the recorded dtype and the value range of the pre-event image.
print(f"\nTile information:")
for _label, _key in (('Pre-image', 'pre_image'), ('Post-image', 'post_image'), ('Mask', 'mask')):
    print(f"  {_label} shape: {sample_data[_key].shape}")
print(f"  Pre-image dtype: {sample_data['pre_metadata']['dtype']}")
print(f"  Pre-image range: [{sample_data['pre_image'].min():.3f}, {sample_data['pre_image'].max():.3f}]")

# Check mask classes: pixel count and share of the tile for every class id present
unique_classes = np.unique(sample_data['mask'])
print(f"\nMask classes present: {unique_classes}")
for cls in unique_classes:
    count = np.count_nonzero(sample_data['mask'] == cls)
    pct = (count / sample_data['mask'].size) * 100
    print(f"  Class {cls} ({CLASS_NAMES.get(cls, 'unknown')}): {count} pixels ({pct:.2f}%)")

In [None]:
# Visual overview of the sample tile loaded above, labelled with the class names from config.
fig = plot_tile_overview(sample_data, CLASS_NAMES)
plt.show()


In [None]:
# Preprocessor configured for CLAHE with clip limit 2.0 on an 8x8 tile grid.
# The same instance is reused later for the image quality check.
preprocessor = ImagePreprocessor(
    apply_clahe=True,
    clahe_clip_limit=2.0,
    clahe_tile_grid_size=(8, 8)
)

# NOTE(review): `pre_enhanced` / `post_enhanced` are assigned again by the later
# synthetic degradation/restoration cell, so after a full run these names no longer
# hold the CLAHE results computed here.
pre_enhanced = preprocessor.apply_clahe_enhancement(sample_data['pre_image'])
post_enhanced = preprocessor.apply_clahe_enhancement(sample_data['post_image'])

# Compare the raw tile with its CLAHE-enhanced versions.
fig = plot_clahe_comparison(sample_data, pre_enhanced, post_enhanced)
plt.show()

print("\nCLAHE Enhancement Applied:")
print("  - Improves local contrast")
print("  - Better visibility of flood boundaries")
print("  - Histogram equalization in tiles (8x8)")


In [None]:
# Import advanced image processing libraries
# First try the imports directly; if any of them is missing, install/upgrade
# scikit-image and scipy into the running interpreter and import again.
import sys
try:
    from skimage import morphology, filters, exposure, restoration, transform
    from skimage.filters import rank, gaussian
    from skimage.morphology import disk, remove_small_objects, remove_small_holes
    from scipy import ndimage
    from scipy.signal import convolve2d
    print("Advanced image processing libraries loaded successfully")
except ImportError as e:
    print(f"Installing required libraries: {e}")
    import subprocess
    # sys.executable makes pip target the interpreter that runs this kernel.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "scikit-image", "scipy"])
    # Same import list as above, repeated after the install.
    from skimage import morphology, filters, exposure, restoration, transform
    from skimage.filters import rank, gaussian
    from skimage.morphology import disk, remove_small_objects, remove_small_holes
    from scipy import ndimage
    from scipy.signal import convolve2d
    print("Libraries installed and loaded")

# Summary of the techniques implemented by the functions defined in the next cell.
print("\nAdvanced preprocessing methods:")
print("  1. Multi-stage cloud detection (brightness + texture + saturation)")
print("  2. Morphological cloud refinement")
print("  3. Advanced inpainting (Navier-Stokes + Telea)")
print("  4. Wiener deconvolution for deblurring")
print("  5. Richardson-Lucy deconvolution")
print("  6. Unsharp masking with adaptive strength")
print("  7. CLAHE enhancement per channel")

In [None]:
def create_synthetic_degraded_image(clean_image, seed=None):
    """
    Create a moderately degraded copy of an image to demonstrate preprocessing.

    Applies, in order: haze, soft-edged cloud patches, horizontal motion blur,
    Gaussian noise and a Gaussian blur. The result is clipped to [0, 1].

    Args:
        clean_image: Array indexed as (row, column, channel). The haze and cloud
            colours have three components and channels 0-2 are processed, so a
            3-channel image is required. Values are expected in [0, 1]
            (the haze/cloud colours and the final clip use that scale).
        seed: Optional integer seed. When given, the cloud layout and the noise
            are reproducible. When None (default) the global ``np.random``
            state is used, exactly as before this parameter existed.

    Returns:
        Floating-point array with the same shape as ``clean_image``, values in [0, 1].
    """
    # RandomState offers the same randint/uniform/normal calls as the np.random
    # module, so one code path serves both the seeded and the global-state case.
    rng = np.random if seed is None else np.random.RandomState(seed)

    degraded = clean_image.copy()
    h, w = degraded.shape[:2]

    # 1. Atmospheric haze: blend 30% of a blueish-white colour into every pixel,
    #    which lowers contrast and adds a blue tint.
    haze_strength = 0.3
    haze_color = np.array([0.7, 0.75, 0.85])  # Blueish-white
    degraded = degraded * (1 - haze_strength) + haze_color * haze_strength

    # Pixel coordinate grids are identical for every cloud, so build them once.
    y_coords, x_coords = np.ogrid[:h, :w]

    # 2. Cloud patches: 5-9 clouds (randint's upper bound is exclusive).
    num_clouds = rng.randint(5, 10)
    for _ in range(num_clouds):
        # Random cloud centre (column, row)
        cx, cy = rng.randint(0, w), rng.randint(0, h)

        # Horizontal / vertical spread of the cloud in pixels
        cloud_w = rng.randint(60, 150)
        cloud_h = rng.randint(40, 120)

        # Soft-edged mask: anisotropic Gaussian falloff around the centre
        cloud_mask = np.exp(-((x_coords - cx)**2 / (2 * cloud_w**2) +
                              (y_coords - cy)**2 / (2 * cloud_h**2)))

        # Bright white with a slight per-channel variation, blended at 30-60% opacity
        cloud_color = np.array([0.85, 0.88, 0.95]) + rng.uniform(-0.05, 0.05, 3)
        cloud_opacity = rng.uniform(0.3, 0.6)

        # Alpha-blend the cloud colour into each channel
        for c in range(3):
            degraded[:, :, c] = (degraded[:, :, c] * (1 - cloud_mask * cloud_opacity) +
                                 cloud_color[c] * cloud_mask * cloud_opacity)

    # 3. Motion blur: average over 9 horizontally adjacent pixels
    kernel_size = 9
    motion_kernel = np.zeros((kernel_size, kernel_size))
    motion_kernel[kernel_size // 2, :] = 1.0 / kernel_size

    blurred = np.zeros_like(degraded)
    for c in range(3):
        blurred[:, :, c] = convolve2d(degraded[:, :, c], motion_kernel, mode='same', boundary='symm')
    degraded = blurred

    # 4. Additive Gaussian noise (standard deviation 0.02)
    noise = rng.normal(0, 0.02, degraded.shape)
    degraded = degraded + noise

    # 5. Slight overall softening; channel_axis=2 keeps the channels from mixing
    degraded = gaussian(degraded, sigma=1.0, channel_axis=2)

    return np.clip(degraded, 0, 1)


def advanced_cloud_removal(image, aggressive=True):
    """
    Detect cloud-like pixels in an RGB image and fill them in by inpainting.

    Detection combines brightness, blue excess, local entropy, saturation and
    HSV value tests. The mask is cleaned morphologically and dilated, then the
    masked area is inpainted with two OpenCV methods that are blended together.

    Args:
        image: Float RGB array indexed as (row, column, channel), expected in
            [0, 1]. Values outside that range are clipped before processing.
        aggressive: When True use a lower brightness threshold (160 instead of
            180) and a larger dilation radius (5 instead of 3).

    Returns:
        Tuple ``(result, cloud_mask)`` of float32 arrays: the processed image
        scaled to [0, 1] and the final mask with 1.0 marking cloud pixels.
    """
    # Clip before quantising: without it, values slightly outside [0, 1]
    # wrap around when cast to uint8 (e.g. 1.01 * 255 -> 257 -> 1).
    img_uint8 = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

    # === MULTI-STAGE CLOUD DETECTION ===

    # Stage 1: Brightness analysis
    gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)
    bright_mask = gray > (160 if aggressive else 180)

    # Stage 2: Blue channel analysis - blue exceeds the mean of red and green by more than 10
    blue_excess = img_uint8[:, :, 2].astype(float) - (img_uint8[:, :, 0].astype(float) + img_uint8[:, :, 1].astype(float)) / 2
    blue_mask = blue_excess > 10

    # Stage 3: Texture analysis - keep the 25% of pixels with the lowest local entropy
    selem = disk(7)
    entropy_img = rank.entropy(gray, selem)
    texture_mask = entropy_img < np.percentile(entropy_img, 25)

    # Stage 4: Saturation analysis (low saturation)
    hsv = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2HSV)
    low_sat_mask = hsv[:, :, 1] < 40

    # Stage 5: Value analysis (bright in HSV)
    high_value_mask = hsv[:, :, 2] > 200

    # Combine all stages: any one of the three conjunctions marks a cloud pixel
    cloud_mask_binary = ((bright_mask & blue_mask)
                         | (bright_mask & low_sat_mask & texture_mask)
                         | (high_value_mask & low_sat_mask))

    # === MORPHOLOGICAL REFINEMENT ===

    # Remove small false positives and fill small holes inside detected regions
    cloud_mask_binary = remove_small_objects(cloud_mask_binary, min_size=100, connectivity=2)
    cloud_mask_binary = remove_small_holes(cloud_mask_binary, area_threshold=200)

    # Dilate so the mask also covers the soft cloud edges
    selem_dilate = disk(5 if aggressive else 3)
    cloud_mask_binary = morphology.dilation(cloud_mask_binary, selem_dilate)

    # OpenCV inpainting expects an 8-bit mask
    cloud_mask_final = (cloud_mask_binary * 255).astype(np.uint8)

    # === ADVANCED INPAINTING ===

    # Skip inpainting when 200 or fewer pixels are flagged
    if np.sum(cloud_mask_final > 0) > 200:
        # Method 1: Navier-Stokes based inpainting
        inpainted_ns = cv2.inpaint(img_uint8, cloud_mask_final, 10, cv2.INPAINT_NS)

        # Method 2: Telea (fast marching) inpainting
        inpainted_fm = cv2.inpaint(img_uint8, cloud_mask_final, 7, cv2.INPAINT_TELEA)

        # Blend both methods (60% Navier-Stokes, 40% Telea)
        result = cv2.addWeighted(inpainted_ns, 0.6, inpainted_fm, 0.4, 0)

        # Bilateral filter for smooth transitions
        result = cv2.bilateralFilter(result, 7, 75, 75)
    else:
        result = img_uint8

    return result.astype(np.float32) / 255.0, cloud_mask_final.astype(np.float32) / 255.0


def advanced_deblurring(image, strength='high'):
    """
    Sharpen an RGB image by blending five enhancement passes, then equalising contrast.

    The passes are Wiener deconvolution, Richardson-Lucy deconvolution, unsharp
    masking, Sobel edge enhancement and per-channel CLAHE. They are mixed with
    weights 0.20 / 0.15 / 0.35 / 0.15 / 0.15 (sum 1.0) and the blend is passed
    through a final per-channel CLAHE.

    Args:
        image: Float RGB array indexed as (row, column, channel), expected in
            [0, 1]. Values outside that range are clipped for the 8-bit passes.
        strength: 'high' uses unsharp weights (2.0, -1.0); any other value uses
            the milder (1.8, -0.8).

    Returns:
        float32 array scaled to [0, 1] with the same height, width and channels.
    """
    # Clip before quantising: out-of-range floats would wrap around in uint8.
    img_uint8 = (np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)

    # Horizontal motion-blur PSF shared by both deconvolution passes. It is built
    # outside the try blocks so Richardson-Lucy never depends on the Wiener block
    # having run far enough to define it.
    # NOTE(review): the synthetic degradation in this notebook uses a 9-pixel
    # kernel, so this 11-pixel PSF is only an approximation of that blur.
    kernel_size = 11
    psf = np.zeros((kernel_size, kernel_size))
    psf[kernel_size // 2, :] = 1.0
    psf = psf / psf.sum()

    # === METHOD 1: WIENER DECONVOLUTION ===
    try:
        deconvolved = np.zeros_like(image)
        for c in range(3):
            deconv_channel = restoration.wiener(image[:, :, c], psf, balance=0.05)
            deconvolved[:, :, c] = np.clip(deconv_channel, 0, 1)

        deconv_uint8 = (deconvolved * 255).astype(np.uint8)
    except Exception as exc:
        # Best effort: fall back to the unprocessed image for this term, but say so.
        print(f"Wiener deconvolution failed ({exc}); using the input image for this term")
        deconv_uint8 = img_uint8

    # === METHOD 2: RICHARDSON-LUCY DECONVOLUTION ===
    try:
        rl_deconvolved = np.zeros_like(image)
        for c in range(3):
            rl_channel = restoration.richardson_lucy(image[:, :, c], psf, num_iter=15)
            rl_deconvolved[:, :, c] = np.clip(rl_channel, 0, 1)

        rl_uint8 = (rl_deconvolved * 255).astype(np.uint8)
    except Exception as exc:
        print(f"Richardson-Lucy deconvolution failed ({exc}); using the input image for this term")
        rl_uint8 = img_uint8

    # === METHOD 3: ENHANCED UNSHARP MASKING ===
    # addWeighted saturates uint8 results to [0, 255], so no extra clipping is needed.
    gaussian_blur = cv2.GaussianBlur(img_uint8, (9, 9), 2.0)
    if strength == 'high':
        unsharp = cv2.addWeighted(img_uint8, 2.0, gaussian_blur, -1.0, 0)
    else:
        unsharp = cv2.addWeighted(img_uint8, 1.8, gaussian_blur, -0.8, 0)

    # === METHOD 4: EDGE ENHANCEMENT ===
    gray = cv2.cvtColor(img_uint8, cv2.COLOR_RGB2GRAY)

    # Sobel gradient magnitude; it is float64, so it must be clipped before the uint8 cast
    sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
    sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
    edges = np.sqrt(sobelx**2 + sobely**2)
    edges = np.clip(edges, 0, 255).astype(np.uint8)
    edges_colored = cv2.cvtColor(edges, cv2.COLOR_GRAY2RGB)

    # Add 30% of the edge map on top of the image
    edge_enhanced = cv2.addWeighted(img_uint8, 1.0, edges_colored, 0.3, 0)

    # === METHOD 5: ADAPTIVE CLAHE ===
    clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
    clahe_enhanced = np.zeros_like(img_uint8)
    for c in range(3):
        clahe_enhanced[:, :, c] = clahe.apply(img_uint8[:, :, c])

    # === BLEND ALL METHODS WITH NORMALIZED WEIGHTS ===
    # Weights: Wiener (20%) + RL (15%) + Unsharp (35%) + Edge (15%) + CLAHE (15%) = 100%
    result = (deconv_uint8.astype(np.float32) * 0.20 +
              rl_uint8.astype(np.float32) * 0.15 +
              unsharp.astype(np.float32) * 0.35 +
              edge_enhanced.astype(np.float32) * 0.15 +
              clahe_enhanced.astype(np.float32) * 0.15)

    result = np.clip(result, 0, 255).astype(np.uint8)

    # === FINAL CONTRAST ENHANCEMENT ===
    # Per-channel CLAHE on the blended result
    final_clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    result_enhanced = np.zeros_like(result)
    for c in range(3):
        result_enhanced[:, :, c] = final_clahe.apply(result[:, :, c])

    return result_enhanced.astype(np.float32) / 255.0


def calculate_quality_metrics(image):
    """
    Summarise an RGB image with four scalar quality indicators.

    Returns a dict with keys 'sharpness' (variance of the Laplacian of the
    grayscale image), 'contrast' (standard deviation of the input values),
    'brightness' (mean of the input values) and 'edge_density' (fraction of
    Canny edge pixels, floored at 1e-6 so it is safe as a denominator).
    """
    luminance = cv2.cvtColor((image * 255).astype(np.uint8), cv2.COLOR_RGB2GRAY)

    # Fraction of pixels that Canny (thresholds 50/150) marks as edges
    canny_edges = cv2.Canny(luminance, 50, 150)
    edge_fraction = (canny_edges > 0).sum() / canny_edges.size

    metrics = {}
    metrics['sharpness'] = cv2.Laplacian(luminance, cv2.CV_64F).var()
    metrics['contrast'] = np.std(image)
    metrics['brightness'] = np.mean(image)
    metrics['edge_density'] = max(edge_fraction, 1e-6)
    return metrics

print("Advanced preprocessing functions defined")

In [None]:
# Apply advanced preprocessing to both pre-event and post-event images
def _percent_change(after, before):
    """Return the percentage change from `before` to `after`; 0.0 when `before` is below 1e-6."""
    if before < 1e-6:
        return 0.0
    return (after / before - 1) * 100


def _run_restoration_demo(label, clean_image):
    """
    Degrade `clean_image` synthetically, restore it, and print the effect.

    Returns the tuple (degraded, cloud_removed, cloud_mask, enhanced, cloud_cov,
    orig_metrics, deg_metrics, enh_metrics).
    """
    print(f"\n[{label} IMAGE]")
    degraded = create_synthetic_degraded_image(clean_image)
    cloud_removed, cloud_mask = advanced_cloud_removal(degraded, aggressive=True)
    enhanced = advanced_deblurring(cloud_removed, strength='high')
    # The mask holds 0.0/1.0 values, so its mean is the covered fraction
    cloud_cov = np.mean(cloud_mask) * 100

    # Calculate metrics for the clean, degraded and restored versions
    orig_metrics = calculate_quality_metrics(clean_image)
    deg_metrics = calculate_quality_metrics(degraded)
    enh_metrics = calculate_quality_metrics(enhanced)

    print(f"  Cloud coverage detected: {cloud_cov:.1f}%")
    print(f"  Sharpness improvement: {_percent_change(enh_metrics['sharpness'], deg_metrics['sharpness']):+.1f}%")
    print(f"  Contrast improvement:  {_percent_change(enh_metrics['contrast'], deg_metrics['contrast']):+.1f}%")

    return (degraded, cloud_removed, cloud_mask, enhanced, cloud_cov,
            orig_metrics, deg_metrics, enh_metrics)


print("Applying advanced preprocessing to BOTH pre-event and post-event images...")
print("="*80)

# NOTE(review): `pre_enhanced` / `post_enhanced` are also assigned by the CLAHE cell
# earlier in the notebook; from here on they hold the degraded-then-restored images.

# Pre-event processing
(pre_degraded, pre_cloud_removed, pre_cloud_mask, pre_enhanced, pre_cloud_cov,
 pre_orig_metrics, pre_deg_metrics, pre_enh_metrics) = _run_restoration_demo(
    "PRE-EVENT", sample_data['pre_image'])

# Post-event processing
(post_degraded, post_cloud_removed, post_cloud_mask, post_enhanced, post_cloud_cov,
 post_orig_metrics, post_deg_metrics, post_enh_metrics) = _run_restoration_demo(
    "POST-EVENT", sample_data['post_image'])

print("\n" + "="*80)
print("Preprocessing complete for both images")
print("="*80)

In [None]:
# Prepare metrics for visualization
# Each dict groups the metrics of the original, degraded and restored image
# under the keys 'orig', 'degraded' and 'enhanced'.
pre_metrics_dict = {
    'orig': pre_orig_metrics,
    'degraded': pre_deg_metrics,
    'enhanced': pre_enh_metrics
}

post_metrics_dict = {
    'orig': post_orig_metrics,
    'degraded': post_deg_metrics,
    'enhanced': post_enh_metrics
}

# Plot every intermediate stage of the pre- and post-event images with their metrics.
fig = plot_advanced_preprocessing(
    sample_data, pre_degraded, pre_enhanced, pre_cloud_removed, pre_cloud_mask,
    post_degraded, post_enhanced, post_cloud_removed, post_cloud_mask,
    pre_metrics_dict, post_metrics_dict
)
plt.show()

# Print comprehensive summary with safe percentage calculations
print("\n" + "="*80)
print("PREPROCESSING EFFECTIVENESS SUMMARY")
print("="*80)

# Safe percentage calculation function
def safe_improvement(enhanced, degraded):
    """Percentage change from `degraded` to `enhanced`; 0.0 when `degraded` is below 1e-6."""
    if not degraded < 1e-6:
        ratio = enhanced / degraded
        return (ratio - 1) * 100
    # Baseline is zero, near zero or negative: report no change instead of dividing
    return 0.0

# Cloud coverage in percent: the masks hold 0.0/1.0 values, so the mean is the covered fraction
pre_cloud_cov = np.mean(pre_cloud_mask) * 100
post_cloud_cov = np.mean(post_cloud_mask) * 100

# One summary section per image: (heading, restored metrics, degraded metrics, cloud coverage)
_summary_rows = (
    ("PRE-EVENT IMAGE", pre_enh_metrics, pre_deg_metrics, pre_cloud_cov),
    ("POST-EVENT IMAGE", post_enh_metrics, post_deg_metrics, post_cloud_cov),
)
for _heading, _enh, _deg, _cloud_cov in _summary_rows:
    print(f"\n{_heading}:")
    print(f"    Sharpness improvement:    {safe_improvement(_enh['sharpness'], _deg['sharpness']):+.1f}%")
    print(f"    Contrast improvement:     {safe_improvement(_enh['contrast'], _deg['contrast']):+.1f}%")
    print(f"    Edge density improvement: {safe_improvement(_enh['edge_density'], _deg['edge_density']):+.1f}%")
    print(f"    Cloud coverage removed:   {_cloud_cov:.1f}%")

print("\n" + "="*80)
print("TECHNIQUES APPLIED:")
for _technique in (
    "Multi-stage cloud detection (brightness + blue excess + texture + saturation + value)",
    "Morphological refinement (remove_small_objects + remove_small_holes + dilation)",
    "Dual inpainting (Navier-Stokes 60% + Telea 40%)",
    "Wiener deconvolution (20% weight)",
    "Richardson-Lucy deconvolution (15% weight)",
    "Enhanced unsharp masking (35% weight)",
    "Sobel edge enhancement (15% weight)",
    "Adaptive CLAHE (15% weight, clip=3.0)",
    "Final CLAHE enhancement (clip=2.0)",
):
    print(f"  - {_technique}")
print("="*80)


In [None]:
# Check image quality of the raw pre- and post-event images
quality_pre = preprocessor.check_image_quality(sample_data['pre_image'])
quality_post = preprocessor.check_image_quality(sample_data['post_image'])

# Dump every metric the preprocessor reports, one section per image
for _heading, _quality in (("Pre-Event Quality Metrics:", quality_pre),
                           ("\nPost-Event Quality Metrics:", quality_post)):
    print(_heading)
    for key, value in _quality.items():
        print(f"  {key}: {value}")

# Visualize quality metrics: grouped bars, one pair per metric
metrics = ['valid_ratio', 'cloud_ratio', 'dark_ratio', 'mean_intensity', 'std_intensity']
pre_values = [quality_pre[m] for m in metrics]
post_values = [quality_post[m] for m in metrics]

x = np.arange(len(metrics))
width = 0.35

fig, ax = plt.subplots(figsize=(12, 6))
# Shift each series by half a bar width so the pairs sit side by side
ax.bar(x - width/2, pre_values, width, label='Pre-Event', color='#3498db')
ax.bar(x + width/2, post_values, width, label='Post-Event', color='#e74c3c')
ax.set(xlabel='Metric', ylabel='Value', title='Image Quality Metrics Comparison')
ax.set_xticks(x)
ax.set_xticklabels(metrics, rotation=45, ha='right')
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Initialize patch extractor with the patch geometry and flood threshold from config
patch_extractor = PatchExtractor(
    patch_size=PATCH_SIZE,
    overlap=PATCH_OVERLAP,
    min_flood_pixels=MIN_FLOOD_PIXELS  # actual value and share of a patch are printed below
)

print(f"Patch extractor configuration:")
print(f"  Patch size: {PATCH_SIZE}x{PATCH_SIZE}")
print(f"  Overlap: {PATCH_OVERLAP}")
print(f"  Min flood pixels: {MIN_FLOOD_PIXELS} ({(MIN_FLOOD_PIXELS/(PATCH_SIZE**2))*100:.2f}% of patch)")

# Concatenate pre and post images along the channel axis
# NOTE(review): `pre_enhanced` / `post_enhanced` are assigned by both the CLAHE cell and
# the synthetic degradation/restoration cell; whichever ran last is what gets patched here.
combined_image = np.concatenate([pre_enhanced, post_enhanced], axis=2)
print(f"\nCombined image shape: {combined_image.shape} (6 channels: 3 pre + 3 post)")

# Extract patches (without oversampling for demonstration)
patches = patch_extractor.extract_patches(
    combined_image,
    mask=sample_data['mask'],
    oversample_flood=False  # Disabled to show true distribution
)

print(f"\nExtracted {len(patches)} patches")

# Count flood-positive patches
flood_positive = [p for p in patches if p['is_flood_positive']]
print(f"  Flood-positive patches: {len(flood_positive)}")
print(f"  Non-flood patches: {len(patches) - len(flood_positive)}")
# Guard the ratio: an extraction that yields no patches must not raise ZeroDivisionError
flood_ratio = (len(flood_positive) / len(patches) * 100) if len(patches) > 0 else 0.0
print(f"  Flood ratio: {flood_ratio:.1f}%")

In [None]:
# Class ids present in the sample mask and how many pixels each one covers
unique, counts = np.unique(sample_data['mask'], return_counts=True)

print("Mask shape:", sample_data['mask'].shape)
print("Unique classes in mask:", unique)
print("Mask value counts:")
for cls, count in zip(unique, counts):
    pct = (count / sample_data['mask'].size) * 100
    print(f"  Class {cls} ({CLASS_NAMES.get(cls, 'unknown')}): {count:,} pixels ({pct:.2f}%)")

# Check if mask has any flood-related classes (2, 3, 4, 5): every class id above 1 is counted
flood_related_pixels = (sample_data['mask'] > 1).sum()
flood_related_pct = (flood_related_pixels / sample_data['mask'].size) * 100
print(f"\nTotal flood-related pixels (class > 1): {flood_related_pixels:,}")
print(f"Percentage: {flood_related_pct:.2f}%")

In [None]:
# Try loading different tiles to find one with mixed flood/non-flood patches
print("Testing different tiles to find varied flood distribution:")
print("="*60)

# Fetch the tile list once instead of on every iteration; only the first five tiles are checked.
candidate_tiles = germany_loader.get_tile_list()[:5]

# NOTE(review): this cell counts classes 2-4 as flood, while the mask summary cell
# counts every class above 1 - confirm which definition is intended.
flood_class_ids = (2, 3, 4)

for idx, tile_name in enumerate(candidate_tiles):
    tile_data = load_tile_data(GERMANY_TRAIN, tile_name, 'Germany')

    # Calculate flood percentage
    flood_px = np.sum(np.isin(tile_data['mask'], flood_class_ids))
    total_px = tile_data['mask'].size
    flood_pct = (flood_px / total_px) * 100

    print(f"\nTile {idx}: {tile_name}")
    print(f"  Flood pixels: {flood_px:,} ({flood_pct:.2f}%)")
    print(f"  Classes present: {np.unique(tile_data['mask'])}")

    # If this tile has moderate flooding (2-15%), use it for demo.
    # After the break, `tile_name` / `tile_data` still refer to the candidate.
    if 2.0 <= flood_pct <= 15.0:
        print(f" Good candidate for mixed flood/non-flood patches")
        break

In [None]:
print("\n" + "="*60)
print("Testing patch extraction with different thresholds:")
print("="*60)

# Re-run extraction on the same combined image with several flood-pixel thresholds
for threshold in [100, 2621, 5000, 10000, 20000]:
    test_extractor = PatchExtractor(
        patch_size=PATCH_SIZE,
        overlap=PATCH_OVERLAP,
        min_flood_pixels=threshold
    )

    test_patches = test_extractor.extract_patches(
        combined_image,
        mask=sample_data['mask'],
        oversample_flood=True  # ENABLED: Production setting for balanced training
    )

    flood_count = sum(1 for p in test_patches if p['is_flood_positive'])
    non_flood_count = len(test_patches) - flood_count
    # Guard the ratio so a threshold that yields no patches does not abort the sweep
    flood_ratio = (flood_count / len(test_patches) * 100) if len(test_patches) > 0 else 0.0

    print(f"\nThreshold: {threshold} pixels ({(threshold/(PATCH_SIZE**2))*100:.2f}% of patch)")
    print(f"  Total patches: {len(test_patches)}")
    print(f"  Flood-positive: {flood_count}")
    print(f"  Non-flood: {non_flood_count}")
    print(f"  Flood ratio: {flood_ratio:.1f}%")

In [None]:
# Pixels per patch; the denominator for every percentage below
total_px = PATCH_SIZE * PATCH_SIZE

print("\nPatch-level flood analysis:")
for i, patch in enumerate(patches):
    flood_px = patch['flood_pixels']
    flood_pct = (flood_px / total_px) * 100
    print(f"  Patch {i}: {flood_px} flood pixels ({flood_pct:.2f}%)")

# Put the per-patch numbers next to the threshold the extractor applies
min_flood_pct = (patch_extractor.min_flood_pixels / total_px) * 100
print(f"\nPatch size: {PATCH_SIZE}x{PATCH_SIZE} = {total_px:,} pixels")
print(f"Min flood pixels threshold: {patch_extractor.min_flood_pixels}")
print(f"Min flood percentage needed: {min_flood_pct:.2f}%")

In [None]:
# Show up to 8 of the patches extracted above.
fig = plot_patch_samples(patches, n_samples=8, patch_size=PATCH_SIZE)
plt.show()


In [None]:
# Total pixel count per class id across all extracted patches.
# Ids 0-6 are pre-seeded with zero so they always appear in the result.
class_totals = {i: 0 for i in range(7)}

for patch in patches:
    # Patches without a 'class_distribution' entry contribute nothing
    for cls, count in patch.get('class_distribution', {}).items():
        cls_id = int(cls)
        # .get() keeps an unexpected class id (outside 0-6) from raising KeyError
        class_totals[cls_id] = class_totals.get(cls_id, 0) + count

# NOTE(review): `class_totals` is not passed to the plot below; the plotting helper
# receives the patches themselves.
fig = plot_class_distribution(patches, CLASS_NAMES, CLASS_COLORS)
plt.show()


In [None]:
# Training-time augmentation pipeline sized for the extracted patches
train_aug = get_training_augmentation(image_size=PATCH_SIZE)

# Prefer a flood-positive patch for the demo; otherwise use the first patch of any kind
demo_patch = (flood_positive or patches)[0]
# Only the first three channels of the patch image are shown
demo_image, demo_mask = demo_patch['image'][:, :, :3], demo_patch['mask']

fig = plot_augmentation_samples(demo_image, demo_mask, train_aug, n_samples=6)
plt.show()


In [None]:

# Run full preprocessing pipeline
print("Starting preprocessing pipeline...")

# Clean up previous preprocessing output before running
# NOTE(review): presumably this removes existing files under config.PROCESSED_DIR -
# confirm against preprocessing.cleanup_processed_data before re-running on data worth keeping.
from preprocessing import cleanup_processed_data
import config

cleanup_processed_data(config.PROCESSED_DIR)

In [None]:
# Execute the preprocessing script
# This will now:
# 1. Process training data (Germany + Louisiana-East)
# 2. Create train/val split (85%/15%)
# 3. Process test data (Louisiana-West_Test_Public)

# The script path is relative to the working directory, which the setup cell
# leaves inside the checkout on Colab; locally src/ is looked up one level up.
if IS_COLAB:
    %run src/run_preprocessing.py
else:    
    %run ../src/run_preprocessing.py

In [None]:
# Process test data
# Same path convention as the other %run cells: src/ is relative to the working
# directory on Colab and one level up for a local run.
if IS_COLAB:
    %run src/process_test_data.py
else:    
    %run ../src/process_test_data.py


In [None]:
# Validate class balance after preprocessing
# Both the processed-data location and the class count are read from config.
from preprocessing import validate_class_balance
import config

validate_class_balance(config.PROCESSED_DIR, config.NUM_CLASSES)


In [None]:
from pathlib import Path

# Location of the processed data relative to the working directory: inside the
# checkout on Colab, one level up for a local run (same switch as the comparison
# cell further down uses).
base_dir = Path('dataset/processed') if IS_COLAB else Path('../dataset/processed')
splits = ['train', 'val', 'test']

for split in splits:
    split_dir = base_dir / split
    print(f"\n{split.upper()}:")

    # Patch arrays: one .npy file per image / mask
    images_dir = split_dir / 'images'
    masks_dir = split_dir / 'masks'
    if images_dir.exists() and masks_dir.exists():
        n_images = len(list(images_dir.glob('*.npy')))
        n_masks = len(list(masks_dir.glob('*.npy')))
        print(f"  Patches: {n_images} images, {n_masks} masks")
    else:
        # Say so explicitly instead of printing nothing for a missing split
        print(f"  Patches: images/masks directories not found under {split_dir}")

    # Full-resolution processed images, grouped per region into PRE-event / POST-event folders
    processed_images_dir = split_dir / 'processed_images'
    if processed_images_dir.exists():
        regions = [d.name for d in processed_images_dir.iterdir() if d.is_dir()]
        for region in regions:
            pre_count = len(list((processed_images_dir / region / 'PRE-event').glob('*.tif')))
            post_count = len(list((processed_images_dir / region / 'POST-event').glob('*.tif')))
            print(f"  {region}: {pre_count} PRE, {post_count} POST")

In [None]:
# Run metadata export script
# Same path convention as the other %run cells: src/ is relative to the working
# directory on Colab and one level up for a local run.
if IS_COLAB:
    %run src/export_metadata.py
else:    
    %run ../src/export_metadata.py

In [None]:
# Sanity-check the processed training set: file counts, metadata summary and one sample patch.
if PROCESSED_TRAIN_DIR.exists():
    print("Processed data directory exists")

    # glob() order depends on the filesystem; sort so every run inspects the same files
    train_images = sorted((PROCESSED_TRAIN_DIR / 'images').glob('*.npy'))
    train_masks = sorted((PROCESSED_TRAIN_DIR / 'masks').glob('*.npy'))

    print(f"\nTraining set:")
    print(f"  Images: {len(train_images)}")
    print(f"  Masks: {len(train_masks)}")

    metadata_path = PROCESSED_TRAIN_DIR / 'metadata' / 'train_metadata.json'
    if metadata_path.exists():
        import json
        with open(metadata_path, 'r') as f:
            metadata = json.load(f)

        print(f"  Metadata entries: {len(metadata)}")

        flood_count = sum(1 for m in metadata if m['is_flood_positive'])
        # An empty metadata file must not raise ZeroDivisionError
        flood_share = (flood_count / len(metadata) * 100) if len(metadata) > 0 else 0.0
        print(f"  Flood-positive patches: {flood_count} ({flood_share:.1f}%)")

    # Both lists must be non-empty before a sample pair can be loaded
    if train_images and train_masks:
        sample_image_path = train_images[0]
        # Prefer the mask with the same file name as the image so the pair really belongs
        # together; fall back to the first mask when the naming scheme differs.
        same_name_mask = PROCESSED_TRAIN_DIR / 'masks' / sample_image_path.name
        sample_mask_path = same_name_mask if same_name_mask.exists() else train_masks[0]

        sample_img = np.load(sample_image_path)
        sample_mask = np.load(sample_mask_path)

        print(f"\nSample patch:")
        print(f"  Image shape: {sample_img.shape}")
        print(f"  Mask shape: {sample_mask.shape}")
        print(f"  Image range: [{sample_img.min():.3f}, {sample_img.max():.3f}]")
        print(f"  Mask classes: {np.unique(sample_mask)}")

        fig = plot_processed_sample(sample_img, sample_mask)
        plt.show()
else:
    print("Processed data not found. Run preprocessing first.")


In [None]:
import random
from pathlib import Path
import pandas as pd

# Collect matching raw/processed PRE- and POST-event image sets so they can be
# compared side by side. Each entry of `comparison_samples` is a dict with the
# keys 'region', 'tile', 'raw_pre', 'raw_post', 'processed_pre' and
# 'processed_post'.

# Check if processed full-resolution images exist
if IS_COLAB:
    processed_base = Path('dataset/processed')
else:
    processed_base = Path('../dataset/processed')
train_processed_images = processed_base / 'train' / 'processed_images'

# Defined before the existence check so that code reading
# `comparison_samples` afterwards sees an empty list (not a NameError) when
# the processed images are missing.
comparison_samples = []

if not train_processed_images.exists():
    print("Processed full-resolution images not found.")
    print("Run preprocessing pipeline first to generate processed images.")
else:
    print("Processed images directory found")

    # Get available regions (sorted for a stable processing order)
    available_regions = sorted(d.name for d in train_processed_images.iterdir() if d.is_dir())
    print(f"Available regions: {available_regions}")

    # Map to raw data directories (use actual directory names)
    region_mapping = {
        'Germany_Training_Public': GERMANY_TRAIN,
        'Louisiana-East_Training_Public': LOUISIANA_EAST_TRAIN
    }

    # CSV columns that link a pre-event image to its post-event image
    pre_column = 'pre-event image'
    post_column = 'post-event image 1'

    # Select up to 2 random tiles from each region
    for region in available_regions:
        if region not in region_mapping:
            continue

        raw_dir = region_mapping[region]

        # Load CSV mapping file
        csv_name = f"{region}_label_image_mapping.csv"
        csv_path = raw_dir / csv_name

        if not csv_path.exists():
            print(f"CSV mapping not found: {csv_path}")
            continue

        # Read CSV to get pre/post image mappings
        mapping_df = pd.read_csv(csv_path)
        print(f"\nLoaded {len(mapping_df)} mappings from {csv_name}")

        missing_columns = [c for c in (pre_column, post_column) if c not in mapping_df.columns]
        if missing_columns:
            print(f"CSV mapping {csv_name} lacks column(s): {', '.join(missing_columns)}")
            continue

        # Get list of processed PRE-event images
        pre_processed_dir = train_processed_images / region / 'PRE-event'
        if not pre_processed_dir.exists():
            print(f"No PRE-event processed images for {region}")
            continue

        # Sorted so that random.sample draws from a stable sequence; with a
        # seeded `random` module the selection is then reproducible.
        processed_pre_files = sorted(pre_processed_dir.glob('*.tif'))

        if len(processed_pre_files) == 0:
            print(f"No TIF files found for {region}")
            continue

        # Select 2 random samples (fewer if the region has fewer files)
        n_samples = min(2, len(processed_pre_files))
        selected_files = random.sample(processed_pre_files, n_samples)

        for pre_tif in selected_files:
            # Processed files are named after the pre-event image
            pre_image_name = pre_tif.name  # e.g., "10500500C4DD7000_0_41_59.tif"

            # Find matching row in CSV
            matching_row = mapping_df[mapping_df[pre_column] == pre_image_name]

            if matching_row.empty:
                print(f"No CSV mapping found for: {pre_image_name}")
                continue

            # Get post-event image name from CSV; an empty cell is read as NaN
            # and cannot be joined onto a path.
            post_image_name = matching_row.iloc[0][post_column]
            if pd.isna(post_image_name):
                print(f"No post-event image listed for: {pre_image_name}")
                continue
            post_image_name = str(post_image_name)

            # Paths
            raw_pre_path = raw_dir / 'PRE-event' / pre_image_name
            raw_post_path = raw_dir / 'POST-event' / post_image_name
            # Processed POST-event images are saved with their original POST-event filenames
            processed_post_path = train_processed_images / region / 'POST-event' / post_image_name

            required_paths = {
                "raw_pre": raw_pre_path,
                "raw_post": raw_post_path,
                "processed_post": processed_post_path,
            }
            missing = [label for label, path in required_paths.items() if not path.exists()]

            if missing:
                print(f"Missing files for {pre_tif.stem}: {', '.join(missing)}")
                continue

            comparison_samples.append({
                'region': region,
                'tile': pre_tif.stem,
                'raw_pre': raw_pre_path,
                'raw_post': raw_post_path,
                'processed_pre': pre_tif,
                'processed_post': processed_post_path
            })
            print(f"Found complete set: {pre_tif.stem}")

    print(f"\nFound {len(comparison_samples)} complete sample sets for comparison")
    for sample in comparison_samples:
        print(f"  - {sample['region']}: {sample['tile']}")


In [None]:
def load_tif_image(path):
    """Load a TIF image and return it as float32 scaled to [0, 1].

    The image is read unchanged (original bit depth and channel count).
    Three-dimensional arrays are converted from OpenCV's BGR channel order
    to RGB.

    Unsigned integer images are divided by the maximum value of their own
    dtype (255 for uint8, 65535 for uint16), so an 8-bit TIF is no longer
    squashed into [0, 0.004]. Any other dtype keeps the previous behaviour
    of dividing by 65535.

    Args:
        path: Location of the image; converted with ``str()`` before reading.

    Returns:
        numpy.ndarray of dtype float32.

    Raises:
        ValueError: If OpenCV cannot read the file (``imread`` returned None).
    """
    img = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError(f"Failed to load image: {path}")

    # Convert BGR to RGB
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Pick the divisor from the stored bit depth instead of assuming uint16.
    if np.issubdtype(img.dtype, np.unsignedinteger):
        scale = float(np.iinfo(img.dtype).max)
    else:
        scale = 65535.0

    return img.astype(np.float32) / scale


def load_raw_png_image(path):
    """Read a raw image as 3-channel colour and return it as float32 RGB in [0, 1].

    Raises:
        ValueError: If OpenCV cannot read the file.
    """
    bgr = cv2.imread(str(path), cv2.IMREAD_COLOR)
    if bgr is None:
        raise ValueError(f"Failed to load image: {path}")

    # OpenCV delivers BGR; reorder to RGB, then scale the 8-bit values to [0, 1].
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    return rgb.astype(np.float32) / 255.0


# Visualize all comparison samples
def _format_change(before, after):
    """Return the relative change from `before` to `after` as a signed percent.

    Returns "n/a" when `before` is zero, because the relative change is
    undefined there and the division would otherwise fail.
    """
    if before == 0:
        return "n/a"
    return f"{(after / before - 1) * 100:+.1f}%"


if len(comparison_samples) > 0:
    n_samples = len(comparison_samples)

    # Create figure with subplots: 4 columns (Raw Pre, Processed Pre, Raw Post, Processed Post) x n_samples rows.
    # squeeze=False keeps `axes` two-dimensional even for a single sample.
    fig, axes = plt.subplots(n_samples, 4, figsize=(20, 5*n_samples), squeeze=False)

    for idx, sample in enumerate(comparison_samples):
        try:
            # Load images
            raw_pre = load_raw_png_image(sample['raw_pre'])
            raw_post = load_raw_png_image(sample['raw_post'])
            processed_pre = load_tif_image(sample['processed_pre'])
            processed_post = load_tif_image(sample['processed_post'])

            # Display images: (image, title, title colour) per column
            panels = [
                (raw_pre, f"{sample['region']}\n{sample['tile']}\nRaw PRE-event", None),
                (processed_pre, "Processed PRE-event\n(CLAHE + Cloud Removal + Deblur)", 'darkgreen'),
                (raw_post, "Raw POST-event", None),
                (processed_post, "Processed POST-event\n(CLAHE + Cloud Removal + Deblur)", 'darkgreen'),
            ]
            for ax, (image, title, title_color) in zip(axes[idx], panels):
                ax.imshow(image)
                title_kwargs = {'fontsize': 11, 'fontweight': 'bold'}
                if title_color is not None:
                    title_kwargs['color'] = title_color
                ax.set_title(title, **title_kwargs)
                ax.axis('off')
        except Exception as e:
            print(f"Error loading sample {sample['tile']}: {e}")
            for ax in axes[idx]:
                ax.text(0.5, 0.5, 'Error loading image',
                        ha='center', va='center', color='red')
                ax.axis('off')
            continue

        # Quality metrics are handled separately from loading/plotting so a
        # metrics problem is reported as such and does not draw an error
        # message over images that were displayed successfully.
        try:
            # Calculate quality improvements
            pre_raw_metrics = calculate_quality_metrics(raw_pre)
            pre_proc_metrics = calculate_quality_metrics(processed_pre)
            post_raw_metrics = calculate_quality_metrics(raw_post)
            post_proc_metrics = calculate_quality_metrics(processed_post)

            print(f"\n{sample['region']} - {sample['tile']}:")
            metric_pairs = [
                ("PRE", pre_raw_metrics, pre_proc_metrics),
                ("POST", post_raw_metrics, post_proc_metrics),
            ]
            for event, raw_metrics, proc_metrics in metric_pairs:
                print(f"  {event}-event improvements:")
                print(f"    Sharpness: {raw_metrics['sharpness']:.1f} → {proc_metrics['sharpness']:.1f} "
                      f"({_format_change(raw_metrics['sharpness'], proc_metrics['sharpness'])})")
                print(f"    Contrast:  {raw_metrics['contrast']:.3f} → {proc_metrics['contrast']:.3f} "
                      f"({_format_change(raw_metrics['contrast'], proc_metrics['contrast'])})")
        except Exception as e:
            print(f"Could not compute quality metrics for {sample['tile']}: {e}")

    plt.suptitle('Raw vs Processed Full-Resolution Images Comparison\n(Germany & Louisiana-East Datasets)',
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.show()

    print("\n" + "="*80)
    print("PREPROCESSING EFFECTS SUMMARY")
    print("="*80)
    print("CLAHE Enhancement: Improved local contrast and visibility")
    print("Cloud Removal: Multi-stage detection and advanced inpainting")
    print("Deblurring: Wiener + Richardson-Lucy + Unsharp masking + Edge enhancement")
    print("Format: Saved as TIF (uint16) for quality preservation")
    print("="*80)
else:
    print("\n No comparison samples available. Run preprocessing pipeline first.")