In [None]:
# Performance guards
import os, warnings
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")
warnings.filterwarnings("ignore", message="Can't initialize NVML")

from pathlib import Path
import re, uuid
import numpy as np
import pandas as pd
import rasterio
from rasterio.windows import Window
from rasterio.features import rasterize
from shapely.geometry import mapping
import pyarrow as pa
import pyarrow.parquet as pq
import pyarrow.dataset as ds

# ---------------- config ----------------
DATA_ROOT = Path("../../Dataset")
DIST_ROOT = DATA_ROOT / "AnnualDisturbance_1999_present"
EVENTS_ROOT = Path("../../Outputs") / "treatments"
OUT_ROOT = Path("../../Outputs") / "model_database"
FIRE_GDB = DATA_ROOT / "S_USA.MTBS_BURN_AREA_BOUNDARY.gdb"

LOOKBACK_YEARS = 10
BUFFER_METERS = 1609            # ~1 mile
TILE_SIZE = 4096                # must match step 1 output
WRITE_BUFFER_ROWS = 200_000     # rows per file before flushing
# ---------------------------------------

OUT_ROOT.mkdir(parents=True, exist_ok=True)

# ---------- reference grid ----------
def find_reference_tif(dist_root: Path) -> Path:
    cands = sorted(dist_root.rglob("Tif/*.tif"))
    if not cands:
        raise FileNotFoundError("No disturbance GeoTIFFs found under Dataset/AnnualDisturbance_1999_present")
    return cands[0]

def open_grid(ref_tif: Path):
    return rasterio.open(ref_tif)  # caller closes

# ---------- tiling helpers ----------
def tiles_covering_bbox(r0, c0, r1, c1, tile_size):
    ti0 = max(0, r0 // tile_size); tj0 = max(0, c0 // tile_size)
    ti1 = r1 // tile_size;         tj1 = c1 // tile_size
    for ti in range(ti0, ti1 + 1):
        for tj in range(tj0, tj1 + 1):
            yield (ti, tj)

def window_from_tile(ti, tj, H, W, tile_size):
    row_off = ti * tile_size
    col_off = tj * tile_size
    height  = min(tile_size, H - row_off)
    width   = min(tile_size, W - col_off)
    return Window(col_off=col_off, row_off=row_off, width=width, height=height)

# ---------- events (step-1) lookup ----------
def _encode_keys(rows: np.ndarray, cols: np.ndarray) -> np.ndarray:
    """Combine (row, col) into uint64 keys for fast membership."""
    return (rows.astype(np.uint64) << 32) | cols.astype(np.uint64)

def load_event_keys_for_partition(events_root: Path, year: int, tile_i: int, tile_j: int) -> np.ndarray:
    """Return uint64 keys for pixels treated in given (year, tile_i, tile_j)."""
    part_dir = events_root / f"yr={year}" / f"tile_i={tile_i}" / f"tile_j={tile_j}"
    if not part_dir.exists():
        return np.empty(0, dtype=np.uint64)
    dset = ds.dataset(part_dir.as_posix(), format="parquet")
    tbl = dset.to_table(columns=["row", "col"])
    if tbl.num_rows == 0:
        return np.empty(0, dtype=np.uint64)
    rows = tbl.column("row").to_numpy()
    cols = tbl.column("col").to_numpy()
    return _encode_keys(rows, cols)

def write_panel_chunk(fire_id: str, frame: pd.DataFrame):
    out_dir = OUT_ROOT / f"fire_id={fire_id}"
    out_dir.mkdir(parents=True, exist_ok=True)
    fn = out_dir / f"part-{uuid.uuid4().hex[:8]}.parquet"
    pq.write_table(pa.Table.from_pandas(frame, preserve_index=False),
                   fn, compression="zstd", use_dbictionary=True)


In [9]:
from functools import lru_cache
from shapely import wkb as _wkb

def build_panel_for_fire(fire_id: str, fire_year: int, geom_wkb: bytes, ref_tif_path: str):
    """Compute labels for a single fire polygon with lookback years, write Parquet parts."""
    # open grid inside worker
    with rasterio.open(ref_tif_path) as src:
        transform = src.transform
        H, W = src.height, src.width

        geom = _wkb.loads(geom_wkb)
        if geom is None or geom.is_empty:
            return (fire_id, 0, 0)

        # buffer in projected meters
        buffered = geom.buffer(BUFFER_METERS)

        # pixel bbox
        inv = ~transform
        minx, miny, maxx, maxy = buffered.bounds
        c0, r1 = inv * (minx, miny)
        c1, r0 = inv * (maxx, maxy)
        r0 = int(np.floor(max(0, min(r0, H - 1))))
        r1 = int(np.floor(max(0, min(r1, H - 1))))
        c0 = int(np.floor(max(0, min(c0, W - 1))))
        c1 = int(np.floor(max(0, min(c1, W - 1))))

        total_rows_written = 0
        parts_written = 0
        write_buffer = []  # list[pd.DataFrame]

        @lru_cache(maxsize=10000)
        def cached_event_keys(target_year: int, tile_i: int, tile_j: int) -> np.ndarray:
            return load_event_keys_for_partition(EVENTS_ROOT, target_year, tile_i, tile_j)

        for tile_i, tile_j in tiles_covering_bbox(r0, c0, r1, c1, TILE_SIZE):
            win = window_from_tile(tile_i, tile_j, H, W, TILE_SIZE)
            if win.width <= 0 or win.height <= 0:
                continue

            # rasterize buffered polygon into this tile
            w_transform = rasterio.windows.transform(win, transform)
            mask = rasterize(
                [(mapping(buffered), 1)],
                out_shape=(int(win.height), int(win.width)),
                transform=w_transform,
                fill=0, dtype="uint8", all_touched=False,
            )
            if mask.max() == 0:
                continue

            rr, cc = np.nonzero(mask == 1)
            if rr.size == 0:
                continue

            rows = rr + int(win.row_off)
            cols = cc + int(win.col_off)
            pixel_keys = _encode_keys(rows, cols)

            base_cols = {
                "fire_id": np.full(rows.size, fire_id),
                "row": rows.astype(np.int32),
                "col": cols.astype(np.int32),
                "tile_i": np.full(rows.size, tile_i, dtype=np.int32),
                "tile_j": np.full(rows.size, tile_j, dtype=np.int32),
            }

            # t in [fire_year - LOOKBACK_YEARS, fire_year - 1]
            for t in range(fire_year - LOOKBACK_YEARS, fire_year):
                ev_keys = cached_event_keys(t + 1, tile_i, tile_j)  # “treated next year”
                treated = np.isin(pixel_keys, ev_keys, assume_unique=False).astype(np.int8)

                frame = pd.DataFrame({
                    **base_cols,
                    "year_t": np.full(rows.size, t, dtype=np.int16),
                    "treated_next_year": treated,
                })
                write_buffer.append(frame)
                total_rows_written += frame.shape[0]

                if total_rows_written >= WRITE_BUFFER_ROWS:
                    write_panel_chunk(fire_id, pd.concat(write_buffer, ignore_index=True))
                    write_buffer.clear()
                    parts_written += 1

        if write_buffer:
            write_panel_chunk(fire_id, pd.concat(write_buffer, ignore_index=True))
            parts_written += 1

        return (fire_id, total_rows_written, parts_written)


In [11]:
from concurrent.futures import ProcessPoolExecutor, as_completed
import geopandas as gpd, fiona
import multiprocessing as mp

# read fires once, project to raster CRS, and serialize geometry to WKB
ref_tif = find_reference_tif(DIST_ROOT)
with open_grid(ref_tif) as src:
    raster_crs = src.crs

layers = fiona.listlayers(FIRE_GDB.as_posix())
layer_name = next((L for L in layers if "BURN" in L.upper() and "BOUND" in L.upper()), layers[0])

fires_gdf = gpd.read_file(FIRE_GDB.as_posix(), layer=layer_name).to_crs(raster_crs)
geom_col = fires_gdf.geometry.name
cols_upper = {c.upper(): c for c in fires_gdf.columns}
year_col = next(c for k in ("YEAR","FIRE_YEAR","YR") if (c := cols_upper.get(k)))
id_col   = next(c for k in ("FIRE_ID","IRWINID","MAP_ID","OBJECTID") if (c := cols_upper.get(k)))

fires_gdf = fires_gdf[fires_gdf[geom_col].notnull()].copy()
fires_gdf["FIRE_ID"] = fires_gdf[id_col].astype(str)
fires_gdf["YEAR"]    = fires_gdf[year_col].astype(int)
fires_gdf["GEOM_WKB"] = fires_gdf.geometry.to_wkb()

fire_records = list(fires_gdf[["FIRE_ID","YEAR","GEOM_WKB"]].itertuples(index=False, name=None))
print(f"fires={len(fire_records)}")

# parallel CPU using fork (works in notebooks)
ctx = mp.get_context("fork")
WORKERS = max(1, min(12, (os.cpu_count() or 8) - 4))
print(f"workers={WORKERS}")

stats = []
with ProcessPoolExecutor(max_workers=WORKERS, mp_context=ctx) as ex:
    futs = {
        ex.submit(build_panel_for_fire, fid, yr, wkb, str(ref_tif)): fid
        for (fid, yr, wkb) in fire_records
    }
    for f in as_completed(futs):
        fid = futs[f]
        try:
            stats.append(f.result())
        except Exception as e:
            print("fail:", fid, e)

print("processed_fires:", len(stats))
print("rows_written:", sum(r for _, r, _ in stats))
print("parquet_parts:", sum(p for _, _, p in stats))


  return ogr_read(


fires=30730
workers=12
processed_fires: 30730
rows_written: 21539339900
parquet_parts: 197006
