In [None]:
# build_panel_labels.py
# Create pixel-year labels per fire from events parquet (thinning) without covariates.

import os, re, uuid, warnings
from pathlib import Path
import numpy as np
import pandas as pd
import geopandas as gpd
import rasterio
from rasterio.features import rasterize
from shapely.geometry import mapping
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as ds
import fiona

# ---------------- config ----------------
# Inputs (paths are relative to the working directory the notebook runs in)
DATA_ROOT = Path("../../Dataset")
DIST_ROOT = DATA_ROOT / "AnnualDisturbance_1999_present"   # searched for */Tif/*.tif reference rasters
FIRE_GDB = DATA_ROOT / "S_USA.MTBS_BURN_AREA_BOUNDARY.gdb"  # MTBS burn-area polygons

# Treatment events are read from EVENTS_ROOT (yr=/tile_i=/tile_j= partitions);
# label chunks are written under OUT_ROOT (fire_id= partitions).
EVENTS_ROOT = Path("../../Outputs") / "treatments"
OUT_ROOT = Path("../../Outputs") / "model_database"

# Labelling parameters
LOOKBACK_YEARS = 10   # a fire in year Y gets label years t = Y-10 .. Y-1
BUFFER_METERS = 1609  # ~1 mile; applied in raster CRS units (assumed meters)
TILE = 4096           # tile edge in pixels -- must match step 1
BAND = 1              # NOTE(review): not referenced in this cell
# ---------------------------------------

# Module-level side effect: ensure the output root exists before any writes.
OUT_ROOT.mkdir(parents=True, exist_ok=True)

def find_reference_tif(dist_root: Path) -> Path:
    """Return the lexicographically first ``Tif/*.tif`` found anywhere under *dist_root*.

    The chosen raster only supplies the grid (CRS, transform, shape).

    Raises:
        FileNotFoundError: if no matching GeoTIFF exists.
    """
    first = min(dist_root.rglob("Tif/*.tif"), default=None)
    if first is None:
        raise FileNotFoundError("No disturbance GeoTIFFs found under Dataset/AnnualDisturbance_1999_present")
    return first

def open_grid(ref_tif: Path):
    """Open *ref_tif* with rasterio and hand back the dataset.

    The caller owns the handle and must close it (e.g. ``with open_grid(p) as src:``).
    """
    dataset = rasterio.open(ref_tif)
    return dataset

def list_fire_layers(gdb_path: Path):
    """Return the layer names fiona reports for the geodatabase at *gdb_path*."""
    layer_names = fiona.listlayers(gdb_path)
    return layer_names

def _first_present(columns, candidates, what):
    """Return the first column whose upper-cased name is in *candidates* (candidate order wins).

    Raises:
        ValueError: naming *what* and the candidates when none is present.
    """
    by_upper = {c.upper(): c for c in columns}
    for key in candidates:
        if key in by_upper:
            return by_upper[key]
    raise ValueError(
        f"no {what} column found; looked for {', '.join(candidates)} among {list(columns)}"
    )


def load_fires(gdb_path: Path, raster_crs):
    """Load burn-area polygons from *gdb_path* and reproject them to *raster_crs*.

    Layer choice: the first layer whose upper-cased name contains both "BURN"
    and "BOUND", otherwise the first layer listed. All non-geometry column
    names are upper-cased. The year column is the first of YEAR / FIRE_YEAR / YR
    present; the id column is the first of FIRE_ID / IRWINID / MAP_ID / OBJECTID.

    Args:
        gdb_path: path to the file geodatabase.
        raster_crs: target CRS passed to ``GeoDataFrame.to_crs`` (here the
            reference raster's CRS).

    Returns:
        GeoDataFrame with columns FIRE_ID (str), YEAR (int) and the active
        geometry column. Rows with null geometry or a missing year are dropped
        (missing years are reported with ``warnings.warn``).

    Raises:
        ValueError: if the geodatabase has no layers, or no year/id column exists.
    """
    layers = fiona.listlayers(gdb_path.as_posix())
    if not layers:
        raise ValueError(f"no layers found in {gdb_path}")
    lname = next(
        (name for name in layers if "BURN" in name.upper() and "BOUND" in name.upper()),
        layers[0],
    )

    gdf = gpd.read_file(gdb_path.as_posix(), layer=lname)

    # Upper-case every column except the active geometry. Because the geometry
    # column keeps its name, it stays the active geometry -- no set_geometry needed.
    geom_col = gdf.geometry.name
    gdf = gdf.rename(columns={c: c.upper() for c in gdf.columns if c != geom_col})

    year_col = _first_present(gdf.columns, ("YEAR", "FIRE_YEAR", "YR"), "year")
    id_col = _first_present(gdf.columns, ("FIRE_ID", "IRWINID", "MAP_ID", "OBJECTID"), "fire id")

    gdf = gdf[gdf[geom_col].notnull()]

    # astype(int) cannot represent NaN; skip those fires instead of aborting the run.
    year_missing = gdf[year_col].isna()
    if year_missing.any():
        warnings.warn(f"dropping {int(year_missing.sum())} fire(s) with missing {year_col}")
        gdf = gdf[~year_missing]

    # to_crs returns a new frame, so the column assignments below do not touch a view.
    gdf = gdf.to_crs(raster_crs)

    gdf["YEAR"] = gdf[year_col].astype(int)
    gdf["FIRE_ID"] = gdf[id_col].astype(str)
    return gdf[["FIRE_ID", "YEAR", geom_col]].set_geometry(geom_col)


def tiles_touching_bounds(row0, col0, row1, col1, tile_size):
    """Yield ``(tile_i, tile_j)`` for every tile overlapping the inclusive pixel box.

    The box spans rows row0..row1 and cols col0..col1. Tiles are yielded in
    row-major order. Start tile indices are clamped at 0; end indices are not
    clamped, so an inverted box yields nothing.
    """
    first_i = max(0, row0 // tile_size)
    first_j = max(0, col0 // tile_size)
    last_i = row1 // tile_size
    last_j = col1 // tile_size
    yield from (
        (i, j)
        for i in range(first_i, last_i + 1)
        for j in range(first_j, last_j + 1)
    )

def window_for_tile(ti, tj, H, W, tile_size):
    """Return the rasterio Window covering tile (ti, tj) on an H x W grid.

    Edge tiles are truncated to the grid. For a tile starting at or beyond the
    grid edge, width/height come out <= 0, and callers must check for that.
    """
    row_start, col_start = ti * tile_size, tj * tile_size
    return rasterio.windows.Window(
        col_off=col_start,
        row_off=row_start,
        width=min(tile_size, W - col_start),
        height=min(tile_size, H - row_start),
    )

def events_for_year_tile(events_root: Path, year: int, ti: int, tj: int):
    """Return the set of ``(row, col)`` event pixels for one year/tile partition.

    Reads only ``events_root/yr={year}/tile_i={ti}/tile_j={tj}``. An empty set
    is returned when that directory is absent or holds no rows.
    """
    leaf = events_root.joinpath(f"yr={year}", f"tile_i={ti}", f"tile_j={tj}")
    if not leaf.exists():
        return set()

    table = ds.dataset(leaf.as_posix(), format="parquet").to_table(columns=["row", "col"])
    if table.num_rows == 0:
        return set()

    row_vals = table.column("row").to_numpy().tolist()
    col_vals = table.column("col").to_numpy().tolist()
    return set(zip(row_vals, col_vals))

def write_panel_chunk(fire_id: str, df: pd.DataFrame):
    """Write *df* as a new zstd-compressed parquet part under ``OUT_ROOT/fire_id=<fire_id>/``.

    Each call creates a fresh ``part-<8 hex>.parquet`` file. The DataFrame index
    is not stored.
    """
    target_dir = OUT_ROOT.joinpath(f"fire_id={fire_id}")
    target_dir.mkdir(parents=True, exist_ok=True)
    part_path = target_dir / ("part-" + uuid.uuid4().hex[:8] + ".parquet")
    pq.write_table(
        pa.Table.from_pandas(df, preserve_index=False),
        part_path,
        compression="zstd",
        use_dictionary=True,
    )

def _pixel_bbox(geom, src):
    """Return the inclusive pixel box ``(row0, col0, row1, col1)`` of *geom*'s bounds on *src*'s grid.

    Values are clamped to the grid. Returns None when the bounds lie entirely
    off-grid. Rows/cols are ordered explicitly, so north-up and south-up
    transforms both work. Only the two bounding corners are mapped, so rotated
    transforms are not handled.
    """
    H, W = src.height, src.width
    Tinv = ~src.transform
    minx, miny, maxx, maxy = geom.bounds
    col_a, row_a = Tinv * (minx, miny)
    col_b, row_b = Tinv * (maxx, maxy)
    row_lo, row_hi = min(row_a, row_b), max(row_a, row_b)
    col_lo, col_hi = min(col_a, col_b), max(col_a, col_b)
    if row_hi < 0 or col_hi < 0 or row_lo >= H or col_lo >= W:
        return None

    def clamp(v, n):
        return int(np.floor(max(0, min(v, n - 1))))

    return clamp(row_lo, H), clamp(col_lo, W), clamp(row_hi, H), clamp(col_hi, W)


def _treated_flags(rows, cols, ev_set):
    """Return an int8 array: 1 where ``(rows[k], cols[k])`` is in *ev_set*, else 0.

    The output is aligned element-for-element with *rows*/*cols*.
    """
    rows = np.asarray(rows)
    if not ev_set:
        return np.zeros(rows.size, dtype=np.int8)
    return np.fromiter(
        (1 if rc in ev_set else 0 for rc in zip(rows.tolist(), np.asarray(cols).tolist())),
        dtype=np.int8,
        count=rows.size,
    )


def process_fire(feature, src):
    """Write pixel-year treatment labels for one fire.

    The fire polygon is buffered by BUFFER_METERS (in raster CRS units), then
    rasterized tile by tile on *src*'s grid (pixel centres, ``all_touched=False``).
    For each covered pixel and each year t in [Y - LOOKBACK_YEARS, Y - 1],
    ``treated_next_year`` is 1 when (row, col) appears in the events partition
    for year t + 1 and the pixel's tile. One parquet chunk is written per
    (tile, year) via ``write_panel_chunk``.

    Args:
        feature: a row with FIRE_ID, YEAR and a ``.geometry`` attribute.
        src: open rasterio dataset defining the grid.
    """
    fire_id = feature["FIRE_ID"]
    Y = int(feature["YEAR"])
    geom = feature.geometry
    if geom is None or geom.is_empty:
        return

    geom_buf = geom.buffer(BUFFER_METERS)
    bbox = _pixel_bbox(geom_buf, src)
    if bbox is None:  # buffered fire does not overlap this raster at all
        return
    r0, c0, r1, c1 = bbox

    H, W = src.height, src.width
    T = src.transform
    shapes = [(mapping(geom_buf), 1)]

    for ti, tj in tiles_touching_bounds(r0, c0, r1, c1, TILE):
        win = window_for_tile(ti, tj, H, W, TILE)
        if win.width <= 0 or win.height <= 0:
            continue

        mask = rasterize(
            shapes,
            out_shape=(int(win.height), int(win.width)),
            transform=rasterio.windows.transform(win, T),
            fill=0,
            dtype="uint8",
            all_touched=False,
        )
        rr, cc = np.nonzero(mask)
        if rr.size == 0:
            continue

        # absolute grid coordinates of the covered pixels
        rows = (rr + int(win.row_off)).astype(np.int32)
        cols = (cc + int(win.col_off)).astype(np.int32)
        tile_i = np.full(rows.size, ti, dtype=np.int32)
        tile_j = np.full(rows.size, tj, dtype=np.int32)

        for t in range(Y - LOOKBACK_YEARS, Y):
            # label for year t uses events recorded at year t + 1
            ev_set = events_for_year_tile(EVENTS_ROOT, t + 1, ti, tj)
            df = pd.DataFrame({
                "fire_id": fire_id,
                "year_t": np.full(rows.size, t, dtype=np.int16),
                "row": rows,
                "col": cols,
                "tile_i": tile_i,
                "tile_j": tile_j,
                # flags must follow rows/cols order (a set would scramble it)
                "treated_next_year": _treated_flags(rows, cols, ev_set),
            })
            write_panel_chunk(fire_id, df)

def main():
    """Label every fire from FIRE_GDB on the grid of the first disturbance GeoTIFF.

    A failure in one fire is reported with ``warnings.warn`` and the loop moves on.
    """
    reference = find_reference_tif(DIST_ROOT)
    with open_grid(reference) as grid:
        fires = load_fires(FIRE_GDB, grid.crs)
        for _, fire in fires.iterrows():
            try:
                process_fire(fire, grid)
                print(f"done: {fire['FIRE_ID']} ({fire['YEAR']})")
            except Exception as exc:
                warnings.warn(f"fire {fire['FIRE_ID']} failed: {exc}")


if __name__ == "__main__":
    main()


done: AR3453309385820240223 (2024)
done: WY4435410557920240822 (2024)
done: UT3996611396320230721 (2023)
done: AR3449309427120240223 (2024)
done: AR3483909390820240223 (2024)
done: AZ3338011031820240727 (2024)
done: AR3487509395420240222 (2024)
done: AZ3344911027020240721 (2024)
done: OR4403712030520240902 (2024)
done: UT3984911337020230412 (2023)
done: CA3389711605420230610 (2023)
done: AR3476209371020240220 (2024)
done: AZ3343311026120240726 (2024)
done: AZ3346711020320240720 (2024)
done: AK6412114167320230807 (2023)
done: AR3437509350720240319 (2024)
done: AK6411514113120230803 (2023)
done: WY4288410636420240911 (2024)
done: AR3498009395320240318 (2024)
done: AR3441209431320240312 (2024)
done: AR3458509363520240312 (2024)
done: CA3614212061820231009 (2023)
done: OR4462111740120240723 (2024)
done: AR3486009389020240312 (2024)
done: AR3572709377920240318 (2024)
done: OR4242612068520230918 (2023)
done: AR3573509298920240317 (2024)
done: NV4077811973920240724 (2024)
done: AR359700922012