In [None]:
import sys
from pathlib import Path
sys.path.insert(0, str(Path('../../src')))

from analysis import slope_area_regression_binned, compute_iqr_errors, weighted_r2
from utils import find_project_root
from plotting import set_nature_style


In [None]:
import os
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import LogLocator

from TopoAnalysis import dem as d


# ─────────────────────────────────────────────
# Plot appearance: apply the shared style imported from `plotting`
# ─────────────────────────────────────────────
set_nature_style()


# ─────────────────────────────────────────────
# Per-location settings
#   tile_dir : sub-folder of data/<LOCATION>/ to read
#   min_area : threshold forwarded to the slope–area regression
#              (presumably m² — confirm against `analysis`)
# ─────────────────────────────────────────────
LOC_CONFIGS = dict(
    AP=dict(tile_dir="tile1", min_area=8.89e3),  # point tile_dir at another tile to view its data
    GM=dict(tile_dir="tile2", min_area=5.49e3),  # point tile_dir at another tile to view its data
)

# Palette: one color per data series
GRAY_ORIG = "0.25"             # original topography (raw points)
COLOR_PART = "#3A7D44"         # partitioned model — green
COLOR_AD_NO_OPT = "#1f77b4"    # AD model without optimization — blue
COLOR_AD_OPT = "#D55E00"       # AD model with optimization — orange

# Width of the black outline drawn around the binned-median markers (kept thin).
MEDIAN_EDGEWIDTH = 0.15


# ─────────────────────────────────────────────
# Portable paths: locate the project root, then its data/ folder
# ─────────────────────────────────────────────
PROJECT_ROOT = find_project_root()
DATA_ROOT = PROJECT_ROOT / "data"

# Echo the resolved locations so a wrong working directory is obvious.
print("PROJECT_ROOT:", PROJECT_ROOT)
print("DATA_ROOT   :", DATA_ROOT)

# ─────────────────────────────────────────────
# Helper: unpack a regression debug dict (IQR errors and R² are imported from analysis)
# ─────────────────────────────────────────────
def unpack_dbg(dbg):
    """
    Split a regression debug mapping into its individual pieces.

    Reads the keys "A", "S", "bin_logA", "bin_logS" and "bin_counts"
    from `dbg` and additionally converts the two log10 bin values back
    to linear space.

    Returns a 7-tuple:
      (A, S, bin_logA, bin_logS, bin_counts, 10**bin_logA, 10**bin_logS)

    Raises KeyError if any of the five keys is missing.
    """
    raw_area, raw_slope, log_area, log_slope, counts = (
        dbg[key] for key in ("A", "S", "bin_logA", "bin_logS", "bin_counts")
    )
    # Binned values are stored as log10; undo that for plotting on log axes.
    return (
        raw_area,
        raw_slope,
        log_area,
        log_slope,
        counts,
        10**log_area,
        10**log_slope,
    )

# ─────────────────────────────────────────────
# Load + run regressions for one location
# ─────────────────────────────────────────────
def prepare_location(location, tile_dir, min_area,
                     vertical_interval=10, nbins=24):
    """
    Load the original grid and the three modeled grids for one tile and
    fit a binned slope–area regression to each of them.

    Directory layout that is read:
      data/<LOCATION>/<tile_dir>/          original Elevation / Area / FD grids
      data/<LOCATION>/<tile_dir>/outputs/  modeled (filled) elevation grids

    Parameters
    ----------
    location : str
        Location code; used both as folder name and as file-name prefix.
    tile_dir : str
        Tile sub-folder inside the location folder.
    min_area : float
        Forwarded unchanged to `slope_area_regression_binned`.
    vertical_interval, nbins :
        Forwarded unchanged to `slope_area_regression_binned`.

    Returns
    -------
    dict
        Keys "orig", "AD_no_opt", "AD_opt", "part"; each value is a dict
        with "Ks", "theta", "R2" (weighted, in log–log space) and "dbg"
        (the debug mapping returned by the regression).

    Raises
    ------
    FileNotFoundError
        If the tile folder or its outputs/ sub-folder does not exist.
    """
    tile_path = DATA_ROOT / location / tile_dir
    outputs_path = tile_path / "outputs"

    # Fail early (tile folder first) with the offending path in the message.
    for label, folder in (("tile", tile_path), ("outputs", outputs_path)):
        if not folder.is_dir():
            raise FileNotFoundError(f"Missing {label} directory: {folder}")

    stem = f"{location}_1m_best_tile"

    # Original grids live directly in the tile folder.
    # NOTE(review): the original "_filled" grid is loaded through d.Elevation
    # while the modeled "_filled" grids use d.FilledElevation — confirm intended.
    elev_orig = d.Elevation.load(str(tile_path / f"{stem}_filled"))
    area = d.Area.load(str(tile_path / f"{stem}_area"))
    fd = d.FlowDirectionD8.load(str(tile_path / f"{stem}_fd"))

    def load_model(tag):
        # Modeled grids are stored as <stem>_<tag>_filled inside outputs/.
        return d.FilledElevation.load(str(outputs_path / f"{stem}_{tag}_filled"))

    elev_part = load_model("Partitioned-model-a_crit-k-d-opt")
    elev_AD_no_opt = load_model("AD-no-opt")
    elev_AD_opt = load_model("AD-opt")

    print(f"\nLoaded grids for {location} {tile_dir}:")
    for caption, grid in (
        ("  elev_orig      :", elev_orig),
        ("  elev_part      :", elev_part),
        ("  elev_AD_no_opt :", elev_AD_no_opt),
        ("  elev_AD_opt    :", elev_AD_opt),
        ("  area           :", area),
    ):
        print(caption, grid._griddata.shape)

    def fit_slope_area(elev_obj):
        # The same area / flow-direction grids are reused for every elevation grid.
        Ks, theta, dbg = slope_area_regression_binned(
            elev_obj=elev_obj,
            area_obj=area,
            fd_obj=fd,
            min_area=min_area,
            vertical_interval=vertical_interval,
            nbins=nbins,
            min_per_bin=10,
            agg="median",
            require_weighted=True,
            relax_if_sparse=True,
        )
        log_area = dbg["bin_logA"]
        log_slope = dbg["bin_logS"]
        weights = dbg["bin_counts"].astype(float)
        # Power law S = Ks * A**(-theta), evaluated in log10 space at the bins.
        predicted = np.log10(Ks) - theta * log_area
        return dict(
            Ks=Ks,
            theta=theta,
            R2=weighted_r2(log_area, log_slope, predicted, weights),
            dbg=dbg,
        )

    # (result key, printed caption, elevation grid) in regression / print order.
    runs = (
        ("orig", "ORIGINAL", elev_orig),
        ("AD_no_opt", "AD no-opt", elev_AD_no_opt),
        ("AD_opt", "AD opt", elev_AD_opt),
        ("part", "PARTITIONED", elev_part),
    )
    fits = {key: fit_slope_area(elev) for key, _, elev in runs}

    print(f"\nTile: {location} {tile_dir}")
    for key, caption, _ in runs:
        fit = fits[key]
        print(
            f"  {caption:<12}: Ks = {fit['Ks']:.3e} m, "
            f"θ = {fit['theta']:.3f}, R² = {fit['R2']:.3f}"
        )

    return fits


# ─────────────────────────────────────────────
# Prepare every location listed in LOC_CONFIGS
# ─────────────────────────────────────────────
results = {}
for loc, cfg in LOC_CONFIGS.items():
    results[loc] = prepare_location(loc, cfg["tile_dir"], cfg["min_area"])


# ─────────────────────────────────────────────
# Compute global Y-limits from binned S across all panels
# ─────────────────────────────────────────────
# Every panel overlays the original regression with one model, so the shared
# limits must span the binned slopes of all four regressions per location.
# Iterating over `results` (instead of a hard-coded location list) keeps this
# in step with LOC_CONFIGS, and each regression is visited exactly once.
Smins, Smaxs = [], []
for res_loc in results.values():
    for run_key in ("orig", "AD_no_opt", "AD_opt", "part"):
        S_bin = unpack_dbg(res_loc[run_key]["dbg"])[-1]  # last item: linear-space binned S
        Smins.append(np.nanmin(S_bin))
        Smaxs.append(np.nanmax(S_bin))

# Leave 10 % headroom below and above the extreme binned values.
global_ymin = min(Smins) * 0.9
global_ymax = max(Smaxs) * 1.1


# ─────────────────────────────────────────────
# 3×2 FIGURE: rows = models, cols = locations (AP / GM)
# ─────────────────────────────────────────────
row_model_keys = ["AD_no_opt", "AD_opt", "part"]
row_labels     = ["AD no-opt", "AD opt", "Partitioned"]
col_locations  = list(LOC_CONFIGS)  # one column per configured location

# Series color per model key; unknown keys fall back to the partitioned color.
MODEL_COLORS = {
    "AD_no_opt": COLOR_AD_NO_OPT,
    "AD_opt": COLOR_AD_OPT,
    "part": COLOR_PART,
}


def _draw_regression(ax, run, A_fit, *, raw_kwargs, median_face, line_color,
                     fit_kwargs=None):
    """
    Draw one regression on `ax`: raw points, binned medians with IQR error
    bars, and the fitted power law S = Ks * A**(-theta) evaluated on `A_fit`.

    `run` is one entry of the dict returned by prepare_location
    (keys "Ks", "theta", "dbg"). `raw_kwargs` carries the marker size, alpha
    and color of the raw cloud; `median_face` fills the median markers;
    `line_color` colors both the error bars and the fit line; `fit_kwargs`
    holds optional extra line properties for the fit (e.g. a dash pattern).
    """
    A_raw, S_raw, _, _, _, A_bin, S_bin = unpack_dbg(run["dbg"])
    err_lo, err_hi = compute_iqr_errors(run["dbg"])

    # Raw cloud: rasterized so vector exports stay small.
    ax.loglog(
        A_raw, S_raw,
        'o',
        markeredgewidth=0.0,
        markeredgecolor="none",
        rasterized=True,
        **raw_kwargs,
    )

    # Binned values with asymmetric (lower, upper) error bars.
    ax.errorbar(
        A_bin, S_bin,
        yerr=[err_lo, err_hi],
        fmt='o',
        markersize=2.5,
        mfc=median_face,
        mec="black",
        mew=MEDIAN_EDGEWIDTH,
        ecolor=line_color,
        elinewidth=0.3,
        capsize=1.0,
        zorder=3,
    )

    # Fitted power law.
    ax.loglog(
        A_fit, run["Ks"] * A_fit**(-run["theta"]),
        color=line_color,
        lw=0.6,
        **(fit_kwargs or {}),
    )


fig, axes = plt.subplots(
    nrows=len(row_model_keys), ncols=len(col_locations),
    figsize=(6.5, 7.0),
    sharey=True,
    squeeze=False,  # always a 2-D array, so axes[row, col] is valid for any grid size
)

last_row = len(row_model_keys) - 1

for col, loc in enumerate(col_locations):
    tile_dir = LOC_CONFIGS[loc]["tile_dir"]
    res_loc = results[loc]
    orig = res_loc["orig"]

    for row, (model_key, row_label) in enumerate(zip(row_model_keys, row_labels)):
        ax = axes[row, col]
        mod = res_loc[model_key]
        color_model = MODEL_COLORS.get(model_key, COLOR_PART)

        print(f"\n{loc} {tile_dir} – {row_label}")
        print(f"  ORIG: Ks = {orig['Ks']:.3e}, θ = {orig['theta']:.3f}, R² = {orig['R2']:.3f}")
        print(f"  {row_label}: Ks = {mod['Ks']:.3e}, θ = {mod['theta']:.3f}, R² = {mod['R2']:.3f}")

        # Fit lines span the combined binned-area range of original + model.
        bA_o = orig["dbg"]["bin_logA"]
        bA_m = mod["dbg"]["bin_logA"]
        logA_min = min(np.nanmin(bA_o), np.nanmin(bA_m))
        logA_max = max(np.nanmax(bA_o), np.nanmax(bA_m))
        A_fit = np.logspace(logA_min, logA_max, 200)

        # --------- PLOTTING ----------
        # Original first (gray/black, solid fit), then the model on top
        # (model color, dashed fit).
        _draw_regression(
            ax, orig, A_fit,
            raw_kwargs=dict(markersize=0.2, alpha=0.15, color=GRAY_ORIG),
            median_face="0.7",
            line_color="black",
        )
        _draw_regression(
            ax, mod, A_fit,
            raw_kwargs=dict(markersize=0.7, alpha=0.3, color=color_model),
            median_face=color_model,
            line_color=color_model,
            fit_kwargs=dict(linestyle=(0, (3, 3))),
        )

        # Axes settings
        # minorticks_on() installs its own minor locator on log axes, so it
        # must run BEFORE the custom minor locators or it would replace them.
        ax.minorticks_on()
        ax.xaxis.set_major_locator(LogLocator(base=10))
        ax.yaxis.set_major_locator(LogLocator(base=10))
        ax.xaxis.set_minor_locator(LogLocator(base=10, subs=np.arange(2, 10) * 0.1))
        ax.yaxis.set_minor_locator(LogLocator(base=10, subs=np.arange(2, 10) * 0.1))

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        ax.set_ylim(global_ymin, global_ymax)

        # Column title only on the top row.
        if row == 0:
            ax.set_title(f"{loc} {tile_dir}", pad=3)

        # Left column carries the model label and the y-axis label.
        if col == 0:
            ax.text(
                0.02, 0.95, row_label,
                transform=ax.transAxes,
                ha="left", va="top",
                fontsize=7,
            )
            ax.set_ylabel(r"Slope $S$")
        else:
            ax.set_ylabel("")

        # Only the bottom row carries the x-axis label.
        if row == last_row:
            ax.set_xlabel(r"Drainage area $A$ (m$^2$)")
        else:
            ax.set_xlabel("")

fig.tight_layout()
plt.show()
