In [1]:
"""
04_interactive_heatmap.ipynb

Purpose:
- Interactive heatmap UI for Site Suitability App
- Uses precomputed spatial tables
- Driven entirely by app/config.py

NO data aggregation
NO suitability logic
"""

import sys
sys.path.append("/content")

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

from ipywidgets import Dropdown, Checkbox, HBox, VBox, Output

from app.transforms import build_site_week_matrix

from matplotlib.colors import LinearSegmentedColormap




In [2]:
from app.config import (
    DATASETS,
    VARIABLES,
    OVERLAY_MODES,
    APP_DEFAULTS
)

print("Config loaded")


Config loaded


In [3]:
# -----------------------------
# VARIABLE → COLORMAP MAPPING
# -----------------------------

# Maps each plottable variable key to the colormap handed to imshow.
# Entries are grouped by variable family; key order matches the family order
# below, and every humidity key receives its own gradient instance.
VARIABLE_CMAPS = {
    # Suitability scores
    **dict.fromkeys(
        (
            "suitability",
            "suitability_temp",
            "suitability_humidity",
            "suitability_wind",
        ),
        "RdYlGn",
    ),

    # Temperature statistics
    **dict.fromkeys(
        (
            "temperature_mean",
            "temperature_absmin",
            "temperature_absmax",
        ),
        "coolwarm",
    ),

    # Humidity statistics: custom two-stop gradient, built once per key
    **{
        humidity_key: LinearSegmentedColormap.from_list(
            "BeigeToBlue",
            ["#f5f5dc", "#08306b"],
        )
        for humidity_key in ("humidity_mean", "humidity_absmax")
    },

    # Wind statistics
    **dict.fromkeys(("wind_mean", "wind_absmax"), "Greens"),
}


In [4]:
# Primary selectors. Option lists are (label, key) pairs derived from the
# config dictionaries; initial selections come from APP_DEFAULTS.
dataset_dropdown = Dropdown(
    description="Dataset:",
    options=[(cfg["label"], key) for key, cfg in DATASETS.items()],
    value=APP_DEFAULTS["dataset"],
)

variable_dropdown = Dropdown(
    description="Variable:",
    options=[(cfg["label"], key) for key, cfg in VARIABLES.items()],
    value=APP_DEFAULTS["variable"],
)

overlay_dropdown = Dropdown(
    description="Overlay:",
    options=[(cfg["label"], key) for key, cfg in OVERLAY_MODES.items()],
    value=APP_DEFAULTS["overlay"],
)

show_colorbar_checkbox = Checkbox(
    description="Show colorbar",
    value=APP_DEFAULTS["show_colorbar"],
)

# Stack the four controls vertically.
# NOTE(review): the next cell rebinds `controls` to a larger VBox.
controls = VBox(
    children=[
        dataset_dropdown,
        variable_dropdown,
        overlay_dropdown,
        show_colorbar_checkbox,
    ]
)

controls


VBox(children=(Dropdown(description='Dataset:', options=(('2018–2024 (Full History)', 'full'), ('2021–2024 (La…

In [5]:
from ipywidgets import IntRangeSlider, SelectMultiple

# -----------------------------
# FOCUS CONTROLS (UI ONLY)
# -----------------------------

# Inclusive week window; starts fully open (1 through 52).
# continuous_update=False: the value changes on release, not while dragging.
week_range_slider = IntRangeSlider(
    description="Weeks:",
    min=1,
    max=52,
    step=1,
    value=[1, 52],
    continuous_update=False,
    layout={"width": "600px"},
)

# Distinct site names, sorted, read from the default dataset's CSV.
# NOTE(review): the list is built once from the default dataset and is not
# refreshed when the dataset dropdown changes.
all_sites = sorted(
    pd.read_csv(DATASETS[APP_DEFAULTS["dataset"]]["path"])["site_name"]
    .drop_duplicates()
)

# Multi-select with a leading "ALL" entry, which is selected initially.
site_selector = SelectMultiple(
    description="Sites:",
    options=["ALL", *all_sites],
    value=("ALL",),
    layout={"width": "300px", "height": "220px"},
)

# Rebuild the control column so it also contains the focus widgets.
controls = VBox(
    children=[
        dataset_dropdown,
        variable_dropdown,
        overlay_dropdown,
        show_colorbar_checkbox,
        week_range_slider,
        site_selector,
    ]
)

controls


VBox(children=(Dropdown(description='Dataset:', options=(('2018–2024 (Full History)', 'full'), ('2021–2024 (La…

In [6]:
def load_dataset(dataset_key: str) -> pd.DataFrame:
    """Read the CSV registered under ``dataset_key`` in ``DATASETS``.

    Args:
        dataset_key: Key into ``DATASETS``; the entry's ``"path"`` value is
            passed to ``pd.read_csv``.

    Returns:
        The table exactly as parsed by ``pd.read_csv`` (no post-processing).
    """
    return pd.read_csv(DATASETS[dataset_key]["path"])


In [7]:
def build_matrix(df: pd.DataFrame, value_col: str) -> pd.DataFrame:
    """Reshape a long table into a site-by-week grid.

    Args:
        df: Table with ``site_name`` and ``week_bin`` columns plus the
            column named by ``value_col``.
        value_col: Column whose values fill the grid cells.

    Returns:
        A frame indexed by ``site_name`` (sorted) with one column per
        ``week_bin`` value.
    """
    wide = df.pivot(index="site_name", columns="week_bin", values=value_col)
    return wide.sort_index()


In [8]:
def plot_heatmap(
    df: pd.DataFrame,
    variable_key: str,
    overlay_key: str,
    show_colorbar: bool
):
    """Draw a site x week heatmap for one configured variable on a new figure.

    NOTE(review): a later cell in this notebook defines ``plot_heatmap``
    again with two additional parameters (``active_weeks``,
    ``active_sites``). Once that cell has run, this definition is shadowed;
    ``refresh_plot`` passes those two keywords, which this signature does
    not accept.

    Args:
        df: Long table handed to ``build_matrix``.
        variable_key: Key into ``VARIABLES``. The entry supplies ``column``,
            ``colormap``, ``vmin``, ``vmax``, ``label`` and the ``allow_*``
            overlay switches.
        overlay_key: ``"rank"``, ``"value"`` or ``"winner"`` selects an
            overlay (when the variable's config allows it); any other value
            draws the bare heatmap.
        show_colorbar: Whether to attach a colorbar to the image.

    Returns:
        None. The figure is created through pyplot and is not shown or
        returned by this function.
    """
    cfg = VARIABLES[variable_key]
    matrix = build_matrix(df, cfg["column"])
    n_rows, n_cols = matrix.shape

    plt.figure(figsize=(18, 10))
    image = plt.imshow(
        matrix.values,
        aspect="auto",
        cmap=cfg["colormap"],
        vmin=cfg["vmin"],
        vmax=cfg["vmax"],
    )

    # One tick per row/column, labelled with the site name / week label
    plt.yticks(range(n_rows), matrix.index)
    plt.xticks(range(n_cols), matrix.columns)
    plt.xlabel("Week")
    plt.ylabel("Site")
    plt.title(cfg["label"])

    if show_colorbar:
        plt.colorbar(image)

    # -------- Overlays (at most one applies per call) --------

    if overlay_key == "rank":
        if cfg["allow_rank_overlay"]:
            # Dense rank within each week column; always descending, so the
            # largest value in a week gets rank 1.
            ranks = matrix.rank(axis=0, method="dense", ascending=False)

            for i, j in np.ndindex(n_rows, n_cols):
                rank = ranks.iat[i, j]
                if np.isnan(rank):
                    continue  # no value in this cell, nothing to annotate
                plt.text(
                    j, i,
                    int(rank),
                    ha="center",
                    va="center",
                    fontsize=7,
                    fontweight="bold" if rank == 1 else "normal",
                )

    if overlay_key == "value":
        if cfg["allow_value_overlay"]:
            for i, j in np.ndindex(n_rows, n_cols):
                cell = matrix.iat[i, j]
                if np.isnan(cell):
                    continue
                plt.text(
                    j, i,
                    f"{cell:.1f}",
                    ha="center",
                    va="center",
                    fontsize=6,
                )

    if overlay_key == "winner":
        if cfg["allow_winner_strip"]:
            # Row label of the maximum in each week column
            for week_pos, top_site in enumerate(matrix.idxmax(axis=0)):
                if top_site not in matrix.index:
                    continue
                plt.scatter(
                    week_pos,
                    matrix.index.get_loc(top_site),
                    color="black",
                    s=12,
                    zorder=5,
                )


In [9]:
def refresh_plot(*args):
    """Redraw the heatmap inside ``output`` from the current widget values.

    Intended as an ``observe`` callback; positional arguments are accepted
    and ignored. Reads the module-level widgets, reloads the selected
    dataset, derives the week and site focus sets, and delegates drawing to
    ``plot_heatmap``.
    """
    with output:
        # Start from a clean slate: drop the old render and any open figures
        output.clear_output(wait=True)
        plt.close("all")

        df = load_dataset(dataset_dropdown.value)

        # Weeks in focus: every integer in the slider's inclusive range
        first_week, last_week = week_range_slider.value
        weeks_in_focus = set(range(first_week, last_week + 1))

        # Sites in focus: the explicit selection, unless it is empty or
        # contains the "ALL" sentinel, in which case every site in the data
        chosen = set(site_selector.value)
        sites_in_data = set(df["site_name"].unique())
        use_every_site = (not chosen) or ("ALL" in chosen)
        sites_in_focus = sites_in_data if use_every_site else chosen

        plot_heatmap(
            df=df,
            variable_key=variable_dropdown.value,
            overlay_key=overlay_dropdown.value,
            show_colorbar=show_colorbar_checkbox.value,
            active_weeks=weeks_in_focus,
            active_sites=sites_in_focus,
        )


In [10]:
def _shade_inactive_cells(ax, matrix, active_weeks, active_sites, alpha=0.9):
    """Cover every cell whose site or week is out of focus with a grey patch."""
    # Week membership does not depend on the row, so evaluate it once.
    week_is_active = [week in active_weeks for week in matrix.columns]

    for row_idx, site in enumerate(matrix.index):
        site_is_active = site in active_sites

        for col_idx, week_active in enumerate(week_is_active):
            if site_is_active and week_active:
                continue
            # imshow centres cell (col, row) on integer coordinates, so the
            # unit square starts half a cell up and to the left.
            ax.add_patch(
                plt.Rectangle(
                    (col_idx - 0.5, row_idx - 0.5),
                    1,
                    1,
                    color="#d9d9d9",
                    alpha=alpha,
                    linewidth=0
                )
            )


def _draw_rank_labels(ax, rank_values):
    """Write the integer rank into every cell that has one (rank 1 in bold)."""
    for (row_idx, col_idx), rank in np.ndenumerate(rank_values):
        if pd.notna(rank):
            ax.text(
                col_idx,
                row_idx,
                str(int(rank)),
                ha="center",
                va="center",
                fontsize=8,
                fontweight="bold" if rank == 1 else "normal",
                color="black",
                zorder=7
            )


def _mark_winners(ax, matrix, rank_values, active_weeks, active_sites):
    """Place a dot on every rank-1 cell that lies inside the focus masks."""
    xs, ys = [], []
    for col_idx, week in enumerate(matrix.columns):
        if week not in active_weeks:
            continue
        for row_idx, site in enumerate(matrix.index):
            # Ties are kept: every site with rank 1 in a week is marked.
            if site in active_sites and rank_values[row_idx, col_idx] == 1:
                xs.append(col_idx)
                ys.append(row_idx)

    if xs:
        ax.scatter(
            xs,
            ys,
            s=180,
            marker="o",
            facecolor="black",
            edgecolor="white",
            linewidth=1.5,
            zorder=8
        )


def plot_heatmap(
    df: pd.DataFrame,
    variable_key: str,
    overlay_key: str,
    show_colorbar: bool,
    active_weeks: set,
    active_sites: set
):
    """Render the site x week heatmap with focus greying and an optional overlay.

    Args:
        df: Long table passed to ``build_site_week_matrix``.
        variable_key: Key into ``VARIABLES``; its ``"column"`` entry names
            the plotted column. The colormap comes from ``VARIABLE_CMAPS``;
            if the key is absent there, the variable's own ``"colormap"``
            config entry is used, and failing that matplotlib's default.
        overlay_key: ``"rank"`` annotates each cell with its rank;
            ``"winner"`` marks rank-1 cells inside the focus. Both need the
            variable's ``"rank_column"`` to be configured and present in
            ``df``; otherwise no overlay is drawn. Other values draw none.
        show_colorbar: Whether to attach a colorbar.
        active_weeks: Week labels in focus, tested with ``in`` against the
            matrix columns.
        active_sites: Site names in focus, tested with ``in`` against the
            matrix index.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    var_cfg = VARIABLES[variable_key]

    # -----------------------------
    # BUILD SITE × WEEK MATRIX
    # -----------------------------
    matrix = build_site_week_matrix(
        df=df,
        value_col=var_cfg["column"]
    )

    # -----------------------------
    # FIGURE / AXES
    # -----------------------------
    fig, ax = plt.subplots(figsize=(18, 10))

    im = ax.imshow(
        matrix.values,
        aspect="auto",
        # A variable missing from VARIABLE_CMAPS no longer raises KeyError;
        # cmap=None lets matplotlib use its configured default.
        cmap=VARIABLE_CMAPS.get(variable_key, var_cfg.get("colormap"))
    )

    # -----------------------------
    # AXIS LABELS / TITLE
    # -----------------------------
    ax.set_yticks(range(len(matrix.index)))
    ax.set_yticklabels(matrix.index)

    ax.set_xticks(range(len(matrix.columns)))
    ax.set_xticklabels(matrix.columns)

    ax.set_xlabel("Week")
    ax.set_ylabel("Site")
    ax.set_title(var_cfg.get("label", variable_key))

    # -----------------------------
    # VISUAL GREYING (FOCUS MASK)
    # -----------------------------
    _shade_inactive_cells(ax, matrix, active_weeks, active_sites)

    # -----------------------------
    # RANK / WINNER OVERLAYS
    # -----------------------------
    if overlay_key in ("rank", "winner"):
        rank_col = var_cfg.get("rank_column")

        if rank_col and rank_col in df.columns:
            # Align the rank table to the plotted grid so positions match
            # cell-for-cell; cells absent from the rank table become NaN
            # and are skipped instead of raising KeyError.
            rank_values = (
                build_site_week_matrix(df=df, value_col=rank_col)
                .reindex(index=matrix.index, columns=matrix.columns)
                .values
            )

            if overlay_key == "rank":
                _draw_rank_labels(ax, rank_values)
            else:
                _mark_winners(
                    ax, matrix, rank_values, active_weeks, active_sites
                )

    # -----------------------------
    # COLORBAR
    # -----------------------------
    if show_colorbar:
        fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)

    plt.tight_layout()
    plt.show()


In [11]:
from ipywidgets import Output, VBox

# -----------------------------
# OUTPUT AREA
# -----------------------------
# Widget that captures whatever refresh_plot renders
output = Output()

# -----------------------------
# WIRE CONTROLS TO REFRESH
# -----------------------------
# Any change to a control's `value` trait triggers a redraw
for _control in (
    dataset_dropdown,
    variable_dropdown,
    overlay_dropdown,
    show_colorbar_checkbox,
    week_range_slider,
    site_selector,
):
    _control.observe(refresh_plot, names="value")

# -----------------------------
# INITIAL RENDER
# -----------------------------
# Draw once with the controls' starting values
refresh_plot()

# -----------------------------
# DISPLAY APP
# -----------------------------
# Controls on top, plot area underneath
VBox(children=[controls, output])


VBox(children=(VBox(children=(Dropdown(description='Dataset:', options=(('2018–2024 (Full History)', 'full'), …