In [None]:
import pandas as pd

# Path to your Excel file
file_path = r'C:\Users\steph\Downloads\EFFIS_All_Fires_Compiled_all.xlsx'

# Read the Excel file (pandas reads the first sheet by default)
df = pd.read_excel(file_path)

# Print all column names so you can identify the exact ones
print("Columns in the file:")
print(df.columns.tolist())

# UPDATE THESE WITH THE EXACT COLUMN NAMES FROM THE OUTPUT ABOVE
lat_col = 'Lat'      # e.g., 'latitude', 'LAT', 'Latitude'
lon_col = 'Lon'      # e.g., 'longitude', 'LON', 'Longitude'
date_col = 'Date'    # e.g., 'date', 'FIRE_DATE', 'Date '

# Fail with a readable message (still a KeyError, like the selection below
# would raise) when a configured name does not match the file's header.
missing_cols = [col for col in (lat_col, lon_col, date_col) if col not in df.columns]
if missing_cols:
    raise KeyError(f"Column(s) not found in the file: {missing_cols}")

# Extract only the three columns
extracted_df = df[[lat_col, lon_col, date_col]].copy()

# Convert Date to proper datetime BEFORE dropping missing values:
# errors='coerce' turns unparseable dates into NaT, and those rows must be
# removed by the dropna below instead of silently surviving the cleaning.
extracted_df[date_col] = pd.to_datetime(extracted_df[date_col], errors='coerce')

# Drop rows with any missing value (including the NaT dates produced above)
rows_before_drop = len(extracted_df)
extracted_df = extracted_df.dropna()
print(f"Dropped {rows_before_drop - len(extracted_df)} rows with missing coordinates or invalid dates.")

# No file output, no previews printed
print(f"Done! Loaded and cleaned {len(extracted_df)} rows into 'extracted_df'.")
print("You can now use the DataFrame in your script, e.g., extracted_df.head()")

In [None]:
import os
import numpy as np
from IPython.display import display
from PIL import Image as PILImage, ImageDraw, ImageFont

# Configuration
# Source TIFFs are read from PREVIEW_DIR; enhanced PNGs go to PROCESSED_DIR.
PREVIEW_DIR = r'C:\Users\steph\Downloads\master_thesis\scripts\sentinel2_fire_images\preview'
PROCESSED_DIR = os.path.join(PREVIEW_DIR, 'processed')

# Create the processed directory only when the source directory exists.
# os.makedirs also creates missing parents, so doing it unconditionally would
# silently create an empty PREVIEW_DIR and make the "Directory not found"
# check in main() (and the existence check in find_image_files) unreachable.
if os.path.isdir(PREVIEW_DIR):
    os.makedirs(PROCESSED_DIR, exist_ok=True)

def extract_date_from_filename(filename):
    """Return the first run of 8 digits in *filename* formatted as YYYY-MM-DD.

    The digits are assumed to be YYYYMMDD; they are not validated as a real
    calendar date. Returns "Unknown Date" when the name has no 8-digit run.
    """
    import re
    found = re.search(r'(\d{8})', filename)
    if not found:
        return "Unknown Date"
    digits = found.group(1)
    year, month, day = digits[:4], digits[4:6], digits[6:]
    return f"{year}-{month}-{day}"

def extract_image_type_from_filename(filename):
    """Label *filename* as a before- or after-fire image.

    Matching is case-insensitive and "before" takes precedence when both
    words occur. Returns "Before Fire", "After Fire" or "Unknown".
    """
    lowered = filename.lower()
    for keyword, label in (('before', "Before Fire"), ('after', "After Fire")):
        if keyword in lowered:
            return label
    return "Unknown"

def extract_fire_id_from_filename(filename):
    """Return the first 3-digit token enclosed in underscores (``_001_`` -> "001").

    Returns "Unknown" when the name contains no such token.
    """
    import re
    found = re.search(r'_(\d{3})_', filename)
    return found.group(1) if found else "Unknown"

def add_scale_bar_to_image(pil_image, resolution_m, scale_km=2):
    """Return a copy of *pil_image* with a labelled scale bar in the bottom-right corner.

    Args:
        pil_image: PIL image to annotate; it is not modified (a copy is drawn on).
        resolution_m: ground size of one pixel in metres.
        scale_km: ground distance the bar represents, in kilometres.

    Returns:
        The annotated copy, or the original *pil_image* unchanged if drawing
        fails for any reason (a warning is printed instead of raising).
    """
    try:
        # Create a copy to draw on
        image_with_scale = pil_image.copy()
        draw = ImageDraw.Draw(image_with_scale)

        # Bar length in pixels for the requested ground distance
        scale_pixels = int((scale_km * 1000) / resolution_m)

        # Anchor 5% in from the right edge; y_pos is the bar's bottom edge,
        # 5% up from the bottom of the image
        width, height = image_with_scale.size
        x_pos = width - int(width * 0.05) - scale_pixels
        y_pos = height - int(height * 0.05)

        # Bar thickness: 1% of the image height, but at least 1 px so the bar
        # does not vanish on small images
        bar_height = max(int(height * 0.01), 1)
        bar_width = scale_pixels

        # Pillow requires the box as [x0, y0, x1, y1] with x1 >= x0 and
        # y1 >= y0 (top-left, bottom-right). Passing y1 < y0 makes recent
        # Pillow raise ValueError, which skipped the bar for every image.
        draw.rectangle([x_pos, y_pos - bar_height, x_pos + bar_width, y_pos],
                       fill='white', outline='black', width=2)

        # Add scale text centred above the bar
        text = f"{scale_km} km"
        try:
            font = ImageFont.load_default()
            text_width = draw.textlength(text, font=font)
            text_x = x_pos + (bar_width - text_width) // 2
            text_y = y_pos - bar_height - 20

            # Draw text background
            padding = 4
            draw.rectangle([text_x - padding, text_y - padding,
                            text_x + text_width + padding, text_y + 12 + padding],
                           fill='black')

            # Draw text
            draw.text((text_x, text_y), text, fill='white', font=font)
        except Exception:
            # Fallback (e.g. ImageDraw.textlength unavailable in older Pillow):
            # fixed-size label box without an explicit font
            text_x = x_pos + bar_width // 2 - 10
            text_y = y_pos - bar_height - 15
            draw.rectangle([text_x - 15, text_y - 2, text_x + 25, text_y + 10], fill='black')
            draw.text((text_x, text_y), text, fill='white')

        return image_with_scale

    except Exception as e:
        print(f"  ‚ö† Could not add scale bar: {e}")
        return pil_image

def find_image_files():
    """Split the square_10km TIFFs in PREVIEW_DIR into (before, after) name lists.

    A file qualifies when its name ends in ``.tif``/``.tiff`` (case-sensitive)
    and contains ``square_10km``. It goes to the BEFORE list if its
    lower-cased name contains "before", otherwise to the AFTER list if it
    contains "after", and is ignored if it contains neither. Both lists keep
    ``os.listdir`` order. Two empty lists are returned when PREVIEW_DIR does
    not exist.
    """
    if not os.path.exists(PREVIEW_DIR):
        return [], []

    candidates = [
        name for name in os.listdir(PREVIEW_DIR)
        if name.endswith(('.tif', '.tiff')) and 'square_10km' in name
    ]
    before_files = [name for name in candidates if 'before' in name.lower()]
    after_files = [
        name for name in candidates
        if 'before' not in name.lower() and 'after' in name.lower()
    ]
    return before_files, after_files

def get_rgb_bands(src):
    """Read the red, green and blue bands from an open raster dataset.

    Band selection, in order of preference:
      1. band descriptions mentioning 'B04' (red), 'B03' (green), 'B02' (blue);
         each description maps to at most one code, and a later band with the
         same code overrides an earlier one;
      2. with 4+ bands, the Sentinel-2 order B01, B02, B03, B04 (bands 4/3/2);
      3. with 3 bands, bands 1/2/3 as R/G/B;
      4. with 1-2 bands, band 1 repeated as a grayscale triple.

    Args:
        src: object exposing ``count``, ``descriptions`` and ``read(index)``
            with 1-based band indices (as a rasterio dataset does).

    Returns:
        tuple: ``(red, green, blue)`` as returned by ``src.read``.

    Raises:
        ValueError: if the dataset reports no bands.
    """
    print(f"  - Available bands: {src.count}")

    descriptions = []
    if src.descriptions:
        descriptions = list(src.descriptions)
        print(f"  - Band descriptions: {descriptions}")

    # Map Sentinel-2 band code -> 1-based band index found in the descriptions
    located = {}
    for index, text in enumerate(descriptions, start=1):
        if not text:
            continue
        for code in ('B04', 'B03', 'B02'):
            if code in text:
                located[code] = index
                break

    red_idx = located.get('B04')
    green_idx = located.get('B03')
    blue_idx = located.get('B02')
    if red_idx and green_idx and blue_idx:
        print(f"  - Using band descriptions:")
        print(f"    Band {blue_idx}: {descriptions[blue_idx-1]} (Blue)")
        print(f"    Band {green_idx}: {descriptions[green_idx-1]} (Green)")
        print(f"    Band {red_idx}: {descriptions[red_idx-1]} (Red)")
        return src.read(red_idx), src.read(green_idx), src.read(blue_idx)

    band_total = src.count

    if band_total >= 4:
        # Standard Sentinel-2 stack starts B01, B02, B03, B04
        rgb = src.read(4), src.read(3), src.read(2)
        print(f"  - Using standard Sentinel-2 band order:")
        print(f"    Band 2: Blue (B02)")
        print(f"    Band 3: Green (B03)")
        print(f"    Band 4: Red (B04)")
        return rgb

    if band_total >= 3:
        print("  - Using first 3 bands as RGB")
        return src.read(1), src.read(2), src.read(3)

    if band_total >= 1:
        print("  - Single band image, creating grayscale RGB")
        gray = src.read(1)
        return gray, gray, gray

    raise ValueError("No bands found in image")

def enhance_rgb_contrast(red, green, blue):
    """Contrast-stretch three bands independently to uint8.

    Each band is linearly stretched between the 2nd and 98th percentile of
    its strictly positive pixels (zeros are excluded from the statistics)
    and clipped to 0-255. If that is not possible (no positive pixels, or
    the percentiles coincide), a plain min-max stretch is used. A constant
    band is clipped into the uint8 range rather than cast directly, so values
    above 255 saturate instead of wrapping around.

    Returns:
        tuple: ``(red, green, blue)`` uint8 arrays with the input shapes.
    """
    def enhance_band(band):
        # Percentiles of positive pixels only, so zero-valued pixels and
        # outliers don't dominate the stretch
        valid_pixels = band[band > 0]
        if valid_pixels.size > 0:
            p2, p98 = np.percentile(valid_pixels, [2, 98])
            if p98 > p2:
                enhanced = np.clip((band - p2) / (p98 - p2), 0, 1)
                return (enhanced * 255).astype(np.uint8)
        # Fallback: simple min-max normalization
        band_min = band.min()
        band_max = band.max()
        if band_max > band_min:
            return ((band - band_min) / (band_max - band_min) * 255).astype(np.uint8)
        # Constant band: a bare astype(uint8) would wrap e.g. 1000 -> 232
        return np.clip(band, 0, 255).astype(np.uint8)

    return enhance_band(red), enhance_band(green), enhance_band(blue)

def save_processed_image(pil_image, original_filename, suffix=""):
    """Save *pil_image* as a PNG in PROCESSED_DIR.

    The output name is the original file's stem, followed by ``_<suffix>``
    when *suffix* is non-empty, with a ``.png`` extension. PROCESSED_DIR is
    created if it does not exist yet, so saving never depends on import-time
    side effects.

    Returns:
        str: full path of the written PNG.
    """
    base_name = os.path.splitext(original_filename)[0]
    if suffix:
        new_filename = f"{base_name}_{suffix}.png"
    else:
        new_filename = f"{base_name}.png"

    os.makedirs(PROCESSED_DIR, exist_ok=True)
    output_path = os.path.join(PROCESSED_DIR, new_filename)
    pil_image.save(output_path, 'PNG')
    print(f"  üíæ Saved: {new_filename}")
    return output_path

def display_all_images_with_messages():
    """Render every before/after image in PREVIEW_DIR, grouped by fire ID.

    Files come from ``find_image_files()`` and are rendered one by one with
    ``process_and_display_image()``. For every fire ID, a message is printed
    when it has no BEFORE or no AFTER image. When no matching files exist,
    the TIFFs that are present (if any) are listed instead.

    Returns:
        int: number of images processed successfully; 0 when nothing was
        found, so callers can compare the result numerically.
    """
    print("\n" + "="*60)
    print("üåç DISPLAYING ALL FIRE IMAGES WITH SCALE BARS")
    print("="*60)

    before_files, after_files = find_image_files()

    if not before_files and not after_files:
        print("‚ùå No square_10km TIFF files found in the directory.")
        print("\nüìÅ Available files in directory:")
        # find_image_files() also returns empty lists when the directory is
        # missing; listing it would then raise FileNotFoundError.
        available_files = []
        if os.path.isdir(PREVIEW_DIR):
            available_files = [f for f in os.listdir(PREVIEW_DIR) if f.endswith(('.tif', '.tiff'))]
        if available_files:
            for file in sorted(available_files):
                print(f"  - {file}")
        else:
            print("  No TIFF files found.")
        # Return 0 (not None) so `images_processed > 0` in the caller works
        return 0

    print(f"üìä Found {len(before_files)} BEFORE and {len(after_files)} AFTER images")

    # Group files by fire ID; find_image_files() has already split them
    # into before/after, so reuse that split instead of re-testing names.
    fire_groups = {}
    for kind, files in (('before', before_files), ('after', after_files)):
        for img_file in files:
            fire_id = extract_fire_id_from_filename(img_file)
            fire_groups.setdefault(fire_id, {'before': [], 'after': []})[kind].append(img_file)

    # Display images grouped by fire ID, BEFORE images first
    images_processed = 0
    for fire_id in sorted(fire_groups):
        print(f"\n{'='*50}")
        print(f"üî• FIRE LOCATION {fire_id}")
        print(f"{'='*50}")

        for kind, label in (('before', 'BEFORE'), ('after', 'AFTER')):
            images = fire_groups[fire_id][kind]
            if not images:
                print(f"\n‚ùå No {label} images available for fire {fire_id}")
                continue
            print(f"\nüì∏ {label} FIRE IMAGES:")
            for img_file in sorted(images):
                if process_and_display_image(img_file):
                    images_processed += 1

    return images_processed

def process_and_display_image(img_file):
    """Turn one preview TIFF into an enhanced RGB PNG with a 2 km scale bar.

    Opens PREVIEW_DIR/*img_file* with rasterio, selects RGB bands via
    get_rgb_bands(), stretches contrast with enhance_rgb_contrast(), draws a
    2 km scale bar, saves the PNG (suffix "enhanced_scale") into
    PROCESSED_DIR and displays it inline.

    The scale bar uses ``abs(transform[0])`` as metres per pixel — this
    assumes the raster's CRS is projected in metres (TODO confirm for
    geographic-CRS inputs).

    Returns:
        bool: True on success; False if any step raised (the traceback is
        printed rather than propagated).
    """
    source_path = os.path.join(PREVIEW_DIR, img_file)

    label = extract_image_type_from_filename(img_file)
    acquisition_date = extract_date_from_filename(img_file)
    fire_number = extract_fire_id_from_filename(img_file)

    try:
        import rasterio

        with rasterio.open(source_path) as dataset:
            print(f"\nüîÑ Processing: {img_file}")

            stretched = enhance_rgb_contrast(*get_rgb_bands(dataset))
            composite = PILImage.fromarray(np.dstack(stretched))

            pixel_size = abs(dataset.transform[0])
            annotated = add_scale_bar_to_image(composite, pixel_size, scale_km=2)

            save_processed_image(annotated, img_file, "enhanced_scale")

            print(f"üìä {label} - Fire {fire_number} - {acquisition_date}")
            print(f"üéØ Resolution: {pixel_size:.1f} meters/pixel")
            print(f"üìè Scale bar: 2 km")
            print(f"üñºÔ∏è  Image size: {annotated.size}")
            display(annotated)
            print(f"{'='*40}")
            return True

    except Exception as e:
        print(f"‚ùå Error processing {img_file}: {e}")
        import traceback
        traceback.print_exc()
        return False

def create_comparison_grid(before_images, after_images):
    """Placeholder for a side-by-side before/after comparison grid.

    Nothing is built yet: when at least one list is non-empty, the image
    counts and a placeholder notice are printed. Always returns None.
    """
    if not (before_images or after_images):
        return None

    print(f"\nüìä CREATING COMPARISON GRID")
    print(f"   Before images: {len(before_images)}")
    print(f"   After images: {len(after_images)}")

    # TODO: resize the images to common dimensions and lay them out in a
    # matplotlib subplot grid.
    print("   üîß Grid comparison feature would be implemented here")
    return None

def main():
    """Render, save and summarise every before/after fire image.

    Prints the configuration and aborts with a message if PREVIEW_DIR does
    not exist. Otherwise it processes all images via
    display_all_images_with_messages() and prints a summary, including the
    PNG files currently present in PROCESSED_DIR.
    """
    print("üåç DISPLAYING ALL FIRE IMAGES WITH SCALE BARS")
    print("=" * 60)
    print(f"üìÅ Source directory: {PREVIEW_DIR}")
    print(f"üíæ Output directory: {PROCESSED_DIR}")
    print("=" * 60)
    print("üìè Scale bars: 2 km (added to ALL images)")
    print("üé® Images: Enhanced contrast + Scale bars")
    print("=" * 60)

    # Check if directory exists
    if not os.path.exists(PREVIEW_DIR):
        print(f"‚ùå Directory not found: {PREVIEW_DIR}")
        print("Please run the extraction script first to generate image files.")
        return

    # Guard against a None result so the numeric comparison below cannot
    # raise TypeError.
    images_processed = display_all_images_with_messages() or 0

    # Re-scan only to report the BEFORE/AFTER counts in the summary
    before_files, after_files = find_image_files()

    print("\n" + "="*60)
    print("‚úÖ PROCESSING COMPLETED")
    print("="*60)

    if images_processed > 0:
        print(f"üìä Successfully processed {images_processed} images:")
        print(f"   - {len(before_files)} BEFORE fire images")
        print(f"   - {len(after_files)} AFTER fire images")
        print(f"üìè Scale bars added to ALL images (2 km)")
        print(f"üíæ All enhanced images with scale bars saved to: {PROCESSED_DIR}")

        # List saved files (includes PNGs left over from earlier runs)
        saved_files = []
        if os.path.isdir(PROCESSED_DIR):
            saved_files = [f for f in os.listdir(PROCESSED_DIR) if f.endswith('.png')]
        if saved_files:
            print(f"\nüìÅ Saved PNG files (with scale bars):")
            for file in sorted(saved_files):
                print(f"   - {file}")

        # Show scale bar information
        print(f"\nüìê Scale Bar Information:")
        print(f"   - Length: 2 kilometers")
        print(f"   - Position: Bottom-right corner")
        print(f"   - Color: White bar with black border")
        print(f"   - Text: White on black background")

    else:
        print("‚ùå No images were successfully processed.")
        print("üí° Please check that:")
        print("   - The extraction script has been run successfully")
        print("   - TIFF files exist in the preview directory")
        print("   - Files follow the naming pattern: square_10km_allbands_[before/after]_[fire_id]_[date].tif")

# Run the main function (in a notebook cell __name__ is "__main__", so this runs)
if __name__ == "__main__":
    main()

In [None]:
import os
import numpy as np
from PIL import Image as PILImage, ImageDraw
import rasterio
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Configuration - UPDATED OUTPUT DIRECTORY
# Input TIFFs are read from PREVIEW_DIR; classification dashboards are
# written to CLASSIFIED_DIR.
PREVIEW_DIR = r'C:\Users\steph\Downloads\master_thesis\scripts\sentinel2_fire_images\preview'
CLASSIFIED_DIR = r'C:\Users\steph\Downloads\master_thesis\scripts\sentinel2_fire_images\classification_results'
# Created at import time so the dashboard PNGs have a target directory
# (os.makedirs also creates any missing parent directories).
os.makedirs(CLASSIFIED_DIR, exist_ok=True)

# Land cover class definitions
# Maps the integer class codes written by classify_with_corrected_ndvi to a
# display name and an (R, G, B) colour on a 0-255 scale, used for the
# classification map and the distribution bar chart.
CLASSES = {
    1: {'name': 'Fields/Agriculture', 'color': (255, 255, 100)},
    2: {'name': 'Coniferous Forest', 'color': (0, 100, 0)},
    3: {'name': 'Deciduous Forest', 'color': (50, 205, 50)},
    4: {'name': 'Urban/Bare Soil', 'color': (169, 169, 169)}
}

def add_scale_bar_to_plot(ax, transform, scale_km=2):
    """Draw a *scale_km* scale bar with a text label near the bottom-right of *ax*.

    The bar length is converted to data units (pixels) by dividing by the
    pixel width ``abs(transform[0])``. This assumes the image was drawn with
    ``imshow`` in pixel coordinates and that the transform is in metres
    (TODO confirm for non-projected CRSs). Any failure is reported and
    swallowed, so a missing bar never aborts the dashboard.
    """
    try:
        x_lo, x_hi = ax.get_xlim()
        y_lo, y_hi = ax.get_ylim()

        bar_len = (scale_km * 1000) / abs(transform[0])

        width_span = x_hi - x_lo
        height_span = y_hi - y_lo  # signed: negative when the y-axis is inverted

        # 5% in from the right edge and from the first y-limit; using the
        # signed span keeps the offsets pointing into the image whether or
        # not the y-axis is inverted (imshow's default origin='upper' inverts it).
        left = x_hi - width_span * 0.05 - bar_len
        base = y_lo + height_span * 0.05

        ax.add_patch(plt.Rectangle((left, base), bar_len, height_span * 0.005,
                                   facecolor='white', edgecolor='black', linewidth=2))

        ax.text(left + bar_len / 2, base - height_span * 0.02,
                f'{scale_km} km', ha='center', va='top', color='white',
                fontweight='bold', fontsize=10,
                bbox=dict(boxstyle="round,pad=0.3", facecolor='black', alpha=0.7))

    except Exception as e:
        print(f"  ‚ö† Could not add scale bar to plot: {e}")

def identify_sentinel2_bands_correctly(bands):
    """Map a list of band arrays to 'blue', 'green', 'red' and 'nir'.

    * 12 or more bands: treated as a Sentinel-2 stack starting at B01, so
      indices 1/2/3/7 are B02/B03/B04/B08.
    * 4-11 bands: a reflectance heuristic on the first four bands. The band
      with the highest mean is taken as NIR, the lowest mean among the first
      three as Red, and the remaining two (in order) as Blue and Green. If
      both picks land on the same band (e.g. all means equal), NIR falls
      back to the fourth band so Red and NIR stay distinct. This heuristic is
      fragile; verify the printed assignment.
    * fewer than 4 bands: assumes Blue, Green, Red, NIR order; absent bands
      map to None.

    Returns:
        dict: keys 'blue', 'green', 'red', 'nir' -> arrays from *bands* (or None).
    """
    print(f"  - Number of bands: {len(bands)}")

    band_info = {}

    if len(bands) >= 12:
        # Standard 12-band Sentinel-2
        band_info = {
            'blue': bands[1],   # Band 2 - B02
            'green': bands[2],  # Band 3 - B03
            'red': bands[3],    # Band 4 - B04
            'nir': bands[7]     # Band 8 - B08
        }
        print("  - Identified as 12-band Sentinel-2 (using B2, B3, B4, B8)")

    elif len(bands) >= 4:
        band_means = [np.mean(band) for band in bands[:4]]

        # Heuristic: NIR has the highest mean, Red the lowest of the first three
        nir_idx = int(np.argmax(band_means))
        red_idx = int(np.argmin(band_means[:3]))
        if nir_idx == red_idx:
            # Both picks coincide (e.g. equal means); using one band as both
            # Red and NIR would make NDVI identically 0. Red is in the first
            # three, so the fourth band is the only other NIR candidate.
            nir_idx = 3

        remaining = [i for i in range(4) if i not in (nir_idx, red_idx)]
        blue_idx, green_idx = remaining

        band_info = {
            'blue': bands[blue_idx],
            'green': bands[green_idx],
            'red': bands[red_idx],
            'nir': bands[nir_idx]
        }
        print(f"  - Identified by reflectance: B{blue_idx+1}=Blue, B{green_idx+1}=Green, B{red_idx+1}=Red, B{nir_idx+1}=NIR")

    else:
        # Fallback - assume standard order
        band_info = {
            'blue': bands[0] if len(bands) > 0 else None,
            'green': bands[1] if len(bands) > 1 else None,
            'red': bands[2] if len(bands) > 2 else None,
            'nir': bands[3] if len(bands) > 3 else None
        }
        print("  - Using standard band order assumption")

    return band_info

def calculate_ndvi_safe(red, nir):
    """Compute NDVI = (NIR - Red) / (NIR + Red) as float32.

    Pixels where NIR + Red is not strictly positive (zero, negative or NaN)
    get NDVI 0 instead of a division result; the output is clipped to
    [-1, 1].
    """
    red_f = red.astype(np.float32)
    nir_f = nir.astype(np.float32)

    total = nir_f + red_f
    result = np.zeros_like(red_f, dtype=np.float32)
    # Divide only where the denominator is positive; elsewhere `result` keeps 0
    np.divide(nir_f - red_f, total, out=result, where=total > 0)

    return np.clip(result, -1, 1)

def classify_with_corrected_ndvi(band_info):
    """Classify pixels into the CLASSES codes using NDVI thresholds.

    Rules (NDVI from calculate_ndvi_safe):
      * NDVI < 0.1          -> 4 Urban/Bare Soil
      * 0.1 <= NDVI < 0.3   -> 1 Fields/Agriculture
      * NDVI >= 0.3         -> forest, split by the NIR/Red ratio:
            ratio < 3.0  -> 2 Coniferous Forest
            ratio >= 3.0 -> 3 Deciduous Forest
        For positive reflectances, ratio >= 3 is equivalent to NDVI >= 0.5.
        Pixels with Red == 0 are given ratio 1 and so fall into Coniferous.

    Args:
        band_info: dict with 'red' and 'nir' arrays of the same shape (other
            keys are ignored).

    Returns:
        tuple: ``(classification, ndvi)`` — a uint8 class-code array and the
        NDVI array, both with the band shape. If Red or NIR is missing, an
        all-zero classification (code 0, not a CLASSES entry) and an all-zero
        NDVI are returned, shaped like the band that is present.

    Raises:
        ValueError: if neither Red nor NIR is available (no output shape can
            be derived).
    """
    red = band_info.get('red')
    nir = band_info.get('nir')

    if red is None or nir is None:
        print("‚ùå Missing Red or NIR bands for classification")
        # Take the shape from whichever band exists; reading red.shape
        # unconditionally fails when Red is the missing band.
        present = red if red is not None else nir
        if present is None:
            raise ValueError("Neither Red nor NIR band is available for classification")
        return np.zeros(present.shape, dtype=np.uint8), np.zeros(present.shape, dtype=np.float32)

    print(f"  - Red band range: [{red.min():.1f}, {red.max():.1f}]")
    print(f"  - NIR band range: [{nir.min():.1f}, {nir.max():.1f}]")

    ndvi = calculate_ndvi_safe(red, nir)

    print(f"  - NDVI range: [{ndvi.min():.3f}, {ndvi.max():.3f}]")
    print(f"  - NDVI mean: {ndvi.mean():.3f}")

    # Default every pixel to Fields; the masks below are disjoint
    classification = np.ones(red.shape, dtype=np.uint8)

    # NDVI thresholds: < 0.1 urban/bare, 0.1-0.3 fields, >= 0.3 forest
    urban_mask = ndvi < 0.1
    fields_mask = (ndvi >= 0.1) & (ndvi < 0.3)
    forest_mask = ndvi >= 0.3

    if np.any(forest_mask):
        # NIR/Red ratio; Red == 0 pixels get the placeholder ratio 1. The
        # division is still evaluated everywhere, so silence the
        # divide-by-zero / invalid warnings it would otherwise emit.
        with np.errstate(divide='ignore', invalid='ignore'):
            nir_red_ratio = np.where(red == 0, 1, nir / red)

        # Heuristic split: lower NIR/Red ratio -> coniferous
        coniferous_mask = forest_mask & (nir_red_ratio < 3.0)
        deciduous_mask = forest_mask & (nir_red_ratio >= 3.0)
    else:
        coniferous_mask = np.zeros_like(forest_mask)
        deciduous_mask = np.zeros_like(forest_mask)

    classification[urban_mask] = 4          # Urban/Bare Soil
    classification[coniferous_mask] = 2     # Coniferous Forest
    classification[deciduous_mask] = 3      # Deciduous Forest
    classification[fields_mask] = 1         # Fields/Agriculture

    print(f"  - Classification preview:")
    print(f"    Urban/Bare Soil: {np.sum(urban_mask):,} pixels ({np.sum(urban_mask)/urban_mask.size*100:.1f}%)")
    print(f"    Fields/Agriculture: {np.sum(fields_mask):,} pixels ({np.sum(fields_mask)/fields_mask.size*100:.1f}%)")
    print(f"    Forests: {np.sum(forest_mask):,} pixels ({np.sum(forest_mask)/forest_mask.size*100:.1f}%)")
    print(f"    - Coniferous: {np.sum(coniferous_mask):,} pixels")
    print(f"    - Deciduous: {np.sum(deciduous_mask):,} pixels")

    return classification, ndvi

def create_true_color_rgb(band_info):
    """Build a true-colour (Red, Green, Blue) composite as a uint8 RGB array.

    Each band gets a linear 2nd-98th percentile stretch to 0-255. A band
    whose percentiles coincide (e.g. a constant band) or are NaN is rendered
    as 0 instead of dividing by zero.

    Args:
        band_info: dict with 'red', 'green' and 'blue' arrays.

    Returns:
        The stacked ``np.dstack((R, G, B))`` uint8 array, or None (after
        printing a message) if any of the three bands is missing.
    """
    red = band_info.get('red')
    green = band_info.get('green')
    blue = band_info.get('blue')

    if red is None or green is None or blue is None:
        print("‚ùå Missing bands for true color RGB")
        return None

    def enhance_band(band):
        # Percentile stretch to avoid outliers dominating the range
        p2, p98 = np.percentile(band, [2, 98])
        if not p98 > p2:
            # Zero (or NaN) range: the stretch would yield NaN, whose uint8
            # cast is undefined
            return np.zeros(band.shape, dtype=np.uint8)
        enhanced = np.clip((band - p2) / (p98 - p2), 0, 1)
        return (enhanced * 255).astype(np.uint8)

    # TRUE COLOR: Red, Green, Blue
    return np.dstack((enhance_band(red), enhance_band(green), enhance_band(blue)))

def create_false_color_rgb(band_info):
    """Build a false-colour composite with NIR, Red, Green in the R, G, B channels.

    Each band gets a linear 2nd-98th percentile stretch to 0-255. A band
    whose percentiles coincide (e.g. a constant band) or are NaN is rendered
    as 0 instead of dividing by zero.

    Args:
        band_info: dict with 'nir', 'red' and 'green' arrays.

    Returns:
        The stacked ``np.dstack((NIR, Red, Green))`` uint8 array, or None
        (after printing a message) if any of the three bands is missing.
    """
    nir = band_info.get('nir')
    red = band_info.get('red')
    green = band_info.get('green')

    if nir is None or red is None or green is None:
        print("‚ùå Missing bands for false color RGB")
        return None

    def enhance_band(band):
        p2, p98 = np.percentile(band, [2, 98])
        if not p98 > p2:
            # Zero (or NaN) range: the stretch would yield NaN, whose uint8
            # cast is undefined
            return np.zeros(band.shape, dtype=np.uint8)
        enhanced = np.clip((band - p2) / (p98 - p2), 0, 1)
        return (enhanced * 255).astype(np.uint8)

    # FALSE COLOR: NIR as Red, Red as Green, Green as Blue
    return np.dstack((enhance_band(nir), enhance_band(red), enhance_band(green)))

def create_ndvi_colormap(ndvi):
    """Colour a 2-D NDVI array with matplotlib's RdYlGn colormap.

    NDVI is mapped linearly from [-1, 1] to [0, 1] (values outside are
    clipped), so -1 lands on the red end and +1 on the green end.

    Returns:
        uint8 RGB array of shape ``(H, W, 3)``; the colormap's alpha channel
        is dropped.
    """
    scaled = np.clip((ndvi + 1) / 2, 0, 1)
    rgba = plt.cm.RdYlGn(scaled)
    return (rgba * 255).astype(np.uint8)[:, :, :3]

def create_classification_colormap(classification):
    """Paint a 2-D class-code array with the colours defined in CLASSES.

    Pixels whose code is not a CLASSES key stay black (0, 0, 0).

    Returns:
        uint8 array of shape ``(H, W, 3)``.
    """
    rows, cols = classification.shape
    painted = np.zeros((rows, cols, 3), dtype=np.uint8)
    for code in CLASSES:
        painted[classification == code] = CLASSES[code]['color']
    return painted

def create_comprehensive_dashboard(true_color, false_color, classification, stats, ndvi, img_file, image_type, transform):
    """Create a comprehensive dashboard figure with scale bars.

    Builds a 20x12-inch figure laid out as a 2x3 grid:
      row 1: true colour, false colour (NIR/Red/Green), NDVI map
      row 2: classification map, class-percentage bar chart, text summary
    The true/false colour panels are left blank when the corresponding input
    is None. The figure is returned; it is neither shown nor saved here.

    Args:
        true_color: image array for imshow, or None.
        false_color: image array for imshow, or None.
        classification: 2-D array of CLASSES codes.
        stats: dict keyed by class id; must contain 1, 2, 3 and 4, each with
            'percentage' and 'pixels' (these keys are indexed directly below).
        ndvi: NDVI array, colourised with create_ndvi_colormap.
        img_file: file name shown in the title.
        image_type: label upper-cased into the title (presumably
            "before"/"after" — confirm against caller).
        transform: affine transform; abs(transform[0]) is reported as the
            resolution and passed to the scale bars.

    Returns:
        matplotlib.figure.Figure
    """
    # Create visualizations
    ndvi_vis = create_ndvi_colormap(ndvi)
    class_vis = create_classification_colormap(classification)
    
    # Create figure with subplots
    fig = plt.figure(figsize=(20, 12))
    fig.suptitle(f'{image_type.upper()} FIRE - Land Cover Classification - {img_file}', 
                 fontsize=18, fontweight='bold', y=0.98)
    
    # Create grid layout
    gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)
    
    # Plot 1: TRUE COLOR RGB (Red, Green, Blue)
    ax1 = fig.add_subplot(gs[0, 0])
    if true_color is not None:
        ax1.imshow(true_color)
        # Scale bar is added after imshow so it uses the image's axis limits
        add_scale_bar_to_plot(ax1, transform)
    ax1.set_title('TRUE COLOR (Red, Green, Blue)', fontsize=12, fontweight='bold')
    ax1.axis('off')
    
    # Plot 2: FALSE COLOR (NIR, Red, Green)
    ax2 = fig.add_subplot(gs[0, 1])
    if false_color is not None:
        ax2.imshow(false_color)
        add_scale_bar_to_plot(ax2, transform)
    ax2.set_title('FALSE COLOR (NIR, Red, Green)', fontsize=12, fontweight='bold')
    ax2.axis('off')
    
    # Plot 3: NDVI Visualization
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.imshow(ndvi_vis)
    add_scale_bar_to_plot(ax3, transform)
    ax3.set_title('NDVI Map (Red=Low, Green=High Vegetation)', fontsize=12, fontweight='bold')
    ax3.axis('off')
    
    # Plot 4: Classification Map
    ax4 = fig.add_subplot(gs[1, 0])
    ax4.imshow(class_vis)
    add_scale_bar_to_plot(ax4, transform)
    ax4.set_title('Land Cover Classification', fontsize=12, fontweight='bold')
    ax4.axis('off')
    
    # Plot 5: Classification Statistics
    # Bars are ordered by class id; colours are CLASSES RGB scaled to 0-1
    ax5 = fig.add_subplot(gs[1, 1])
    classes = [CLASSES[class_id]['name'] for class_id in sorted(stats.keys())]
    percentages = [stats[class_id]['percentage'] for class_id in sorted(stats.keys())]
    colors = [tuple(c/255 for c in CLASSES[class_id]['color']) for class_id in sorted(stats.keys())]
    
    bars = ax5.bar(classes, percentages, color=colors, edgecolor='black', alpha=0.7)
    ax5.set_title('Land Cover Distribution (%)', fontsize=12, fontweight='bold')
    ax5.set_ylabel('Percentage (%)')
    ax5.tick_params(axis='x', rotation=45)
    ax5.grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bar, percentage in zip(bars, percentages):
        height = bar.get_height()
        ax5.text(bar.get_x() + bar.get_width()/2., height + 0.5,
                f'{percentage:.1f}%', ha='center', va='bottom', fontweight='bold')
    
    # Plot 6: Text summary
    ax6 = fig.add_subplot(gs[1, 2])
    ax6.axis('off')
    
    # Calculate summary statistics
    # Note: Fields (class 1) is counted as vegetation together with both forest classes
    vegetation_classes = [1, 2, 3]
    total_vegetation = sum(stats[c]['pixels'] for c in vegetation_classes)
    total_vegetation_pct = (total_vegetation / classification.size) * 100
    
    # NDVI distribution
    # Thresholds duplicate those in classify_with_corrected_ndvi (0.1 / 0.3); keep them in sync
    ndvi_urban = np.sum(ndvi < 0.1) / ndvi.size * 100
    ndvi_fields = np.sum((ndvi >= 0.1) & (ndvi < 0.3)) / ndvi.size * 100
    ndvi_forest = np.sum(ndvi >= 0.3) / ndvi.size * 100
    
    # "2 km" below is hard-coded; it matches add_scale_bar_to_plot's default scale_km=2
    summary_text = [
        "LAND COVER SUMMARY:",
        f"Total Vegetation: {total_vegetation_pct:.1f}%",
        f"Forests: {stats[2]['percentage'] + stats[3]['percentage']:.1f}%",
        f"Fields: {stats[1]['percentage']:.1f}%",
        f"Urban/Bare Soil: {stats[4]['percentage']:.1f}%",
        "",
        "NDVI DISTRIBUTION:",
        f"NDVI < 0.1 (Urban): {ndvi_urban:.1f}%",
        f"NDVI 0.1-0.3 (Fields): {ndvi_fields:.1f}%",
        f"NDVI ‚â• 0.3 (Forests): {ndvi_forest:.1f}%",
        f"NDVI Range: [{ndvi.min():.3f}, {ndvi.max():.3f}]",
        "",
        "SCALE INFORMATION:",
        f"Scale bars: 2 km (all images)",
        f"Resolution: {abs(transform[0]):.1f} m/pixel"
    ]
    
    # Text is placed in axes coordinates (0-1), top-left of the panel
    ax6.text(0.02, 0.95, "\n".join(summary_text), transform=ax6.transAxes, 
             fontsize=11, verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.7))
    
    return fig

def save_dashboard_as_png(true_color, false_color, classification, stats, ndvi, img_file, image_type, output_dir, transform):
    """Render the dashboard for *img_file* and write it to *output_dir* as a PNG.

    The file is named ``<img_file stem>_dashboard.png`` and written at
    300 dpi with a tight bounding box on a white background. The figure is
    closed afterwards to release its memory.

    Returns:
        str: path of the written PNG.
    """
    stem = os.path.splitext(img_file)[0]

    dashboard = create_comprehensive_dashboard(
        true_color, false_color, classification, stats, ndvi,
        img_file, image_type, transform,
    )
    dashboard_path = os.path.join(output_dir, f"{stem}_dashboard.png")
    dashboard.savefig(dashboard_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(dashboard)

    print(f"üíæ Dashboard saved as: {dashboard_path}")
    return dashboard_path

def display_corrected_results(true_color, false_color, classification, stats, ndvi, img_file, image_type, transform):
    """Show and save the classification dashboard for one image, then print statistics.

    The dashboard figure is built once, saved to CLASSIFIED_DIR as
    ``<img_file stem>_dashboard.png`` (300 dpi, tight bbox, white
    background), shown, and then closed. After that, the following are
    printed: per-class percentages (largest first), a vegetation summary,
    the NDVI distribution for the 0.1 / 0.3 thresholds, the resolution, and
    a classification-vs-NDVI cross-check.

    Args:
        true_color, false_color: image arrays (or None) for the dashboard.
        classification: 2-D array of CLASSES codes.
        stats: dict keyed by class id 1-4, each with 'name', 'percentage'
            and 'pixels'.
        ndvi: NDVI array.
        img_file: source file name (used for the title and the PNG name).
        image_type: label upper-cased into the headers.
        transform: affine transform; abs(transform[0]) is printed as
            metres/pixel.

    Returns:
        tuple: ``(stats, dashboard_path)`` — *stats* unchanged and the PNG path.
    """
    display(HTML(f"<h2>üåç {image_type.upper()} FIRE - {img_file}</h2>"))

    # Build the large dashboard only once and save that same figure, instead
    # of rendering a second identical one just for the PNG.
    fig = create_comprehensive_dashboard(true_color, false_color, classification, stats, ndvi, img_file, image_type, transform)
    base_name = os.path.splitext(img_file)[0]
    dashboard_path = os.path.join(CLASSIFIED_DIR, f"{base_name}_dashboard.png")
    # Save before showing: some backends (e.g. Jupyter inline) close figures on show()
    fig.savefig(dashboard_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"üíæ Dashboard saved as: {dashboard_path}")
    plt.show()
    # Close explicitly so figures don't accumulate on non-inline backends
    plt.close(fig)

    print("\n" + "="*70)
    print(f"üìä {image_type.upper()} FIRE - CLASSIFICATION STATISTICS")
    print("="*70)

    # Sort by percentage (descending)
    sorted_stats = sorted(stats.items(), key=lambda x: x[1]['percentage'], reverse=True)

    for class_id, stat in sorted_stats:
        print(f"üè∑Ô∏è  {stat['name']:20} {stat['percentage']:6.2f}% ({stat['pixels']:>8,} pixels)")

    # Vegetation analysis (Fields count as vegetation)
    vegetation_classes = [1, 2, 3]
    total_vegetation = sum(stats[c]['pixels'] for c in vegetation_classes)
    total_vegetation_pct = (total_vegetation / classification.size) * 100

    print(f"\nüåø VEGETATION ANALYSIS:")
    print(f"   Total Vegetation: {total_vegetation_pct:.1f}%")
    print(f"   Forests: {stats[2]['percentage'] + stats[3]['percentage']:.1f}%")
    print(f"   Fields: {stats[1]['percentage']:.1f}%")
    print(f"   Urban/Bare Soil: {stats[4]['percentage']:.1f}%")

    # NDVI distribution; thresholds mirror classify_with_corrected_ndvi
    print(f"\nüéØ CORRECTED NDVI DISTRIBUTION:")
    ndvi_urban = np.sum(ndvi < 0.1) / ndvi.size * 100
    ndvi_fields = np.sum((ndvi >= 0.1) & (ndvi < 0.3)) / ndvi.size * 100  # FIELDS: 0.1-0.3
    ndvi_forest = np.sum(ndvi >= 0.3) / ndvi.size * 100                   # FORESTS: >= 0.3

    print(f"   NDVI < 0.1 (Urban):        {ndvi_urban:.1f}%")
    print(f"   NDVI 0.1-0.3 (Fields):     {ndvi_fields:.1f}%")
    print(f"   NDVI ‚â• 0.3 (Forests):      {ndvi_forest:.1f}%")
    print(f"   NDVI Range: [{ndvi.min():.3f}, {ndvi.max():.3f}]")

    # Scale information
    print(f"\nüìè SCALE INFORMATION:")
    print(f"   Scale bars: 2 km (added to all images)")
    print(f"   Resolution: {abs(transform[0]):.1f} meters/pixel")

    # Verify classification matches NDVI distribution
    print(f"\n‚úÖ VERIFICATION:")
    print(f"   Fields classification:    {stats[1]['percentage']:.1f}%")
    print(f"   Fields from NDVI:         {ndvi_fields:.1f}%")
    print(f"   Forests classification:   {(stats[2]['percentage'] + stats[3]['percentage']):.1f}%")
    print(f"   Forests from NDVI:        {ndvi_forest:.1f}%")

    print("="*70)

    return stats, dashboard_path

def process_image_corrected(img_file, image_type):
    """Classify one GeoTIFF from PREVIEW_DIR and render/save its dashboard.

    Args:
        img_file: Filename inside ``PREVIEW_DIR``.
        image_type: Label for the log output and the dashboard (e.g. ``"BEFORE"``).

    Returns:
        tuple: ``(classification, stats, dashboard_path)`` on success, where
        ``stats`` maps every ``CLASSES`` ID to ``{'name', 'percentage', 'pixels'}``.
        ``(None, None, None)`` if anything raises; the traceback is printed
        instead of being propagated so a batch run can continue.
    """
    source_path = os.path.join(PREVIEW_DIR, img_file)

    try:
        with rasterio.open(source_path) as src:
            print(f"\nüîÑ PROCESSING {image_type.upper()} FIRE: {img_file}")
            print(f"  - Image shape: {src.shape}")
            print(f"  - Number of bands: {src.count}")
            print(f"  - Data type: {src.dtypes[0]}")
            print(f"  - Resolution: {abs(src.transform[0]):.1f} m/pixel")

            # rasterio band indices are 1-based
            band_arrays = []
            for band_index in range(1, src.count + 1):
                data = src.read(band_index)
                print(f"  - Band {band_index}: range [{data.min():.1f}, {data.max():.1f}]")
                band_arrays.append(data)

            band_info = identify_sentinel2_bands_correctly(band_arrays)
            classification, ndvi = classify_with_corrected_ndvi(band_info)

            # Per-class pixel counts and their share of the whole raster
            n_pixels = classification.size
            counts = {class_id: np.sum(classification == class_id) for class_id in CLASSES}
            class_stats = {
                class_id: {
                    'name': meta['name'],
                    'percentage': (counts[class_id] / n_pixels) * 100,
                    'pixels': counts[class_id],
                }
                for class_id, meta in CLASSES.items()
            }

            rgb_true = create_true_color_rgb(band_info)    # Red, Green, Blue
            rgb_false = create_false_color_rgb(band_info)  # NIR, Red, Green

            class_stats, dashboard_path = display_corrected_results(
                rgb_true, rgb_false, classification, class_stats, ndvi,
                img_file, image_type, src.transform,
            )
            return classification, class_stats, dashboard_path

    except Exception as exc:
        print(f"‚ùå Error processing {img_file}: {exc}")
        import traceback
        traceback.print_exc()
        return None, None, None

def find_and_group_image_files(directory=None):
    """Find the square_10km GeoTIFFs and group them by fire ID and phase.

    Args:
        directory: Folder to scan. Defaults to ``PREVIEW_DIR``.

    Returns:
        tuple: ``(all_files, fire_groups)``.
        - ``all_files`` is the sorted list of ``.tif``/``.tiff`` filenames that
          contain ``square_10km``.
        - ``fire_groups`` maps the 3-digit fire ID (the first ``_NNN_`` token,
          e.g. ``'001'`` in ``square_10km_allbands_before_001_20231015.tif``)
          to ``{'before': [...], 'after': [...]}``.
        Both are empty when the directory does not exist.
    """
    import re  # imported locally, as the original did

    if directory is None:
        directory = PREVIEW_DIR
    if not os.path.isdir(directory):
        # A list (not a dict) so the first value has the same type on every path
        return [], {}

    all_files = sorted(
        f for f in os.listdir(directory)
        if f.endswith(('.tif', '.tiff')) and 'square_10km' in f
    )

    fire_id_pattern = re.compile(r'_(\d{3})_')
    fire_groups = {}
    for file in all_files:
        fire_match = fire_id_pattern.search(file)
        if not fire_match:
            continue
        lowered = file.lower()
        if 'before' in lowered:
            phase = 'before'
        elif 'after' in lowered:
            phase = 'after'
        else:
            # No phase tag: skip rather than creating an empty "phantom" fire group
            continue
        fire_groups.setdefault(fire_match.group(1), {'before': [], 'after': []})[phase].append(file)

    return all_files, fire_groups

def _classify_fire_phase(fire_id, phase, image_files, header, results):
    """Classify every image of one phase ('before'/'after') of a fire; return the success count.

    Successful results are stored in *results* keyed by filename.
    """
    if not image_files:
        print(f"\n‚ùå No {phase.upper()} images available for fire {fire_id}")
        return 0
    print(header)
    processed = 0
    for img_file in sorted(image_files):
        _classification, stats, dashboard_path = process_image_corrected(img_file, phase.upper())
        if stats:
            results[img_file] = {'stats': stats, 'type': phase, 'dashboard': dashboard_path}
            processed += 1
    return processed


def display_all_images_with_messages():
    """Classify all grouped before/after images and report fires missing a phase.

    Returns:
        tuple: ``(images_processed, results)`` where ``results`` maps each
        successfully processed filename to ``{'stats', 'type', 'dashboard'}``.
        Always a 2-tuple, including when no files are found, so callers can
        unpack it unconditionally.
    """
    print("\n" + "="*80)
    print("üåç DISPLAYING ALL FIRE CLASSIFICATION RESULTS")
    print("="*80)

    all_files, fire_groups = find_and_group_image_files()

    if not all_files:
        print("‚ùå No square_10km TIFF files found in the directory.")
        print("\nüìÅ Available files in directory:")
        available_files = []
        if os.path.isdir(PREVIEW_DIR):
            available_files = [f for f in os.listdir(PREVIEW_DIR) if f.endswith(('.tif', '.tiff'))]
        if available_files:
            for file in sorted(available_files):
                print(f"  - {file}")
        else:
            print("  No TIFF files found.")
        # Same (count, results) shape as the success path
        return 0, {}

    print(f"üìä Found {len(all_files)} total images across {len(fire_groups)} fire locations")

    results = {}
    images_processed = 0

    for fire_id in sorted(fire_groups):
        print(f"\n{'='*60}")
        print(f"üî• FIRE LOCATION {fire_id}")
        print(f"{'='*60}")

        images_processed += _classify_fire_phase(
            fire_id, 'before', fire_groups[fire_id]['before'],
            "\nüü¢ PROCESSING BEFORE FIRE IMAGES:", results,
        )
        images_processed += _classify_fire_phase(
            fire_id, 'after', fire_groups[fire_id]['after'],
            "\nüî¥ PROCESSING AFTER FIRE IMAGES:", results,
        )

    return images_processed, results

def compare_before_after(before_stats, after_stats, before_file, after_file):
    """Print a before/after land-cover change table and a fire-impact verdict.

    Args:
        before_stats, after_stats: Per-class statistics keyed by the IDs in
            ``CLASSES``. Each value needs ``'percentage'`` and ``'pixels'``, as
            built by ``process_image_corrected``.
        before_file, after_file: Source filenames, echoed in the header so the
            report states which images were compared.

    Forest is class 2 + class 3 and vegetation is forest + class 1. This matches
    the "Forests"/"Fields" breakdown printed by the classification dashboard.
    The "Change %" values are differences in percentage points.
    """
    print("\n" + "="*80)
    print("üî• FIRE IMPACT ANALYSIS - BEFORE vs AFTER")
    print("="*80)
    print(f"   Before: {before_file}")
    print(f"   After: {after_file}")

    print(f"\nüìà CHANGE ANALYSIS:")
    print("-" * 70)
    print(f"{'Land Cover Type':20} {'Before %':>10} {'After %':>10} {'Change %':>12} {'Area Change':>15}")
    print("-" * 70)

    for class_id in sorted(CLASSES.keys()):
        class_name = CLASSES[class_id]['name']
        before_pct = before_stats[class_id]['percentage']
        after_pct = after_stats[class_id]['percentage']
        change_pct = after_pct - before_pct
        area_change = after_stats[class_id]['pixels'] - before_stats[class_id]['pixels']

        # Changes within +/-0.1 points are shown as "approximately equal"
        change_symbol = "‚Üë" if change_pct > 0.1 else "‚Üì" if change_pct < -0.1 else "‚âà"

        print(f"{class_name:20} {before_pct:>9.2f}% {after_pct:>9.2f}% {change_pct:>+11.2f}% {area_change:>+13,} px {change_symbol}")

    before_forest = before_stats[2]['percentage'] + before_stats[3]['percentage']
    after_forest = after_stats[2]['percentage'] + after_stats[3]['percentage']
    forest_change = after_forest - before_forest

    before_vegetation = before_stats[1]['percentage'] + before_forest
    after_vegetation = after_stats[1]['percentage'] + after_forest
    vegetation_change = after_vegetation - before_vegetation

    print(f"\nüå≤ FIRE IMPACT SUMMARY:")
    print(f"   Total Forest Cover Change: {forest_change:+.2f}%")
    print(f"   Total Vegetation Change: {vegetation_change:+.2f}%")

    if forest_change < -5:
        print("   üö® SIGNIFICANT FOREST LOSS DETECTED")
    elif forest_change < -1:
        print("   ‚ö†Ô∏è  MODERATE FOREST LOSS DETECTED")
    elif forest_change > 1:
        print("   üå± FOREST GROWTH DETECTED")
    else:
        print("   ‚úÖ MINIMAL FOREST CHANGE")

    # Mirrors the forest verdict, so that a gain is not reported as "minimal"
    if vegetation_change < -5:
        print("   üö® SIGNIFICANT VEGETATION LOSS DETECTED")
    elif vegetation_change < -1:
        print("   ‚ö†Ô∏è  MODERATE VEGETATION LOSS DETECTED")
    elif vegetation_change > 1:
        print("   üå± VEGETATION GROWTH DETECTED")
    else:
        print("   ‚úÖ MINIMAL VEGETATION CHANGE")

    print("="*80)

def _first_before_after_pair(results):
    """Return ``(before_file, after_file)`` for the lowest fire ID that has both phases, else ``None``.

    The fire ID is the first ``_NNN_`` token in the filename. This is the same
    pattern ``find_and_group_image_files`` groups on, so both files belong to one fire.
    """
    import re

    by_fire = {}
    for name, info in results.items():
        fire_match = re.search(r'_(\d{3})_', name)
        if fire_match is None or info.get('type') not in ('before', 'after'):
            continue
        phases = by_fire.setdefault(fire_match.group(1), {'before': [], 'after': []})
        phases[info['type']].append(name)

    for fire_id in sorted(by_fire):
        phases = by_fire[fire_id]
        if phases['before'] and phases['after']:
            return min(phases['before']), min(phases['after'])
    return None


def main_corrected():
    """Classify every before/after image, list the saved dashboards and compare one fire.

    Steps:
        - Classify all images via ``display_all_images_with_messages``.
        - List the PNGs found in ``CLASSIFIED_DIR``.
        - If at least one fire has both a BEFORE and an AFTER result, print a
          ``compare_before_after`` report for the lowest such fire ID. Pairing by
          fire ID prevents comparing images of two different fires.
    """
    print("‚úÖ üåç CORRECTED SENTINEL-2 CLASSIFICATION - BEFORE & AFTER FIRE")
    print("=" * 80)
    print("üéØ CORRECTED NDVI THRESHOLDS:")
    print("   - Fields/Agriculture: NDVI 0.1 to 0.3")
    print("   - Forests: NDVI ‚â• 0.3")
    print("   - Urban/Bare Soil: NDVI < 0.1")
    print("üìè SCALE BARS: 2 km (added to ALL images)")
    print(f"üíæ RESULTS SAVED TO: {CLASSIFIED_DIR}")
    print("=" * 80)

    if not os.path.exists(PREVIEW_DIR):
        print(f"‚ùå Directory not found: {PREVIEW_DIR}")
        print("Please run the extraction script first to generate image files.")
        return

    images_processed, results = display_all_images_with_messages()

    print(f"\n{'='*80}")
    print("‚úÖ PROCESSING COMPLETED")
    print("="*80)

    if images_processed <= 0:
        print("‚ùå No images were successfully processed.")
        print("üí° Please check that:")
        print("   - The extraction script has been run successfully")
        print("   - TIFF files exist in the preview directory")
        print("   - Files follow the naming pattern: square_10km_allbands_[before/after]_[fire_id]_[date].tif")
        print(f"   - Source directory: {PREVIEW_DIR}")
        return

    print(f"üìä Successfully processed {images_processed} images:")

    before_count = sum(1 for r in results.values() if r['type'] == 'before')
    after_count = sum(1 for r in results.values() if r['type'] == 'after')

    print(f"   - {before_count} BEFORE fire images")
    print(f"   - {after_count} AFTER fire images")
    print(f"üìè Scale bars added to ALL images (2 km)")
    print(f"üíæ All classification dashboards saved to: {CLASSIFIED_DIR}")

    # Guard listdir: the output directory may not exist if every dashboard save failed
    saved_files = []
    if os.path.isdir(CLASSIFIED_DIR):
        saved_files = [f for f in os.listdir(CLASSIFIED_DIR) if f.endswith('.png')]
    if saved_files:
        print(f"\nüìÅ Saved dashboard files:")
        for file in sorted(saved_files):
            print(f"   - {file}")

    print(f"\nüìÇ Classification Results Directory:")
    print(f"   {CLASSIFIED_DIR}")

    pair = _first_before_after_pair(results)
    if pair is not None:
        before_file, after_file = pair
        compare_before_after(
            results[before_file]['stats'],
            results[after_file]['stats'],
            before_file,
            after_file,
        )

# Run the corrected version.
# Inside a Jupyter/IPython kernel __name__ is "__main__", so this also runs
# every time this notebook cell is executed, not only as a standalone script.
if __name__ == "__main__":
    main_corrected()

In [None]:
import os
import numpy as np
from PIL import Image as PILImage, ImageDraw
import rasterio
from IPython.display import display, HTML
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# Configuration - UPDATED PATHS
# Input GeoTIFFs are read from PREVIEW_DIR; NBR dashboards are written to NBR_DIR,
# which is created here, when the cell runs, if it does not exist yet.
PREVIEW_DIR = r'C:\Users\steph\Downloads\master_thesis\scripts\sentinel2_fire_images\preview'
NBR_DIR = r'C:\Users\steph\Downloads\master_thesis\scripts\sentinel2_fire_images\nbr_analysis'
os.makedirs(NBR_DIR, exist_ok=True)

# Burn severity classes based on NBR (dNBR = NBR_before - NBR_after).
# Keys are the class IDs produced by classify_burn_severity (0-4); 'color' is the
# RGB used for the severity map and bar chart. 'dNBR_range' is not read anywhere
# in this section and appears to be a descriptive label only.
# NOTE(review): the cut points (-0.25, -0.1, +0.1, +0.27) are the lower USGS/FIREMON
# (Key & Benson) dNBR boundaries. In that scheme, however:
#   -0.25..-0.1 is "enhanced regrowth (low)",
#   -0.1..+0.1 is "unburned",
#   +0.1..+0.27 is "low severity",
#   >= +0.27 starts "moderate-low".
# The names below are shifted one class up, so "Low Severity" here also covers
# unchanged pixels. Confirm this is intended.
BURN_CLASSES = {
    0: {'name': 'Enhanced Regrowth', 'color': (0, 100, 0), 'dNBR_range': '<-0.25'},
    1: {'name': 'Unburned', 'color': (50, 205, 50), 'dNBR_range': '-0.25 to -0.1'},
    2: {'name': 'Low Severity', 'color': (255, 255, 0), 'dNBR_range': '-0.1 to +0.1'},
    3: {'name': 'Moderate Severity', 'color': (255, 165, 0), 'dNBR_range': '+0.1 to +0.27'},
    4: {'name': 'High Severity', 'color': (255, 0, 0), 'dNBR_range': '>+0.27'}
}

def add_scale_bar_to_plot(ax, transform, scale_km=2):
    """Draw a ``scale_km`` kilometre scale bar near the bottom-right of an image axis.

    The bar length in data units is ``scale_km * 1000 / abs(transform[0])``, so
    ``transform[0]`` is taken to be the pixel width in metres. Any failure is
    printed and swallowed, so a missing scale bar never aborts the dashboard.
    """
    try:
        x_lo, x_hi = ax.get_xlim()
        y_lo, y_hi = ax.get_ylim()

        metres_per_pixel = abs(transform[0])
        bar_len = (scale_km * 1000) / metres_per_pixel

        width_span = x_hi - x_lo
        height_span = y_hi - y_lo

        # Inset 5% from the right and bottom edges
        bar_x = x_hi - width_span * 0.05 - bar_len
        bar_y = y_lo + height_span * 0.05

        bar = plt.Rectangle(
            (bar_x, bar_y), bar_len, height_span * 0.005,
            facecolor='white', edgecolor='black', linewidth=2,
        )
        ax.add_patch(bar)

        label_style = dict(boxstyle="round,pad=0.3", facecolor='black', alpha=0.7)
        ax.text(
            bar_x + bar_len / 2, bar_y - height_span * 0.02, f'{scale_km} km',
            ha='center', va='top', color='white',
            fontweight='bold', fontsize=10, bbox=label_style,
        )
    except Exception as err:
        print(f"  ‚ö† Could not add scale bar to plot: {err}")

def identify_sentinel2_bands_for_nbr(bands):
    """Pick the NIR and SWIR bands used for the Normalized Burn Ratio.

    Bands are picked by position for stacks of 10 or more bands, and by a
    mean-reflectance heuristic for smaller stacks.

    Args:
        bands: Sequence of 2-D band arrays in file order (band 1 first).

    Returns:
        dict: ``{'nir': array, 'swir2': array}`` referencing arrays from
        *bands*. Returns an empty dict (after printing a message) when fewer
        than two bands are available.
    """
    n_bands = len(bands)
    print(f"  - Number of bands: {n_bands}")

    if n_bands >= 12:
        # NOTE(review): assumes a 12-band order without B10 (B1..B8, B8A, B9, B11, B12),
        # so index 7 = B8 and index 11 = B12. A 13-band L1C stack (with B10) would put
        # B11 at index 11. Confirm against the export.
        band_info = {'nir': bands[7], 'swir2': bands[11]}
        print("  - Identified as 12-band Sentinel-2 (using B8 for NIR, B12 for SWIR)")
    elif n_bands >= 10:
        band_info = {
            'nir': bands[7],                                  # assume band 8 is NIR
            'swir2': bands[10] if n_bands > 10 else bands[9],  # band 11, else band 10
        }
        print("  - Using bands 8 (NIR) and 11/10 (SWIR) for NBR")
    elif n_bands >= 2:
        # Heuristic: NIR = highest mean band, SWIR = lowest mean of the rest.
        # NOTE(review): in vegetated scenes the lowest-mean band may be a visible
        # band (e.g. blue) rather than SWIR. Treat the result with caution.
        band_means = [np.mean(band) for band in bands]
        nir_idx = np.argmax(band_means)
        remaining = [i for i in range(n_bands) if i != nir_idx]
        swir_idx = remaining[np.argmin([band_means[i] for i in remaining])]
        band_info = {'nir': bands[nir_idx], 'swir2': bands[swir_idx]}
        print(f"  - Identified by reflectance: B{nir_idx+1}=NIR, B{swir_idx+1}=SWIR")
    else:
        # 0 or 1 band: np.argmax would raise on an empty list, so report instead
        print("‚ùå Not enough bands for NBR calculation")
        band_info = {}

    return band_info

def calculate_nbr(nir, swir2):
    """Return the Normalized Burn Ratio ``(NIR - SWIR2) / (NIR + SWIR2)`` as float32.

    Pixels whose band sum is not strictly positive (including NaN sums) are set
    to 0 instead of being divided. The result is clipped to [-1, 1].
    """
    nir_f = nir.astype(np.float32)
    swir_f = swir2.astype(np.float32)

    band_sum = nir_f + swir_f
    # Divide only where the sum is positive; every other pixel keeps the 0 from `out`
    ratio = np.divide(
        nir_f - swir_f,
        band_sum,
        out=np.zeros_like(nir_f, dtype=np.float32),
        where=band_sum > 0,
    )
    return np.clip(ratio, -1, 1)

def calculate_dnbr(nbr_before, nbr_after):
    """Return the differenced NBR, ``dNBR = nbr_before - nbr_after``.

    Positive values mean NBR dropped between the two dates, which is what the
    severity classes treat as burn. Negative values mean NBR rose.

    Raises:
        ValueError: If the two rasters differ in shape. Without this check NumPy
            would silently broadcast, e.g., a (1, W) raster against an (H, W) one.
    """
    nbr_before = np.asarray(nbr_before)
    nbr_after = np.asarray(nbr_after)
    if nbr_before.shape != nbr_after.shape:
        raise ValueError(
            f"NBR rasters differ in shape: before {nbr_before.shape} vs after {nbr_after.shape}"
        )
    dNBR = nbr_before - nbr_after
    return dNBR

def classify_burn_severity(dNBR, thresholds=(-0.25, -0.1, 0.1, 0.27)):
    """Bin a dNBR raster into the five BURN_CLASSES IDs (0-4).

    With ``thresholds = (t0, t1, t2, t3)`` (lower bounds inclusive), a pixel is:
        0 if dNBR < t0
        1 if t0 <= dNBR < t1
        2 if t1 <= dNBR < t2
        3 if t2 <= dNBR < t3
        4 if dNBR >= t3
    NaN pixels fail every comparison and stay 0.

    NOTE(review): the default cut points are the lower USGS/FIREMON (Key & Benson)
    dNBR boundaries, but the BURN_CLASSES names attached to them are shifted one class
    up relative to that scheme (there, -0.1..+0.1 is "unburned"). To align the names
    with USGS (with the moderate classes merged), pass ``(-0.1, 0.1, 0.27, 0.66)``.
    Confirm which convention the analysis intends.

    Args:
        dNBR: Array-like of dNBR values (any shape).
        thresholds: Four strictly increasing class boundaries.

    Returns:
        np.ndarray: uint8 array of class IDs with the same shape as *dNBR*.

    Raises:
        ValueError: If *thresholds* is not four strictly increasing values.
    """
    dNBR = np.asarray(dNBR)
    bounds = tuple(float(t) for t in thresholds)
    if len(bounds) != 4:
        raise ValueError(f"expected 4 thresholds, got {len(bounds)}")
    t0, t1, t2, t3 = bounds
    if not (t0 < t1 < t2 < t3):
        raise ValueError(f"thresholds must be strictly increasing, got {bounds}")

    # Everything starts as class 0 (dNBR < t0)
    classification = np.zeros(dNBR.shape, dtype=np.uint8)
    classification[(dNBR >= t0) & (dNBR < t1)] = 1
    classification[(dNBR >= t1) & (dNBR < t2)] = 2
    classification[(dNBR >= t2) & (dNBR < t3)] = 3
    classification[dNBR >= t3] = 4

    return classification

def create_nbr_colormap(nbr):
    """Render an NBR raster as an RGB uint8 image: low NBR red, high NBR green.

    NBR in [-1, 1] is mapped linearly to [0, 1] (values outside are clipped) and
    coloured along red -> orange -> yellow -> green -> dark green. High NBR
    indicates healthy vegetation, so this matches the dashboard's
    "(Green=Healthy, Red=Stressed)" panel titles. Non-finite pixels are black.

    Args:
        nbr: Array-like of NBR values (any shape).

    Returns:
        np.ndarray: uint8 array of shape ``nbr.shape + (3,)``.
    """
    nbr = np.asarray(nbr, dtype=np.float64)
    nbr_normalized = np.clip((nbr + 1) / 2, 0, 1)

    # Colour stops evenly spaced over [0, 1]. The ramp runs red -> green so that
    # healthy (high-NBR) vegetation ends up green.
    anchors = np.array([
        (1.0, 0.0, 0.0),   # NBR -1: red
        (1.0, 0.5, 0.0),   # orange
        (1.0, 1.0, 0.0),   # NBR  0: yellow
        (0.0, 1.0, 0.0),   # green
        (0.0, 0.5, 0.0),   # NBR +1: dark green
    ])
    positions = np.linspace(0.0, 1.0, len(anchors))

    rgb = np.stack(
        [np.interp(nbr_normalized, positions, anchors[:, ch]) for ch in range(3)],
        axis=-1,
    )
    # np.interp propagates NaN; paint those pixels black before the uint8 cast
    rgb[~np.isfinite(nbr_normalized)] = 0.0
    return (rgb * 255).astype(np.uint8)

def create_burn_severity_colormap(burn_classification):
    """Paint a 2-D burn-severity class raster as an RGB uint8 image.

    Each class ID in ``BURN_CLASSES`` gets that class's ``'color'``. Pixels with
    any other value stay black (0, 0, 0).
    """
    rows, cols = burn_classification.shape
    canvas = np.zeros((rows, cols, 3), dtype=np.uint8)

    for class_id in BURN_CLASSES:
        canvas[burn_classification == class_id] = BURN_CLASSES[class_id]['color']

    return canvas

def create_comprehensive_dashboard(nbr_before, nbr_after, dNBR, burn_severity, stats, before_file, after_file, transform):
    """Build (but do not show or save) the NBR burn-severity dashboard figure.

    Layout on a 3x4 grid:
        - Top row: NBR before, NBR after, dNBR and the severity map, each with
          a 2 km scale bar.
        - Middle row: severity bar chart and dNBR histogram.
        - Bottom row: a text summary across the full width.

    Args:
        nbr_before, nbr_after: NBR rasters for the two dates.
        dNBR: ``nbr_before - nbr_after``.
        burn_severity: Class-ID raster from ``classify_burn_severity``.
        stats: Dict keyed by BURN_CLASSES IDs with ``'percentage'`` and ``'pixels'``.
        before_file, after_file: Filenames shown in the figure title.
        transform: Raster geotransform; ``transform[0]`` is used as the pixel size.

    Returns:
        matplotlib.figure.Figure: The caller is responsible for showing, saving and closing it.
    """
    nbr_before_vis = create_nbr_colormap(nbr_before)
    nbr_after_vis = create_nbr_colormap(nbr_after)
    severity_vis = create_burn_severity_colormap(burn_severity)

    fig = plt.figure(figsize=(20, 16))
    fig.suptitle(f'NBR Burn Severity Analysis: {before_file} vs {after_file}',
                 fontsize=18, fontweight='bold', y=0.98)
    gs = fig.add_gridspec(3, 4, hspace=0.3, wspace=0.3)

    # Plot 1: NBR Before Fire
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.imshow(nbr_before_vis)
    add_scale_bar_to_plot(ax1, transform)
    ax1.set_title('NBR - Before Fire\n(Green=Healthy, Red=Stressed)', fontsize=12, fontweight='bold')
    ax1.axis('off')

    # Plot 2: NBR After Fire
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.imshow(nbr_after_vis)
    add_scale_bar_to_plot(ax2, transform)
    ax2.set_title('NBR - After Fire\n(Green=Healthy, Red=Stressed)', fontsize=12, fontweight='bold')
    ax2.axis('off')

    # Plot 3: dNBR (Difference); reversed RdYlGn so positive dNBR (burn) is red
    ax3 = fig.add_subplot(gs[0, 2])
    im = ax3.imshow(dNBR, cmap='RdYlGn_r', vmin=-0.5, vmax=0.5)
    add_scale_bar_to_plot(ax3, transform)
    ax3.set_title('dNBR (Before - After)\n(Red=Burned, Green=Recovery)', fontsize=12, fontweight='bold')
    ax3.axis('off')
    plt.colorbar(im, ax=ax3, fraction=0.046, pad=0.04)

    # Plot 4: Burn Severity Classification
    ax4 = fig.add_subplot(gs[0, 3])
    ax4.imshow(severity_vis)
    add_scale_bar_to_plot(ax4, transform)
    ax4.set_title('Burn Severity Classification', fontsize=12, fontweight='bold')
    ax4.axis('off')

    # Plot 5: Burn Severity Statistics
    ax5 = fig.add_subplot(gs[1, :2])
    class_ids = sorted(stats.keys())
    class_names = [BURN_CLASSES[class_id]['name'] for class_id in class_ids]
    percentages = [stats[class_id]['percentage'] for class_id in class_ids]
    bar_colors = [tuple(c / 255 for c in BURN_CLASSES[class_id]['color']) for class_id in class_ids]

    bars = ax5.bar(class_names, percentages, color=bar_colors, edgecolor='black', alpha=0.7)
    ax5.set_title('Burn Severity Distribution (%)', fontsize=14, fontweight='bold')
    ax5.set_ylabel('Percentage (%)', fontsize=12)
    ax5.tick_params(axis='x', rotation=45)
    ax5.grid(True, alpha=0.3)

    for bar, percentage in zip(bars, percentages):
        height = bar.get_height()
        ax5.text(bar.get_x() + bar.get_width() / 2., height + 0.5,
                 f'{percentage:.1f}%', ha='center', va='bottom', fontweight='bold', fontsize=10)

    # Plot 6: dNBR Histogram
    ax6 = fig.add_subplot(gs[1, 2:])
    ax6.hist(dNBR.ravel(), bins=50, color='orange', alpha=0.7, edgecolor='black')
    ax6.set_title('dNBR Value Distribution', fontsize=14, fontweight='bold')
    ax6.set_xlabel('dNBR Value', fontsize=12)
    ax6.set_ylabel('Pixel Count', fontsize=12)
    ax6.grid(True, alpha=0.3)

    # One dashed line per class boundary, labelled with the two classes it separates,
    # so every class (including High Severity) appears in the legend.
    # Keep these values in sync with classify_burn_severity's default thresholds.
    thresholds = [-0.25, -0.1, 0.1, 0.27]
    line_colors = ['green', 'lightgreen', 'yellow', 'orange']
    for i, (threshold, line_color) in enumerate(zip(thresholds, line_colors)):
        lower_name = BURN_CLASSES[i]['name']
        upper_name = BURN_CLASSES[i + 1]['name']
        ax6.axvline(x=threshold, color=line_color, linestyle='--', alpha=0.7,
                    label=f'{threshold:+.2f}: {lower_name} | {upper_name}')
    ax6.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    # Plot 7: Text summary
    ax7 = fig.add_subplot(gs[2, :])
    ax7.axis('off')

    # NOTE(review): class 2 ("Low Severity", dNBR -0.1..+0.1) is counted as burned;
    # under the USGS scheme that band is "unburned". See BURN_CLASSES.
    burned_classes = [2, 3, 4]
    total_burned = sum(stats[c]['pixels'] for c in burned_classes)
    total_burned_pct = (total_burned / burn_severity.size) * 100
    high_severity_pct = stats[4]['percentage']

    summary_text = [
        "BURN IMPACT SUMMARY:",
        f"   * Total Burned Area: {total_burned_pct:.1f}%",
        f"   * High Severity Burn: {high_severity_pct:.1f}%",
        f"   * Moderate Severity Burn: {stats[3]['percentage']:.1f}%",
        f"   * Low Severity Burn: {stats[2]['percentage']:.1f}%",
        f"   * Unburned Area: {stats[1]['percentage']:.1f}%",
        f"   * Enhanced Regrowth: {stats[0]['percentage']:.1f}%",
        "",
        "NBR STATISTICS:",
        f"   * Pre-fire NBR Mean: {nbr_before.mean():.3f}",
        f"   * Post-fire NBR Mean: {nbr_after.mean():.3f}",
        f"   * dNBR Mean: {dNBR.mean():.3f}",
        f"   * dNBR Range: [{dNBR.min():.3f}, {dNBR.max():.3f}]",
        "",
        "SCALE INFORMATION:",
        f"   * Scale bars: 2 km (all images)",
        f"   * Resolution: {abs(transform[0]):.1f} m/pixel"
    ]

    # Plain-text verdicts; the figure font may not render emoji
    if total_burned_pct > 50:
        impact = "ALERT: EXTENSIVE FIRE - Over 50% of area burned"
    elif total_burned_pct > 25:
        impact = "WARNING: MAJOR FIRE - 25-50% of area burned"
    elif total_burned_pct > 10:
        impact = "MODERATE FIRE - 10-25% of area burned"
    else:
        impact = "LIMITED FIRE - Less than 10% of area burned"

    if high_severity_pct > 20:
        damage = "ALERT: SEVERE DAMAGE - High severity burns over 20%"
    elif high_severity_pct > 10:
        damage = "WARNING: SIGNIFICANT DAMAGE - High severity burns 10-20%"
    else:
        damage = "MINOR DAMAGE - High severity burns under 10%"

    summary_text.extend([
        "",
        "FIRE IMPACT ASSESSMENT:",
        f"   * {impact}",
        f"   * {damage}"
    ])

    ax7.text(0.02, 0.95, "\n".join(summary_text), transform=ax7.transAxes,
             fontsize=12, verticalalignment='top', fontfamily='monospace',
             bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgray", alpha=0.7))

    return fig

def save_dashboard_as_png(nbr_before, nbr_after, dNBR, burn_severity, stats, before_file, after_file, output_dir, transform):
    """Render the NBR dashboard and save it as a 300-dpi PNG in *output_dir*.

    The file is named ``<before stem>_vs_<after stem>_dashboard.png``.
    *output_dir* is created if needed. The figure is always closed, even if
    saving fails.

    Returns:
        str: Path of the written PNG.
    """
    os.makedirs(output_dir, exist_ok=True)
    base_name = f"{os.path.splitext(before_file)[0]}_vs_{os.path.splitext(after_file)[0]}"
    dashboard_path = os.path.join(output_dir, f"{base_name}_dashboard.png")

    fig = create_comprehensive_dashboard(nbr_before, nbr_after, dNBR, burn_severity, stats, before_file, after_file, transform)
    try:
        fig.savefig(dashboard_path, dpi=300, bbox_inches='tight', facecolor='white')
    finally:
        # Release the figure even if savefig raises, so batch runs don't accumulate figures
        plt.close(fig)

    print(f"üíæ Dashboard saved as: {dashboard_path}")
    return dashboard_path

def display_nbr_analysis(nbr_before, nbr_after, dNBR, burn_severity, stats, before_file, after_file, transform):
    """Show the NBR dashboard inline, save it as PNG in NBR_DIR and print a console report.

    Args:
        Same as ``create_comprehensive_dashboard``.

    Returns:
        str: Path of the saved dashboard PNG.
    """
    fig = create_comprehensive_dashboard(nbr_before, nbr_after, dNBR, burn_severity, stats, before_file, after_file, transform)

    display(HTML(f"<h2>üî• NBR BURN SEVERITY ANALYSIS</h2>"))
    display(HTML(f"<h3>Before: {before_file} | After: {after_file}</h3>"))
    plt.show()
    # Close explicitly: non-inline backends keep shown figures open, and this
    # function runs once per fire (the inline backend already closes on show).
    plt.close(fig)

    dashboard_path = save_dashboard_as_png(nbr_before, nbr_after, dNBR, burn_severity, stats, before_file, after_file, NBR_DIR, transform)

    print("\n" + "="*80)
    print("üî• BURN SEVERITY STATISTICS")
    print("="*80)

    # Largest classes first
    sorted_stats = sorted(stats.items(), key=lambda x: x[1]['percentage'], reverse=True)
    for class_id, stat in sorted_stats:
        print(f"üè∑Ô∏è  {stat['name']:25} {stat['percentage']:6.2f}% ({stat['pixels']:>8,} pixels)")

    # NOTE(review): class 2 ("Low Severity", dNBR -0.1..+0.1) is counted as burned,
    # matching the dashboard summary. See BURN_CLASSES for the USGS comparison.
    burned_classes = [2, 3, 4]  # Low, Moderate, High severity
    total_burned = sum(stats[c]['pixels'] for c in burned_classes)
    total_burned_pct = (total_burned / burn_severity.size) * 100
    high_severity_pct = stats[4]['percentage']

    print(f"\nüìä BURN IMPACT SUMMARY:")
    print(f"   Total Burned Area: {total_burned_pct:.1f}%")
    print(f"   High Severity Burn: {high_severity_pct:.1f}%")
    print(f"   Moderate Severity Burn: {stats[3]['percentage']:.1f}%")
    print(f"   Low Severity Burn: {stats[2]['percentage']:.1f}%")
    print(f"   Unburned Area: {stats[1]['percentage']:.1f}%")
    print(f"   Enhanced Regrowth: {stats[0]['percentage']:.1f}%")

    print(f"\nüéØ NBR STATISTICS:")
    print(f"   Pre-fire NBR Range: [{nbr_before.min():.3f}, {nbr_before.max():.3f}]")
    print(f"   Pre-fire NBR Mean: {nbr_before.mean():.3f}")
    print(f"   Post-fire NBR Range: [{nbr_after.min():.3f}, {nbr_after.max():.3f}]")
    print(f"   Post-fire NBR Mean: {nbr_after.mean():.3f}")
    print(f"   dNBR Range: [{dNBR.min():.3f}, {dNBR.max():.3f}]")
    print(f"   dNBR Mean: {dNBR.mean():.3f}")

    print(f"\nüìè SCALE INFORMATION:")
    print(f"   Scale bars: 2 km (added to all images)")
    print(f"   Resolution: {abs(transform[0]):.1f} meters/pixel")

    # Same thresholds as the dashboard's FIRE IMPACT ASSESSMENT
    print(f"\nüî• FIRE IMPACT ASSESSMENT:")
    if total_burned_pct > 50:
        print("   üö® EXTENSIVE FIRE: Over 50% of area burned")
    elif total_burned_pct > 25:
        print("   ‚ö†Ô∏è  MAJOR FIRE: 25-50% of area burned")
    elif total_burned_pct > 10:
        print("   üî• MODERATE FIRE: 10-25% of area burned")
    else:
        print("   ‚úÖ LIMITED FIRE: Less than 10% of area burned")

    if high_severity_pct > 20:
        print("   üö® SEVERE DAMAGE: High severity burns over 20%")
    elif high_severity_pct > 10:
        print("   ‚ö†Ô∏è  SIGNIFICANT DAMAGE: High severity burns 10-20%")
    else:
        print("   ‚úÖ MINOR DAMAGE: High severity burns under 10%")

    print("="*80)

    return dashboard_path

def _read_nbr_from_file(path, label):
    """Open one GeoTIFF and compute its NBR.

    Returns:
        tuple | None: ``(nbr, transform)``, or ``None`` if the NIR/SWIR bands
        cannot be identified.
    """
    with rasterio.open(path) as src:
        print(f"\nüîÑ Reading {label} image...")
        print(f"   Resolution: {abs(src.transform[0]):.1f} m/pixel")
        # Capture the geotransform while the dataset is still open
        transform = src.transform
        bands = [src.read(i) for i in range(1, src.count + 1)]

        band_info = identify_sentinel2_bands_for_nbr(bands)
        if 'nir' not in band_info or 'swir2' not in band_info:
            print(f"‚ùå Missing required bands for NBR calculation in {label} image")
            return None

        return calculate_nbr(band_info['nir'], band_info['swir2']), transform


def process_nbr_analysis(before_file, after_file):
    """Compute NBR, dNBR and burn severity for a before/after pair, then display and save it.

    Args:
        before_file, after_file: Filenames inside ``PREVIEW_DIR``.

    Returns:
        tuple: ``(nbr_before, nbr_after, dNBR, burn_severity, stats,
        dashboard_path, transform)``, where ``transform`` is the BEFORE image's
        geotransform. Returns seven ``None`` values if bands are missing, the
        rasters differ in shape, or anything else raises (the traceback is printed).
    """
    failure = (None,) * 7
    before_path = os.path.join(PREVIEW_DIR, before_file)
    after_path = os.path.join(PREVIEW_DIR, after_file)

    try:
        print(f"\nüî• PROCESSING NBR ANALYSIS")
        print(f"   Before: {before_file}")
        print(f"   After: {after_file}")

        before = _read_nbr_from_file(before_path, "BEFORE")
        if before is None:
            return failure
        nbr_before, transform = before

        after = _read_nbr_from_file(after_path, "AFTER")
        if after is None:
            return failure
        nbr_after, _after_transform = after

        # Pixel-wise differencing requires identical grids; fail loudly instead of broadcasting
        if nbr_before.shape != nbr_after.shape:
            raise ValueError(
                f"BEFORE and AFTER rasters differ in shape: {nbr_before.shape} vs {nbr_after.shape}"
            )

        dNBR = calculate_dnbr(nbr_before, nbr_after)
        burn_severity = classify_burn_severity(dNBR)

        total_pixels = burn_severity.size
        stats = {}
        for class_id, class_info in BURN_CLASSES.items():
            class_pixels = np.sum(burn_severity == class_id)
            stats[class_id] = {
                'name': class_info['name'],
                'percentage': (class_pixels / total_pixels) * 100,
                'pixels': class_pixels
            }

        dashboard_path = display_nbr_analysis(nbr_before, nbr_after, dNBR, burn_severity, stats, before_file, after_file, transform)

        return nbr_before, nbr_after, dNBR, burn_severity, stats, dashboard_path, transform

    except Exception as e:
        print(f"‚ùå Error processing NBR analysis: {e}")
        import traceback
        traceback.print_exc()
        return failure

def find_and_group_image_files(directory=None):
    """Find the square_10km GeoTIFFs and group them by fire ID and phase.

    Args:
        directory: Folder to scan. Defaults to ``PREVIEW_DIR``.

    Returns:
        tuple: ``(all_files, fire_groups)``.
        - ``all_files`` is the sorted list of ``.tif``/``.tiff`` filenames that
          contain ``square_10km``.
        - ``fire_groups`` maps the 3-digit fire ID (the first ``_NNN_`` token,
          e.g. ``'001'`` in ``square_10km_allbands_before_001_20231015.tif``)
          to ``{'before': [...], 'after': [...]}``.
        Both are empty when the directory does not exist.
    """
    import re  # imported locally, as the original did

    if directory is None:
        directory = PREVIEW_DIR
    if not os.path.isdir(directory):
        # A list (not a dict) so the first value has the same type on every path
        return [], {}

    all_files = sorted(
        f for f in os.listdir(directory)
        if f.endswith(('.tif', '.tiff')) and 'square_10km' in f
    )

    fire_id_pattern = re.compile(r'_(\d{3})_')
    fire_groups = {}
    for file in all_files:
        fire_match = fire_id_pattern.search(file)
        if not fire_match:
            continue
        lowered = file.lower()
        if 'before' in lowered:
            phase = 'before'
        elif 'after' in lowered:
            phase = 'after'
        else:
            # No phase tag: skip rather than creating an empty "phantom" fire group
            continue
        fire_groups.setdefault(fire_match.group(1), {'before': [], 'after': []})[phase].append(file)

    return all_files, fire_groups

def display_all_nbr_analyses_with_messages():
    """Run the NBR burn-severity analysis for every fire location and report progress.

    Image files are discovered with ``find_and_group_image_files()``. Fires are
    visited in sorted ID order. A fire with at least one BEFORE and one AFTER
    image has the lexicographically first file of each kind passed to
    ``process_nbr_analysis``. Fires missing either side are reported and skipped.

    Returns:
        ``(analyses_completed, results)``:
        - ``analyses_completed`` is the number of fires whose analysis produced
          a burn-severity map.
        - ``results`` maps each fire ID to a dict with the input file names,
          the per-class ``stats``, the ``dashboard`` path, and the
          ``nbr_before``, ``nbr_after`` and ``dNBR`` arrays.
        When no matching image files exist, the result is ``(0, {})``.
    """
    print("\n" + "="*80)
    print("üî• DISPLAYING ALL NBR BURN SEVERITY ANALYSES")
    print("="*80)

    # Find and group all image files
    all_files, fire_groups = find_and_group_image_files()

    if not all_files:
        print("‚ùå No square_10km TIFF files found in the directory.")
        print("\nüìÅ Available files in directory:")
        # os.listdir raises FileNotFoundError on a missing directory. This function
        # can be called directly, without main_nbr_analysis's existence check.
        if os.path.isdir(PREVIEW_DIR):
            available_files = [f for f in os.listdir(PREVIEW_DIR) if f.endswith(('.tif', '.tiff'))]
        else:
            available_files = []
        if available_files:
            for file in sorted(available_files):
                print(f"  - {file}")
        else:
            print("  No TIFF files found.")
        return 0, {}

    print(f"üìä Found {len(all_files)} total images across {len(fire_groups)} fire locations")

    results = {}
    analyses_completed = 0

    for fire_id in sorted(fire_groups):
        print(f"\n{'='*60}")
        print(f"üî• FIRE LOCATION {fire_id}")
        print(f"{'='*60}")

        before_images = fire_groups[fire_id]['before']
        after_images = fire_groups[fire_id]['after']

        # dNBR needs both a pre-fire and a post-fire image.
        if not before_images:
            print(f"\n‚ùå No BEFORE images available for fire {fire_id} - cannot perform NBR analysis")
            continue
        if not after_images:
            print(f"\n‚ùå No AFTER images available for fire {fire_id} - cannot perform NBR analysis")
            continue

        print(f"\nüîÑ PROCESSING NBR ANALYSIS FOR FIRE {fire_id}")
        print(f"   Before images: {len(before_images)}")
        print(f"   After images: {len(after_images)}")

        # Only the first before/after pair (by file name) is analysed per fire.
        before_file = min(before_images)
        after_file = min(after_images)

        (nbr_before, nbr_after, dNBR, burn_severity,
         stats, dashboard_path, _transform) = process_nbr_analysis(before_file, after_file)

        # process_nbr_analysis signals failure by returning a tuple of Nones.
        if burn_severity is None:
            print(f"‚ùå NBR analysis failed for fire {fire_id}")
            continue

        results[fire_id] = {
            'before_file': before_file,
            'after_file': after_file,
            'stats': stats,
            'dashboard': dashboard_path,
            'nbr_before': nbr_before,
            'nbr_after': nbr_after,
            'dNBR': dNBR
        }
        analyses_completed += 1
        print(f"‚úÖ NBR analysis completed for fire {fire_id}")

    return analyses_completed, results

def main_nbr_analysis():
    """Entry point for the NBR (Normalized Burn Ratio) burn-severity workflow.

    Steps:
    1. Print a banner.
    2. Return early if ``PREVIEW_DIR`` does not exist.
    3. Run ``display_all_nbr_analyses_with_messages()``.
    4. Print a per-fire summary: burned-area percentage, input files and
       dashboard file name.
    5. List the PNG files found in ``NBR_DIR``.

    All output is printed; the function returns ``None``.
    """
    # NOTE(review): the emoji in these messages look like UTF-8 text decoded as
    # Mac Roman (e.g. "‚ùå" for U+274C). They are left as-is to match the rest of the file.
    print("‚úÖ üî• NORMALIZED BURN RATIO (NBR) ANALYSIS")
    print("=" * 80)
    print("üéØ Burn Severity Assessment using dNBR (differenced NBR)")
    print("üéØ Requires both BEFORE and AFTER fire images")
    print("üìè Scale bars: 2 km (added to ALL images)")
    print(f"üíæ Results saved to: {NBR_DIR}")
    print("=" * 80)

    if not os.path.exists(PREVIEW_DIR):
        print(f"‚ùå Directory not found: {PREVIEW_DIR}")
        print("Please run the extraction script first to generate image files.")
        return

    analyses_completed, results = display_all_nbr_analyses_with_messages()

    print(f"\n{'='*80}")
    print("‚úÖ NBR ANALYSIS COMPLETED")
    print("="*80)

    if analyses_completed > 0:
        print(f"üìä Successfully completed {analyses_completed} NBR analyses:")

        # Severity class ids counted as "burned". Presumably the low/moderate/high
        # classes of BURN_CLASSES (not visible here) -- TODO confirm.
        burned_classes = (2, 3, 4)
        for fire_id, result in results.items():
            stats = result['stats']
            # Skip ids that BURN_CLASSES does not define instead of raising KeyError.
            total_burned = sum(stats[c]['pixels'] for c in burned_classes if c in stats)
            total_pixels = result['dNBR'].size
            total_burned_pct = (total_burned / total_pixels) * 100 if total_pixels else 0.0

            print(f"   üî• Fire {fire_id}: {total_burned_pct:.1f}% burned area")
            print(f"      - Before: {result['before_file']}")
            print(f"      - After: {result['after_file']}")
            # The dashboard path comes from display_nbr_analysis (not visible here).
            # Guard it, because os.path.basename(None) raises TypeError.
            dashboard = result['dashboard']
            dashboard_name = os.path.basename(dashboard) if dashboard else "not saved"
            print(f"      - Dashboard: {dashboard_name}")

        print("üìè Scale bars added to ALL images (2 km)")
        print(f"üíæ All NBR dashboards saved to: {NBR_DIR}")

        # os.listdir raises FileNotFoundError if NBR_DIR was never created.
        saved_files = []
        if os.path.isdir(NBR_DIR):
            saved_files = sorted(f for f in os.listdir(NBR_DIR) if f.endswith('.png'))
        if saved_files:
            print("\nüìÅ Saved NBR dashboard files:")
            for file in saved_files:
                print(f"   - {file}")

        print("\nüìÇ NBR Analysis Directory:")
        print(f"   {NBR_DIR}")

    else:
        print("‚ùå No NBR analyses were completed.")
        print("üí° Please check that:")
        print("   - Both BEFORE and AFTER images exist for at least one fire location")
        print("   - Images contain the required NIR and SWIR bands")
        print("   - Files follow the naming pattern: square_10km_allbands_[before/after]_[fire_id]_[date].tif")
        print(f"   - Source directory: {PREVIEW_DIR}")

# Entry point: start the complete NBR burn-severity workflow only when this
# code is executed directly, not when it is imported as a module.
if __name__ == '__main__':
    main_nbr_analysis()