In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.neighbors import KNeighborsClassifier
import ipyleaflet
from ipyleaflet import TileLayer, Map, ImageOverlay, CircleMarker, LayersControl
from ipywidgets import *
from geotessera import GeoTessera
from sklearn.decomposition import PCA
import time

In [2]:
# Bounding box of the area of interest (longitude / latitude bounds).
MIN_LON = -0.10
MAX_LON = 0.30
MIN_LAT = 52.00
MAX_LAT = 52.20

print("Area of Interest defined.")

Area of Interest defined.


In [3]:
# -- 2. FETCH, STITCH, AND PREPARE PCA VISUALIZATION --

tessera = GeoTessera()

# Find all available tiles inside the ROI with a single scan of the listing.
# Each entry is unpacked as (year, lat, lon); the sets collapse entries that
# share a coordinate.
roi_lons, roi_lats = set(), set()
for _year, lat, lon in tessera.list_available_embeddings():
    if MIN_LON <= lon < MAX_LON and MIN_LAT <= lat < MAX_LAT:
        roi_lons.add(lon)
        roi_lats.add(lat)

unique_lons = sorted(roi_lons)                # ascending  -> mosaic columns, left to right
unique_lats = sorted(roi_lats, reverse=True)  # descending -> mosaic rows, top to bottom

if not unique_lons or not unique_lats:
    raise ValueError("No tiles found for the specified ROI.")

# The count is the size of the lat x lon grid that will be fetched below.
print(f"Found {len(unique_lats) * len(unique_lons)} tiles.")

# -- PASS 1: METADATA SCAN --
# Fetch every tile of the lat x lon grid and remember its array shape so the
# mosaic layout can be computed before any pixel is copied.
tile_data, tile_shapes = {}, {}
print("Pass 1: Fetching data and scanning tile shapes...")
for lat in unique_lats:
    for lon in unique_lons:
        print(f"Fetching tile at ({lat:.2f}, {lon:.2f})...")
        try:
            data = tessera.get_embedding(lat, lon)
            tile_data[(lat, lon)], tile_shapes[(lat, lon)] = data, data.shape
        except Exception as e:
            # Best effort: a missing tile is recorded as None and leaves a
            # zero-filled gap in the mosaic instead of aborting the run.
            print(f"  Could not fetch tile: {e}")
            tile_data[(lat, lon)] = None

# Without a single fetched tile the layout below would be 0 x 0 and the
# stitching step would fail with an unrelated error, so stop here instead.
if not tile_shapes:
    raise RuntimeError("None of the tiles in the ROI could be fetched.")

# A column is as wide as its widest tile and a row as tall as its tallest
# tile; tiles that could not be fetched count as size 0.
col_widths = {lon: max(tile_shapes.get((lat, lon), (0, 0))[1] for lat in unique_lats) for lon in unique_lons}
row_heights = {lat: max(tile_shapes.get((lat, lon), (0, 0))[0] for lon in unique_lons) for lat in unique_lats}

# Running sums turn the sizes into the pixel offset of each column / row.
col_starts, current_x = {}, 0
for lon in unique_lons:
    col_starts[lon] = current_x
    current_x += col_widths[lon]
mosaic_width = current_x

row_starts, current_y = {}, 0
for lat in unique_lats:
    row_starts[lat] = current_y
    current_y += row_heights[lat]
mosaic_height = current_y
print(f"Calculated mosaic dimensions: {mosaic_width}px width, {mosaic_height}px height")

# -- PASS 2: STITCHING --
# Read the channel count from the first tile that was actually fetched: failed
# downloads are stored as None, so the first dictionary value may not be an array.
first_tile = next((tile for tile in tile_data.values() if tile is not None), None)
if first_tile is None:
    raise ValueError("No tile data available to stitch.")
num_channels = first_tile.shape[2]

# Areas that no tile covers keep the zero fill value.
embedding_mosaic = np.zeros((mosaic_height, mosaic_width, num_channels), dtype=np.float32)
print("Pass 2: Stitching tiles into mosaic...")
for (lat, lon), data in tile_data.items():
    if data is not None:
        h, w, _ = data.shape
        y_start, x_start = row_starts[lat], col_starts[lon]
        embedding_mosaic[y_start : y_start+h, x_start : x_start+w, :] = data
print("Data fetched and stitched.")
print(f"Embedding Mosaic Shape: {embedding_mosaic.shape}")

# --- PCA VISUALIZATION ---
print("\nCreating PCA-based visualization...")
# 1. Flatten the mosaic into one row per pixel
pixels = embedding_mosaic.reshape(-1, num_channels)

# 2. Fit PCA on a random subsample for efficiency
n_pixels_total = pixels.shape[0]
n_sample = min(n_pixels_total, 100000)
# Fixed seed so that Restart & Run All reproduces the same picture.
PCA_SEED = 0
pca_rng = np.random.default_rng(PCA_SEED)
sample_indices = pca_rng.choice(n_pixels_total, n_sample, replace=False)
sample_pixels = pixels[sample_indices, :]

print(f"Fitting PCA on a sample of {n_sample} pixels...")
# random_state pins the solver's randomness in case a randomized solver is used.
pca = PCA(n_components=3, random_state=PCA_SEED)
# Fit on the subsample only; fitting on `pixels` would defeat the sampling above.
pca.fit(sample_pixels)

# 3. Transform all pixels using the fitted model
print("Transforming all pixels with the PCA model...")
transformed_pixels = pca.transform(pixels)
pca_image = transformed_pixels.reshape(mosaic_height, mosaic_width, 3)

# 4. Normalize each channel to the [0, 1] range for display
print("Normalizing PCA components for display...")
vis_mosaic = np.zeros_like(pca_image)
for i in range(3):
    channel = pca_image[:, :, i]
    # Stretch between the 2nd and 98th percentile so outliers do not wash out
    # the contrast; a constant channel is left at zero.
    min_val, max_val = np.percentile(channel, [2, 98])
    if max_val > min_val:
        vis_mosaic[:, :, i] = np.clip((channel - min_val) / (max_val - min_val), 0, 1)

print("PCA visualization created.")

# A timestamp in the file name gives every run a new file name, which keeps
# the notebook from showing a cached copy of an older image on the map.
timestamp = int(time.time())
VIS_FILENAME = f"temp_vis_{timestamp}.png"

# Write the RGB visualization to disk under that name
plt.imsave(fname=VIS_FILENAME, arr=vis_mosaic)
print(f"Saved visualization to {VIS_FILENAME}")

Found 8 tiles.
Pass 1: Fetching data and scanning tile shapes...
Fetching tile at (52.15, -0.05)...
Fetching tile at (52.15, 0.05)...
Fetching tile at (52.15, 0.15)...
Fetching tile at (52.15, 0.25)...
Fetching tile at (52.05, -0.05)...
Fetching tile at (52.05, 0.05)...
Fetching tile at (52.05, 0.15)...
Fetching tile at (52.05, 0.25)...
Calculated mosaic dimensions: 2919px width, 2280px height
Pass 2: Stitching tiles into mosaic...
Data fetched and stitched.
Embedding Mosaic Shape: (2280, 2919, 128)

Creating PCA-based visualization...
Transforming all pixels with the PCA model...
Normalizing PCA components for display...
PCA visualization created.
Saved visualization to temp_vis_1751645619.png


In [4]:
# Mutable state shared by the event handlers of this cell.
training_points = list()   # (coords, class_name) pairs, one per clicked pin
markers = dict()           # tuple(coords) -> CircleMarker shown on the map
class_color_map = dict()   # class name -> hex colour string

# Palette that supplies the default colour of each newly seen class.
tab10_cmap = plt.colormaps.get_cmap('tab10')

def get_or_assign_color_for_class(class_name):
    """Return the hex colour of ``class_name``, registering one on first use.

    A class seen for the first time receives the tab10 palette entry whose
    index is the number of classes registered so far, modulo 10, and that
    colour is stored in ``class_color_map`` so later calls return it again.
    """
    try:
        return class_color_map[class_name]
    except KeyError:
        palette_slot = len(class_color_map) % 10
        assigned_color = mcolors.to_hex(tab10_cmap(palette_slot))
        class_color_map[class_name] = assigned_color
        return assigned_color

# Classes offered at start-up; each one gets its palette colour right away.
initial_classes = ['Water', 'Urban']
for class_name in initial_classes:
    get_or_assign_color_for_class(class_name)

# --- Class selection / creation ---
class_dropdown = Dropdown(
    options=initial_classes,
    value='Water',
    description='Class:',
)
new_class_text = Text(
    value='',
    placeholder='Type new class name',
    description='New Class:',
)
add_class_button = Button(description="Add")

# The picker starts on the colour of the class selected in the dropdown.
initial_picker_color = class_color_map.get(class_dropdown.value, '#FFFFFF')
color_picker = ColorPicker(
    concise=False,
    description='Set Color:',
    value=initial_picker_color,
    disabled=False,
)

# --- Embedding overlay visibility ---
opacity_toggle = ToggleButton(
    value=True,
    description='Show Embedding',
    button_style='info',
)
opacity_slider = FloatSlider(
    value=0.7,
    min=0,
    max=1.0,
    step=0.05,
    description='Opacity:',
)

# --- Actions and log area ---
classify_button = Button(description="Classify")
clear_pins_button = Button(description="Clear All Pins")
# Stays disabled until a classification layer has been put on the map.
clear_classification_button = Button(description="Clear Classification", disabled=True)
output_log = Output()

# Selectable background layers, keyed by the label shown in the basemap
# dropdown. Every layer carries its own distinct `name`, and all tile URLs use
# https so they are not blocked as mixed content when the notebook itself is
# served over https.
basemap_layers = {
    'Esri Satellite': TileLayer(
        url='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
        attribution='Esri', name='Esri Satellite'
    ),
    'Google Earth': TileLayer(
        url='https://mt0.google.com/vt/lyrs=y&hl=en&x={x}&y={y}&z={z}',
        attribution='Google Earth', name='Google Earth'
    ),
    'Google Maps': TileLayer(
        url='https://mt0.google.com/vt/lyrs=p&hl=en&x={x}&y={y}&z={z}',
        attribution='Google Maps', name='Google Maps'
    ),
}

# Layer currently shown as the background; replaced when the dropdown changes.
current_basemap = basemap_layers['Esri Satellite']

basemap_selector = Dropdown(
    options=list(basemap_layers.keys()),
    value='Esri Satellite',
    description='Basemap:',
)

# Map widget centred on the middle of the area of interest.
map_layout = Layout(height='600px', width='100%')
roi_center = ((MIN_LAT + MAX_LAT) / 2, (MIN_LON + MAX_LON) / 2)
m = Map(
    layers=(current_basemap,),
    center=roi_center,
    zoom=12,
    layout=map_layout,
)

# The saved PCA image is stretched over the ROI bounding box; it starts fully
# transparent when the "Show Embedding" toggle is off.
if opacity_toggle.value:
    initial_opacity = opacity_slider.value
else:
    initial_opacity = 0
roi_bounds = ((MIN_LAT, MIN_LON), (MAX_LAT, MAX_LON))
image_overlay = ImageOverlay(url=VIS_FILENAME, bounds=roi_bounds, opacity=initial_opacity)
m.add(image_overlay)

# --- EVENT HANDLERS & OBSERVERS ---
def update_opacity(change):
    """Sync the embedding overlay with the visibility toggle and the opacity slider.

    ``change`` is the trait-change notification; its content is not used
    because the current widget values are read directly.
    """
    if opacity_toggle.value:
        # Visible: follow the slider and keep it usable.
        image_overlay.opacity = opacity_slider.value
        opacity_slider.disabled = False
    else:
        # Hidden: fully transparent, and the slider is locked.
        image_overlay.opacity = 0
        opacity_slider.disabled = True

def on_add_class_button_clicked(b):
    """Add the class typed into the text box to the dropdown and select it.

    Blank names and names already present in the dropdown are ignored.
    ``b`` is the clicked button and is not used.
    """
    new_class = new_class_text.value.strip()
    if not new_class or new_class in class_dropdown.options:
        return

    class_dropdown.options = class_dropdown.options + (new_class,)
    class_dropdown.value = new_class
    # Assign a color to the new class and show it in the picker
    color_picker.value = get_or_assign_color_for_class(new_class)
    new_class_text.value = ''
    with output_log:
        output_log.clear_output()
        print(f"Added new class: '{new_class}'")

def on_basemap_change(change):
    """Swap the background layer when another basemap is picked in the dropdown.

    ``change['new']`` is the selected dropdown label, which is also the key of
    the layer in ``basemap_layers``.
    """
    global current_basemap
    new_layer = basemap_layers[change['new']]
    # Use Map.add / Map.remove like the rest of this notebook instead of the
    # deprecated add_layer / remove_layer aliases.
    if current_basemap in m.layers:
        m.remove(current_basemap)
    m.add(new_layer)
    current_basemap = new_layer

def on_class_selection_change(change):
    """When dropdown changes, update the color picker to match.

    The picker's own observer is detached while its value is set so that this
    programmatic update is not treated as a user recolouring the class.
    """
    selected_class = change.new
    if selected_class is None:
        # Nothing is selected, so there is no class whose colour could be shown.
        return
    color = get_or_assign_color_for_class(selected_class)
    color_picker.unobserve(on_color_change, names='value')
    try:
        color_picker.value = color
    finally:
        # Always re-attach, even if the assignment above is rejected, otherwise
        # the picker would silently stop recolouring pins.
        color_picker.observe(on_color_change, names='value')

def on_color_change(change):
    """When color picker changes, update the map and redraw pins.

    The new colour is stored for the class currently selected in the dropdown
    and every pin of that class is replaced by a marker in the new colour.
    """
    new_color = change.new
    class_to_update = class_dropdown.value

    class_color_map[class_to_update] = new_color

    for coords, class_name in training_points:
        if class_name != class_to_update:
            continue
        marker_key = tuple(coords)
        old_marker = markers.get(marker_key)
        # Only remove a marker that is still on the map; removing a layer the
        # map does not hold is an error.
        if old_marker is not None and old_marker in m.layers:
            m.remove(old_marker)
        recolored_marker = CircleMarker(location=coords, radius=6, color=new_color, fill_color=new_color, fill_opacity=0.8, weight=1)
        m.add(recolored_marker)
        markers[marker_key] = recolored_marker

def handle_map_click(**kwargs):
    """Record a training pin for the selected class wherever the map is clicked.

    Only interactions whose ``type`` is ``'click'`` are handled. The clicked
    coordinates are stored with the class chosen in the dropdown, and a circle
    marker in that class's colour is added to the map.
    """
    if kwargs.get('type') != 'click':
        return

    coords = kwargs.get('coordinates')
    selected_class = class_dropdown.value
    training_points.append((coords, selected_class))

    pin_color = get_or_assign_color_for_class(selected_class)
    pin_style = dict(radius=6, color=pin_color, fill_color=pin_color, fill_opacity=0.8, weight=1)
    pin = CircleMarker(location=coords, **pin_style)
    pin_key = tuple(coords)
    m.add(pin)
    markers[pin_key] = pin

    with output_log:
        output_log.clear_output(wait=True)
        print(f"Added '{selected_class}' point at ({coords[0]:.4f}, {coords[1]:.4f}). Total points: {len(training_points)}")

def on_clear_pins_button_clicked(b):
    """Remove every pin from the map and reset the training state.

    The class colours are reset as well: only the initial classes (and the
    class currently selected in the dropdown) get a palette colour again.
    ``b`` is the clicked button and is not used.
    """
    with output_log:
        for marker in markers.values():
            # Skip markers that are no longer on the map; removing a layer the
            # map does not hold is an error.
            if marker in m.layers:
                m.remove(marker)
        # Empty the shared containers in place so that every reference to them
        # stays valid; no global rebinding is needed.
        training_points.clear()
        markers.clear()
        class_color_map.clear()
        output_log.clear_output()
        print("All pins cleared.")
        for c in initial_classes:
            get_or_assign_color_for_class(c)
        color_picker.value = get_or_assign_color_for_class(class_dropdown.value)


def on_clear_classification_clicked(b):
    """Take the classification overlay off the map and disable this button again.

    Does nothing when no classification layer is currently on the map.
    ``b`` is the clicked button and is not used.
    """
    global classification_layer
    # `classification_layer` is created by the classification cell; look it up
    # defensively so this handler cannot fail with a NameError if that cell
    # has not been run yet.
    layer = globals().get('classification_layer')
    if layer is None or layer not in m.layers:
        return
    m.remove(layer)
    classification_layer = None
    clear_classification_button.disabled = True
    with output_log:
        output_log.clear_output()
        print("Classification layer removed.")

# --- ATTACH ALL HANDLERS AND OBSERVERS ---
# Trait observers: each callback reacts to changes of the widget's `value`.
for observed, on_value_change in (
    (opacity_toggle, update_opacity),
    (opacity_slider, update_opacity),
    (class_dropdown, on_class_selection_change),
    (color_picker, on_color_change),
    (basemap_selector, on_basemap_change),
):
    observed.observe(on_value_change, names='value')

# Button click handlers.
for clickable, on_button_click in (
    (add_class_button, on_add_class_button_clicked),
    (clear_pins_button, on_clear_pins_button_clicked),
    (clear_classification_button, on_clear_classification_clicked),
):
    clickable.on_click(on_button_click)

# Map interactions (the handler itself filters for clicks).
m.on_interaction(handle_map_click)

# --- LAYOUT THE UI ---
# Control rows shown above the map.
class_controls = HBox(children=[class_dropdown, color_picker])
new_class_controls = HBox(children=[new_class_text, add_class_button])
opacity_controls = HBox(children=[opacity_toggle, opacity_slider])
controls = VBox(children=[
    basemap_selector,
    class_controls,
    new_class_controls,
    opacity_controls,
])

# Action buttons shown below the map, followed by the log area.
buttons = HBox(children=[classify_button, clear_pins_button, clear_classification_button])
ui = VBox(children=[controls, m, buttons, output_log])

display(ui)

VBox(children=(VBox(children=(Dropdown(description='Basemap:', options=('Esri Satellite', 'Google Earth', 'Goo…

In [5]:
# Overlay holding the latest classification result (None while there is none).
# If this cell is re-run while an overlay is still displayed, take that overlay
# off the map first: resetting the variable alone would leave it on the map
# with no way to clear it.
_stale_classification_layer = globals().get('classification_layer')
if _stale_classification_layer is not None and _stale_classification_layer in m.layers:
    m.remove(_stale_classification_layer)
    clear_classification_button.disabled = True
classification_layer = None

def on_classify_button_clicked(b):
    """Train a k-NN classifier on the pinned points and overlay its prediction.

    Each pin is mapped to a mosaic pixel by linear interpolation inside the ROI
    bounding box, and that pixel's embedding becomes a training sample. The
    model (k = min(5, number of pins)) then labels every pixel of the mosaic.
    The labels are coloured with the same colours as the pins, saved as a PNG,
    and shown on the map as an image overlay that replaces any previous one.

    Requires at least two pins from two different classes; otherwise a hint is
    printed to the log and nothing else happens. ``b`` is the clicked button
    and is not used.
    """
    global classification_layer

    with output_log:
        output_log.clear_output()
        if len(training_points) < 2 or len(set(c for p, c in training_points)) < 2:
            print("Please add at least two points from two different classes.")
            return

        print("Starting classification...")
        # Sorted class names give every class a stable integer label.
        unique_class_names = sorted({name for _point, name in training_points})
        class_index_map = {name: i for i, name in enumerate(unique_class_names)}
        print(f"Discovered classes for training: {unique_class_names}")

        X_train, y_train = [], []
        for (lat, lon), class_name in training_points:
            # Row 0 is the northern edge, so rows grow as latitude decreases.
            row = int(mosaic_height * (MAX_LAT - lat) / (MAX_LAT - MIN_LAT))
            col = int(mosaic_width * (lon - MIN_LON) / (MAX_LON - MIN_LON))
            # Pins outside the ROI are snapped to the nearest edge pixel.
            row, col = np.clip(row, 0, mosaic_height - 1), np.clip(col, 0, mosaic_width - 1)
            X_train.append(embedding_mosaic[row, col, :])
            y_train.append(class_index_map[class_name])

        print(f"Training k-NN on {len(X_train)} points...")
        k = min(5, len(X_train))  # k must not exceed the number of samples
        model = KNeighborsClassifier(n_neighbors=k)
        model.fit(X_train, y_train)

        print("Predicting on the full image...")
        all_pixels = embedding_mosaic.reshape(-1, num_channels)
        predicted_labels = model.predict(all_pixels)
        classification_result = predicted_labels.reshape(mosaic_height, mosaic_width)

        # Create a colormap for the final image using the same colors as the pins
        # Get the list of hex colors in the correct order
        color_list = [get_or_assign_color_for_class(name) for name in unique_class_names]
        cmap = mcolors.ListedColormap(color_list)

        # Normalize the result to be in the range of the colormap
        norm = mcolors.Normalize(vmin=0, vmax=len(unique_class_names)-1)
        colored_result = cmap(norm(classification_result))

        # A timestamp in the file name gives every result a new URL; with a
        # fixed name, re-classifying can show a cached copy of the old image.
        classification_filename = f"temp_classification_{int(time.time())}.png"
        plt.imsave(classification_filename, colored_result)

        print("Displaying result...")
        if classification_layer is not None and classification_layer in m.layers:
            m.remove(classification_layer)

        classification_layer = ipyleaflet.ImageOverlay(
            url=classification_filename,
            bounds=((MIN_LAT, MIN_LON), (MAX_LAT, MAX_LON)),
            opacity=0.7,
            name='Classification'
        )
        m.add(classification_layer)
        # Enable the clear button now that a layer exists
        clear_classification_button.disabled = False
        print("Classification complete")

        # Text legend: class name -> hex colour used in the overlay.
        legend_parts = [f"'{name}': {color}" for name, color in zip(unique_class_names, color_list)]
        print("Legend: " + ", ".join(legend_parts))

# Attach the function to the button's click event.
# Re-running this cell redefines the handler, so detach the one registered by
# a previous run first; otherwise one click would run the classification once
# per registration.
_previous_classify_handler = globals().get('_registered_classify_handler')
if _previous_classify_handler is not None:
    classify_button.on_click(_previous_classify_handler, remove=True)
classify_button.on_click(on_classify_button_clicked)
_registered_classify_handler = on_classify_button_clicked