# Bangalore Tessera Embeddings Interactive Viewer

Interactive viewer with 9 synchronized maps:
- Center: 2024 Satellite RGB
- Surrounding: Tessera embeddings (2017-2024)
- Click to place labeled pins for training data collection


In [None]:
import numpy as np
import rasterio
from rasterio.warp import transform_bounds
from ipyleaflet import Map, ImageOverlay, Marker, MarkerCluster
from ipywidgets import Output, VBox, HBox, Text, Button, Dropdown, Label, IntSlider
from pathlib import Path
import json
from io import BytesIO
import base64
from PIL import Image
import matplotlib.pyplot as plt

In [None]:
# Configuration
# Root directory holding one sub-directory of pyramid GeoTIFFs per map id.
PYRAMIDS_DIR = Path("pyramids")
# Embedding years shown around the satellite view (2017 through 2024).
YEARS = [*range(2017, 2025)]
# JSON file that pin labels are saved to and restored from.
LABELS_FILE = Path("training_labels.json")

# 3x3 grid of map ids, listed row by row; the centre cell is the satellite
# image and the eight surrounding cells are the embedding years.
MAP_LAYOUT = [
    [2017, 2018, 2019],       # top row
    [2020, "satellite", 2021],  # middle row
    [2022, 2023, 2024],       # bottom row
]

In [None]:
class BangaloreViewer:
    """Nine-map ipyleaflet viewer for placing labeled pins on imagery.

    One map is created per entry in ``MAP_LAYOUT`` (eight embedding years plus
    the ``'satellite'`` view). Each map gets an image overlay read from
    ``PYRAMIDS_DIR`` when the file exists, and all maps have their ``center``
    and ``zoom`` traits linked to the satellite map. Pins are tracked per map
    in ``self.markers`` / ``self.labels`` and persisted to ``LABELS_FILE``.
    """

    def __init__(self, initial_zoom=4):
        """Initialize the 9-map viewer with synchronized zoom/pan.

        Args:
            initial_zoom: Pyramid level to load (used in the
                ``level_<n>.tif`` file name); the Leaflet zoom of each map is
                this value plus 8.
        """
        # Bangalore center coordinates
        self.center = [12.97, 77.59]
        self.initial_zoom = initial_zoom
        self.current_zoom_level = initial_zoom  # 0-7 for pyramid levels

        # Current label for new pins. Set before any marker work so the
        # instance is fully initialised by the time labels are restored.
        self.current_label = "unlabeled"

        # Training labels: {map_id: [(lat, lon, label), ...]}
        self.labels = {str(year): [] for year in YEARS}
        self.labels['satellite'] = []

        # Markers: {map_id: {(lat, lon): marker}}
        self.markers = {str(year): {} for year in YEARS}
        self.markers['satellite'] = {}

        # Create maps
        self.maps = {}
        self.overlays = {}
        self.create_maps()

        # Load existing labels if available
        self.load_labels()

    def load_image_as_overlay(self, image_path, map_id):
        """Load a GeoTIFF and create an ImageOverlay for ipyleaflet.

        Args:
            image_path: Path of the raster to open with rasterio.
            map_id: Id of the map the overlay is for (currently unused).

        Returns:
            An ``ImageOverlay`` whose URL is a base64 PNG data URI and whose
            bounds are ``[[south, west], [north, east]]`` in EPSG:4326.
        """
        with rasterio.open(image_path) as src:
            if src.count == 1:
                # Single band - replicate it into three channels
                band = src.read(1)
                rgb = np.stack([band, band, band], axis=0)
            else:
                # Use the first three bands as RGB
                rgb = src.read([1, 2, 3])

            # Get bounds in lat/lon
            bounds = src.bounds
            if src.crs != 'EPSG:4326':
                bounds = transform_bounds(src.crs, 'EPSG:4326', *bounds)

            # rasterio returns (bands, rows, cols); PIL wants (rows, cols, bands).
            # NOTE(review): mode='RGB' presumably needs 8-bit data - confirm the
            # pyramid files are stored as uint8.
            rgb_transposed = np.transpose(rgb, (1, 2, 0))
            img = Image.fromarray(rgb_transposed, mode='RGB')

            # Convert to base64 for ipyleaflet
            buffer = BytesIO()
            img.save(buffer, format='PNG')
            img_str = base64.b64encode(buffer.getvalue()).decode()
            img_url = f'data:image/png;base64,{img_str}'

            # Create bounds for ipyleaflet: [[south, west], [north, east]]
            bounds_leaflet = [[bounds[1], bounds[0]], [bounds[3], bounds[2]]]

            return ImageOverlay(url=img_url, bounds=bounds_leaflet)

    def _make_click_handler(self, map_id):
        """Return an interaction callback bound to one specific map id."""
        def handler(**kwargs):
            self.handle_map_click(map_id, **kwargs)
        return handler

    def create_maps(self):
        """Create all 9 maps in a 3x3 grid."""
        for row in MAP_LAYOUT:
            for map_id in row:
                # Create map
                m = Map(
                    center=self.center,
                    zoom=self.initial_zoom + 8,  # Leaflet zoom (ipyleaflet uses web mercator zoom)
                    scroll_wheel_zoom=True,
                    dragging=True
                )

                # 'satellite' and the year ids share the same directory scheme.
                img_path = PYRAMIDS_DIR / str(map_id) / f'level_{self.current_zoom_level}.tif'

                # Maps without a pyramid file are still created, just empty.
                if img_path.exists():
                    overlay = self.load_image_as_overlay(img_path, map_id)
                    m.add_layer(overlay)
                    self.overlays[str(map_id)] = overlay

                # Add click handler for pin placement. The handler is built by
                # a factory so each map keeps its own id; a lambda written
                # inline here would see only the loop variable's final value.
                m.on_interaction(self._make_click_handler(map_id))

                self.maps[str(map_id)] = m

        # Link all maps for synchronized zoom/pan
        self.link_maps()

    def link_maps(self):
        """Link all maps so they zoom/pan together."""
        from ipywidgets import jslink

        # Get reference map (center/satellite)
        ref_map = self.maps['satellite']

        # Link every other map's center and zoom traits to the reference
        for map_id, m in self.maps.items():
            if map_id != 'satellite':
                jslink((ref_map, 'center'), (m, 'center'))
                jslink((ref_map, 'zoom'), (m, 'zoom'))

    def handle_map_click(self, map_id, **kwargs):
        """Handle click on map to place/remove pins.

        Only interactions with ``type == 'click'`` and a non-empty
        ``coordinates`` entry are acted on. A click whose coordinates equal an
        existing marker key removes that marker; any other click adds a marker
        carrying ``self.current_label``.
        """
        if kwargs.get('type') != 'click':
            return
        coords = kwargs.get('coordinates')
        if not coords:
            return
        lat, lon = coords

        # NOTE(review): removal needs the click coordinates to equal a stored
        # key exactly - confirm a click on a marker icon reports that location.
        if (lat, lon) in self.markers[str(map_id)]:
            self.remove_marker(map_id, lat, lon)
        else:
            self.add_marker(map_id, lat, lon, self.current_label)

    def add_marker(self, map_id, lat, lon, label):
        """Add a labeled marker to the map and record its label."""
        marker = Marker(
            location=(lat, lon),
            draggable=False,
            title=label
        )

        key = str(map_id)
        self.maps[key].add_layer(marker)
        self.markers[key][(lat, lon)] = marker
        self.labels[key].append((lat, lon, label))

        print(f"Added '{label}' marker at ({lat:.4f}, {lon:.4f}) on {map_id} map")

    def remove_marker(self, map_id, lat, lon):
        """Remove the marker stored at exactly ``(lat, lon)`` and its labels.

        Does nothing when no marker is stored under that key.
        """
        key = str(map_id)
        marker = self.markers[key].pop((lat, lon), None)
        if marker is None:
            return
        self.maps[key].remove_layer(marker)

        # Drop labels by the same exact key the marker was stored under. A
        # distance tolerance here would also delete the labels of nearby pins
        # whose markers are still on the map.
        self.labels[key] = [
            entry for entry in self.labels[key]
            if (entry[0], entry[1]) != (lat, lon)
        ]

        print(f"Removed marker at ({lat:.4f}, {lon:.4f}) from {map_id} map")

    def save_labels(self):
        """Save labels to JSON file and print a per-map summary."""
        with open(LABELS_FILE, 'w', encoding='utf-8') as f:
            json.dump(self.labels, f, indent=2)
        print(f"Saved labels to {LABELS_FILE}")

        # Print summary
        total = sum(len(v) for v in self.labels.values())
        print(f"Total labels: {total}")
        for map_id, labels in self.labels.items():
            if labels:
                print(f"  {map_id}: {len(labels)} labels")

    def load_labels(self):
        """Load labels from JSON file if it exists and restore their markers.

        Entries for map ids that have no map in this viewer are skipped
        instead of raising ``KeyError``.
        """
        if not LABELS_FILE.exists():
            return

        with open(LABELS_FILE, 'r', encoding='utf-8') as f:
            loaded_labels = json.load(f)

        # Restore markers
        restored = 0
        for map_id, labels in loaded_labels.items():
            if str(map_id) not in self.maps:
                print(f"Skipping labels for unknown map '{map_id}'")
                continue
            for lat, lon, label in labels:
                self.add_marker(map_id, lat, lon, label)
                restored += 1

        print(f"Loaded {restored} labels from {LABELS_FILE}")

    def display(self):
        """Display the 9-map grid with controls.

        Returns:
            A ``VBox`` holding the control row (label text box, save button,
            clear button) above the 3x3 grid of labeled maps.
        """
        # Controls
        label_input = Text(
            value='building',
            placeholder='Enter label',
            description='Label:',
            disabled=False
        )
        # The observer below fires only on changes, so adopt the text box's
        # starting value now; otherwise new pins would carry a label that
        # differs from the one shown.
        self.current_label = label_input.value

        def update_label(change):
            self.current_label = change['new']

        label_input.observe(update_label, names='value')

        save_btn = Button(description='Save Labels', button_style='success')
        save_btn.on_click(lambda b: self.save_labels())

        clear_btn = Button(description='Clear All', button_style='danger')

        def clear_all(b):
            for map_id, entries in self.labels.items():
                # Iterate a copy: remove_marker replaces the per-map list.
                for lat, lon, _ in list(entries):
                    self.remove_marker(map_id, lat, lon)

        clear_btn.on_click(clear_all)

        controls = HBox([label_input, save_btn, clear_btn])

        # Create map grid: each cell is a caption stacked above its map
        rows = []
        for row in MAP_LAYOUT:
            maps_row = []
            for map_id in row:
                map_widget = self.maps[str(map_id)]
                label_widget = Label(value=f"{map_id}")
                maps_row.append(VBox([label_widget, map_widget]))
            rows.append(HBox(maps_row))

        map_grid = VBox(rows)

        return VBox([controls, map_grid])

In [None]:
# Create and display the viewer
# Constructing the viewer builds all nine maps and restores any saved labels.
viewer = BangaloreViewer(initial_zoom=4)
# Keep this call as the cell's final bare expression so the notebook renders
# the widget it returns.
viewer.display()

## Usage Instructions

1. **Zoom and Pan**: Interact with the center (satellite) map - all other maps will follow
2. **Add Labels**: 
   - Enter a label name in the text box (e.g., "building", "road", "vegetation")
   - Click on any map to place a marker with that label
3. **Remove Labels**: Click on an existing marker to remove it
4. **Save**: Click "Save Labels" to save all markers to `training_labels.json`
5. **Clear**: Click "Clear All" to remove all markers

The labels are saved as JSON, keyed by map id (a year or `satellite`), with each entry stored as `[lat, lon, label]`. These coordinates can be used to extract embeddings for training classifiers.
