# GEE-Style Image Explorer

Google Earth Engine style interface for exploring your Planet image collection.

In [None]:
import numpy as np
import rasterio
import matplotlib.pyplot as plt
import glob
import os
import ipywidgets as widgets
from IPython.display import display
import re
from datetime import datetime

plt.rcParams['figure.dpi'] = 100

## 1. Create ImageCollection (Load all images)

In [None]:
class ImageCollection:
    """
    Date-indexed catalogue of local raster files (GEE ImageCollection look-alike).

    ``self.images`` maps a ``YYYY-MM-DD`` string found in a file's base name to
    a metadata dict with the keys 'path', 'bands', 'shape', 'dtype' and 'crs'.
    Pixel data is not kept in memory; it is read from disk on every ``select``.
    """
    def __init__(self, pattern):
        self.images = {}
        self.load_collection(pattern)

    def load_collection(self, pattern):
        """Index every file matching the glob ``pattern`` and print a summary.

        Files whose base name contains no ``YYYY-MM-DD`` substring are ignored.
        When two files carry the same date, the one visited last replaces the
        earlier entry.
        """
        for raster_path in glob.glob(pattern):
            found = re.search(r'(\d{4}-\d{2}-\d{2})', os.path.basename(raster_path))
            if found is None:
                continue

            # Only the header is read here; the file is closed again right away.
            with rasterio.open(raster_path) as dataset:
                metadata = {
                    'path': raster_path,
                    'bands': dataset.count,
                    'shape': (dataset.height, dataset.width),
                    'dtype': dataset.dtypes[0],
                    'crs': dataset.crs
                }
            self.images[found.group(1)] = metadata

        print(f"📡 ImageCollection loaded: {len(self.images)} images")
        for date in sorted(self.images):
            info = self.images[date]
            print(f"  {date}: {info['bands']} bands, {info['shape']}")

    def select(self, date, band=None):
        """Read pixel data for ``date`` from disk.

        ``band`` is a 0-based index; when given, only that band is read,
        otherwise the result of reading every band is returned.
        Raises ValueError when ``date`` is not part of the collection.
        """
        if date not in self.images:
            raise ValueError(f"Date {date} not found in collection")

        with rasterio.open(self.images[date]['path']) as dataset:
            if band is None:
                return dataset.read()  # All bands
            return dataset.read(band + 1)  # rasterio is 1-indexed

    def dates(self):
        """Return the collection's date strings in ascending order"""
        return sorted(self.images)

    def band_names(self, date):
        """Return display names for the bands of the image at ``date``.

        The names are chosen purely from the band count.
        NOTE(review): the 4-band order Red/Green/Blue/NIR is assumed, not read
        from the file - confirm it matches the mosaics' actual band layout.
        """
        band_count = self.images[date]['bands']
        names_by_count = {
            4: ['Red', 'Green', 'Blue', 'NIR'],
            1: ['NDWI'],
        }
        if band_count in names_by_count:
            return names_by_count[band_count]
        return [f'Band_{number}' for number in range(1, band_count + 1)]

In [None]:
# Load your ImageCollection
# Each ImageCollection(...) call globs the pattern, opens every dated match
# with rasterio to read its metadata, and prints a summary. Both collections
# are loaded here even though only one of them ends up bound to `ic`.
# Option 1: Planet mosaics (RGB+NIR)
ic_planet = ImageCollection("testimages25/*_mosaic.tif")

# Option 2: NDWI images (single band)
ic_ndwi = ImageCollection("outputs_cleaned/*_ndwi.tif")

# Choose which collection to use
# `ic` is the collection handed to the Visualizer further down.
ic = ic_planet  # Change to ic_ndwi for NDWI images

## 2. GEE-Style Visualization Interface

In [None]:
class Visualizer:
    """
    GEE-style Map.addLayer() equivalent

    Wraps an image collection in ipywidgets controls: one (date, band) picker
    per display channel plus Min/Max stretch sliders. The collection object
    must provide ``dates()``, ``band_names(date)`` and ``select(date, band)``.
    """
    def __init__(self, image_collection):
        self.ic = image_collection
        self.setup_widgets()

    @staticmethod
    def _date_dropdown(description, dates, value):
        """Build one date picker with the shared dropdown styling."""
        return widgets.Dropdown(
            options=dates,
            value=value,
            description=description,
            style={'description_width': 'initial'}
        )

    @staticmethod
    def _band_dropdown(description):
        """Build one (initially empty) band picker with the shared styling."""
        return widgets.Dropdown(
            options=[],
            description=description,
            style={'description_width': 'initial'}
        )

    def setup_widgets(self):
        """Create GEE-style interface widgets

        Defaults are the first, middle and last date for red, green and blue.
        An empty collection yields empty pickers instead of raising.
        """
        dates = self.ic.dates()

        # All three defaults are guarded so an empty collection does not
        # raise IndexError; with a single date len(dates)//2 is 0 anyway.
        first_date = dates[0] if dates else None
        middle_date = dates[len(dates) // 2] if dates else None
        last_date = dates[-1] if dates else None

        # Red channel selection
        self.red_date = self._date_dropdown('Red Date:', dates, first_date)
        self.red_band = self._band_dropdown('Red Band:')

        # Green channel selection
        self.green_date = self._date_dropdown('Green Date:', dates, middle_date)
        self.green_band = self._band_dropdown('Green Band:')

        # Blue channel selection
        self.blue_date = self._date_dropdown('Blue Date:', dates, last_date)
        self.blue_band = self._band_dropdown('Blue Band:')

        # Visualization parameters (GEE style); applied to the already
        # percentile-normalized [0, 1] data, hence the 0..1 slider range.
        self.vis_min = widgets.FloatSlider(
            value=0.02,
            min=0,
            max=1,
            step=0.01,
            description='Min:'
        )

        self.vis_max = widgets.FloatSlider(
            value=0.98,
            min=0,
            max=1,
            step=0.01,
            description='Max:'
        )

        # Update band options when date changes
        self.red_date.observe(self.update_red_bands, names='value')
        self.green_date.observe(self.update_green_bands, names='value')
        self.blue_date.observe(self.update_blue_bands, names='value')

        # Initialize band options
        if dates:
            self.update_red_bands(None)
            self.update_green_bands(None)
            self.update_blue_bands(None)

    def _refresh_band_options(self, date_widget, band_widget):
        """Fill ``band_widget`` with the bands of the date chosen in ``date_widget``."""
        date = date_widget.value
        if date is None:
            # Nothing selected (empty collection): leave the band picker empty.
            return
        names = self.ic.band_names(date)
        # Options are (label, value) pairs; the value is the 0-based band index.
        band_widget.options = [(name, index) for index, name in enumerate(names)]
        band_widget.value = 0

    def update_red_bands(self, change):
        """Observer: refresh the red band picker (``change`` is unused)."""
        self._refresh_band_options(self.red_date, self.red_band)

    def update_green_bands(self, change):
        """Observer: refresh the green band picker (``change`` is unused)."""
        self._refresh_band_options(self.green_date, self.green_band)

    def update_blue_bands(self, change):
        """Observer: refresh the blue band picker (``change`` is unused)."""
        self._refresh_band_options(self.blue_date, self.blue_band)

    def display_controls(self):
        """Display GEE-style control panel"""
        print("🗺️  Map.addLayer() Controls:")

        red_box = widgets.VBox([
            widgets.HTML("<b style='color:red'>Red Channel</b>"),
            self.red_date,
            self.red_band
        ])

        green_box = widgets.VBox([
            widgets.HTML("<b style='color:green'>Green Channel</b>"),
            self.green_date,
            self.green_band
        ])

        blue_box = widgets.VBox([
            widgets.HTML("<b style='color:blue'>Blue Channel</b>"),
            self.blue_date,
            self.blue_band
        ])

        vis_box = widgets.VBox([
            widgets.HTML("<b>Visualization Parameters</b>"),
            self.vis_min,
            self.vis_max
        ])

        display(widgets.HBox([red_box, green_box, blue_box, vis_box]))

        # Add layer button
        add_layer_btn = widgets.Button(
            description='🗺️ Map.addLayer()',
            button_style='success',
            layout=widgets.Layout(width='200px', height='40px')
        )
        add_layer_btn.on_click(self.add_layer)

        display(add_layer_btn)

    def add_layer(self, button):
        """GEE-style Map.addLayer() function

        Reads the three selected bands, normalizes each one, applies the
        Min/Max stretch and plots the three channels next to the composite.
        ``button`` is the clicked widget and is not used. Any failure is
        reported with a printed message instead of being raised.

        NOTE(review): depending on the notebook front end, output produced
        inside a button callback may not appear under the cell; capturing it
        in a widgets.Output may be needed - confirm in the target environment.
        """
        try:
            channels = [
                ('Red', 'Reds', self.red_date, self.red_band),
                ('Green', 'Greens', self.green_date, self.green_band),
                ('Blue', 'Blues', self.blue_date, self.blue_band),
            ]

            vis_min = self.vis_min.value
            vis_max = self.vis_max.value
            if vis_max == vis_min:
                # Equal sliders would divide by zero in the stretch below.
                raise ValueError("Min and Max must differ (zero-width stretch)")

            # Get selected data and normalize each channel independently
            normalized = [
                self.normalize_band(self.ic.select(date_w.value, band_w.value))
                for _, _, date_w, band_w in channels
            ]

            # Create RGB (stacking fails if the images differ in shape; that
            # error is reported by the handler below)
            rgb = np.stack(normalized, axis=2)

            # Apply vis parameters
            rgb = np.clip((rgb - vis_min) / (vis_max - vis_min), 0, 1)

            # Plot (GEE style)
            fig, axes = plt.subplots(1, 4, figsize=(20, 5))

            # Individual channels (zip stops after the three channel axes)
            for ax, data, (label, cmap, date_w, band_w) in zip(axes, normalized, channels):
                band_name = self.ic.band_names(date_w.value)[band_w.value]
                ax.imshow(data, cmap=cmap)
                ax.set_title(f'{label}: {date_w.value}\n{band_name}')
                ax.axis('off')

            # RGB composite
            axes[3].imshow(rgb)
            axes[3].set_title(f'RGB Composite\nvis: [{vis_min:.2f}, {vis_max:.2f}]')
            axes[3].axis('off')

            plt.tight_layout()
            plt.show()

            # Print GEE-style info
            print(f"✅ Layer added successfully!")
            print(f"   RGB: {self.red_date.value}/{self.green_date.value}/{self.blue_date.value}")
            print(f"   Bands: {self.red_band.value}/{self.green_band.value}/{self.blue_band.value}")
            print(f"   Vis params: min={vis_min}, max={vis_max}")

        except Exception as e:
            # Callback boundary: report the problem instead of raising.
            print(f"❌ Error adding layer: {e}")

    def normalize_band(self, band):
        """Normalize band using percentiles

        Maps the 2nd..98th percentile range of ``band`` linearly onto [0, 1]
        and clips everything outside it. NaN pixels are ignored when the
        percentiles are computed and stay NaN in the result. A band without
        usable spread (constant, all-NaN or non-finite percentiles) yields
        all zeros instead of dividing by zero.
        """
        band = band.astype(np.float32)
        p2, p98 = np.nanpercentile(band, [2, 98])
        spread = p98 - p2
        if not np.isfinite(spread) or spread <= 0:
            return np.zeros_like(band)
        return np.clip((band - p2) / spread, 0, 1)

# Create visualizer
# Constructing it builds all widgets (the constructor calls setup_widgets);
# nothing is shown until viz.display_controls() is called. The preset and
# export cells below read this module-level `viz`.
viz = Visualizer(ic)

## 3. GEE-Style Interface

Just like in Google Earth Engine!

In [None]:
# Display the GEE-style interface
# Shows the three channel pickers, the Min/Max sliders and the
# Map.addLayer() button; clicking the button runs viz.add_layer.
viz.display_controls()

## 4. Quick Presets (GEE style)

In [None]:
def preset_temporal_rgb():
    """Temporal RGB: Early/Mid/Late season

    Points the red, green and blue date pickers of the module-level ``viz``
    at the first, middle and last date of its collection. Band pickers are
    not set here. With fewer than three dates nothing is changed and a short
    notice is printed, so a button click never passes silently.
    """
    dates = viz.ic.dates()
    if len(dates) < 3:
        print(f"⚠️ Temporal RGB needs at least 3 dates (collection has {len(dates)})")
        return

    early, mid, late = dates[0], dates[len(dates) // 2], dates[-1]
    viz.red_date.value = early  # Early season
    viz.green_date.value = mid  # Mid season
    viz.blue_date.value = late  # Late season
    print(f"🗓️ Temporal RGB: {early} / {mid} / {late}")

def preset_ndwi_evolution():
    """NDWI temporal evolution

    Points the red, green and blue date pickers of the module-level ``viz``
    at the first, middle and last date and selects band index 0 on every
    channel. With fewer than three dates nothing is changed and a short
    notice is printed, so a button click never passes silently.

    NOTE(review): band 0 is only NDWI when ``viz`` wraps the single-band
    NDWI collection; on a multi-band collection this picks its first band -
    confirm which collection is in use before relying on the label.
    """
    dates = viz.ic.dates()
    if len(dates) < 3:
        print(f"⚠️ NDWI Evolution needs at least 3 dates (collection has {len(dates)})")
        return

    early, mid, late = dates[0], dates[len(dates) // 2], dates[-1]
    viz.red_date.value = early
    viz.green_date.value = mid
    viz.blue_date.value = late
    # Set all to NDWI band (0 for single-band NDWI images). Done explicitly
    # because a date picker that already held the target date fires no change
    # event, so its band picker would otherwise keep its previous selection.
    viz.red_band.value = 0
    viz.green_band.value = 0
    viz.blue_band.value = 0
    print(f"💧 NDWI Evolution: {early} / {mid} / {late}")

def preset_winter_spring_summer():
    """Seasonal composite: winter on red, spring on green, summer on blue.

    A season is recognised only by fixed month-day suffixes of the date
    string. If any of the three seasons has no matching date, the
    early/mid/late preset is used instead.
    """
    available = viz.ic.dates()

    def first_ending_with(suffixes):
        # Dates arrive sorted, so this is the earliest match (or None).
        return next((day for day in available if day.endswith(suffixes)), None)

    winter_day = first_ending_with(('01-08', '02-08'))
    spring_day = first_ending_with(('04-25', '05-25'))
    summer_day = first_ending_with(('07-06', '08-21'))

    if winter_day is None or spring_day is None or summer_day is None:
        preset_temporal_rgb()
        return

    viz.red_date.value = winter_day
    viz.green_date.value = spring_day
    viz.blue_date.value = summer_day
    print(f"❄️🌱☀️ Seasonal: {winter_day} / {spring_day} / {summer_day}")

# Preset buttons
# One button per preset, in the same order as the handlers wired up below.
preset_btns = [
    widgets.Button(description='⏰ Temporal RGB', button_style='info'),
    widgets.Button(description='💧 NDWI Evolution', button_style='info'),
    widgets.Button(description='🗓️ Winter/Spring/Summer', button_style='info')
]

# on_click passes the clicked button to its handler; the presets take no
# arguments, so each lambda discards it.
preset_btns[0].on_click(lambda b: preset_temporal_rgb())
preset_btns[1].on_click(lambda b: preset_ndwi_evolution())
preset_btns[2].on_click(lambda b: preset_winter_spring_summer())

print("🚀 Quick Presets (just like GEE!):")
# The presets only change the pickers of `viz`; they do not draw anything.
display(widgets.HBox(preset_btns))

## 5. Export (GEE style)

In [None]:
def export_image():
    """Export.image.toDrive() equivalent

    Rebuilds the composite from the current selections of the module-level
    ``viz`` (dates, bands, Min/Max stretch) and saves it as a 300-dpi PNG in
    the current working directory. Any failure is reported with a printed
    message instead of being raised.

    The file name contains only the three dates, so exporting a different
    band combination for the same dates overwrites the earlier file.
    """
    try:
        vis_min = viz.vis_min.value
        vis_max = viz.vis_max.value
        if vis_max == vis_min:
            # Equal sliders would divide by zero in the stretch below.
            raise ValueError("Min and Max must differ (zero-width stretch)")

        # Get current composite
        red_data = viz.ic.select(viz.red_date.value, viz.red_band.value)
        green_data = viz.ic.select(viz.green_date.value, viz.green_band.value)
        blue_data = viz.ic.select(viz.blue_date.value, viz.blue_band.value)

        # Normalize and create RGB
        red_norm = viz.normalize_band(red_data)
        green_norm = viz.normalize_band(green_data)
        blue_norm = viz.normalize_band(blue_data)
        rgb = np.stack([red_norm, green_norm, blue_norm], axis=2)

        # Apply vis parameters
        rgb = np.clip((rgb - vis_min) / (vis_max - vis_min), 0, 1)

        # Export filename
        filename = f"temporal_composite_{viz.red_date.value}_{viz.green_date.value}_{viz.blue_date.value}.png"

        # Save
        fig = plt.figure(figsize=(12, 8))
        try:
            plt.imshow(rgb)
            plt.axis('off')
            plt.title(f'Temporal Composite: R={viz.red_date.value}, G={viz.green_date.value}, B={viz.blue_date.value}')
            plt.savefig(filename, dpi=300, bbox_inches='tight', pad_inches=0)
        finally:
            # Close the figure even when drawing or saving fails, so a failed
            # export does not leave an open figure behind.
            plt.close(fig)

        print(f"📤 Export.image.toDrive() complete!")
        print(f"   Saved: {filename}")

    except Exception as e:
        # Callback boundary: report the problem instead of raising.
        print(f"❌ Export failed: {e}")

# Export button
export_btn = widgets.Button(
    description='📤 Export.image.toDrive()',
    button_style='warning',
    layout=widgets.Layout(width='200px')
)
# on_click passes the clicked button to its handler; export_image takes no
# arguments, so the lambda discards it.
export_btn.on_click(lambda b: export_image())

display(export_btn)