In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.neighbors import KNeighborsClassifier
import ipyleaflet
from ipyleaflet import TileLayer, Map, ImageOverlay, CircleMarker, LayersControl
from ipywidgets import *
from geotessera import GeoTessera
from sklearn.decomposition import PCA
import time
from functools import partial
import io
import base64
import json
import rasterio
from rasterio.io import MemoryFile
from rasterio.merge import merge
from rasterio.warp import calculate_default_transform, reproject
from rasterio.enums import Resampling


In [2]:
# Area of interest as a longitude/latitude bounding box, in decimal degrees.
# Each edge is given in increments of 0.1, and the box must span at least
# 0.1 x 0.1 degrees.
MIN_LAT = 52.00
MAX_LAT = 52.20
MIN_LON = -0.10
MAX_LON = 0.30
print("Area of Interest defined.")

Area of Interest defined.


In [3]:
# -- 2. FETCH, STITCH, AND PREPARE PCA VISUALIZATION --

# Client used for every embedding / landmask lookup in this cell.
tessera = GeoTessera()

# --- CONFIGURATION ---
# ROI is ordered (west, south, east, north).
roi_bounds = (MIN_LON, MIN_LAT, MAX_LON, MAX_LAT)
target_year = 2024
# Common CRS that every tile is reprojected into before merging.
target_crs = 'EPSG:4326'

print(f"Searching for tiles in ROI: {roi_bounds} for year {target_year}")

# --- Find tiles in ROI ---
# Make sure the listing for every available year is loaded before enumerating.
for year in tessera.get_available_years():
    tessera._ensure_year_loaded(year)

# A tile is selected when its listed (lat, lon) lies strictly inside the ROI.
roi_west, roi_south, roi_east, roi_north = roi_bounds
tiles_to_merge = [
    (lat, lon)
    for emb_year, lat, lon in tessera.list_available_embeddings()
    if emb_year == target_year
    and roi_west < lon < roi_east
    and roi_south < lat < roi_north
]

if len(tiles_to_merge) == 0:
    raise ValueError(f"No embedding tiles found for the specified ROI in year {target_year}")
print(f"Found {len(tiles_to_merge)} tiles to merge.")

# --- PROCESS AND REPROJECT EACH TILE IN MEMORY ---
# `src_files_to_merge` holds the open dataset readers handed to merge();
# `tile_memfiles` holds the MemoryFile objects backing them. Both are tracked
# so the `finally` below can release them whether or not merging succeeds.
src_files_to_merge = []
tile_memfiles = []
try:
    print("\nPass 1: Creating reprojected, georeferenced files in memory...")
    for i, (lat, lon) in enumerate(tiles_to_merge):
        print(f"  Processing tile {i+1}/{len(tiles_to_merge)} at ({lat:.2f}, {lon:.2f})...")
        try:
            # 1. Fetch source data and metadata. The landmask raster supplies
            #    the CRS/transform used to georeference the embedding array.
            embedding = tessera.get_embedding(lat, lon, year=target_year) # (H, W, C)
            landmask_path = tessera._fetch_landmask(lat, lon, progressbar=False)
            with rasterio.open(landmask_path) as src:
                src_crs = src.crs
                src_transform = src.transform
                src_height, src_width = src.height, src.width
                src_bounds = src.bounds

            # 2. Calculate the parameters for reprojection
            dst_transform, dst_width, dst_height = calculate_default_transform(
                src_crs, target_crs, src_width, src_height, *src_bounds
            )

            # 3. Allocate the band-first destination array
            reprojected_embedding = np.empty(
                (embedding.shape[2], dst_height, dst_width), dtype=embedding.dtype
            )

            # 4. Perform the reprojection band by band
            for band_idx in range(embedding.shape[2]):
                reproject(
                    source=embedding[:, :, band_idx],
                    destination=reprojected_embedding[band_idx],
                    src_transform=src_transform,
                    src_crs=src_crs,
                    dst_transform=dst_transform,
                    dst_crs=target_crs,
                    resampling=Resampling.bilinear
                )

            # 5. Write the reprojected data to an in-memory file and reopen it
            #    for reading. If anything fails here the MemoryFile is closed
            #    immediately instead of being leaked.
            memfile = MemoryFile()
            try:
                with memfile.open(
                    driver='GTiff',
                    height=dst_height,
                    width=dst_width,
                    count=embedding.shape[2],
                    dtype=embedding.dtype,
                    crs=target_crs,
                    transform=dst_transform,
                ) as dataset:
                    dataset.write(reprojected_embedding)
                tile_reader = memfile.open()
            except Exception:
                memfile.close()
                raise
            tile_memfiles.append(memfile)
            src_files_to_merge.append(tile_reader)

        except Exception as e:
            # Best effort: a tile that cannot be processed is reported and skipped.
            print(f"    ! WARNING: Could not process tile ({lat:.2f}, {lon:.2f}): {e}")

    # --- MERGE ALL REPROJECTED IN-MEMORY FILES ---
    if not src_files_to_merge:
        raise RuntimeError("Failed to create any in-memory georeferenced files.")

    print("\nPass 2: Merging all tiles...")
    merged_array, mosaic_transform = merge(src_files_to_merge)
finally:
    # Close the readers first, then the in-memory files that back them.
    for src in src_files_to_merge:
        src.close()
    for memfile in tile_memfiles:
        memfile.close()

# merge() returns a band-first array; downstream code expects (H, W, C).
embedding_mosaic = np.transpose(merged_array, (1, 2, 0)) # (H, W, C)

print(f"Final Embedding Mosaic Shape: {embedding_mosaic.shape}")

mosaic_height, mosaic_width, num_channels = embedding_mosaic.shape
# The mosaic is now in EPSG:4326, so its bounds are already in lat/lon
west, south, east, north = rasterio.transform.array_bounds(mosaic_height, mosaic_width, mosaic_transform)
# Stored as ((south, west), (north, east)) for use as ImageOverlay bounds.
VIS_BOUNDS = ((south, west), (north, east))

# --- PCA VISUALIZATION ---
print("\nCreating PCA-based visualization...")

# Fixed seed so the sampled pixels, and therefore the false-colour image,
# are the same on every run of this cell.
PCA_RANDOM_SEED = 42
rng = np.random.default_rng(PCA_RANDOM_SEED)

# Flatten the mosaic to one row per pixel.
pixels = embedding_mosaic.reshape(-1, num_channels)

# Fit the PCA on a random subset of at most 100000 pixels, then project all of them.
n_sample = min(pixels.shape[0], 100000)
sample_indices = rng.choice(pixels.shape[0], n_sample, replace=False)
pca = PCA(n_components=3, random_state=PCA_RANDOM_SEED)
pca.fit(pixels[sample_indices, :])
transformed_pixels = pca.transform(pixels)
pca_image = transformed_pixels.reshape(mosaic_height, mosaic_width, 3)

print("Normalizing PCA components for display...")
# Stretch each component between its 2nd and 98th percentile into [0, 1].
# A constant component (max == min) is left at zero.
vis_mosaic = np.zeros_like(pca_image)
for i in range(3):
    channel = pca_image[:, :, i]
    min_val, max_val = np.percentile(channel, [2, 98])
    if max_val > min_val:
        vis_mosaic[:, :, i] = np.clip((channel - min_val) / (max_val - min_val), 0, 1)
print("PCA visualization created.")

# Encode the RGB image as a base64 PNG data URL so it can be passed as an
# ImageOverlay `url`.
buffer = io.BytesIO()
plt.imsave(buffer, vis_mosaic, format='png')
buffer.seek(0)
b64_data = base64.b64encode(buffer.read()).decode('utf-8')
VIS_DATA_URL = f"data:image/png;base64,{b64_data}"

Searching for tiles in ROI: (-0.1, 52.0, 0.3, 52.2) for year 2024
Found 8 tiles to merge.

Pass 1: Creating reprojected, georeferenced files in memory...
  Processing tile 1/8 at (52.05, -0.05)...
  Processing tile 2/8 at (52.05, 0.05)...
  Processing tile 3/8 at (52.05, 0.15)...
  Processing tile 4/8 at (52.05, 0.25)...
  Processing tile 5/8 at (52.15, -0.05)...
  Processing tile 6/8 at (52.15, 0.05)...
  Processing tile 7/8 at (52.15, 0.15)...
  Processing tile 8/8 at (52.15, 0.25)...

Pass 2: Merging all tiles into a seamless mosaic...
Data fetched and stitched correctly from memory.
Final Embedding Mosaic Shape: (1917, 3858, 128)
Calculated mosaic geographic bounds (WGS84): ((51.997495118546304, -0.10648650491666342), (52.20249798239069, 0.3060857687983307))

Creating PCA-based visualization...
Normalizing PCA components for display...
PCA visualization created.


In [10]:
# Shared, mutable state for the labelling UI below.
class_color_map = {}               # class name -> colour string
training_points = []               # (coords, class_name) pairs
markers = {}                       # tuple(coords) -> CircleMarker on the map
# Set by remove_marker(); handle_map_click() skips one event while it is True.
A_MARKER_WAS_JUST_REMOVED = False
# Palette from which default class colours are drawn.
tab10_cmap = plt.colormaps.get_cmap('tab10')

def get_or_assign_color_for_class(class_name):
    """Return the colour registered for ``class_name``, registering one if needed.

    A class seen for the first time receives the tab10 palette entry whose
    index is the current number of registered classes modulo 10; the choice is
    stored in ``class_color_map`` as a hex string and reused on later calls.
    """
    try:
        return class_color_map[class_name]
    except KeyError:
        palette_slot = len(class_color_map) % 10
        hex_color = mcolors.to_hex(tab10_cmap(palette_slot))
        class_color_map[class_name] = hex_color
        return hex_color

# Classes offered by default; each gets its palette colour up front.
initial_classes = ['Water', 'Urban']
for c in initial_classes:
    get_or_assign_color_for_class(c)

# --- Class selection / creation widgets ---
class_dropdown = Dropdown(
    options=initial_classes,
    value='Water',
    description='Class:',
)
new_class_text = Text(
    value='',
    placeholder='Type new class name',
    description='New Class:',
)
add_class_button = Button(description="Add")

# The picker starts on the colour of the class selected in the dropdown.
color_picker = ColorPicker(
    concise=False,
    description='Set Color:',
    value=class_color_map.get(class_dropdown.value, '#FFFFFF'),
    disabled=False,
)

# --- Embedding overlay visibility widgets ---
opacity_toggle = ToggleButton(
    value=True,
    description='Show Embedding',
    button_style='info',
)
opacity_slider = FloatSlider(
    value=0.7,
    min=0,
    max=1.0,
    step=0.05,
    description='Opacity:',
)

# --- Action buttons ---
classify_button = Button(description="Classify")
clear_pins_button = Button(description="Clear All Pins")
# Starts disabled; there is no classification layer yet.
clear_classification_button = Button(description="Clear Classification", disabled=True)

# --- Save / load widgets ---
filename_text = Text(
    value='labels.json',
    placeholder='Enter filename',
    description='Filename:',
)
save_button = Button(description="Save Labels", button_style='success')
load_button = Button(description="Load Labels", button_style='primary')

# Status messages from the handlers are printed into this widget.
output_log = Output()

# Basemaps selectable from the 'Basemap:' dropdown, keyed by the label shown there.
# All tile URLs use HTTPS so they are not blocked as mixed content when the
# notebook itself is served over HTTPS, and every layer carries a distinct
# name matching its key.
basemap_layers = {
    'Esri Satellite': TileLayer(
        url='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
        attribution='Esri', name='Esri Satellite'
    ),
    'Google Earth': TileLayer(
        url='https://mt0.google.com/vt/lyrs=y&hl=en&x={x}&y={y}&z={z}',
        attribution='Google Earth', name='Google Earth'
    ),
    'Google Maps': TileLayer(
        url='https://mt0.google.com/vt/lyrs=p&hl=en&x={x}&y={y}&z={z}',
        attribution='Google Maps', name='Google Maps'
    )

}

# Basemap currently on the map; on_basemap_change() keeps this up to date.
current_basemap = basemap_layers['Esri Satellite']

basemap_selector = Dropdown(
    options=list(basemap_layers),
    value='Esri Satellite',
    description='Basemap:',
)

# Centre the map on the middle of the area of interest.
map_center = ((MIN_LAT + MAX_LAT) / 2, (MIN_LON + MAX_LON) / 2)
map_layout = Layout(height='600px', width='100%')
m = Map(layers=(current_basemap,), center=map_center, zoom=12, layout=map_layout)

# The PCA image is drawn over the basemap; it starts fully transparent when
# the 'Show Embedding' toggle is off.
if opacity_toggle.value:
    initial_overlay_opacity = opacity_slider.value
else:
    initial_overlay_opacity = 0
image_overlay = ImageOverlay(
    url=VIS_DATA_URL,
    bounds=VIS_BOUNDS,
    opacity=initial_overlay_opacity,
)
m.add(image_overlay)

# --- EVENT HANDLERS & OBSERVERS ---
def update_opacity(change):
    """Apply the toggle/slider state to the embedding overlay.

    With the toggle on, the overlay takes the slider's opacity and the slider
    is enabled; with it off, the overlay opacity is 0 and the slider is
    disabled. ``change`` is the observer payload and is not inspected.
    """
    if opacity_toggle.value:
        image_overlay.opacity = opacity_slider.value
        opacity_slider.disabled = False
    else:
        image_overlay.opacity = 0
        opacity_slider.disabled = True

def on_add_class_button_clicked(b):
    """Register the class typed in the 'New Class' box and select it.

    Blank names and names already present in the dropdown are ignored.
    """
    new_class = new_class_text.value.strip()
    if not new_class or new_class in class_dropdown.options:
        return

    class_dropdown.options = class_dropdown.options + (new_class,)
    class_dropdown.value = new_class
    # Assign a colour to the new class and show it in the picker.
    color_picker.value = get_or_assign_color_for_class(new_class)
    new_class_text.value = ''
    with output_log:
        output_log.clear_output()
        print(f"Added new class: '{new_class}'")

def on_basemap_change(change):
    """Swap the map's basemap for the one picked in the 'Basemap:' dropdown.

    ``change['new']`` is the selected key of ``basemap_layers``. The module
    global ``current_basemap`` is updated to the newly shown layer.
    """
    global current_basemap
    new_layer = basemap_layers[change['new']]

    # Map.add/Map.remove are used here, matching how layers are added in the
    # rest of this notebook.
    if current_basemap in m.layers:
        m.remove(current_basemap)
    # Guard against adding a layer that is already on the map.
    if new_layer not in m.layers:
        m.add(new_layer)
    current_basemap = new_layer

def on_class_selection_change(change):
    """When dropdown changes, update the color picker to match.

    The colour observer is detached while the picker is updated so that
    on_color_change() does not run for this programmatic change.
    """
    color = get_or_assign_color_for_class(change.new)
    color_picker.unobserve(on_color_change, names='value')
    try:
        color_picker.value = color
    finally:
        # Always re-attach: if the assignment raises, the picker must not be
        # left without its observer.
        color_picker.observe(on_color_change, names='value')

def on_color_change(change):
    """When color picker changes, update the map and redraw pins.

    The new colour is stored for the class currently selected in the dropdown,
    and every pin of that class is replaced by a marker in the new colour.
    """
    new_color = change.new
    class_to_update = class_dropdown.value

    class_color_map[class_to_update] = new_color

    for coords, class_name in training_points:
        if class_name != class_to_update:
            continue

        marker_key = tuple(coords)
        stale_marker = markers.get(marker_key)
        if stale_marker is not None:
            m.remove(stale_marker)

        recolored_marker = CircleMarker(
            location=coords, radius=6, color=new_color,
            fill_color=new_color, fill_opacity=0.8, weight=1,
        )
        # Attach the click-to-remove handler to the recolored marker
        recolored_marker.on_click(partial(remove_marker, marker_key))

        m.add(recolored_marker)
        markers[marker_key] = recolored_marker

def remove_marker(marker_key, **kwargs):
    """Delete the pin identified by ``marker_key`` from the map and the labels.

    ``marker_key`` is the ``tuple(coords)`` under which the marker is stored in
    ``markers``. Extra keyword arguments supplied by the marker's click event
    are ignored. Sets ``A_MARKER_WAS_JUST_REMOVED`` so that handle_map_click()
    skips the next event it receives.
    """
    global A_MARKER_WAS_JUST_REMOVED, training_points

    # Remove from map. The bookkeeping entry is dropped first so that a marker
    # which is no longer on the map cannot stop the training data below from
    # being updated.
    marker = markers.pop(marker_key, None)
    if marker is not None and marker in m.layers:
        m.remove(marker)

    # Remove from training data
    coords_to_remove = marker_key
    training_points = [p for p in training_points if tuple(p[0]) != coords_to_remove]

    A_MARKER_WAS_JUST_REMOVED = True

    with output_log:
        output_log.clear_output(wait=True)
        print(f"Removed point at ({coords_to_remove[0]:.4f}, {coords_to_remove[1]:.4f}). Total points: {len(training_points)}")

def handle_map_click(**kwargs):
    """Map interaction callback: drop a pin of the selected class on click.

    Reads ``type`` and ``coordinates`` from the event keywords. The first event
    received after a marker removal is discarded, whatever its type, and clears
    the ``A_MARKER_WAS_JUST_REMOVED`` flag.
    """
    global A_MARKER_WAS_JUST_REMOVED

    # A marker was just deleted: this event is treated as already used.
    if A_MARKER_WAS_JUST_REMOVED:
        A_MARKER_WAS_JUST_REMOVED = False
        return

    # Only plain clicks create pins.
    if kwargs.get('type') != 'click':
        return

    coords = kwargs.get('coordinates')
    selected_class = class_dropdown.value
    marker_key = tuple(coords)

    if marker_key in markers:
        with output_log:
            output_log.clear_output(wait=True)
            print("A point already exists at this exact location. Click it to remove.")
        return

    training_points.append((coords, selected_class))

    pin_color = get_or_assign_color_for_class(selected_class)
    marker = CircleMarker(
        location=coords,
        radius=6,
        color=pin_color,
        fill_color=pin_color,
        fill_opacity=0.8,
        weight=1,
    )
    # Clicking the pin later removes it again.
    marker.on_click(partial(remove_marker, marker_key))
    m.add(marker)
    markers[marker_key] = marker

    with output_log:
        output_log.clear_output(wait=True)
        print(f"Added '{selected_class}' point at ({coords[0]:.4f}, {coords[1]:.4f}). Total points: {len(training_points)}")

def on_clear_pins_button_clicked(b):
    """Remove every pin and reset the labelling state.

    All markers are taken off the map; ``training_points``, ``markers`` and
    ``class_color_map`` are rebound to empty containers; the initial classes
    get their default colours back and the picker is set to the colour of the
    currently selected class. ``b`` is the clicked button (unused; the load
    handler calls this with ``None``).
    """
    global training_points, markers, class_color_map
    with output_log:
        for marker in markers.values():
            m.remove(marker)
        training_points, markers, class_color_map = [], {}, {}
        output_log.clear_output(); print("All pins cleared.")
        for c in initial_classes:
            get_or_assign_color_for_class(c)
        color_picker.value = get_or_assign_color_for_class(class_dropdown.value)


def on_clear_classification_clicked(b):
    """Remove the classification overlay from the map, if there is one.

    ``classification_layer`` is created by the classification cell, which runs
    after this one, so it is looked up defensively: calling this handler before
    that cell has run is a no-op instead of a NameError.
    """
    global classification_layer
    layer = globals().get('classification_layer')
    if layer is None:
        return

    was_on_map = layer in m.layers
    if was_on_map:
        m.remove(layer)

    # Reset the handle and the button even if the layer had already left the
    # map, so the UI does not keep pointing at a stale layer.
    classification_layer = None
    clear_classification_button.disabled = True

    if was_on_map:
        with output_log: output_log.clear_output(); print("Classification layer removed.")

def on_save_button_clicked(b):
    """Write the current labels and class colours to the JSON file named in the UI.

    The file holds one object with the keys ``training_points`` and
    ``class_color_map``. Problems are reported in the output log rather than
    raised.
    """
    def _report(message):
        # Replace whatever the log currently shows with `message`.
        with output_log:
            output_log.clear_output()
            print(message)

    fname = filename_text.value
    if not fname:
        _report("Error: Please provide a filename.")
        return

    # Bundle both the points and the color map together for save state
    save_data = {}
    save_data['training_points'] = training_points
    save_data['class_color_map'] = class_color_map

    try:
        with open(fname, 'w') as f:
            json.dump(save_data, f, indent=2)
        _report(f"Successfully saved {len(training_points)} points to {fname}")
    except Exception as e:
        _report(f"Error saving file: {e}")

def on_load_button_clicked(b):
    """Replace the current labels with those stored in the JSON file named in the UI.

    The file is expected to hold an object with ``training_points`` (a list of
    ``[[lat, lon], class_name]`` entries) and ``class_color_map``. The file is
    read and validated *before* the existing pins are cleared, so a missing or
    malformed file leaves the current session untouched. Problems are reported
    in the output log rather than raised.
    """
    def _report(message):
        # Replace whatever the log currently shows with `message`.
        with output_log:
            output_log.clear_output()
            print(message)

    fname = filename_text.value
    if not fname:
        _report("Error: Please provide a filename.")
        return

    try:
        with open(fname, 'r') as f:
            loaded_data = json.load(f)
    except FileNotFoundError:
        _report(f"Error: File not found: {fname}")
        return
    except Exception as e:
        _report(f"Error loading file: {e}")
        return

    # Validate and normalise the content up front; nothing has been modified yet.
    try:
        loaded_colors = dict(loaded_data.get('class_color_map', {}))
        loaded_points = []
        for coords, class_name in loaded_data.get('training_points', []):
            lat, lon = coords
            if not isinstance(class_name, str):
                raise ValueError(f"class name must be a string, got {class_name!r}")
            loaded_points.append(([float(lat), float(lon)], class_name))
    except (AttributeError, TypeError, ValueError) as e:
        _report(f"Error loading file: {e}")
        return

    on_clear_pins_button_clicked(None)

    class_color_map.update(loaded_colors)

    # Re draw all markers on the map
    for coords, class_name in loaded_points:
        marker_key = tuple(coords)
        if marker_key in markers:
            # Keep the first entry for a location. A second marker stored under
            # the same key would replace the first in `markers` and leave it on
            # the map with no way to remove it.
            continue

        # Add class to dropdown
        if class_name not in class_dropdown.options:
            class_dropdown.options += (class_name,)

        training_points.append((coords, class_name))
        pin_color = get_or_assign_color_for_class(class_name)
        marker = CircleMarker(location=coords, radius=6, color=pin_color, fill_color=pin_color, fill_opacity=0.8, weight=1)
        marker.on_click(partial(remove_marker, marker_key))

        m.add(marker)
        markers[marker_key] = marker

    # The picker was reset to a default colour while clearing; bring it in line
    # with the loaded colour of the selected class. The observer is detached
    # meanwhile so this programmatic change does not redraw the fresh pins.
    color_picker.unobserve(on_color_change, names='value')
    try:
        color_picker.value = get_or_assign_color_for_class(class_dropdown.value)
    finally:
        color_picker.observe(on_color_change, names='value')

    _report(f"Successfully loaded {len(training_points)} points from {fname}")

# --- ATTACH ALL HANDLERS AND OBSERVERS ---
# Value observers: widgets whose `value` drives something on the map.
opacity_toggle.observe(update_opacity, names='value')
opacity_slider.observe(update_opacity, names='value')
class_dropdown.observe(on_class_selection_change, names='value')
color_picker.observe(on_color_change, names='value')
basemap_selector.observe(on_basemap_change, names='value')

# Button clicks.
add_class_button.on_click(on_add_class_button_clicked)
clear_pins_button.on_click(on_clear_pins_button_clicked)
clear_classification_button.on_click(on_clear_classification_clicked)
save_button.on_click(on_save_button_clicked)
load_button.on_click(on_load_button_clicked)

# Map interaction events.
m.on_interaction(handle_map_click)

# --- LAYOUT THE UI ---
# Rows above the map.
class_controls = HBox([class_dropdown, color_picker])
new_class_controls = HBox([new_class_text, add_class_button])
opacity_controls = HBox([opacity_toggle, opacity_slider])
controls = VBox([
    basemap_selector,
    class_controls,
    new_class_controls,
    opacity_controls,
])

# Rows below the map.
buttons = HBox([classify_button, clear_pins_button, clear_classification_button])
file_controls = HBox([filename_text, save_button, load_button])

ui = VBox([
    controls,
    m,
    buttons,
    file_controls,
    output_log,
])

display(ui)

VBox(children=(VBox(children=(Dropdown(description='Basemap:', options=('Esri Satellite', 'Google Earth', 'Goo…

In [11]:
# -- 4. CLASSIFICATION STAGE --

# A global variable to hold the classification layer so we can remove it later.
# It is looked up rather than reset to None: re-running this cell while an
# overlay is still on the map keeps the handle, so the overlay can still be
# removed. On a fresh kernel the lookup yields None.
classification_layer = globals().get('classification_layer')

def _collect_training_samples(class_index_map):
    """Look up the embedding vector under every labelled point.

    Each (lat, lon) in the global ``training_points`` is converted to a mosaic
    row/column with ``mosaic_transform``. Points falling outside the mosaic are
    reported with a warning and skipped.

    Returns:
        (X_train, y_train): parallel lists of embedding vectors and of integer
        labels taken from ``class_index_map``.
    """
    mosaic_height, mosaic_width, _ = embedding_mosaic.shape
    X_train, y_train = [], []
    for (lat, lon), class_name in training_points:
        row, col = rasterio.transform.rowcol(mosaic_transform, lon, lat)
        if 0 <= row < mosaic_height and 0 <= col < mosaic_width:
            X_train.append(embedding_mosaic[row, col, :])
            y_train.append(class_index_map[class_name])
        else:
            # This can happen if a user clicks just outside the reprojected mosaic area
            print(f"  - WARNING: Skipping point for '{class_name}' at ({lat:.4f}, {lon:.4f}) as it falls outside the mosaic's bounds.")
    return X_train, y_train


def _predict_label_image(model, batch_size=25000):
    """Classify every mosaic pixel with ``model`` and return an (H, W) label image.

    Pixels are predicted ``batch_size`` at a time, with a progress bar that
    advances by the number of pixels processed.
    """
    from tqdm import tqdm

    mosaic_height, mosaic_width, num_channels = embedding_mosaic.shape
    all_pixels = embedding_mosaic.reshape(-1, num_channels)
    n_pixels = all_pixels.shape[0]
    predicted_labels = np.zeros(n_pixels, dtype=np.uint8)

    with tqdm(total=n_pixels, desc="Classifying pixels") as pbar:
        for start in range(0, n_pixels, batch_size):
            end = min(start + batch_size, n_pixels)
            predicted_labels[start:end] = model.predict(all_pixels[start:end, :])
            pbar.update(end - start)

    return predicted_labels.reshape(mosaic_height, mosaic_width)


def _label_image_to_data_url(label_image, class_names):
    """Render an integer label image as a base64 PNG data URL.

    Label ``i`` is drawn in the colour registered for ``class_names[i]``, i.e.
    the same colour as that class's pins.
    """
    color_list = [get_or_assign_color_for_class(name) for name in class_names]
    cmap = mcolors.ListedColormap(color_list)
    norm = mcolors.Normalize(vmin=0, vmax=len(class_names) - 1)
    colored_result = cmap(norm(label_image))

    buffer = io.BytesIO()
    plt.imsave(buffer, colored_result, format='png')
    buffer.seek(0)
    b64_data = base64.b64encode(buffer.read()).decode('utf-8')
    return f"data:image/png;base64,{b64_data}"


def on_classify_button_clicked(b):
    """Train a k-NN classifier on the pins and overlay its prediction on the map.

    Progress and errors are printed into ``output_log``. On success the global
    ``classification_layer`` is replaced by the new overlay and the 'Clear
    Classification' button is enabled.
    """
    global classification_layer

    with output_log:
        output_log.clear_output()

        # Class names are mapped to integer labels in sorted order.
        unique_class_names = sorted({name for _, name in training_points})
        if len(unique_class_names) < 2:
            print("ERROR: Please add at least two points from two different classes to train the model.")
            return

        print("Starting classification...")
        class_index_map = {name: i for i, name in enumerate(unique_class_names)}
        print(f"Discovered classes for training: {unique_class_names}")

        # --- PREPARE TRAINING DATA ---
        print(f"Mapping {len(training_points)} training points to pixel coordinates...")
        X_train, y_train = _collect_training_samples(class_index_map)

        if not X_train:
            print("ERROR: None of the training points were inside the mosaic bounds.")
            return
        # Points outside the mosaic were dropped above, so re-check that the
        # remaining samples still cover more than one class.
        if len(set(y_train)) < 2:
            print("ERROR: The training points inside the mosaic bounds must cover at least two different classes.")
            return

        # --- TRAIN THE MODEL ---
        print(f"Training k-NN classifier on {len(X_train)} valid points...")
        # Use a sensible k, ensuring it's not larger than the number of samples
        k = min(5, len(X_train))
        model = KNeighborsClassifier(n_neighbors=k, weights='distance')
        model.fit(X_train, y_train)

        # --- PREDICT ON THE ENTIRE MOSAIC ---
        classification_result = _predict_label_image(model)

        # --- VISUALIZE AND DISPLAY THE RESULT ---
        print("Creating visualization of the classification result...")
        classification_data_url = _label_image_to_data_url(classification_result, unique_class_names)

        print("Displaying result on the map...")
        # If an old classification layer exists, remove it first
        if classification_layer is not None and classification_layer in m.layers:
            m.remove(classification_layer)

        # Create the new ImageOverlay for the classification
        classification_layer = ipyleaflet.ImageOverlay(
            url=classification_data_url,
            bounds=VIS_BOUNDS,
            opacity=0.7,
            name='Classification'
        )
        m.add(classification_layer)

        # Enable the clear button now that a layer exists
        clear_classification_button.disabled = False
        print("Classification complete.")

# Attach the function to the button's click event.
# Re-running this cell creates a new function object. The handler registered
# by the previous run is detached first so that one click does not trigger
# several classifications.
_previous_classify_handler = globals().get('_registered_classify_handler')
if _previous_classify_handler is not None:
    classify_button.on_click(_previous_classify_handler, remove=True)
classify_button.on_click(on_classify_button_clicked)
_registered_classify_handler = on_classify_button_clicked