# Somerset Landsat layer explorer

This notebook visualises the outputs generated by `examples/example_workflow.py`. Run the example script first so the Landsat mosaics, water masks, salinity rasters, and land-cover overlays exist on disk.


## Prerequisites

* Execute `python examples/example_workflow.py` (or run `process_landsat_history()` from the module) before opening this notebook so the processed rasters are cached under `swmaps/data/examples/somerset_landsat`.
* The notebook relies on `folium`, `geopandas`, `ipywidgets`, `matplotlib`, `numpy`, `Pillow`, and `rasterio`. Install any missing packages with `pip install` before running the cells. If you use the environment provided with the repository, these libraries are already included.


In [None]:
from __future__ import annotations

import math
import warnings
from pathlib import Path

import folium
import geopandas as gpd
import ipywidgets as widgets
import numpy as np
from IPython.display import display
import rasterio
from rasterio.enums import Resampling
from rasterio.warp import transform_bounds
import matplotlib.pyplot as plt

In [None]:
# Resolve repository paths and locate the processed Somerset outputs
NOTEBOOK_DIR = Path.cwd()
# The repository root is the nearest ancestor of the working directory
# (including the directory itself) that contains config/somerset.geojson.
REPO_ROOT = next(
    (
        candidate
        for candidate in [NOTEBOOK_DIR, *NOTEBOOK_DIR.parents]
        if (candidate / "config" / "somerset.geojson").exists()
    ),
    None,
)
if REPO_ROOT is None:
    raise RuntimeError(
        f"Could not locate the repository root from this notebook directory ({NOTEBOOK_DIR})."
    )

GEOJSON_PATH = REPO_ROOT / "config" / "somerset.geojson"

try:
    from swmaps.config import data_path as _data_path
except Exception as exc:  # pragma: no cover - fallback for bare repository checkouts
    # Any failure importing the package config falls back to the in-repo
    # data directory instead of stopping the notebook.
    warnings.warn(
        "Falling back to the default data directory under 'swmaps/data' because "
        f"swmaps.config.data_path could not be imported: {exc}"
    )
    DATA_ROOT = REPO_ROOT / "swmaps" / "data"

    def data_path(*parts: str | Path) -> Path:
        """Join ``parts`` onto the fallback data directory."""
        return DATA_ROOT.joinpath(*parts)
else:
    data_path = _data_path

OUTPUT_ROOT = data_path("examples", "somerset_landsat")
if not OUTPUT_ROOT.exists():
    raise FileNotFoundError(
        f"No processed rasters were found in {OUTPUT_ROOT}. Run examples/example_workflow.py first."
    )

aoi = gpd.read_file(GEOJSON_PATH).to_crs("EPSG:4326")
# GeoSeries.unary_union is deprecated since geopandas 1.0 in favour of
# union_all(); use the new method when present, fall back on older releases.
_aoi_geometry = aoi.geometry
if hasattr(_aoi_geometry, "union_all"):
    _aoi_union = _aoi_geometry.union_all()
else:
    _aoi_union = _aoi_geometry.unary_union
aoi_center = _aoi_union.representative_point()
# Folium expects [lat, lon]; the AOI is in EPSG:4326 so y is latitude.
MAP_CENTER = [aoi_center.y, aoi_center.x]

In [None]:
# Helper utilities for building map overlays

# RGBA colour (0-255 per channel) for each code in the salinity class raster.
# Land (0) and the nodata code (255) are fully transparent so only water is tinted.
_TRANSPARENT = (0, 0, 0, 0)
SALINITY_CLASS_COLORS = {
    0: _TRANSPARENT,  # land
    1: (102, 194, 165, 220),  # fresh water
    2: (44, 162, 95, 220),  # brackish
    3: (0, 109, 44, 220),  # saline
    255: _TRANSPARENT,  # nodata
}


def format_bounds(src: rasterio.io.DatasetReader) -> list[list[float]]:
    """Return the footprint of ``src`` in EPSG:4326 as ``[[south, west], [north, east]]``.

    That lat/lon corner order is the ``bounds`` format Folium image overlays use.
    """
    west, south, east, north = transform_bounds(src.crs, "EPSG:4326", *src.bounds)
    south_west = [south, west]
    north_east = [north, east]
    return [south_west, north_east]


def read_resampled(
    src: rasterio.io.DatasetReader,
    indexes: tuple[int, ...],
    *,
    max_pixels: int = 750_000,
    resampling: Resampling = Resampling.bilinear,
) -> np.ndarray:
    """Read bands ``indexes`` from ``src``, downsampled to roughly ``max_pixels`` pixels.

    Both axes are shrunk by the same factor so the aspect ratio is kept; rasters
    already within the budget are read at native size (never upsampled), and
    each output dimension is at least 1.

    Returns:
        Array of shape ``(len(indexes), rows, cols)``.
    """
    shrink = math.sqrt(max_pixels / (src.width * src.height))
    if shrink > 1.0:
        shrink = 1.0
    rows = max(1, int(round(src.height * shrink)))
    cols = max(1, int(round(src.width * shrink)))
    return src.read(
        indexes,
        out_shape=(len(indexes), rows, cols),
        resampling=resampling,
    )


def natural_color_image(
    path: Path,
    *,
    max_pixels: int = 750_000,
    bands: tuple[int, int, int] = (3, 2, 1),
) -> tuple[np.ndarray, list[list[float]]]:
    """Render three bands of a multiband raster as a contrast-stretched RGBA image.

    Each band is linearly stretched between its own 2nd and 98th percentiles
    (over finite pixels) and clipped. Pixels equal to the file's nodata value are
    treated as missing; a pixel is transparent only when all three of its bands
    are missing.

    Args:
        path: Multiband raster to read.
        max_pixels: Approximate pixel budget passed to ``read_resampled``.
        bands: 1-based band indexes displayed as (red, green, blue). The default
            ``(3, 2, 1)`` assumes the mosaic stores blue, green, red in bands 1-3
            — NOTE(review): confirm against the workflow's output band order.

    Returns:
        ``(rgba, bounds)``: a ``(rows, cols, 4)`` uint8 array and the EPSG:4326
        bounds as ``[[south, west], [north, east]]``.
    """
    with rasterio.open(path) as src:
        data = read_resampled(src, tuple(bands), max_pixels=max_pixels)
        bounds = format_bounds(src)
        nodata = src.nodata

    pixels_last = np.moveaxis(data, 0, -1)
    rgb = pixels_last.astype(np.float32)
    if nodata is not None and not math.isnan(nodata):
        # Mirror single_band_colormap_image: blank out the declared fill value so
        # it neither drags the percentile stretch nor renders as opaque pixels.
        # (A NaN nodata is already handled by the isfinite checks below.)
        rgb[pixels_last == nodata] = np.nan

    valid = np.any(np.isfinite(rgb), axis=-1)
    for band in range(rgb.shape[-1]):
        band_data = rgb[..., band]
        finite_mask = np.isfinite(band_data)
        finite = band_data[finite_mask]
        if finite.size:
            low, high = (float(v) for v in np.percentile(finite, (2, 98)))
            if high <= low:
                # Constant band: widen in float64 so the denominator is non-zero
                # (in float32, ``low + 1e-6`` can round back to ``low``).
                high = max(low + 1e-6, float(np.nextafter(low, math.inf)))
            scaled = (band_data - low) / (high - low)
        else:
            scaled = np.zeros_like(band_data)
        rgb[..., band] = np.where(finite_mask, np.clip(scaled, 0.0, 1.0), 0.0)

    alpha = np.where(valid, 255, 0).astype(np.uint8)
    rgb_uint8 = (np.clip(rgb, 0.0, 1.0) * 255).astype(np.uint8)
    return np.dstack([rgb_uint8, alpha]), bounds


def single_band_colormap_image(
    path: Path,
    *,
    colormap: str = "viridis",
    percentiles: tuple[float, float] | None = (2.0, 98.0),
    max_pixels: int = 750_000,
    resampling: Resampling = Resampling.bilinear,
) -> tuple[np.ndarray, list[list[float]]]:
    """Colour-map band 1 of a raster into an RGBA overlay image.

    Args:
        path: Raster to read (only band 1 is used).
        colormap: Matplotlib colormap name.
        percentiles: Lower/upper percentiles of the finite pixels used as the
            colour stretch; ``None`` stretches over the full finite range.
        max_pixels: Approximate pixel budget passed to ``read_resampled``.
        resampling: Resampling method used when downsampling.

    Returns:
        ``(rgba, bounds)``: a ``(rows, cols, 4)`` uint8 array in which nodata and
        non-finite pixels are fully transparent, and the EPSG:4326 bounds as
        ``[[south, west], [north, east]]``.
    """
    with rasterio.open(path) as src:
        data = read_resampled(src, (1,), max_pixels=max_pixels, resampling=resampling)[0].astype(np.float32)
        bounds = format_bounds(src)
        nodata = src.nodata

    if nodata is not None:
        data = np.where(data == nodata, np.nan, data)

    finite_mask = np.isfinite(data)
    finite = data[finite_mask]
    if finite.size == 0:
        # No usable pixels: any range works, every pixel ends up transparent.
        vmin, vmax = 0.0, 1.0
    elif percentiles is not None:
        vmin, vmax = (float(v) for v in np.percentile(finite, percentiles))
        if not (math.isfinite(vmin) and math.isfinite(vmax)) or vmin == vmax:
            # Degenerate stretch: fall back to the full finite range, widened
            # slightly so the denominator is never zero. Work in Python floats:
            # a float32 ``max + 1e-6`` can round back to ``max`` at larger
            # magnitudes, which turned constant rasters into 0/0 = NaN (invisible).
            vmin = float(finite.min())
            vmax = float(finite.max()) + 1e-6
            if vmax <= vmin:
                vmax = float(np.nextafter(vmin, math.inf))
    else:
        # Use only finite pixels (avoids All-NaN warnings and lets ±inf not
        # collapse the range).
        vmin, vmax = float(finite.min()), float(finite.max())
        if vmin == vmax:
            vmin, vmax = 0.0, 1.0

    with np.errstate(invalid="ignore", divide="ignore"):
        scaled = (data - vmin) / (vmax - vmin)
    scaled = np.where(finite_mask, np.clip(scaled, 0.0, 1.0), 0.0)
    rgba = plt.get_cmap(colormap)(scaled)
    # Nodata / non-finite pixels become fully transparent.
    rgba[..., 3] = np.where(finite_mask, rgba[..., 3], 0.0)
    rgba_uint8 = (np.clip(rgba, 0.0, 1.0) * 255).astype(np.uint8)
    return rgba_uint8, bounds


def classification_image(
    path: Path,
    *,
    palette: dict[int, tuple[int, int, int, int]] | None = None,
    nodata_values: tuple[int, ...] | None = None,
    max_pixels: int = 750_000,
) -> tuple[np.ndarray, list[list[float]]]:
    """Paint a categorical (integer class) raster as an RGBA overlay image.

    Colours come from, in order of preference: ``palette`` (when non-empty),
    the colour table stored on band 1 of the file, or a ``tab20`` cycle over the
    classes present (used when reading the colour table raises ``ValueError``).
    Pixels equal to the file's nodata value or to any of ``nodata_values`` are
    fully transparent, as is any class that receives no colour. Band 1 is
    downsampled with nearest-neighbour so class codes are never blended.

    Returns:
        ``(rgba, bounds)``: a ``(rows, cols, 4)`` uint8 array and the EPSG:4326
        bounds as ``[[south, west], [north, east]]``.
    """
    with rasterio.open(path) as src:
        classes = read_resampled(
            src, (1,), max_pixels=max_pixels, resampling=Resampling.nearest
        )[0].astype(np.int32)
        bounds = format_bounds(src)
        nodata = src.nodata

        hidden = np.zeros(classes.shape, dtype=bool)
        if nodata is not None:
            hidden |= classes == nodata
        if nodata_values is not None:
            hidden |= np.isin(classes, nodata_values)

        rgba = np.zeros((*classes.shape, 4), dtype=np.uint8)

        if palette:
            for code, rgba_color in palette.items():
                rgba[classes == code] = rgba_color
        else:
            try:
                file_colors = src.colormap(1)  # {class value: (r, g, b, a)}
                # Only codes that occur can be painted, so skip the rest of a
                # (typically 256-entry) colour table.
                for code in np.unique(classes):
                    rgba_color = file_colors.get(int(code))
                    if rgba_color is not None:
                        rgba[classes == code] = rgba_color
            except ValueError:
                # No colour table on the file: cycle tab20 over visible classes.
                visible = np.unique(classes[~hidden])
                if visible.size:
                    cmap = plt.get_cmap("tab20", max(visible.size, 1))
                    for idx, code in enumerate(visible):
                        rgba[classes == code] = (np.array(cmap(idx)) * 255).astype(np.uint8)

        rgba[hidden] = (0, 0, 0, 0)
        return rgba, bounds



def discover_products(output_root: Path) -> dict[str, dict[str, Path | str]]:
    """Index the processed Landsat scenes found directly under ``output_root``.

    A scene is any ``<mission>_somerset_<dates>.tif`` mosaic that is neither a
    land-cover raster (``nlcd_*`` / ``cdl_*``) nor one of the suffixed derived
    products listed below. ``<dates>`` is either ``<start>_<end>`` or, if it does
    not split into exactly two parts, a single token used as the start date.

    Returns:
        Mapping of mosaic stem -> record with keys ``mission_tag``,
        ``date_token``, ``label`` and ``mosaic``, plus the expected path of every
        derived product and of the NLCD/CDL rasters. Derived paths are returned
        whether or not the files exist, so callers must check ``.exists()``.
        For NLCD/CDL the exact-date file is used when present; otherwise the
        ``_approx`` variant's path is returned.
    """
    derived_suffixes = {
        "water_mask": "_ndwi_mask",
        "salinity_score": "_salinity_score",
        "salinity_class": "_salinity_class",
        "salinity_water_mask": "_salinity_water_mask",
    }
    skip_suffixes = tuple(derived_suffixes.values())

    catalogue: dict[str, dict[str, Path | str]] = {}
    for mosaic in sorted(output_root.glob("*_somerset_*.tif")):
        stem = mosaic.stem
        # Land-cover rasters and derived products also match the glob.
        if stem.startswith(("nlcd_", "cdl_")) or stem.endswith(skip_suffixes):
            continue

        mission_tag, _, date_token = stem.partition("_somerset_")
        if not (mission_tag and date_token):
            continue

        date_parts = date_token.split("_")
        if len(date_parts) == 2:
            start_date, end_date = date_parts
        else:
            start_date, end_date = date_token, ""

        mission_label = mission_tag.replace("landsat", "Landsat ").strip().title()
        period = f"{start_date} to {end_date}" if end_date else start_date
        label = f"{mission_label} — {period}"

        entry: dict[str, Path | str] = {
            "mission_tag": mission_tag,
            "date_token": date_token,
            "label": label,
            "mosaic": mosaic,
        }
        for key, suffix in derived_suffixes.items():
            entry[key] = mosaic.with_name(f"{stem}{suffix}.tif")

        # Land-cover overlays: prefer the exact-date raster, else the "_approx" one.
        for prefix in ("nlcd", "cdl"):
            exact = output_root / f"{prefix}_{mission_tag}_{date_token}.tif"
            fallback = output_root / f"{prefix}_{mission_tag}_{date_token}_approx.tif"
            entry[prefix] = exact if exact.exists() else fallback

        catalogue[stem] = entry

    return catalogue


def _add_image_overlay(
    m: folium.Map,
    image: np.ndarray,
    bounds: list[list[float]],
    *,
    name: str,
    opacity: float,
    show: bool = False,
) -> None:
    """Add the RGBA array ``image`` to ``m`` as a named, toggleable overlay."""
    folium.raster_layers.ImageOverlay(
        image=image,
        bounds=bounds,
        name=name,
        opacity=opacity,
        show=show,
    ).add_to(m)


def create_map(scene_id: str) -> folium.Map:
    """Build the Folium map for one scene returned by ``discover_products``.

    The map is centred on ``MAP_CENTER`` and always draws the Somerset AOI
    outline. Each product raster that exists on disk is read and added as an
    image overlay, in a fixed order. Only the natural-colour mosaic is visible
    initially; the other layers can be toggled in the layer control.

    Args:
        scene_id: Key into the global ``PRODUCTS`` mapping.

    Raises:
        KeyError: If ``scene_id`` is not in ``PRODUCTS``.
    """
    info = PRODUCTS[scene_id]
    m = folium.Map(location=MAP_CENTER, zoom_start=9, tiles="CartoDB Positron")

    folium.GeoJson(
        data=aoi.__geo_interface__,
        name="Somerset AOI",
        style_function=lambda _: {"color": "#ff7f00", "weight": 2, "fill": False},
    ).add_to(m)

    mosaic_path: Path = info["mosaic"]  # type: ignore[index]
    if mosaic_path.exists():
        image, bounds = natural_color_image(mosaic_path)
        _add_image_overlay(m, image, bounds, name="Landsat natural colour", opacity=1.0, show=True)

    # Both water masks are categorical: nearest-neighbour downsampling keeps
    # their values intact, whereas bilinear averages neighbouring pixels (and
    # any nodata fill) into fractional values.
    ndwi_path: Path = info["water_mask"]  # type: ignore[index]
    if ndwi_path.exists():
        image, bounds = single_band_colormap_image(
            ndwi_path,
            colormap="Blues",
            percentiles=(0, 100),
            resampling=Resampling.nearest,
        )
        _add_image_overlay(m, image, bounds, name="NDWI water mask", opacity=0.6)

    salinity_score_path: Path = info["salinity_score"]  # type: ignore[index]
    if salinity_score_path.exists():
        image, bounds = single_band_colormap_image(salinity_score_path, colormap="magma")
        _add_image_overlay(m, image, bounds, name="Salinity score", opacity=0.6)

    salinity_class_path: Path = info["salinity_class"]  # type: ignore[index]
    if salinity_class_path.exists():
        image, bounds = classification_image(salinity_class_path, palette=SALINITY_CLASS_COLORS)
        _add_image_overlay(m, image, bounds, name="Salinity classes", opacity=0.7)

    salinity_water_mask_path: Path = info["salinity_water_mask"]  # type: ignore[index]
    if salinity_water_mask_path.exists():
        image, bounds = single_band_colormap_image(
            salinity_water_mask_path,
            colormap="PuBu",
            percentiles=(0, 100),
            resampling=Resampling.nearest,
        )
        _add_image_overlay(m, image, bounds, name="Salinity water mask", opacity=0.6)

    nlcd_path: Path = info["nlcd"]  # type: ignore[index]
    if nlcd_path.exists():
        image, bounds = classification_image(nlcd_path, nodata_values=(0,))
        _add_image_overlay(m, image, bounds, name="NLCD", opacity=0.65)

    cdl_path: Path = info["cdl"]  # type: ignore[index]
    if cdl_path.exists():
        # Treat class 0 as empty, as for NLCD above. NOTE(review): assumes 0 is
        # the CDL background/fill code in these rasters — confirm.
        image, bounds = classification_image(cdl_path, nodata_values=(0,))
        _add_image_overlay(m, image, bounds, name="USDA NASS CDL", opacity=0.65)

    folium.LayerControl(collapsed=False).add_to(m)
    # NOTE(review): nothing is added to this pane; presumably intended for a
    # labels-only basemap on top of the overlays — confirm or remove.
    folium.map.CustomPane("labels").add_to(m)

    return m

In [None]:
# Discover available scenes and build the interactive map widget
PRODUCTS = discover_products(OUTPUT_ROOT)
if not PRODUCTS:
    raise FileNotFoundError(
        f"No Landsat mosaics found in {OUTPUT_ROOT}. Run the example workflow before using the map."
    )

# Dropdown entries are (label, scene_id) pairs, ordered by display label
# (stable sort, so scenes sharing a label keep their discovery order).
scene_options = [(record["label"], scene_id) for scene_id, record in PRODUCTS.items()]
scene_options.sort(key=lambda option: option[0])

selector = widgets.Dropdown(
    options=scene_options,
    description="Scene:",
    layout=widgets.Layout(width="70%"),
)

# Area the selected scene's map is rendered into.
map_output = widgets.Output()


def refresh_map(*_):
    """Re-render the map for ``selector.value`` inside ``map_output``.

    Positional arguments are accepted and ignored so the function can also be
    used as a widget callback. Rendering errors are reported inside the output
    area (as a warning and a printed line) instead of propagating.
    """
    with map_output:
        map_output.clear_output(wait=True)
        try:
            scene_map = create_map(selector.value)
            display(scene_map)
        except Exception as exc:  # noqa: BLE001 - UI boundary: report, don't raise
            message = f"Unable to render map for {selector.value}: {exc}"
            warnings.warn(message)
            print(message)


def handle_change(change):
    """Traitlets observer: redraw the map only for ``value`` change events."""
    is_value_change = change["name"] == "value" and change["type"] == "change"
    if not is_value_change:
        return
    refresh_map()


# Hook the dropdown up to the map, draw the initial scene, then show the UI.
selector.observe(handle_change, names="value")
refresh_map()
explorer = widgets.VBox([selector, map_output])
explorer