In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.neighbors import KNeighborsClassifier
import ipyleaflet
from ipyleaflet import (
    TileLayer,
    Map,
    ImageOverlay,
    CircleMarker,
    LayersControl,
    WidgetControl,
)
from ipywidgets import *
from geotessera import GeoTessera
from sklearn.decomposition import PCA
import time
from functools import partial
import io
import base64
import json
import rasterio
from rasterio.io import MemoryFile
from rasterio.merge import merge
from rasterio.warp import calculate_default_transform, reproject
from rasterio.enums import Resampling
from tqdm import tqdm

from interactive import utils, config, visualisation

# Instantiate the project configuration object.
# NOTE(review): no later visible cell references `Config`, and the instance is
# named like a class — confirm whether it is still needed here.
Config = config.Config()

In [3]:
# -- 1. ROI DEFINITION --
# The ROI works in 0.1-degree steps, and the bounding box must span at least
# 0.1 deg of latitude and 0.1 deg of longitude.
#
# Leave these as None to fall back to the config defaults (South Cambridge),
# or set explicit coordinates to study a different area.
# NOTE(review): the default-values message printed by define_roi appears to
# swap the lat/lon labels — worth checking in `interactive.utils`.
MIN_LON = MAX_LON = None
MIN_LAT = MAX_LAT = None

(MIN_LAT, MAX_LAT), (MIN_LON, MAX_LON) = utils.define_roi(
    lon_coords=(MIN_LON, MAX_LON),
    lat_coords=(MIN_LAT, MAX_LAT),
)

Using config defaults for lat_coords: [-0.1, 0.3] and lon_coords: [52.0, 52.2]
Bounding box defined:
┗ (52.00, -0.10) | ┓ (52.20, 0.30)


In [8]:
# -- 2. FETCH AND MOSAIC RELEVANT TESSERA TILES --
# Locate every Tessera tile overlapping the ROI and merge them into a single
# channels-last embedding array, together with the affine transform that maps
# geographic coordinates onto its pixel grid.
embedding_mosaic, mosaic_transform = utils.TesseraUtils().process_roi_to_mosaic(
    lon_coords=(MIN_LON, MAX_LON),
    lat_coords=(MIN_LAT, MAX_LAT),
)

Bounding box defined:
┗ (52.00, -0.10) | ┓ (52.20, 0.30)

Searching for tiles in ROI: (-0.1, 52.0, 0.3, 52.2) for year 2024

Found 15 tiles to merge.


Processing tiles:   0%|          | 0/15 [00:00<?, ?it/s]


Merging all tiles...
Shape of final embedding mosaic: (2853, 4791, 128)


In [None]:
# -- 3. VISUALISE AND PLACE TRAINING POINTS --
# Build the interactive map for the ROI (with a PCA preview of the embedding
# mosaic as an overlay) and render it; training points are placed on this map.
mapping_tool = visualisation.InteractiveMappingTool(
    MIN_LAT, MAX_LAT,  # latitude range of the ROI
    MIN_LON, MAX_LON,  # longitude range of the ROI
    embedding_mosaic, mosaic_transform,
)
mapping_tool.display()

Calculated bounds: ((51.89743052830078, -0.20626269731868993), (52.202497982390696, 0.3060324910951864))

Creating PCA-based visualization...
Normalizing PCA components for display...
PCA visualization created.
Image overlay added to map


VBox(children=(HBox(children=(VBox(children=(Dropdown(description='Basemap:', options=('Esri Satellite', 'Goog…

In [None]:
# -- 4. CLASSIFICATION STAGE --

# NOTE(review): this cell relies on names that no visible cell defines:
# `output_log`, `training_points`, `get_or_assign_color_for_class`, `m`,
# `VIS_BOUNDS`, `classify_button` and `clear_classification_button`.
# Cell 3 now builds the UI through `visualisation.InteractiveMappingTool`, so
# these are presumably leftovers from an earlier inline version (or live on
# `mapping_tool`) — confirm before relying on Restart & Run All.

# Pixels passed to `model.predict` per call; bounds peak memory during prediction.
PREDICT_BATCH_SIZE = 15000
# Upper bound on k for the k-NN classifier (further capped by the sample count).
MAX_NEIGHBOURS = 5
# Opacity of the classification overlay on the map.
CLASSIFICATION_OPACITY = 0.7

# Holds the classification overlay so it can be removed later. Keep any layer
# left over from a previous run of this cell; resetting to None would orphan
# that overlay on the map with no way to remove it.
classification_layer = globals().get("classification_layer")

# ipywidgets keeps every callback registered via `on_click`, so re-running this
# cell would stack another handler and classify once per registration. Detach
# the previous definition (a no-op if it was never registered on this button).
if "on_classify_button_clicked" in globals():
    classify_button.on_click(on_classify_button_clicked, remove=True)


def _collect_training_samples(points, class_index_map, embedding, transform):
    """Look up the embedding vector underneath each training point.

    Args:
        points: iterable of ``((lat, lon), class_name)`` pairs.
        class_index_map: mapping from class name to integer label.
        embedding: mosaic array indexed as ``[row, col, channel]``.
        transform: affine transform passed to ``rasterio.transform.rowcol``
            together with ``(lon, lat)``.

    Returns:
        ``(features, labels)`` lists. Points that fall outside the mosaic are
        skipped with a printed warning.
    """
    height, width = embedding.shape[:2]
    features, labels = [], []
    for (lat, lon), class_name in points:
        row, col = rasterio.transform.rowcol(transform, lon, lat)
        if 0 <= row < height and 0 <= col < width:
            features.append(embedding[row, col, :])
            labels.append(class_index_map[class_name])
        else:
            # This can happen if a user clicks just outside the reprojected mosaic area
            print(
                f"  - WARNING: Skipping point for '{class_name}' at ({lat:.4f}, {lon:.4f}) as it falls outside the mosaic's bounds."
            )
    return features, labels


def _predict_in_batches(model, embedding, batch_size=PREDICT_BATCH_SIZE):
    """Classify every pixel of ``embedding`` in chunks of ``batch_size`` pixels.

    Returns a ``(rows, cols)`` array of integer labels.
    """
    height, width, channels = embedding.shape
    pixels = embedding.reshape(-1, channels)
    n_pixels = pixels.shape[0]
    # Smallest unsigned dtype that can hold the largest label (uint8 for up to
    # 256 classes) — a fixed uint8 would silently wrap beyond that.
    labels = np.empty(n_pixels, dtype=np.min_scalar_type(model.classes_.max()))

    with tqdm(total=n_pixels, desc="Classifying pixels") as pbar:
        for start in range(0, n_pixels, batch_size):
            stop = min(start + batch_size, n_pixels)
            labels[start:stop] = model.predict(pixels[start:stop])
            pbar.update(stop - start)

    return labels.reshape(height, width)


def _labels_to_png_data_url(labels, colors):
    """Render an integer label raster as a base64-encoded PNG data URL.

    Label ``i`` is painted with ``colors[i]`` (any matplotlib colour spec).
    """
    # 8-bit RGBA palette (same float*255 truncation matplotlib applies when
    # saving float images) indexed directly by label: avoids building a
    # float64 RGBA copy of the whole mosaic.
    palette = (mcolors.to_rgba_array(colors) * 255).astype(np.uint8)
    rgba = palette[labels]

    buffer = io.BytesIO()
    plt.imsave(buffer, rgba, format="png")
    encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"


def on_classify_button_clicked(b):
    """Train a k-NN classifier on the placed pins and overlay the result on the map.

    Click handler for ``classify_button``; ``b`` is the clicked button (unused).
    All progress and error messages are written to ``output_log``.
    """
    global classification_layer

    with output_log:
        output_log.clear_output()

        class_names = {name for _, name in training_points}
        if len(training_points) < 2 or len(class_names) < 2:
            print(
                "ERROR: Please add at least two points from two different classes to train the model."
            )
            return

        print("Starting classification...")

        # Map class names (e.g. 'Water') to integer labels (e.g. 0).
        unique_class_names = sorted(class_names)
        class_index_map = {name: i for i, name in enumerate(unique_class_names)}
        print(f"Discovered classes for training: {unique_class_names}")

        # --- Prepare training data ---
        print(f"Mapping {len(training_points)} training points to pixel coordinates...")
        X_train, y_train = _collect_training_samples(
            training_points, class_index_map, embedding_mosaic, mosaic_transform
        )

        if not X_train:
            print("ERROR: None of the training points were inside the mosaic bounds.")
            return
        # Dropping out-of-bounds points can leave a single class; a classifier
        # trained on one class would paint the whole mosaic that class.
        if len(set(y_train)) < 2:
            print(
                "ERROR: The training points inside the mosaic cover fewer than two classes. Please add points for another class inside the mosaic."
            )
            return

        # --- Train the model ---
        print(f"Training k-NN classifier on {len(X_train)} valid points...")
        # k must not exceed the number of training samples.
        model = KNeighborsClassifier(
            n_neighbors=min(MAX_NEIGHBOURS, len(X_train)), weights="distance"
        )
        model.fit(X_train, y_train)

        # --- Predict on the entire mosaic ---
        classification_result = _predict_in_batches(model, embedding_mosaic)

        # --- Visualise and display the result ---
        print("Creating visualization of the classification result...")
        # Use the same colours as the pins, in label order.
        color_list = [
            get_or_assign_color_for_class(name) for name in unique_class_names
        ]
        classification_data_url = _labels_to_png_data_url(classification_result, color_list)

        print("Displaying result on the map...")
        # Replace any earlier classification overlay.
        if classification_layer is not None and classification_layer in m.layers:
            m.remove(classification_layer)

        classification_layer = ipyleaflet.ImageOverlay(
            url=classification_data_url,
            bounds=VIS_BOUNDS,
            opacity=CLASSIFICATION_OPACITY,
            name="Classification",
        )
        m.add(classification_layer)

        # Enable the clear button now that a layer exists.
        clear_classification_button.disabled = False
        print("Classification complete.")


# Attach the handler to the button's click event.
classify_button.on_click(on_classify_button_clicked)