In [None]:
import os, warnings

# Cap native thread pools *before* numpy/pandas/rasterio are imported: OpenBLAS/MKL/
# numexpr read these variables when the libraries initialise, so setting them after
# the imports may leave the pools (inherited by the forked workers) at full size and
# oversubscribe the CPU once WORKERS processes run in parallel.
# Only takes effect on a fresh kernel.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("NUMEXPR_NUM_THREADS", "1")

from pathlib import Path
import re, uuid
import numpy as np
import pandas as pd
import rasterio
from dbfread import DBF
import pyarrow as pa
import pyarrow.parquet as pq

warnings.filterwarnings("ignore", message="Can't initialize NVML")

# ---- config ----
DATASET_ROOT = Path("../../Dataset")
DISTURBANCES_DATASET = DATASET_ROOT / "AnnualDisturbance_1999_present"
TREATMENTS_OUTPUT = Path("../../Outputs/treatments")
TILE = 4096                 # output tile edge length, in pixels
BAND = 1                    # raster band index (rasterio bands are 1-based)
TREATMENT_NAME = "Thinning" # DIST_TYPE value (case-insensitive) to extract
# Leave ~4 cores free and cap at 12, but always run at least one worker:
# os.cpu_count() may return None, and on <=4-core machines "cpu_count - 4" is <= 0,
# which ProcessPoolExecutor rejects.
WORKERS = max(1, min(12, (os.cpu_count() or 1) - 4))
FLUSH_ROWS = 200_000        # rows buffered per output tile before a Parquet part is written
GPU_MIN_ELEMS = 2_000_000   # NOTE(review): not referenced by the visible cells
TREATMENTS_OUTPUT.mkdir(parents=True, exist_ok=True)

def find_tifs(base: Path):
    """Return every ``*.tif`` located in a ``Tif`` directory anywhere under ``base``.

    The result is a list of paths sorted lexicographically, so file order is stable
    across runs.
    """
    matches = list(base.rglob("Tif/*.tif"))
    matches.sort()
    return matches

def year_from_path(path: Path):
    """Infer the disturbance year of a raster.

    First looks for a 4-digit number following ``US_DIST`` or ``LF`` anywhere in the
    path string. Otherwise it opens the raster and reads its ``DIST_YEAR`` (or ``Year``)
    tag. Returns an int, or None when neither source yields a year or the raster
    cannot be read.
    """
    found = re.search(r"(?:US_DIST|LF)\D*?(\d{4})", str(path))
    if found is not None:
        return int(found.group(1))

    try:
        with rasterio.open(path) as src:
            tags = src.tags()
            raw_year = tags.get("DIST_YEAR") or tags.get("Year")
            if not raw_year:
                return None
            return int(raw_year)
    except Exception:
        # Best effort: unreadable rasters or non-numeric tags mean "unknown year".
        return None

def find_vat_for_tif(tif_path: Path):
    """Locate the attribute table (code -> disturbance type) for a raster.

    Lookup order:
      1. the exact sidecar ``<raster name>.vat.dbf`` next to the raster;
      2. any other ``*.vat.dbf`` in the raster's directory (first in sorted order);
      3. the first ``*.csv`` (sorted) in a ``CSV_Data`` directory that sits beside
         the raster's parent directory.

    Returns the Path found, or None when nothing matches.
    """
    # Build the sidecar path directly instead of globbing on the raster's name:
    # names containing glob metacharacters ("[", "*", "?") would not match themselves.
    sidecar = tif_path.with_name(f"{tif_path.name}.vat.dbf")
    if sidecar.exists():
        return sidecar

    # Sorted so the fallback is deterministic when several VATs share a directory.
    other_vats = sorted(tif_path.parent.glob("*.vat.dbf"))
    if other_vats:
        return other_vats[0]

    # parents[1] does not exist for a bare relative file name such as "x.tif".
    if len(tif_path.parents) > 1:
        csv_dir = tif_path.parents[1] / "CSV_Data"
        if csv_dir.exists():
            csvs = sorted(csv_dir.glob("*.csv"))
            if csvs:
                return csvs[0]
    return None

def load_code_table(vat_path: Path):
    """Read a raster attribute table into a normalised CODE / DIST_TYPE frame.

    ``.dbf`` files are read with dbfread and ``.csv`` files with pandas. Column names
    are stripped and upper-cased, then the first known code column and the first known
    disturbance-type column are selected. If no known type column exists, any column
    whose name contains both "DIST" and "TYPE" is used instead.

    Returns:
        DataFrame with columns ``CODE`` (nullable Int64; rows whose code is not numeric
        are dropped) and ``DIST_TYPE`` (str, stripped and lower-cased).

    Raises:
        FileNotFoundError: the suffix is neither ``.dbf`` nor ``.csv``.
        KeyError: no recognisable code column or type column.
    """
    readers = {
        ".dbf": lambda p: pd.DataFrame(DBF(str(p), load=True)),
        ".csv": pd.read_csv,
    }
    reader = readers.get(vat_path.suffix.lower())
    if reader is None:
        raise FileNotFoundError(vat_path)
    table = reader(vat_path)

    table.columns = [col.strip().upper() for col in table.columns]

    code_names = ("VALUE", "GRIDCODE", "CODE", "VALUE_")
    code_col = next((name for name in code_names if name in table.columns), None)
    if code_col is None:
        raise KeyError(f"No code col in {vat_path}")

    type_names = ("DIST_TYPE", "DISTTYPE", "DIST TYPE", "DIST_TYPE1", "DISTTYPE1",
                  "DIST_TYPE2", "DISTTYPE2", "DIST_TYPE3", "DIST_TYPE4",
                  "DIST_TYPE_", "DISTTYPE_")
    type_col = next((name for name in type_names if name in table.columns), None)
    if type_col is None:
        # Fall back to any column that mentions both DIST and TYPE.
        fuzzy = [col for col in table.columns if "DIST" in col and "TYPE" in col]
        if not fuzzy:
            raise KeyError(f"No type col in {vat_path}")
        type_col = fuzzy[0]

    normalised = (
        table[[code_col, type_col]]
        .set_axis(["CODE", "DIST_TYPE"], axis=1)
        .assign(
            CODE=lambda f: pd.to_numeric(f["CODE"], errors="coerce").astype("Int64"),
            DIST_TYPE=lambda f: f["DIST_TYPE"].astype(str).str.strip().str.lower(),
        )
    )
    return normalised.dropna(subset=["CODE"])


def thinning_codes(vat_df: pd.DataFrame, treat_name=TREATMENT_NAME):
    """Return the set of integer raster codes whose DIST_TYPE equals ``treat_name``.

    ``treat_name`` is stripped and lower-cased before comparing. The DIST_TYPE column
    is compared as-is, so it is expected to be normalised already (as load_code_table
    produces it). The default is the value of TREATMENT_NAME at the moment this
    function was defined.
    """
    target = treat_name.strip().lower()
    matching = vat_df.loc[vat_df["DIST_TYPE"] == target, "CODE"].dropna()
    return {int(code) for code in matching}

def ensure_part_dir(root: Path, year: int, ti: int, tj: int)->Path:
    """Create (if needed) and return the hive-style partition directory
    ``root/yr=<year>/tile_i=<ti>/tile_j=<tj>``."""
    partition = root.joinpath(f"yr={year}", f"tile_i={ti}", f"tile_j={tj}")
    partition.mkdir(parents=True, exist_ok=True)
    return partition

def write_part_fast(df: pd.DataFrame, out_dir: Path):
    """Write ``df`` (index dropped) as a new zstd-compressed Parquet part in ``out_dir``.

    Every call creates a new ``part-<uuid>.parquet`` file alongside any existing parts.

    Returns:
        Path of the file written.
    """
    # Full 128-bit uuid: the previous 8-hex-char (32-bit) suffix could collide and
    # silently overwrite another part in the same partition directory.
    fn = out_dir / f"part-{uuid.uuid4().hex}.parquet"
    # Write under a dot-prefixed temporary name, then rename into place, so an
    # interrupted worker never leaves a truncated part-*.parquet that breaks later
    # reads (pyarrow datasets ignore files starting with "." by default).
    tmp = out_dir / f".{fn.name}.tmp"
    try:
        pq.write_table(pa.Table.from_pandas(df, preserve_index=False),
                       tmp, compression="zstd", use_dictionary=True)
        os.replace(tmp, fn)
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise
    return fn


In [5]:
# ---- CPU-only worker ----
def process_one_tif_cpu(tif_path: str, out_root: str, tile_size: int, gpu_min_elems: int=0, band=None):
    """Extract the treatment pixels of one disturbance raster into tiled Parquet parts.

    Resolves the raster's year and attribute table, keeps every pixel whose code maps to
    TREATMENT_NAME, and writes the pixels under
    ``out_root/yr=<year>/tile_i=<i>/tile_j=<j>/``. Each row holds the year, the tile
    indices, the global (row, col), the in-tile (r, c) and a constant type_code of 1.

    Args:
        tif_path: raster path (a str, so it pickles cheaply to worker processes).
        out_root: root directory of the partitioned Parquet output.
        tile_size: edge length, in pixels, of the output tiles.
        gpu_min_elems: unused; kept so existing call sites keep working.
        band: raster band to scan; None means the BAND config value.

    Returns:
        (tif file name, pixels kept, parquet parts written). (name, 0, 0) when the year,
        the attribute table, or any matching code cannot be found.
    """
    tif_path = Path(tif_path)
    out_root = Path(out_root)
    band = BAND if band is None else band
    skipped = (tif_path.name, 0, 0)

    year = year_from_path(tif_path)
    if not year:
        return skipped
    vat_path = find_vat_for_tif(tif_path)
    if not vat_path:
        return skipped
    codes = thinning_codes(load_code_table(vat_path))
    if not codes:
        return skipped

    code_arr = np.array(sorted(codes))
    # r/c are offsets inside a tile (0 .. tile_size-1): int16 holds them for tiles of
    # up to 32768 px; widen for larger tiles instead of silently overflowing.
    in_tile_dtype = np.int16 if tile_size <= 32768 else np.int32

    buffers = {}        # (tile_i, tile_j) -> list of DataFrames not yet written
    buffered_rows = {}  # (tile_i, tile_j) -> total rows currently in buffers[key]
    kept = written = 0

    def _flush(key):
        # Write all pending rows for one output tile as a single Parquet part.
        nonlocal written
        out_dir = ensure_part_dir(out_root, year, key[0], key[1])
        write_part_fast(pd.concat(buffers[key], ignore_index=True), out_dir)
        buffers[key] = []
        buffered_rows[key] = 0
        written += 1

    with rasterio.open(tif_path) as src:
        nodata = src.nodata
        for _, w in src.block_windows(band):
            a = src.read(band, window=w)
            m = (a != nodata) if nodata is not None else np.ones_like(a, bool)
            if not m.any(): continue
            sel = m & np.isin(a, code_arr)
            if not sel.any(): continue

            # Block-local indices -> global raster row/col.
            rr, cc = np.nonzero(sel)
            rows, cols = rr + int(w.row_off), cc + int(w.col_off)
            kept += rows.size

            tile_i = rows // tile_size;  tile_j = cols // tile_size
            r_in = rows % tile_size;     c_in = cols % tile_size

            df = pd.DataFrame({
                "year": np.full(rows.shape[0], year, dtype=np.int16),
                "tile_i": tile_i.astype(np.int32),
                "tile_j": tile_j.astype(np.int32),
                "row": rows.astype(np.int32),
                "col": cols.astype(np.int32),
                "r": r_in.astype(in_tile_dtype),
                "c": c_in.astype(in_tile_dtype),
                "type_code": np.full(rows.shape[0], 1, dtype=np.int16),
            })

            for (ti, tj), g in df.groupby(["tile_i", "tile_j"], sort=False):
                key = (int(ti), int(tj))
                buffers.setdefault(key, []).append(g)
                # Running count instead of re-summing every buffered frame per append.
                buffered_rows[key] = buffered_rows.get(key, 0) + len(g)
                # Flush per tile once FLUSH_ROWS accumulate to bound worker memory.
                if buffered_rows[key] >= FLUSH_ROWS:
                    _flush(key)

    # Write whatever is left for each tile.
    for key in list(buffers):
        if buffers[key]:
            _flush(key)

    return (tif_path.name, kept, written)


In [None]:
# ---- Driver using fork (works in notebooks) ----
from concurrent.futures import ProcessPoolExecutor, as_completed
import multiprocessing as mp, shutil

# Parts are written with random file names, so re-running into a populated output
# tree would append a second copy of every pixel. Remove this notebook's own
# partitions (yr=*) first so "Restart & Run All" reproduces the same dataset.
# Set to False to keep existing output.
CLEAR_EXISTING_OUTPUT = True
if CLEAR_EXISTING_OUTPUT:
    for old_partition in TREATMENTS_OUTPUT.glob("yr=*"):
        if old_partition.is_dir():
            shutil.rmtree(old_partition)

# "fork" lets workers reuse the helpers defined in earlier cells without pickling
# them (not available on Windows). Keep torch/CUDA out of this kernel: forking a
# process that has initialised CUDA is unsafe.
ctx = mp.get_context("fork")
tifs = find_tifs(DISTURBANCES_DATASET)
print(f"files={len(tifs)}  workers={WORKERS}")

stats, failures = [], []
with ProcessPoolExecutor(max_workers=WORKERS, mp_context=ctx) as ex:
    futs = {ex.submit(process_one_tif_cpu, str(t), str(TREATMENTS_OUTPUT), TILE): t for t in tifs}
    for f in as_completed(futs):
        try:
            stats.append(f.result())
        except Exception as e:
            failures.append((futs[f], e))
            print("fail:", futs[f], e)

print("processed:", len(stats))
print("failed:", len(failures))
# Files with no year / VAT / matching code also return 0 kept pixels.
print("no_matches:", sum(1 for _, k, _ in stats if k == 0))
print("kept_pixels:", sum(k for _, k, _ in stats))
print("parquet_parts:", sum(w for _, _, w in stats))


files=26  workers=12


processed: 26
kept_pixels: 40490314
parquet_parts: 4355
