This script will take the 2 stage model trained on FireCCI and apply it to modis 2020-2023, and then compare burned area. 

First we need to aggregate modis up to the 1 degree grid and save that as a parquet file.  We can use this to apply the saved stage 1 model

In [5]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
from collections import defaultdict
from pathlib import Path

import numpy as np
import pandas as pd
import rasterio as rio
from rasterio.transform import from_origin
from rasterio.features import rasterize
from tqdm import tqdm

import geopandas as gpd
from shapely.geometry import Polygon
from pyproj import Transformer

# ----------------------------------------------------------------------
# PATHS
# ----------------------------------------------------------------------
# Monthly 4 km files with predictors + fraction live here:
# FIXED: Added leading slash "/" to make this an absolute path
OUT_DIR = "/explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction"

# Where to save ANNUAL coarse-grid parquet, tifs, shapefiles
PARQUET_DIR    = Path(OUT_DIR) / "parquet_coarse_grids_annual_mcd_analytical"
COARSE_TIF_DIR = Path(OUT_DIR) / "tifs_coarse_grids_annual_mcd_analytical"
COARSE_SHP_DIR = Path(OUT_DIR) / "shp_coarse_grids_annual_mcd_analytical"

os.makedirs(PARQUET_DIR, exist_ok=True)
os.makedirs(COARSE_TIF_DIR, exist_ok=True)
os.makedirs(COARSE_SHP_DIR, exist_ok=True)

# ----------------------------------------------------------------------
# CONSTANTS
# ----------------------------------------------------------------------
WANTED = [
    "DEM",
    "slope",
    "aspect",                     # will match 'aspect' or 'aspectrad' etc.
    "b1",                         # land cover (categorical)
    "relative_humidity",
    "total_precipitation_sum",
    "temperature_2m",
    "temperature_2m_min",
    "temperature_2m_max",
    "build_up_index",
    "drought_code",
    "duff_moisture_code",
    "fine_fuel_moisture_code",
    "fire_weather_index",
    "initial_fire_spread_index",  # if files use 'initial_spread_index', it’ll still match
]

# UPDATED SETTINGS BASED ON OPTIMIZATION WINNER
GRID_SIZES_DEG      = [1]         # ONLY 1 Degree
BURNED_THRESHOLD    = 0.05        # >=5% of 4 km pixels burned -> coarse cell burned
FRACTION_BAND_NAME  = "fraction"  # description set when you made *_with_fraction.tif

# If True: also write a QA GeoTIFF that paints coarse labels back onto the 4 km EPSG:3413 grid
WRITE_QA_LABEL_ON_4KM = False

# ----------------------------------------------------------------------
# HELPERS
# ----------------------------------------------------------------------
def _norm(s: str) -> str:
    return re.sub(r"[^a-z0-9]", "", s.lower())

WANTED_NORM   = [_norm(x) for x in WANTED]
FRACTION_NORM = _norm(FRACTION_BAND_NAME)

# Filenames like: cems_e5l_mcd_2004_7_with_fraction.tif
# FIXED: Matches 'mcd' filenames
name_re = re.compile(r"cems_e5l_mcd_(\d{4})_(\d{1,2})_with_fraction\.tif$", re.IGNORECASE)

def parse_year_month(path: Path):
    m = name_re.search(path.name)
    return (int(m.group(1)), int(m.group(2))) if m else None

def map_band_indices_by_name(ds: rio.DatasetReader):
    mapping = {}
    descs = ds.descriptions  # tuple length = band count; may contain None
    for i, d in enumerate(descs, start=1):
        if d is None:
            d = f"B{i}"
        mapping[_norm(d)] = i
    return mapping, descs

def compute_lonlat_grid(ds: rio.DatasetReader):
    """
    Compute lon/lat center coordinates for each pixel in ds, always returning EPSG:4326 lon/lat.
    Works whether ds is EPSG:3413 (meters) or already EPSG:4326 (degrees), etc.
    """
    h, w = ds.height, ds.width
    rows, cols = np.indices((h, w))
    xs, ys = rio.transform.xy(ds.transform, rows, cols, offset="center")
    x = np.asarray(xs, dtype=np.float64)
    y = np.asarray(ys, dtype=np.float64)

    if ds.crs is None:
        raise RuntimeError("Dataset has no CRS; cannot compute lon/lat.")

    epsg = ds.crs.to_epsg()
    if epsg == 4326:
        lon = x.astype(np.float32)
        lat = y.astype(np.float32)
        return lon, lat

    transformer = Transformer.from_crs(ds.crs, "EPSG:4326", always_xy=True)
    lon, lat = transformer.transform(x, y)
    return lon.astype(np.float32), lat.astype(np.float32)

def mode_ignore_nan(x: pd.Series):
    """Majority value ignoring NaNs. Returns NaN if all are NaN."""
    x = x.dropna()
    if x.empty:
        return np.nan
    return x.value_counts().idxmax()

def aggregate_to_coarse_grids_annual(
    year: int,
    ds: rio.DatasetReader,
    predictors_stack: np.ndarray,
    predictor_names: list,
    annual_frac: np.ndarray,
    lon: np.ndarray,
    lat: np.ndarray,
    grid_sizes_deg=GRID_SIZES_DEG,
    burned_threshold=BURNED_THRESHOLD,
    parquet_dir: Path = PARQUET_DIR,
    coarse_tif_dir: Path = COARSE_TIF_DIR,
    coarse_shp_dir: Path = COARSE_SHP_DIR,
    base_name: str = "",
):
    """
    Aggregate annual 4 km fraction to coarse grids (1 deg only), build binary label,
    assign unique ID per cell, and save outputs.
    """
    H, W = ds.height, ds.width
    N = H * W

    # Flatten
    lon_flat  = lon.ravel()
    lat_flat  = lat.ravel()
    frac_flat = annual_frac.ravel()

    # Binary 4 km input: burned if fraction > 0 (ANY fire)
    binary_4km_flat = np.zeros_like(frac_flat, dtype=np.uint8)
    valid_frac = ~np.isnan(frac_flat)
    
    # -----------------------------------------------------------
    # CRITICAL LOGIC: Using > 0 to capture any fire for input aggregation
    # -----------------------------------------------------------
    binary_4km_flat[valid_frac & (frac_flat > 0)] = 1 
    binary_4km_flat[~valid_frac] = 0  # will be masked via valid_frac

    # Flatten predictors
    pred_flat = {
        name: band.ravel()
        for name, band in zip(predictor_names, predictors_stack)
    }

    # Use only pixels where fraction is not NaN
    valid = valid_frac
    valid_idx = np.nonzero(valid)[0]

    if valid_idx.size == 0:
        print(f"[WARN] Year {year}: no valid annual fraction pixels; skipping coarse grids.")
        return

    # Per-pixel values for valid pixels
    frac_valid = frac_flat[valid]
    bin_valid  = binary_4km_flat[valid]
    lon_valid  = lon_flat[valid]
    lat_valid  = lat_flat[valid]
    pred_valid = {name: arr[valid] for name, arr in pred_flat.items()}

    for size_deg in grid_sizes_deg:
        # Assign each valid pixel to a coarse EPSG:4326 grid cell
        big_lon = size_deg * np.floor(lon_valid / size_deg)
        big_lat = size_deg * np.floor(lat_valid / size_deg)

        df_dict = {
            "big_lon": big_lon.astype(np.float32),
            "big_lat": big_lat.astype(np.float32),
            "burned_4km": bin_valid.astype(np.uint8),
            "frac_4km": frac_valid.astype(np.float32),
            "flat_idx": valid_idx.astype(np.int64),
        }

        for name in predictor_names:
            # Keep b1 as float32 here; mode_ignore_nan will still work.
            df_dict[name] = pred_valid[name].astype(np.float32)

        df = pd.DataFrame(df_dict)

        # Group by coarse cell
        group_cols = ["big_lon", "big_lat"]
        agg_dict = {
            "burned_4km": "mean",  # fraction of 4 km pixels burned in the coarse cell
            "frac_4km": "mean",    # mean annual fraction (diagnostic)
        }

        # new way which is right
        for name in predictor_names:
            if name == "b1":
                agg_dict[name] = mode_ignore_nan  # majority land cover
            elif name in ["relative_humidity", "total_precipitation_sum"]:
                agg_dict[name] = "min"
            elif name in [
                "temperature_2m",
                "temperature_2m_min",
                "temperature_2m_max",
                "build_up_index",
                "drought_code",
                "duff_moisture_code",
                "fine_fuel_moisture_code",
                "fire_weather_index",
                "initial_fire_spread_index",
            ]:
                agg_dict[name] = "max"
            else:
                # Default to mean for DEM, slope, aspect, and any unspecified variables
                agg_dict[name] = "mean"

        grouped = df.groupby(group_cols, as_index=False).agg(agg_dict)

        # Rename burned_4km -> burned_frac_4km for clarity
        grouped = grouped.rename(columns={"burned_4km": "burned_frac_4km"})

        # Coarse burned/unburned label: 1 if >= threshold (0.05) of underlying 4 km pixels burned
        grouped["burned_label"] = (grouped["burned_frac_4km"] >= burned_threshold).astype(np.uint8)

        # Deterministic row order and assign ID 0..N-1
        grouped = grouped.sort_values(["big_lat", "big_lon"]).reset_index(drop=True)
        grouped["ID"] = np.arange(len(grouped), dtype=np.int64)

        # Metadata
        grouped["year"]     = year
        grouped["grid_deg"] = size_deg

        # Save Parquet: one row per coarse cell
        parquet_name = f"{base_name}_grid{size_deg}deg.parquet"
        parquet_path = parquet_dir / parquet_name
        grouped.to_parquet(parquet_path, index=False)
        print(f"[PARQUET] Saved {parquet_path}")

        # ------------------------------------------------------------------
        # GeoTIFF (COARSE GRID): EPSG:4326 at size_deg resolution
        # ------------------------------------------------------------------
        min_lon = float(grouped["big_lon"].min())
        max_lon = float(grouped["big_lon"].max()) + float(size_deg)
        min_lat = float(grouped["big_lat"].min())
        max_lat = float(grouped["big_lat"].max()) + float(size_deg)

        transform = from_origin(min_lon, max_lat, float(size_deg), float(size_deg))
        width  = int(np.ceil((max_lon - min_lon) / float(size_deg)))
        height = int(np.ceil((max_lat - min_lat) / float(size_deg)))

        shapes = []
        for lon0, lat0, lab in zip(grouped["big_lon"], grouped["big_lat"], grouped["burned_label"]):
            lon1 = float(lon0) + float(size_deg)
            lat1 = float(lat0) + float(size_deg)
            poly = Polygon([(lon0, lat0), (lon1, lat0), (lon1, lat1), (lon0, lat1)])
            shapes.append((poly, int(lab)))

        coarse_raster = rasterize(
            shapes=shapes,
            out_shape=(height, width),
            transform=transform,
            fill=255,           # nodata
            dtype="uint8",
            all_touched=False
        )

        coarse_profile = {
            "driver": "GTiff",
            "height": height,
            "width": width,
            "count": 1,
            "dtype": "uint8",
            "crs": "EPSG:4326",
            "transform": transform,
            "nodata": 255,
            "compress": "LZW",
            "tiled": True,
            "blockxsize": 256,
            "blockysize": 256,
            "BIGTIFF": "IF_SAFER",
        }

        tif_name = f"{base_name}_grid{size_deg}deg_epsg4326_burned_unburned.tif"
        tif_path = coarse_tif_dir / tif_name

        with rio.open(tif_path, "w", **coarse_profile) as dst:
            dst.write(coarse_raster, 1)

        print(f"[TIF] Saved {tif_path} (EPSG:4326, {size_deg}°)")

        # ------------------------------------------------------------------
        # OPTIONAL QA GeoTIFF: paint coarse labels back onto original 4 km grid
        # ------------------------------------------------------------------
        if WRITE_QA_LABEL_ON_4KM:
            label_map = grouped[["big_lon", "big_lat", "burned_label"]].copy()
            df_lbl = df.merge(label_map, on=["big_lon", "big_lat"], how="left")

            coarse_label_flat = np.full(N, 255, dtype=np.uint8)  # nodata=255
            coarse_label_flat[df_lbl["flat_idx"].to_numpy()] = (
                df_lbl["burned_label"].to_numpy().astype(np.uint8)
            )
            coarse_label_4km = coarse_label_flat.reshape(H, W)

            profile_4km = ds.profile.copy()
            profile_4km.update(
                dtype="uint8",
                count=1,
                compress="LZW",
                tiled=True,
                blockxsize=256,
                blockysize=256,
                BIGTIFF="IF_SAFER",
                nodata=255,
            )

            tif_name_4km = f"{base_name}_grid{size_deg}deg_label_on4km_epsg{ds.crs.to_epsg() if ds.crs else 'unknown'}.tif"
            tif_path_4km = coarse_tif_dir / tif_name_4km

            with rio.open(tif_path_4km, "w", **profile_4km) as dst:
                dst.write(coarse_label_4km, 1)

            print(f"[TIF-QA] Saved {tif_path_4km} (label on original grid)")

        # ------------------------------------------------------------------
        # Shapefile: one polygon per coarse cell, attributes: ID + burned_label
        # ------------------------------------------------------------------
        geoms = []
        ids    = grouped["ID"].to_numpy()
        labels = grouped["burned_label"].to_numpy()

        for lon0, lat0 in zip(grouped["big_lon"], grouped["big_lat"]):
            lon1 = float(lon0) + float(size_deg)
            lat1 = float(lat0) + float(size_deg)
            poly = Polygon([
                (lon0, lat0),
                (lon1, lat0),
                (lon1, lat1),
                (lon0, lat1),
                (lon0, lat0),
            ])
            geoms.append(poly)

        shp_gdf = gpd.GeoDataFrame(
            {"ID": ids, "burned_label": labels},
            geometry=geoms,
            crs="EPSG:4326",
        )

        shp_name = f"{base_name}_grid{size_deg}deg_cells_epsg4326.shp"
        shp_path = coarse_shp_dir / shp_name
        shp_gdf.to_file(shp_path)
        print(f"[SHP] Saved {shp_path} (EPSG:4326)")

# ----------------------------------------------------------------------
# MAIN: BUILD ANNUAL FROM MONTHLY *_with_fraction.tif
# ----------------------------------------------------------------------
# FIXED: Looking for 'mcd' files in the absolute path
monthly_tifs = sorted(Path(OUT_DIR).glob("cems_e5l_mcd_*_with_fraction.tif"))
if not monthly_tifs:
    print(f"No monthly _with_fraction.tif files found in {OUT_DIR}")
    raise SystemExit

# Group monthly files by year
year_to_paths = defaultdict(list)
for p in monthly_tifs:
    ym = parse_year_month(p)
    if ym is None:
        print(f"[SKIP name] {p.name}")
        continue
    year, month = ym
    year_to_paths[year].append((month, p))

for year in sorted(year_to_paths.keys()):
    month_paths = sorted(year_to_paths[year], key=lambda x: x[0])
    print(f"\n[YEAR] {year} — {len(month_paths)} monthly files")

    # Use first month's file as template for grid, CRS, etc.
    first_month, first_path = month_paths[0]
    with rio.open(first_path) as ds_template:
        H, W = ds_template.height, ds_template.width
        band_map, descs = map_band_indices_by_name(ds_template)

        # Figure out predictor band indices and fraction band index
        predictor_indices = []
        predictor_names   = []

        for want_norm, want_orig in zip(WANTED_NORM, WANTED):
            if want_norm in band_map:
                predictor_indices.append(band_map[want_norm])
                predictor_names.append(want_orig)
                continue
            # partial match (handles 'aspect' vs 'aspectrad', etc.)
            match_idx = None
            for k_norm, idx in band_map.items():
                if want_norm in k_norm or k_norm in want_norm:
                    match_idx = idx
                    break
            if match_idx is not None:
                predictor_indices.append(match_idx)
                predictor_names.append(want_orig)
            else:
                print(f"[WARN] {first_path.name}: could not find band like '{want_orig}'")

        if FRACTION_NORM not in band_map:
            raise RuntimeError(f"{first_path} has no band named/desc like '{FRACTION_BAND_NAME}'")

        frac_idx = band_map[FRACTION_NORM]

        if not predictor_indices:
            print(f"[SKIP no predictors for year {year}]")
            continue

        # Prepare storage for monthly stacks
        frac_months = []  # list of (H, W)
        pred_months = {name: [] for name in predictor_names}

        # Read all months for this year
        for month, path in month_paths:
            with rio.open(path) as ds_m:
                if ds_m.height != H or ds_m.width != W:
                    raise ValueError(
                        f"Shape mismatch for {path}: expected {(H, W)}, got {(ds_m.height, ds_m.width)}"
                    )

                # predictors
                for name, idx in zip(predictor_names, predictor_indices):
                    arr = ds_m.read(idx).astype(np.float32)
                    pred_months[name].append(arr)

                # fraction
                frac_arr = ds_m.read(frac_idx).astype(np.float32)
                frac_months.append(frac_arr)

        # Annual fraction = max over months
        frac_stack = np.stack(frac_months, axis=0)          # (n_months, H, W)
        annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)

        # Annual predictors = mean over months per pixel
        predictor_arrays = []
        for name in predictor_names:
            stack = np.stack(pred_months[name], axis=0)     # (n_months, H, W)
            annual_pred = np.nanmean(stack, axis=0).astype(np.float32)
            predictor_arrays.append(annual_pred)

        predictors_stack = np.stack(predictor_arrays, axis=0)  # (n_predictors, H, W)

        # lon/lat grid (EPSG:4326) computed from template CRS
        lon, lat = compute_lonlat_grid(ds_template)

        # Aggregate to coarse annual grids
        # FIXED: Naming outputs with 'mcd'
        base_name = f"cems_e5l_mcd_{year}_annual"
        aggregate_to_coarse_grids_annual(
            year=year,
            ds=ds_template,
            predictors_stack=predictors_stack,
            predictor_names=predictor_names,
            annual_frac=annual_frac,
            lon=lon,
            lat=lat,
            grid_sizes_deg=GRID_SIZES_DEG,
            burned_threshold=BURNED_THRESHOLD,
            parquet_dir=PARQUET_DIR,
            coarse_tif_dir=COARSE_TIF_DIR,
            coarse_shp_dir=COARSE_SHP_DIR,
            base_name=base_name,
        )

print("\n[DONE] Analytical coarse grids (1 degree, 0.05 threshold) created.")


[YEAR] 2001 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2001_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2001_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)
  ogr_write(


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2001_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2002 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2002_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2002_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2002_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2003 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2003_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2003_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2003_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2004 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2004_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2004_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2004_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2005 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2005_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2005_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2005_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2006 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2006_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2006_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2006_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2007 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2007_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2007_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2007_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2008 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2008_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2008_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2008_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2009 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2009_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2009_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2009_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2010 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2010_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2010_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2010_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2011 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2011_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2011_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2011_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2012 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2012_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2012_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2012_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2013 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2013_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2013_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2013_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2014 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2014_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2014_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2014_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2015 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2015_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2015_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2015_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2016 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2016_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2016_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2016_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2017 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2017_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2017_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2017_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2018 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2018_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2018_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2018_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2019 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2019_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2019_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2019_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2020 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2020_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2020_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2020_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2021 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2021_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2021_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2021_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2022 — 12 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2022_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2022_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2022_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[YEAR] 2023 — 11 monthly files


  annual_frac = np.nanmax(frac_stack, axis=0)         # (H, W)
  annual_pred = np.nanmean(stack, axis=0).astype(np.float32)


[PARQUET] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2023_annual_grid1deg.parquet
[TIF] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/tifs_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2023_annual_grid1deg_epsg4326_burned_unburned.tif (EPSG:4326, 1°)


  shp_gdf.to_file(shp_path)


[SHP] Saved /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical/cems_e5l_mcd_2023_annual_grid1deg_cells_epsg4326.shp (EPSG:4326)

[DONE] Analytical coarse grids (1 degree, 0.05 threshold) created.


Now apply the stage 1 model to the parquet files

In [6]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Predict Stage-1 burned/unburned on MCD annual 1° parquet (analytical).
- Uses Model trained on FireCCI (MODEL_ROOT).
- Joins to MCD annual 1° shapefiles by ID.
- Saves MCD shapefiles with TP/FN/TN/FP labels.
- Generates summary CSV and 4-panel plot for MCD years.

ALL OUTPUTS are suffixed with '_analytical'.
"""

import re
from pathlib import Path

import numpy as np
import pandas as pd
import geopandas as gpd
import joblib

import matplotlib.pyplot as plt

# ---------------------------------------------------------------------
# PATHS
# ---------------------------------------------------------------------

# 1. WHERE THE DATA IS (MCD)
DATA_ROOT = Path(
    "/explore/nobackup/people/spotter5/clelland_fire_ml/"
    "training_e5l_cems_mcd_with_fraction"
)

# 2. WHERE THE TRAINED MODEL IS (FireCCI)
# We load the model trained on FireCCI to apply it to MCD
MODEL_ROOT = Path(
    "/explore/nobackup/people/spotter5/clelland_fire_ml/"
    "training_e5l_cems_firecci_with_fraction"
)

# INPUTS (MCD Analytical Data)
# Note: These folders must exist from your previous aggregation script
PARQUET_DIR = DATA_ROOT / "parquet_coarse_grids_annual_mcd_analytical"
OBS_SHP_DIR = DATA_ROOT / "shp_coarse_grids_annual_mcd_analytical"

# MODEL FILES (From FireCCI Stage 1)
MODEL_DIR   = MODEL_ROOT / "stage_1_model_analytical"
MODEL_PATH  = MODEL_DIR / "lgbm_stage1_model.joblib"
THRESH_CSV  = MODEL_DIR / "threshold_metrics.csv"
THRESH_TXT  = MODEL_DIR / "final_metrics.txt"

DEFAULT_THRESHOLD = 0.5

# OUTPUTS (Saved inside the MCD folder)
OUT_DIR_BASE = DATA_ROOT / "stage_1_predictions_on_mcd_analytical"
OUT_SHP_DIR  = OUT_DIR_BASE / "pred_vs_obs_shapefiles_annual"
OUT_SHP_DIR.mkdir(parents=True, exist_ok=True)

SUMMARY_CSV = OUT_DIR_BASE / "pred_obs_counts_and_pct_by_year_mcd_analytical.csv"
SUMMARY_PNG = OUT_DIR_BASE / "pred_obs_percent_by_year_4panel_mcd_analytical.png"

# ---------------------------------------------------------------------
# CONFIG
# ---------------------------------------------------------------------
FEATURES = [
    "DEM",
    "slope",
    "aspect",
    "b1",
    "relative_humidity",
    "total_precipitation_sum",
    "temperature_2m",
    "temperature_2m_min",
    "temperature_2m_max",
    "build_up_index",
    "drought_code",
    "duff_moisture_code",
    "fine_fuel_moisture_code",
    "fire_weather_index",
    "initial_fire_spread_index",
]

OBS_LABEL_CANDIDATES = [
    "burned_label",
    "burned_lab",
    "burn_label",
    "burned",
    "label",
    "obs_label",
    "class",
]

# ---------------------------------------------------------------------
# HELPERS
# ---------------------------------------------------------------------
# UPDATED REGEX FOR MCD
parq_re = re.compile(r"cems_e5l_mcd_(\d{4})_annual_grid1deg\.parquet$", re.IGNORECASE)
shp_re  = re.compile(r"cems_e5l_mcd_(\d{4})_annual_grid1deg_cells_epsg4326\.shp$", re.IGNORECASE)


def load_best_threshold() -> float:
    if THRESH_TXT.exists():
        try:
            txt = THRESH_TXT.read_text().splitlines()
            json_start = None
            for i, line in enumerate(txt):
                if line.strip().startswith("{"):
                    json_start = i
                    break
            if json_start is not None:
                import json
                d = json.loads("\n".join(txt[json_start:]))
                return float(d.get("threshold", DEFAULT_THRESHOLD))
        except Exception:
            pass

    if THRESH_CSV.exists():
        try:
            df = pd.read_csv(THRESH_CSV)
            df = df.sort_values(["recall", "precision", "f1"], ascending=False)
            return float(df.iloc[0]["threshold"])
        except Exception:
            pass

    return float(DEFAULT_THRESHOLD)


def ensure_b1_category(X: pd.DataFrame) -> pd.DataFrame:
    X = X.copy()
    X["b1"] = X["b1"].astype("Int64").astype("category")
    return X


def find_year_parquets(parquet_dir: Path):
    out = {}
    if not parquet_dir.exists():
        return out
    for p in parquet_dir.glob("*_grid1deg.parquet"):
        m = parq_re.search(p.name)
        if m:
            out[int(m.group(1))] = p
    return dict(sorted(out.items()))


def find_year_shapefiles(shp_dir: Path):
    out = {}
    if not shp_dir.exists():
        return out
    for p in shp_dir.glob("*.shp"):
        m = shp_re.search(p.name)
        if m:
            out[int(m.group(1))] = p
    return dict(sorted(out.items()))


def pick_obs_label_column(gdf: gpd.GeoDataFrame) -> str:
    cols = list(gdf.columns)
    cols_lower = {c.lower(): c for c in cols}

    for cand in OBS_LABEL_CANDIDATES:
        if cand.lower() in cols_lower:
            return cols_lower[cand.lower()]

    for c in cols:
        cl = c.lower()
        if ("burn" in cl and "lab" in cl) or cl in ("burn", "burned"):
            return c

    return ""


def label_tpfn_tnfp(obs: np.ndarray, pred: np.ndarray) -> np.ndarray:
    obs = obs.astype(np.uint8)
    pred = pred.astype(np.uint8)
    out = np.empty(obs.shape[0], dtype=object)
    out[(pred == 1) & (obs == 1)] = "TP"
    out[(pred == 0) & (obs == 1)] = "FN"
    out[(pred == 0) & (obs == 0)] = "TN"
    out[(pred == 1) & (obs == 0)] = "FP"
    return out


def plot_percent_4panel(df_counts: pd.DataFrame, out_png: Path):
    """
    df_counts columns:
      year, TP_pct, FN_pct, TN_pct, FP_pct

    Floating y-axis: each panel auto-scales independently.
    """
    dfp = df_counts.sort_values("year").copy()
    years = dfp["year"].to_numpy()

    fig, axes = plt.subplots(2, 2, figsize=(12, 7), sharex=True)
    axes = axes.ravel()

    panels = [
        ("TP_pct", "TP (%)"),
        ("FN_pct", "FN (%)"),
        ("TN_pct", "TN (%)"),
        ("FP_pct", "FP (%)"),
    ]

    for ax, (col, title) in zip(axes, panels):
        ax.plot(years, dfp[col].to_numpy(), marker="o")
        ax.set_title(title)
        ax.set_xlabel("Year")
        ax.set_ylabel("Percent")
        ax.autoscale(enable=True, axis="y")
        ax.grid(True)

    plt.tight_layout()
    fig.savefig(out_png, dpi=200, bbox_inches="tight")
    plt.close(fig)


# ---------------------------------------------------------------------
# MAIN
# ---------------------------------------------------------------------
def main():
    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

    model = joblib.load(MODEL_PATH)
    thr = load_best_threshold()

    print(f"[MODEL] {MODEL_PATH}")
    print(f"[THR]   {thr:.3f}")

    year_to_parq = find_year_parquets(PARQUET_DIR)
    year_to_shp  = find_year_shapefiles(OBS_SHP_DIR)

    years = sorted(set(year_to_parq) & set(year_to_shp))
    
    print(f"Looking for Parquets in: {PARQUET_DIR}")
    print(f"Looking for Shapefiles in: {OBS_SHP_DIR}")
    
    if not years:
        raise RuntimeError(
            "No overlapping years between parquet and shapefiles.\n"
            f"Parquet years found: {list(year_to_parq.keys())}\n"
            f"SHP years found: {list(year_to_shp.keys())}"
        )

    print(f"[YEARS] {years}")

    summary_rows = []

    for year in years:
        parq_path = year_to_parq[year]
        shp_path  = year_to_shp[year]

        print(f"\n=== {year} ===")
        print(f"[PARQ] {parq_path.name}")
        print(f"[SHP ] {shp_path.name}")

        # --- predict on parquet ---
        dfp = pd.read_parquet(parq_path, columns=["ID"] + FEATURES).copy()
        dfp = dfp.dropna(subset=FEATURES).copy()

        X = ensure_b1_category(dfp[FEATURES])
        prob = model.predict_proba(X)[:, 1].astype(np.float32)
        pred = (prob >= thr).astype(np.uint8)

        pred_df = pd.DataFrame(
            {
                "ID": dfp["ID"].astype(np.int64).to_numpy(),
                "pred_prob": prob,
                "pred_label": pred,
            }
        )

        # --- read observed shapefile ---
        gdf = gpd.read_file(shp_path)

        if "ID" not in gdf.columns:
            raise RuntimeError(f"[{year}] Shapefile missing 'ID' column: {shp_path}")

        obs_col = pick_obs_label_column(gdf)
        if not obs_col:
            raise RuntimeError(
                f"[{year}] Could not find an observed label column in {shp_path}\n"
                f"Available columns: {list(gdf.columns)}\n"
                f"Tried candidates: {OBS_LABEL_CANDIDATES}"
            )
        print(f"[OBS] Using observed label column: '{obs_col}'")

        gdf["ID"] = gdf["ID"].astype(np.int64)

        obs_vals = pd.to_numeric(gdf[obs_col], errors="coerce")
        gdf["obs_label"] = obs_vals  # float with NaNs
        valid_obs = gdf["obs_label"].isin([0, 1])

        # --- join ---
        gdf = gdf.merge(pred_df, on="ID", how="left", validate="one_to_one")

        missing_pred = int(gdf["pred_label"].isna().sum())
        if missing_pred:
            print(f"[WARN] {missing_pred:,} polygons had no matching prediction by ID")

        # --- label TP/FN/TN/FP/NA ---
        gdf["pred_obs"] = "NA"
        valid = valid_obs & (~gdf["pred_label"].isna())

        if valid.any():
            gdf.loc[valid, "pred_label"] = gdf.loc[valid, "pred_label"].astype(np.uint8)
            gdf.loc[valid, "pred_obs"] = label_tpfn_tnfp(
                gdf.loc[valid, "obs_label"].astype(np.uint8).to_numpy(),
                gdf.loc[valid, "pred_label"].to_numpy(),
            )

        # --- counts + percents (percents exclude NA) ---
        vc = gdf["pred_obs"].value_counts().to_dict()
        tp = int(vc.get("TP", 0))
        fn = int(vc.get("FN", 0))
        tn = int(vc.get("TN", 0))
        fp = int(vc.get("FP", 0))
        na = int(vc.get("NA", 0))
        denom = tp + fn + tn + fp  # VALID comparisons only

        def pct(x):
            return float(100.0 * x / denom) if denom > 0 else 0.0

        row = {
            "year": int(year),
            "TP": tp,
            "FN": fn,
            "TN": tn,
            "FP": fp,
            "NA": na,
            "n_total": int(len(gdf)),
            "n_valid": int(denom),
            "TP_pct": pct(tp),
            "FN_pct": pct(fn),
            "TN_pct": pct(tn),
            "FP_pct": pct(fp),
        }
        summary_rows.append(row)

        print(f"[COUNTS] TP={tp:,} FN={fn:,} TN={tn:,} FP={fp:,} NA={na:,} (valid={denom:,})")
        print(f"[PCT]    TP={row['TP_pct']:.2f}% FN={row['FN_pct']:.2f}% TN={row['TN_pct']:.2f}% FP={row['FP_pct']:.2f}%")

        # ensure year column
        if "year" not in gdf.columns:
            gdf["year"] = int(year)

        # --- write shapefile (Analytical) ---
        out_name = f"cems_e5l_mcd_{year}_annual_grid1deg_pred_vs_obs_analytical.shp"
        out_path = OUT_SHP_DIR / out_name
        gdf.to_file(out_path)
        print(f"[SAVE] {out_path}")

    # -----------------------------------------------------------------
    # Save summary dataframe + plot percents (Analytical)
    # -----------------------------------------------------------------
    if summary_rows:
        df_sum = pd.DataFrame(summary_rows).sort_values("year").reset_index(drop=True)
        df_sum.to_csv(SUMMARY_CSV, index=False)
        print(f"\n[SUMMARY] Saved CSV:  {SUMMARY_CSV}")

        plot_percent_4panel(df_sum[["year", "TP_pct", "FN_pct", "TN_pct", "FP_pct"]], SUMMARY_PNG)
        print(f"[PLOT]    Saved PNG:  {SUMMARY_PNG}")

        print(f"\n[DONE] Wrote {len(years)} annual MCD shapefiles to {OUT_SHP_DIR}")
    else:
        print("\n[WARN] No years processed.")


if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


[MODEL] /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_firecci_with_fraction/stage_1_model_analytical/lgbm_stage1_model.joblib
[THR]   0.100
Looking for Parquets in: /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/parquet_coarse_grids_annual_mcd_analytical
Looking for Shapefiles in: /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/shp_coarse_grids_annual_mcd_analytical
[YEARS] [2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023]

=== 2001 ===
[PARQ] cems_e5l_mcd_2001_annual_grid1deg.parquet
[SHP ] cems_e5l_mcd_2001_annual_grid1deg_cells_epsg4326.shp
[OBS] Using observed label column: 'burned_lab'
[WARN] 5,827 polygons had no matching prediction by ID
[COUNTS] TP=137 FN=41 TN=3,482 FP=1,831 NA=5,827 (valid=5,491)
[PCT]    TP=2.49% FN=0.75% TN=63.41% FP=33.35%
[SAVE] /explore/nobackup/people/spotter5/clell

Now take the cells we predicted as burnable and extract 4km predictor data per year and month and save to parquet file

In [4]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import os
import re
from pathlib import Path

import numpy as np
import pandas as pd
import rasterio as rio
from rasterio.features import rasterize
from rasterio.warp import transform as rio_transform
import geopandas as gpd
from tqdm import tqdm

import pyarrow as pa
import pyarrow.parquet as pq

# ================== CONFIG ==================
# 1. INPUT DATA (MCD)
IN_DIR = Path("/explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction")

# 2. INPUT SHAPEFILES (MCD Analytical Predictions from previous step)
# These are used to create the spatial mask (only keeping cells predicted as 'burned')
PRED_SHP_DIR = IN_DIR / "stage_1_predictions_on_mcd_analytical" / "pred_vs_obs_shapefiles_annual"

# 3. OUTPUT DATASET (MCD Parquet)
OUT_DATASET_DIR = Path("/explore/nobackup/people/spotter5/clelland_fire_ml/parquet_cems_mcd_with_fraction_dataset_burnedlab_mask_analytical")
OUT_DATASET_DIR.mkdir(parents=True, exist_ok=True)

REPROJECT_TO_EPSG4326 = True

# Years to process
YEAR_MIN = 2001
YEAR_MAX = 2023

# Mask shapefile criterion: keep only burned-lab cells
BURNED_LAB_VALUE = 1
BURNED_LAB_FIELD_OVERRIDE = None  # set if you know exact field name

# Fraction band description/name candidates (searched in ds.descriptions)
FRACTION_BAND_DESC_CANDIDATES = ["fraction", "frac", "burn_fraction"]

# Pixel label from fraction
PIXEL_BURN_THRESHOLD = 0.5  # burned if fraction > 0.5, unburned if fraction < 0.5

# ================== HELPERS ==================
def sanitize_names(names):
    """Make unique, safe column names (avoid duplicates)."""
    seen = {}
    out = []
    for n in names:
        if n is None or str(n).strip() == "":
            n = "band"
        n0 = re.sub(r"[^a-zA-Z0-9_]", "_", str(n).strip())
        n0 = re.sub(r"_+", "_", n0).strip("_")
        if n0 == "":
            n0 = "band"
        if n0 in seen:
            seen[n0] += 1
            n0 = f"{n0}_{seen[n0]}"
        else:
            seen[n0] = 1
        out.append(n0)
    return out

# UPDATED REGEX FOR MCD TIFFS
name_re = re.compile(r"cems_e5l_mcd_(\d{4})_(\d{1,2})_with_fraction\.tif$", re.IGNORECASE)

# UPDATED REGEX FOR MCD SHAPEFILES
# Matches files like: cems_e5l_mcd_2001_annual_grid1deg_pred_vs_obs_analytical.shp
shp_re = re.compile(r"cems_e5l_mcd_(\d{4})_annual_grid1deg_pred_vs_obs_analytical\.shp$", re.IGNORECASE)

def parse_year_month(fname: str):
    m = name_re.search(fname)
    if not m:
        return None, None
    return int(m.group(1)), int(m.group(2))

def append_chunk_to_dataset(df: pd.DataFrame, root: Path):
    if not df.columns.is_unique:
        dups = df.columns[df.columns.duplicated()].tolist()
        raise ValueError(f"Duplicate column names found: {dups}")
    table = pa.Table.from_pandas(df, preserve_index=False)
    pq.write_to_dataset(
        table,
        root_path=str(root),
        partition_cols=["year", "month"],
        use_dictionary=False
    )

def find_fraction_band_index(ds: rio.DatasetReader) -> int:
    """
    Return 0-based band index for fraction band by inspecting ds.descriptions.
    """
    descs = list(ds.descriptions) if ds.descriptions else [None] * ds.count
    descs_safe = sanitize_names([d if d else f"B{i}" for i, d in enumerate(descs, start=1)])
    descs_safe_lower = [d.lower() for d in descs_safe]

    for cand in FRACTION_BAND_DESC_CANDIDATES:
        cand = cand.lower()
        for i, d in enumerate(descs_safe_lower):
            if cand == d or cand in d:
                return i

    raise RuntimeError(
        "Could not find fraction band by description. "
        f"Band descriptions (sanitized): {descs_safe}"
    )

def build_lonlat(ds: rio.DatasetReader, xs, ys):
    if (
        REPROJECT_TO_EPSG4326
        and ds.crs is not None
        and ds.crs.to_string().upper() not in ("EPSG:4326", "OGC:CRS84")
    ):
        lons, lats = rio_transform(ds.crs, "EPSG:4326", xs, ys)
        return np.asarray(lons, dtype=np.float64), np.asarray(lats, dtype=np.float64)
    return xs.astype(np.float64), ys.astype(np.float64)

def find_burned_lab_field(gdf: gpd.GeoDataFrame) -> str:
    """
    Find the 'burned_lab' field even if DBF truncates it.
    """
    if BURNED_LAB_FIELD_OVERRIDE:
        if BURNED_LAB_FIELD_OVERRIDE not in gdf.columns:
            raise RuntimeError(f"Override burned-lab field '{BURNED_LAB_FIELD_OVERRIDE}' not in: {list(gdf.columns)}")
        return BURNED_LAB_FIELD_OVERRIDE

    cols_lower = {c.lower(): c for c in gdf.columns}

    # common names
    candidates = ["burned_lab", "burned_label", "burnedlab", "burn_lab", "burnlab", "burned"]
    for c in candidates:
        if c in cols_lower:
            return cols_lower[c]

    # fuzzy fallback
    for c in gdf.columns:
        cl = c.lower()
        if "burn" in cl and ("lab" in cl or "label" in cl):
            return c

    raise RuntimeError(f"Could not find burned_lab field. Columns: {list(gdf.columns)}")

def raster_mask_from_burnedlab(ds: rio.DatasetReader, shp_path: Path) -> np.ndarray:
    """
    Rasterize polygons where burned_lab==1 onto ds grid -> boolean mask (H,W).
    """
    gdf = gpd.read_file(shp_path)
    lab_col = find_burned_lab_field(gdf)

    lab_vals = pd.to_numeric(gdf[lab_col], errors="coerce")
    gdf_keep = gdf.loc[lab_vals == BURNED_LAB_VALUE].copy()

    if gdf_keep.empty:
        return np.zeros((ds.height, ds.width), dtype=bool)

    if ds.crs is None:
        raise RuntimeError(f"Raster has no CRS; cannot rasterize: {shp_path}")
    if gdf_keep.crs is None:
        raise RuntimeError(f"Shapefile has no CRS; cannot rasterize: {shp_path}")

    if gdf_keep.crs != ds.crs:
        gdf_keep = gdf_keep.to_crs(ds.crs)

    shapes = [(geom, 1) for geom in gdf_keep.geometry if geom is not None and not geom.is_empty]
    if not shapes:
        return np.zeros((ds.height, ds.width), dtype=bool)

    mask_u8 = rasterize(
        shapes=shapes,
        out_shape=(ds.height, ds.width),
        transform=ds.transform,
        fill=0,
        dtype="uint8",
        all_touched=False,
    )
    return mask_u8.astype(bool)

# ================== MAIN ==================
def main():
    # UPDATED GLOB FOR MCD
    tifs = sorted(IN_DIR.glob("cems_e5l_mcd_*_with_fraction.tif"))
    if not tifs:
        raise FileNotFoundError(f"No monthly _with_fraction.tif found in {IN_DIR}")

    # Filter to years 2001-2019 only
    todo = []
    for tif in tifs:
        y, m = parse_year_month(tif.name)
        if y is None:
            continue
        if y < YEAR_MIN or y > YEAR_MAX:
            continue
        todo.append((y, m, tif))
    todo.sort()

    if not todo:
        raise RuntimeError(f"No TIFFs found in year range {YEAR_MIN}-{YEAR_MAX}")

    # Cache the rasterized burned-lab mask per year
    year_mask_cache = {}

    canonical_cols = None

    # Global ratio counters (only where burned_pixel is defined)
    burned_total = 0
    unburned_total = 0
    valid_lab_total = 0

    print(f"Scanning for analytical shapefiles in: {PRED_SHP_DIR}")

    for year, month, tif in tqdm(todo, desc="Building partitioned Parquet dataset (burned_lab mask)"):
        # UPDATED: Filename format for MCD analytical shapefiles
        shp_name = f"cems_e5l_mcd_{year}_annual_grid1deg_pred_vs_obs_analytical.shp"
        shp_path = PRED_SHP_DIR / shp_name
        
        if not shp_path.exists():
            print(f"\n[SKIP] {tif.name} (missing annual analytical shapefile: {shp_path})")
            continue

        with rio.open(tif) as ds:
            # band names
            band_names = list(ds.descriptions) if ds.descriptions else []
            if not any(band_names):
                band_names = [f"B{i}" for i in range(1, ds.count + 1)]
            safe_names = sanitize_names(band_names)

            # fraction band index (0-based)
            frac_band0 = find_fraction_band_index(ds)
            frac_col_name = "fraction"

            # burned-lab mask per year (rasterized once)
            if year not in year_mask_cache:
                mask = raster_mask_from_burnedlab(ds, shp_path)
                year_mask_cache[year] = mask
                print(f"\n[YEAR {year}] burned_lab mask keeps {mask.sum():,} / {mask.size:,} pixels ({100*mask.mean():.2f}%)")
            else:
                mask = year_mask_cache[year]
                if mask.shape != (ds.height, ds.width):
                    raise RuntimeError(f"Mask shape mismatch for {year}: mask {mask.shape} vs raster {(ds.height, ds.width)}")

            if mask.sum() == 0:
                continue

            # Read raster (bands, H, W)
            data = ds.read().astype(np.float32)
            bands, h, w = data.shape

            # Flatten to (pixels, bands)
            arr2d = data.reshape(bands, -1).T

            # Keep only pixels with build_up_index not NaN (domain mask)
            build_col = None
            for s in safe_names:
                if "build" in s.lower() and "index" in s.lower():
                    build_col = s
                    break
            if build_col is None:
                raise ValueError(f"Could not find build_up_index band in: {tif.name}")

            build_idx = safe_names.index(build_col)
            build_vals = arr2d[:, build_idx]

            keep_mask = mask.reshape(-1) & (~np.isnan(build_vals))
            if not keep_mask.any():
                continue

            # Subset pixels
            arr_keep = arr2d[keep_mask, :]
            df = pd.DataFrame(arr_keep, columns=safe_names)

            # Ensure fraction column exists exactly once
            frac_vals_from_band = df.iloc[:, frac_band0].astype(np.float32).to_numpy()
            df[frac_col_name] = frac_vals_from_band  # overwrite if already present

            # burned_pixel binary from fraction
            frac_vals = df[frac_col_name].to_numpy(dtype=np.float32, copy=False)
            burned_pixel = np.full(frac_vals.shape, np.nan, dtype=np.float32)
            valid_frac = ~np.isnan(frac_vals)
            burned_pixel[valid_frac & (frac_vals > PIXEL_BURN_THRESHOLD)] = 1.0
            burned_pixel[valid_frac & (frac_vals < PIXEL_BURN_THRESHOLD)] = 0.0
            df["burned_pixel"] = burned_pixel

            # Update global counters
            valid_lab = ~np.isnan(burned_pixel)
            if valid_lab.any():
                burned_total += int(np.sum(burned_pixel[valid_lab] == 1.0))
                unburned_total += int(np.sum(burned_pixel[valid_lab] == 0.0))
                valid_lab_total += int(valid_lab.sum())

            # Coordinates for kept pixels
            rows = np.arange(h)
            cols = np.arange(w)
            rr, cc = np.meshgrid(rows, cols, indexing="ij")
            xs, ys = rio.transform.xy(ds.transform, rr, cc, offset="center")
            xs = np.asarray(xs, dtype=np.float64).reshape(-1)[keep_mask]
            ys = np.asarray(ys, dtype=np.float64).reshape(-1)[keep_mask]
            lons, lats = build_lonlat(ds, xs, ys)

            df["longitude"] = lons
            df["latitude"] = lats
            df["year"] = year
            df["month"] = month

            # Canonical schema
            if canonical_cols is None:
                canonical_cols = list(safe_names)
                if frac_col_name not in canonical_cols:
                    canonical_cols.append(frac_col_name)
                for extra in ["burned_pixel", "longitude", "latitude", "year", "month"]:
                    if extra not in canonical_cols:
                        canonical_cols.append(extra)
                if len(canonical_cols) != len(set(canonical_cols)):
                    raise RuntimeError(f"Canonical cols not unique: {canonical_cols}")

            for col in canonical_cols:
                if col not in df.columns:
                    df[col] = np.nan

            df = df[canonical_cols]
            append_chunk_to_dataset(df, OUT_DATASET_DIR)

    print(f"\n✅ Done. Parquet dataset at:\n{OUT_DATASET_DIR}\n(partitioned by year=/month=)")

    # Global ratios
    print("\n=== Burned/Unburned pixel counts (filtered to burned_lab==1 1° cells) ===")
    print(f"Valid labeled pixels (fraction != NaN and != {PIXEL_BURN_THRESHOLD}): {valid_lab_total:,}")
    print(f"Burned pixels    (fraction > {PIXEL_BURN_THRESHOLD}): {burned_total:,}")
    print(f"Unburned pixels (fraction < {PIXEL_BURN_THRESHOLD}): {unburned_total:,}")

    if burned_total > 0:
        ratio = unburned_total / burned_total
        print(f"Unburned:Burned ratio = {ratio:.3f} : 1")
    else:
        print("Unburned:Burned ratio = inf (no burned pixels found)")

if __name__ == "__main__":
    main()

Scanning for analytical shapefiles in: /explore/nobackup/people/spotter5/clelland_fire_ml/training_e5l_cems_mcd_with_fraction/stage_1_predictions_on_mcd_analytical/pred_vs_obs_shapefiles_annual


Building partitioned Parquet dataset (burned_lab mask):   0%|          | 0/275 [00:00<?, ?it/s]


[YEAR 2001] burned_lab mask keeps 253,221 / 4,273,642 pixels (5.93%)


Building partitioned Parquet dataset (burned_lab mask):   4%|▍         | 12/275 [00:35<09:05,  2.07s/it]


[YEAR 2002] burned_lab mask keeps 397,886 / 4,273,642 pixels (9.31%)


Building partitioned Parquet dataset (burned_lab mask):   9%|▊         | 24/275 [01:23<15:49,  3.78s/it]


[YEAR 2003] burned_lab mask keeps 460,364 / 4,273,642 pixels (10.77%)


Building partitioned Parquet dataset (burned_lab mask):  13%|█▎        | 36/275 [01:47<07:52,  1.98s/it]


[YEAR 2004] burned_lab mask keeps 374,898 / 4,273,642 pixels (8.77%)


Building partitioned Parquet dataset (burned_lab mask):  17%|█▋        | 48/275 [02:29<14:00,  3.70s/it]


[YEAR 2005] burned_lab mask keeps 358,563 / 4,273,642 pixels (8.39%)


Building partitioned Parquet dataset (burned_lab mask):  22%|██▏       | 60/275 [02:57<06:32,  1.83s/it]


[YEAR 2006] burned_lab mask keeps 404,603 / 4,273,642 pixels (9.47%)


Building partitioned Parquet dataset (burned_lab mask):  26%|██▌       | 72/275 [03:32<12:07,  3.58s/it]


[YEAR 2007] burned_lab mask keeps 339,347 / 4,273,642 pixels (7.94%)


Building partitioned Parquet dataset (burned_lab mask):  31%|███       | 84/275 [04:06<06:21,  2.00s/it]


[YEAR 2008] burned_lab mask keeps 401,621 / 4,273,642 pixels (9.40%)


Building partitioned Parquet dataset (burned_lab mask):  35%|███▍      | 96/275 [04:42<10:44,  3.60s/it]


[YEAR 2009] burned_lab mask keeps 352,779 / 4,273,642 pixels (8.25%)


Building partitioned Parquet dataset (burned_lab mask):  39%|███▉      | 108/275 [05:17<05:34,  2.00s/it]


[YEAR 2010] burned_lab mask keeps 379,490 / 4,273,642 pixels (8.88%)


Building partitioned Parquet dataset (burned_lab mask):  44%|████▎     | 120/275 [05:49<08:50,  3.42s/it]


[YEAR 2011] burned_lab mask keeps 283,831 / 4,273,642 pixels (6.64%)


Building partitioned Parquet dataset (burned_lab mask):  48%|████▊     | 132/275 [06:27<05:15,  2.21s/it]


[YEAR 2012] burned_lab mask keeps 394,427 / 4,273,642 pixels (9.23%)


Building partitioned Parquet dataset (burned_lab mask):  52%|█████▏    | 144/275 [07:05<08:04,  3.70s/it]


[YEAR 2013] burned_lab mask keeps 242,740 / 4,273,642 pixels (5.68%)


Building partitioned Parquet dataset (burned_lab mask):  57%|█████▋    | 156/275 [07:37<03:46,  1.90s/it]


[YEAR 2014] burned_lab mask keeps 374,668 / 4,273,642 pixels (8.77%)


Building partitioned Parquet dataset (burned_lab mask):  61%|██████    | 168/275 [08:09<05:59,  3.36s/it]


[YEAR 2015] burned_lab mask keeps 359,512 / 4,273,642 pixels (8.41%)


Building partitioned Parquet dataset (burned_lab mask):  65%|██████▌   | 180/275 [08:47<03:32,  2.24s/it]


[YEAR 2016] burned_lab mask keeps 268,992 / 4,273,642 pixels (6.29%)


Building partitioned Parquet dataset (burned_lab mask):  70%|██████▉   | 192/275 [09:18<04:28,  3.23s/it]


[YEAR 2017] burned_lab mask keeps 341,895 / 4,273,642 pixels (8.00%)


Building partitioned Parquet dataset (burned_lab mask):  74%|███████▍  | 204/275 [10:01<03:10,  2.68s/it]


[YEAR 2018] burned_lab mask keeps 316,095 / 4,273,642 pixels (7.40%)


Building partitioned Parquet dataset (burned_lab mask):  79%|███████▊  | 216/275 [10:33<03:24,  3.46s/it]


[YEAR 2019] burned_lab mask keeps 325,901 / 4,273,642 pixels (7.63%)


Building partitioned Parquet dataset (burned_lab mask):  83%|████████▎ | 228/275 [11:13<01:50,  2.35s/it]


[YEAR 2020] burned_lab mask keeps 284,904 / 4,273,642 pixels (6.67%)


Building partitioned Parquet dataset (burned_lab mask):  87%|████████▋ | 240/275 [11:43<01:55,  3.30s/it]


[YEAR 2021] burned_lab mask keeps 330,361 / 4,273,642 pixels (7.73%)


Building partitioned Parquet dataset (burned_lab mask):  92%|█████████▏| 252/275 [12:23<00:52,  2.27s/it]


[YEAR 2022] burned_lab mask keeps 259,732 / 4,273,642 pixels (6.08%)


Building partitioned Parquet dataset (burned_lab mask):  96%|█████████▌| 264/275 [12:57<00:39,  3.55s/it]


[YEAR 2023] burned_lab mask keeps 327,195 / 4,273,642 pixels (7.66%)


Building partitioned Parquet dataset (burned_lab mask): 100%|██████████| 275/275 [13:32<00:00,  2.95s/it]


✅ Done. Parquet dataset at:
/explore/nobackup/people/spotter5/clelland_fire_ml/parquet_cems_mcd_with_fraction_dataset_burnedlab_mask_analytical
(partitioned by year=/month=)

=== Burned/Unburned pixel counts (filtered to burned_lab==1 1° cells) ===
Valid labeled pixels (fraction != NaN and != 0.5): 29,114,950
Burned pixels    (fraction > 0.5): 62,312
Unburned pixels (fraction < 0.5): 29,052,638
Unburned:Burned ratio = 466.245 : 1





Now apply stage 2 model to these files we just saved, and only in TP/FP stage 1 model cells

In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Apply Stage-2 XGBoost model to MCD 4km pixels.
-- ANALYTICAL VERSION --

Logic:
1. Load "Analytical" Stage 1 predictions for MCD (Shapefiles).
2. Filter for coarse grid cells labeled 'TP' or 'FP'.
3. For every month in that year, load the 4km MCD GeoTIFF.
4. Mask 4km pixels: only predict on pixels INSIDE those TP/FP coarse cells.
5. Save the resulting probability map as a new GeoTIFF.
"""

import os
import re
import json
from pathlib import Path

import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio as rio
from rasterio import features
import xgboost as xgb

# ============================================================
# CONFIG
# ============================================================

# 1. INPUT SHAPEFILES (MCD Analytical Predictions from Stage 1)
SHP_DIR = Path(
    "/explore/nobackup/people/spotter5/clelland_fire_ml/"
    "training_e5l_cems_mcd_with_fraction/stage_1_predictions_on_mcd_analytical/pred_vs_obs_shapefiles_annual"
)

# 2. INPUT TIFFS (MCD Monthly Data)
IN_TIF_DIR = Path(
    "/explore/nobackup/people/spotter5/clelland_fire_ml/"
    "training_e5l_cems_mcd_with_fraction"
)

# 3. MODEL (Trained on FireCCI Stage 2)
MODEL_PATH = Path(
    "/explore/nobackup/people/spotter5/clelland_fire_ml/"
    "ml_training/xgb_pr_auc_filtered_burnedlabmask_native_analytical/models/xgb_best_pr_auc.json"
)

# 4. OUTPUT DIRECTORY
OUT_DIR = Path(
    "/explore/nobackup/people/spotter5/clelland_fire_ml/"
    "predictions_tp_fp_only_mcd_analytical"
)
OUT_DIR.mkdir(parents=True, exist_ok=True)

# 5. OVERWRITE SETTINGS
# Years listed here will be re-processed even if the output file exists.
OVERWRITE_YEARS = [2020, 2021, 2022, 2023]

# 15 Predictors used by the model (Must match training order)
FEATURES = [
    "DEM", "slope", "aspect", "b1", "relative_humidity", 
    "total_precipitation_sum", "temperature_2m", "temperature_2m_min", 
    "temperature_2m_max", "build_up_index", "drought_code", 
    "duff_moisture_code", "fine_fuel_moisture_code", 
    "fire_weather_index", "initial_fire_spread_index"
]

# ============================================================
# Helpers
# ============================================================

def get_annual_shp(year):
    path = SHP_DIR / f"cems_e5l_mcd_{year}_annual_grid1deg_pred_vs_obs_analytical.shp"
    return path if path.exists() else None

def get_monthly_tif(year, month):
    pattern = f"cems_e5l_mcd_{year}_{month}_with_fraction.tif"
    path = IN_TIF_DIR / pattern
    return path if path.exists() else None

# ============================================================
# Main Prediction Logic
# ============================================================

def main():
    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

    # Load XGBoost model
    print(f"Loading model: {MODEL_PATH}")
    booster = xgb.Booster()
    booster.load_model(str(MODEL_PATH))
    
    # Identify available years from shapefiles
    shp_files = list(SHP_DIR.glob("*_analytical.shp"))
    
    years = []
    for f in shp_files:
        m = re.search(r'cems_e5l_mcd_(\d{4})_', f.name)
        if m:
            years.append(int(m.group(1)))
    years = sorted(years)
    
    if not years:
        print(f"No matching analytical shapefiles found in {SHP_DIR}")
        return

    print(f"Found {len(years)} years to process: {years}")
    print(f"Years marked for overwrite: {OVERWRITE_YEARS}")
    
    for year in years:
        shp_path = get_annual_shp(year)
        if not shp_path:
            print(f"Skipping {year}: Expected file not found at {shp_path}")
            continue
        
        print(f"\n--- Starting Year: {year} ---")
        gdf = gpd.read_file(shp_path)
        
        if 'pred_obs' not in gdf.columns:
            print(f"Column 'pred_obs' missing in {shp_path.name}. Skipping.")
            continue

        mask_gdf = gdf[gdf['pred_obs'].str.upper().isin(['TP', 'FP'])].copy()
        
        if mask_gdf.empty:
            print(f"No TP/FP regions found for {year}. Skipping.")
            continue

        print(f"Processing {len(mask_gdf)} TP/FP coarse cells...")

        for month in range(1, 13):
            out_name = OUT_DIR / f"pred_tp_fp_mcd_{year}_{month:02d}.tif"
            
            # --- CHECK FOR EXISTING FILES & OVERWRITE LOGIC ---
            if out_name.exists():
                if year in OVERWRITE_YEARS:
                    print(f"  [OVERWRITE] File exists for {year}-{month:02d}, reprocessing as requested.")
                else:
                    print(f"  [SKIP] File exists for {year}-{month:02d}, skipping.")
                    continue
            # --------------------------------------------------

            tif_path = get_monthly_tif(year, month)
            if not tif_path:
                continue
            
            try:
                with rio.open(tif_path) as src:
                    # 1. Align CRS
                    if mask_gdf.crs != src.crs:
                        mask_gdf = mask_gdf.to_crs(src.crs)
                    
                    # 2. Rasterize TP/FP mask
                    mask = features.rasterize(
                        [(geom, 1) for geom in mask_gdf.geometry],
                        out_shape=src.shape,
                        transform=src.transform,
                        fill=0,
                        dtype='uint8'
                    )
                    
                    if not np.any(mask == 1):
                        continue

                    # 3. Read and Prepare Data
                    img_data = src.read()
                    
                    idx_y, idx_x = np.where(mask == 1)
                    pixels = img_data[:, idx_y, idx_x].T
                    
                    if pixels.shape[1] >= 15:
                        pixels = pixels[:, :15]
                    else:
                        print(f"Error: {tif_path.name} has only {pixels.shape[1]} bands (need 15+).")
                        continue

                    # 4. Predict
                    dmat = xgb.DMatrix(pixels, feature_names=FEATURES)
                    preds = booster.predict(dmat)
                    
                    # 5. Save
                    out_proba = np.zeros((src.height, src.width), dtype='float32')
                    out_proba[idx_y, idx_x] = preds
                    
                    out_meta = src.meta.copy()
                    out_meta.update(
                        dtype='float32', 
                        count=1, 
                        nodata=0,
                        compress='deflate'
                    )
                    
                    with rio.open(out_name, 'w', **out_meta) as dst:
                        dst.write(out_proba, 1)
                    
                    print(f"Saved: {out_name.name} (Predicted {len(preds):,} pixels)")

            except Exception as e:
                print(f"Failed to process {year}-{month:02d}: {e}")

    print("\n✅ All available years and months processed.")

if __name__ == "__main__":
    main()

Loading model: /explore/nobackup/people/spotter5/clelland_fire_ml/ml_training/xgb_pr_auc_filtered_burnedlabmask_native_analytical/models/xgb_best_pr_auc.json
Found 23 years to process: [2001, 2002, 2003, 2004, 2005, 2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020, 2021, 2022, 2023]
Years marked for overwrite: [2020, 2021, 2022, 2023]

--- Starting Year: 2001 ---
Processing 1968 TP/FP coarse cells...
Saved: pred_tp_fp_mcd_2001_01.tif (Predicted 766,801 pixels)
Saved: pred_tp_fp_mcd_2001_02.tif (Predicted 766,801 pixels)
Saved: pred_tp_fp_mcd_2001_03.tif (Predicted 766,801 pixels)
Saved: pred_tp_fp_mcd_2001_04.tif (Predicted 766,801 pixels)
Saved: pred_tp_fp_mcd_2001_05.tif (Predicted 766,801 pixels)
Saved: pred_tp_fp_mcd_2001_06.tif (Predicted 766,801 pixels)
Saved: pred_tp_fp_mcd_2001_07.tif (Predicted 766,801 pixels)
Saved: pred_tp_fp_mcd_2001_08.tif (Predicted 766,801 pixels)
Saved: pred_tp_fp_mcd_2001_09.tif (Predicted 766,801 pixels)
Saved: 

In [4]:
't'

't'

In [None]:
Save burned area per year

In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Calculate Annual Burned Area (Mha) per Ecoregion for MCD Analytical Predictions.

Logic:
1. Loads monthly probability TIFFs for MCD (pred_tp_fp_mcd_YYYY_MM.tif).
2. Thresholds probabilities (> 0.80) to create binary monthly masks.
3. Aggregates monthly masks into a single Annual Burned Mask.
4. Intersects the Annual Mask with Ecoregion polygons.
5. Calculates area in Million Hectares (Mha), accounting for pixel area variation by latitude (EPSG:4326).
"""

import os
import re
from pathlib import Path
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio as rio
from rasterio.features import geometry_mask
from tqdm import tqdm

# ============================
# CONFIG
# ============================
YEARS  = list(range(2001, 2024)) # Adjust as needed
MONTHS = list(range(1, 13))

# UPDATED: Point to the MCD analytical predictions directory
PRED_DIR = Path("/explore/nobackup/people/spotter5/clelland_fire_ml/predictions_tp_fp_only_mcd_analytical")
PROB_THRESHOLD = 0.50  # To convert probability to 0/1 mask

# Ecoregion shapefile
ECOS_PATH = "/explore/nobackup/people/spotter5/helene/raw/merge_eco_v2.shp"
ECO_ID_COL = "ecoregion"

# UPDATED: Output CSV directory (MCD analytical)
OUT_DIR = Path("/explore/nobackup/people/spotter5/clelland_fire_ml/burned_area_summaries_mcd_analytical")
OUT_DIR.mkdir(parents=True, exist_ok=True)
OUT_CSV = OUT_DIR / "ba_ecoregion_tp_fp_predictions_08_mcd_analytical.csv"

# ============================
# HELPERS
# ============================

def get_annual_mask(year, pred_dir, threshold):
    """
    Aggregates monthly probability TIFFs into a single annual binary mask.
    Returns: (annual_mask_bool, transform, crs)
    """
    annual = None
    transform = None
    crs = None
    
    found_any = False

    for month in range(1, 13):
        # UPDATED Pattern: pred_tp_fp_mcd_YYYY_MM.tif
        tif_path = pred_dir / f"pred_tp_fp_mcd_{year}_{month:02d}.tif"
        if not tif_path.exists():
            continue
            
        found_any = True
        with rio.open(tif_path) as src:
            prob = src.read(1)
            # Binary mask: 1 if prob >= threshold, else 0
            monthly_burn = (prob >= threshold).astype(bool)
            
            if annual is None:
                annual = monthly_burn
                transform = src.transform
                crs = src.crs
            else:
                # Logical OR accumulation
                annual = annual | monthly_burn
                
    if not found_any:
        return None, None, None
                
    return annual, transform, crs

def get_pixel_area_grid(shape, transform, crs):
    """
    Generates a grid of pixel areas in Mha.
    Handles EPSG:4326 by calculating area based on latitude.
    """
    height, width = shape
    
    # Resolution
    res_x = abs(transform.a)
    res_y = abs(transform.e)

    if crs.is_geographic:
        # EPSG:4326 - Area depends on latitude
        # 1 degree lat ~ 111,320 meters
        # 1 degree lon ~ 111,320 * cos(lat) meters
        
        # Get latitude of every row center
        # transform * (0, row) gives top-left of row
        # We want center, so row + 0.5
        rows = np.arange(height) + 0.5
        _, lats = rio.transform.xy(transform, rows, np.zeros_like(rows), offset='center')
        lats = np.array(lats)
        
        # Calculate area per pixel for each row (in square meters)
        # Area = (res_x * 111320 * cos(lat)) * (res_y * 111320)
        
        lat_rads = np.radians(lats)
        # Cosine scaling for longitude width
        pixel_width_m = res_x * 111320 * np.cos(lat_rads)
        pixel_height_m = res_y * 111320
        
        row_areas_m2 = pixel_width_m * pixel_height_m
        
        # Broadcast to full grid (H, W)
        # Shape (H, 1) to broadcast across columns
        area_grid_m2 = row_areas_m2[:, np.newaxis] * np.ones((1, width))
        
    else:
        # Projected CRS (Meters) - Constant area
        pixel_area_m2 = res_x * res_y
        area_grid_m2 = np.full(shape, pixel_area_m2)

    # Convert m2 to Million Hectares (Mha)
    # 1 Ha = 10,000 m2
    # 1 Mha = 1,000,000 Ha = 10,000,000,000 m2 (1e10)
    area_grid_Mha = area_grid_m2 / 1e10
    
    return area_grid_Mha

# ============================
# MAIN
# ============================

print("Loading ecoregions...")
ecos = gpd.read_file(ECOS_PATH)

results = []

for year in YEARS:
    print(f"Processing Year: {year}")
    annual_mask, transform, crs = get_annual_mask(year, PRED_DIR, PROB_THRESHOLD)
    
    if annual_mask is None:
        print(f"  No predictions found for {year}, skipping.")
        continue

    # Reproject ecoregions to match raster if necessary
    if ecos.crs != crs:
        ecos_proj = ecos.to_crs(crs)
    else:
        ecos_proj = ecos

    # Generate Pixel Area Map (Mha) for this raster geometry
    pixel_area_map = get_pixel_area_grid(annual_mask.shape, transform, crs)

    height, width = annual_mask.shape
    
    # Iterate through ecoregions
    for idx, row in tqdm(ecos_proj.iterrows(), total=len(ecos_proj), desc=f"Ecoregions {year}"):
        eco_id = row[ECO_ID_COL]
        geom = row.geometry
        
        if geom is None or geom.is_empty:
            continue
            
        try:
            # Create boolean mask for this specific ecoregion
            # invert=True means True inside the shape
            eco_mask = geometry_mask(
                [geom],
                transform=transform,
                invert=True,
                out_shape=(height, width),
                all_touched=False # strict center inclusion
            )
            
            # Intersection: (Pixel is Burned) AND (Pixel is in Ecoregion)
            burned_in_eco_mask = annual_mask & eco_mask
            
            if burned_in_eco_mask.any():
                # Sum the specific areas of the burned pixels
                ba_Mha = pixel_area_map[burned_in_eco_mask].sum()
            else:
                ba_Mha = 0.0
            
            results.append({
                "ecoregion": eco_id,
                "year": year,
                "ba_pred_tp_fp_Mha": ba_Mha
            })
            
        except Exception as e:
            print(f"  Error processing eco {eco_id} in {year}: {e}")

# Save to CSV
if results:
    df_results = pd.DataFrame(results)
    df_results.to_csv(OUT_CSV, index=False)
    print(f"DONE. Results saved to {OUT_CSV}")
else:
    print("No results generated.")

Loading ecoregions...
Processing Year: 2001


Ecoregions 2001: 100%|██████████| 27/27 [00:05<00:00,  4.59it/s]


Processing Year: 2002


Ecoregions 2002: 100%|██████████| 27/27 [00:05<00:00,  4.61it/s]


Processing Year: 2003


Ecoregions 2003: 100%|██████████| 27/27 [00:05<00:00,  4.67it/s]


Processing Year: 2004


Ecoregions 2004: 100%|██████████| 27/27 [00:05<00:00,  4.63it/s]


Processing Year: 2005


Ecoregions 2005: 100%|██████████| 27/27 [00:06<00:00,  4.46it/s]


Processing Year: 2006


Ecoregions 2006: 100%|██████████| 27/27 [00:05<00:00,  4.50it/s]


Processing Year: 2007


Ecoregions 2007: 100%|██████████| 27/27 [00:05<00:00,  4.53it/s]


Processing Year: 2008


Ecoregions 2008: 100%|██████████| 27/27 [00:06<00:00,  4.07it/s]


Processing Year: 2009


Ecoregions 2009: 100%|██████████| 27/27 [00:05<00:00,  4.50it/s]


Processing Year: 2010


Ecoregions 2010: 100%|██████████| 27/27 [00:05<00:00,  4.52it/s]


Processing Year: 2011


Ecoregions 2011: 100%|██████████| 27/27 [00:12<00:00,  2.14it/s]


Processing Year: 2012


Ecoregions 2012: 100%|██████████| 27/27 [00:12<00:00,  2.23it/s]


Processing Year: 2013


Ecoregions 2013: 100%|██████████| 27/27 [00:12<00:00,  2.23it/s]


Processing Year: 2014


Ecoregions 2014: 100%|██████████| 27/27 [00:12<00:00,  2.22it/s]


Processing Year: 2015


Ecoregions 2015: 100%|██████████| 27/27 [00:12<00:00,  2.16it/s]


Processing Year: 2016


Ecoregions 2016: 100%|██████████| 27/27 [00:12<00:00,  2.19it/s]


Processing Year: 2017


Ecoregions 2017: 100%|██████████| 27/27 [00:12<00:00,  2.18it/s]


Processing Year: 2018


Ecoregions 2018: 100%|██████████| 27/27 [00:12<00:00,  2.18it/s]


Processing Year: 2019


Ecoregions 2019: 100%|██████████| 27/27 [00:12<00:00,  2.22it/s]


Processing Year: 2020


Ecoregions 2020: 100%|██████████| 27/27 [00:12<00:00,  2.19it/s]


Processing Year: 2021


Ecoregions 2021: 100%|██████████| 27/27 [00:12<00:00,  2.24it/s]


Processing Year: 2022


Ecoregions 2022: 100%|██████████| 27/27 [00:12<00:00,  2.21it/s]


Processing Year: 2023


Ecoregions 2023: 100%|██████████| 27/27 [00:12<00:00,  2.22it/s]

DONE. Results saved to /explore/nobackup/people/spotter5/clelland_fire_ml/burned_area_summaries_mcd_analytical/ba_ecoregion_tp_fp_predictions_08_mcd_analytical.csv





Make multipanel plot per ecoregion

In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Compare TP/FP predictions (MCD Analytical) to MCD64A1 and FireCCI native products (2001-2019).
Includes an ecoregion-summed "Total" panel and professional color palette.
-- MCD ANALYTICAL VERSION --
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# ============================
# CONFIG
# ============================

# Directory containing the REFERENCE data (MCD/FireCCI) - assumed to be in the original folder
REF_DIR = Path("/explore/nobackup/people/spotter5/clelland_fire_ml/burned_area_summaries")

# Directory for NEW outputs (MCD Analytical)
OUT_DIR = Path("/explore/nobackup/people/spotter5/clelland_fire_ml/burned_area_summaries_mcd_analytical")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Files
BASE_CSV     = REF_DIR / "burned_area_by_ecoregion_predictions.csv"  # Contains ba_mcd_native_Mha / ba_firecci_native_Mha
NEW_PRED_CSV = OUT_DIR / "ba_ecoregion_tp_fp_predictions_08_mcd_analytical.csv" # From previous step

# Output Files
FINAL_CSV    = OUT_DIR / "burned_area_by_ecoregion_all_merged_08_mcd_analytical.csv"
OUT_PNG      = OUT_DIR / "burned_area_multipanel_tp_fp_comparison_08_mcd_analytical.png"

# Column Names
ECO_ID_COL  = "ecoregion"
MCD_COL     = "ba_mcd_native_Mha"
FIRECCI_COL = "ba_firecci_native_Mha"
PRED_COL    = "ba_pred_tp_fp_Mha"

YEAR_START, YEAR_END = 2001, 2024

# EXCLUSIONS
EXCLUDE_ECOS = {"WATER", "MIXED WOOD SHIELD", "TEMPERATE PRAIRIES", "WESTERN CORDILLERA"}

# PROFESSIONAL COLORS
COLORS = {
    MCD_COL: "#2c3e50",      # Slate Grey
    FIRECCI_COL: "#e67e22",  # Vivid Orange
    PRED_COL: "#16a085"      # Deep Teal (Predicted MCD TP+FP)
}

def nice_pred_label(colname: str) -> str:
    if colname == "ba_pred_tp_fp_Mha":
        return "Prediction (MCD TP+FP)"
    return colname

# ============================
# MAIN
# ============================

def main():
    print(f"Loading Base Reference Data from: {BASE_CSV}")
    if not BASE_CSV.exists():
        raise FileNotFoundError(f"Base CSV not found: {BASE_CSV}")
        
    print(f"Loading New MCD Predictions from: {NEW_PRED_CSV}")
    if not NEW_PRED_CSV.exists():
        raise FileNotFoundError(f"Prediction CSV not found: {NEW_PRED_CSV}")

    # --- 1. Load, Filter and Merge ---
    df_base = pd.read_csv(BASE_CSV)
    df_pred = pd.read_csv(NEW_PRED_CSV)
    
    # Merge on Ecoregion and Year
    # We left join pred onto base to keep the structure.
    df = df_base.merge(df_pred, on=[ECO_ID_COL, "year"], how="left")
    
    # Filter years 2001-2019
    df = df[(df["year"] >= YEAR_START) & (df["year"] <= YEAR_END)].copy()
    
    # Save merged data for inspection
    df.to_csv(FINAL_CSV, index=False)
    print(f"Merged CSV saved to: {FINAL_CSV}")

    # --- 2. Prepare Subplots ---
    ecos_all = sorted(df[ECO_ID_COL].dropna().unique())
    ecos_list = [e for e in ecos_all if e not in EXCLUDE_ECOS]
    
    # Add a virtual "TOTAL" entry to the list
    plot_list = ecos_list + ["TOTAL BURNED AREA"]
    n_panels = len(plot_list)

    ncols = 4
    nrows = int(np.ceil(n_panels / ncols))

    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, 
        figsize=(4 * ncols, 3.5 * nrows), 
        sharex=True
    )
    axes = axes.flatten()

    handles_for_legend = None

    # --- 3. Plotting Loop ---
    for i, title in enumerate(plot_list):
        ax = axes[i]
        
        if title == "TOTAL BURNED AREA":
            # Aggregate sum across all ecoregions
            df_plot = df.groupby("year")[[MCD_COL, FIRECCI_COL, PRED_COL]].sum().reset_index()
            ax.set_facecolor('#fdfefe') # Light highlight for total panel
        else:
            df_plot = df[df[ECO_ID_COL] == title].sort_values("year")

        # Plot datasets
        p1, = ax.plot(df_plot["year"], df_plot[MCD_COL], marker="o", markersize=4, 
                      label="MCD64A1", color=COLORS[MCD_COL], linewidth=1.5)
        p2, = ax.plot(df_plot["year"], df_plot[FIRECCI_COL], marker="s", markersize=4, 
                      label="Fire CCI", color=COLORS[FIRECCI_COL], linewidth=1.5)
        p3, = ax.plot(df_plot["year"], df_plot[PRED_COL], marker="^", markersize=4, 
                      label=nice_pred_label(PRED_COL), color=COLORS[PRED_COL], linewidth=2)

        ax.set_title(str(title), fontsize=11, fontweight='bold')
        ax.grid(True, ls=":", alpha=0.6)
        ax.tick_params(axis='both', labelsize=9)

        if i == 0:
            handles_for_legend = [p1, p2, p3]

        # Axis labeling
        if i >= (n_panels - ncols):
            ax.set_xlabel("Year", fontsize=10)
        if i % ncols == 0:
            ax.set_ylabel("Burned Area (Mha)", fontsize=10)

        # Handle scaling for very low values (avoid scientific notation or flat lines at 0)
        # Check max of columns ignoring NaNs
        vals = df_plot[[MCD_COL, FIRECCI_COL, PRED_COL]]
        y_max = vals.max().max() if not vals.empty else 0
        
        if pd.notna(y_max) and y_max < 0.005:
            ax.set_ylim(0, 0.01)

    # Clean up empty subplots
    for j in range(i + 1, len(axes)):
        axes[j].axis("off")

    # Global legend
    if handles_for_legend:
        fig.legend(
            handles=handles_for_legend,
            labels=["MCD64A1", "Fire CCI", nice_pred_label(PRED_COL)],
            loc="lower center", 
            ncol=3, 
            fontsize=12,
            frameon=False,
            bbox_to_anchor=(0.5, -0.02)
        )

    plt.tight_layout(rect=[0, 0.03, 1, 0.97])
    plt.savefig(OUT_PNG, dpi=250, bbox_inches="tight")
    plt.close()

    print(f"✅ Comparison plot (2001-2019) with TOTAL panel saved to:\n   {OUT_PNG}")

if __name__ == "__main__":
    main()

Loading Base Reference Data from: /explore/nobackup/people/spotter5/clelland_fire_ml/burned_area_summaries/burned_area_by_ecoregion_predictions.csv
Loading New MCD Predictions from: /explore/nobackup/people/spotter5/clelland_fire_ml/burned_area_summaries_mcd_analytical/ba_ecoregion_tp_fp_predictions_08_mcd_analytical.csv
Merged CSV saved to: /explore/nobackup/people/spotter5/clelland_fire_ml/burned_area_summaries_mcd_analytical/burned_area_by_ecoregion_all_merged_08_mcd_analytical.csv
✅ Comparison plot (2001-2019) with TOTAL panel saved to:
   /explore/nobackup/people/spotter5/clelland_fire_ml/burned_area_summaries_mcd_analytical/burned_area_multipanel_tp_fp_comparison_08_mcd_analytical.png


In [9]:
't'

't'