In [None]:
import os
import glob
import random
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import rasterio
from rasterio.transform import Affine

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm


# =========================
# CONFIG
# =========================
# Root folder of the dataset; every other path below is derived from it.
BASE_DIR = r"G:\Hangkai\Global_Forest_edge_mapping_data"

# 0.01° dynamics rasters
DYN_DIR = os.path.join(BASE_DIR, "Global_001_degree")
# One raster file per dynamics class, keyed by class name.
DYN_FILES = {
    "id":     os.path.join(DYN_DIR, "id.tif"),
    "ii":     os.path.join(DYN_DIR, "ii.tif"),
    "dd":     os.path.join(DYN_DIR, "dd.tif"),
    "di":     os.path.join(DYN_DIR, "di.tif"),
    "stable": os.path.join(DYN_DIR, "stable.tif"),
}
# Class order. load_dynamics_stack() stacks the rasters in this order, so the
# position in this list is what the "winner" index from compute_winner() means.
CLASSES = ["id", "ii", "dd", "di", "stable"]

# High-res folders (0.00025°)
YEARS = [2000, 2005, 2010, 2015, 2020]
# Folder templates; "{year}" is filled in with each entry of YEARS.
AREA_DIR_TMPL = os.path.join(BASE_DIR, "{year}Area")
EDGE_DIR_TMPL = os.path.join(BASE_DIR, "{year}Edge")

# Sampling from 0.01° pixels
N_PER_CLASS = 500             # target number of samples per class
SEED = 7                      # seeds both numpy.random and random
MIN_WINNER_VALUE = 0.0        # set >0 to exclude near-zero dynamics
MIN_SEPARATION_DEG = 0.20     # within-class minimum spacing (~22 km in latitude)
MAX_TRIES_PER_CLASS = 2_000_000  # upper bound on candidates examined per class

# Output structure (per-sample folder)
OUT_DIR = os.path.join(BASE_DIR, "validation_samples_sample_folder")
OUT_CSV = os.path.join(OUT_DIR, f"samples_fullinfo_{N_PER_CLASS}perclass.csv")
SAMPLES_DIR = os.path.join(OUT_DIR, "samples")
# OUT_DIR/samples/<class>/sample_00001/{area_5yr.tif, edge_5yr.tif, quicklook.png, meta.txt}


# =========================
# Utilities
# =========================
def list_tifs(folder: str) -> List[str]:
    """Return the paths matching ``<folder>/*.tif`` in sorted order."""
    pattern = os.path.join(folder, "*.tif")
    matches = glob.glob(pattern)
    matches.sort()
    return matches

def haversine_deg(lon1, lat1, lon2, lat2) -> float:
    """Return the great-circle (haversine) separation of two points, in degrees.

    All four arguments are in degrees. The result is the central angle on a
    sphere, so it lies in [0, 180].
    """
    lon1r, lat1r, lon2r, lat2r = map(np.deg2rad, [lon1, lat1, lon2, lat2])
    dlon = lon2r - lon1r
    dlat = lat2r - lat1r
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1r) * np.cos(lat2r) * np.sin(dlon / 2) ** 2
    # Rounding can push `a` marginally outside [0, 1] for (near-)antipodal
    # points, which would make arcsin return NaN; clamp before inverting.
    a = np.clip(a, 0.0, 1.0)
    c = 2 * np.arcsin(np.sqrt(a))
    return float(np.rad2deg(c))

def center_from_transform(transform: Affine, row: int, col: int) -> Tuple[float, float]:
    """Return the (x, y) coordinate of the centre of pixel (row, col)."""
    center_xy = rasterio.transform.xy(transform, row, col, offset="center")
    cx, cy = center_xy
    return (float(cx), float(cy))

def pixel_bounds_from_transform(transform: Affine, row: int, col: int) -> Tuple[float, float, float, float]:
    """
    Return (left, bottom, right, top) of pixel (row, col).

    The two opposite pixel corners are mapped through ``transform`` and then
    ordered per axis, so the result does not depend on the sign of the pixel
    size.
    """
    corner_a = transform * (col, row)
    corner_b = transform * (col + 1, row + 1)
    xs = (corner_a[0], corner_b[0])
    ys = (corner_a[1], corner_b[1])
    return min(xs), min(ys), max(xs), max(ys)

def build_bounds_index(tif_paths: List[str]) -> List[Tuple[str, float, float, float, float]]:
    """Open each raster and record ``(path, left, bottom, right, top)``."""

    def _entry(path: str) -> Tuple[str, float, float, float, float]:
        # The dataset is only needed for its bounds; close it right away.
        with rasterio.open(path) as src:
            box = src.bounds
        return (path, box.left, box.bottom, box.right, box.top)

    return [_entry(path) for path in tif_paths]

def find_tile(bounds_index: List[Tuple[str, float, float, float, float]], lon: float, lat: float) -> Optional[str]:
    """
    Return the path of the first indexed tile containing (lon, lat), else None.

    Containment is half-open: the left/bottom edges belong to the tile, the
    right/top edges do not.
    """
    hits = (
        path
        for path, west, south, east, north in bounds_index
        if west <= lon < east and south <= lat < north
    )
    return next(hits, None)

def summarize_arr(arr: np.ndarray, nodata: Optional[float]) -> Tuple[float, float, float]:
    """
    Return ``(sum, mean, valid_frac)`` over the valid cells of ``arr``.

    A cell is valid when it is finite and, if ``nodata`` is given, different
    from ``nodata``. With no valid cell the result is ``(nan, nan, 0.0)``.
    """
    mask = np.isfinite(arr)
    if nodata is not None:
        mask = mask & (arr != nodata)

    n_valid = int(np.count_nonzero(mask))
    if n_valid == 0:
        return np.nan, np.nan, 0.0

    # n_valid > 0 implies arr.size > 0, so the division is safe.
    frac = n_valid / arr.size
    good = arr[mask].astype("float64")
    return float(good.sum()), float(good.mean()), float(frac)

def crop_patch_from_tile(tile_path: str, left: float, bottom: float, right: float, top: float):
    """
    Read band 1 of ``tile_path`` cropped to the given bounds.

    Returns ``(array, meta, tile_basename)``. ``meta`` holds crs, nodata,
    dtype, height, width and the transform of the cropped window.
    Returns ``(None, None, None)`` when the rounded window is empty or is not
    fully inside the tile.
    """
    failure = (None, None, None)

    with rasterio.open(tile_path) as src:
        win = src.window(left, bottom, right, top).round_offsets().round_lengths()

        # Degenerate window: nothing to read.
        is_empty = win.width <= 0 or win.height <= 0
        if is_empty:
            return failure

        # Reject windows that stick out of the tile on any side.
        col_end = win.col_off + win.width
        row_end = win.row_off + win.height
        sticks_out = (
            win.col_off < 0
            or win.row_off < 0
            or col_end > src.width
            or row_end > src.height
        )
        if sticks_out:
            return failure

        patch = src.read(1, window=win)
        n_rows, n_cols = patch.shape[0], patch.shape[1]
        meta = dict(
            crs=src.crs,
            nodata=src.nodata,
            dtype=patch.dtype,
            height=n_rows,
            width=n_cols,
            transform=src.window_transform(win),
        )
        return patch, meta, os.path.basename(tile_path)

def write_multiband_tif(out_path: str, bands: List[np.ndarray], meta_ref: dict, band_names: List[str]):
    """
    Write ``bands`` as an LZW-compressed multi-band GeoTIFF at ``out_path``.

    Parameters
    ----------
    out_path : destination file; its parent directory is created if needed.
    bands : one 2-D array per output band, written in order (band 1, 2, ...).
    meta_ref : dict providing "height", "width", "dtype", "crs", "transform"
        and "nodata" for the output profile.
    band_names : band descriptions, matched to ``bands`` by position. Setting
        a description is best-effort: a missing name is skipped and a failure
        is reported as a warning instead of aborting the write.
    """
    import warnings

    # dirname is "" for a bare filename, and os.makedirs("") raises.
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    profile = {
        "driver": "GTiff",
        "height": meta_ref["height"],
        "width": meta_ref["width"],
        "count": len(bands),
        "dtype": meta_ref["dtype"],
        "crs": meta_ref["crs"],
        "transform": meta_ref["transform"],
        "nodata": meta_ref["nodata"],
        "compress": "LZW",
    }
    with rasterio.open(out_path, "w", **profile) as dst:
        for i, arr in enumerate(bands, start=1):
            dst.write(arr, i)
            if i > len(band_names):
                # No name supplied for this band; leave it undescribed.
                continue
            try:
                dst.set_band_description(i, band_names[i - 1])
            except Exception as exc:
                # Descriptions are cosmetic: keep the data, but say what failed.
                warnings.warn(
                    f"Could not set description of band {i} in {out_path}: {exc}"
                )


# =========================
# QUICKLOOK PNG
# =========================
def save_quicklook_png(
    out_png: str,
    area_bands: List[np.ndarray],
    edge_bands: List[np.ndarray],
    years: List[int],
    title_prefix: str = ""
):
    """
    Save a 2x3 quicklook figure comparing the first and last year.

    Layout:
      Row 1: Area (first year), Area (last year), Area delta
      Row 2: Edge (first year), Edge (last year), Edge delta

    - Area panels share one 2-98 percentile colour range.
    - Edge uses LogNorm for readability.
    - Deltas use symmetric robust limits around 0.

    Only the first and last entries of ``area_bands`` / ``edge_bands`` /
    ``years`` are used. The figure is always closed, even if plotting or
    saving fails.

    NOTE(review): only non-finite values are excluded from the colour limits;
    a finite nodata value in the bands is treated as data here.
    """
    y0, y1 = years[0], years[-1]
    A0 = area_bands[0].astype("float64")
    A1 = area_bands[-1].astype("float64")
    E0 = edge_bands[0].astype("float64")
    E1 = edge_bands[-1].astype("float64")

    dA = A1 - A0
    dE = E1 - E0

    def robust_minmax(x, lo=2, hi=98):
        v = x[np.isfinite(x)]
        if v.size == 0:
            return 0.0, 1.0
        return float(np.percentile(v, lo)), float(np.percentile(v, hi))

    def sym_limit(x):
        v = x[np.isfinite(x)]
        if v.size == 0:
            return 1.0
        m = float(np.percentile(np.abs(v), 98))
        return max(m, 1e-6)

    # common area scale for A0/A1
    a_vmin, a_vmax = robust_minmax(np.concatenate([A0.ravel(), A1.ravel()]))

    # edge log scale: limits come from strictly positive values only
    e_pos = np.concatenate([E0.ravel(), E1.ravel()])
    e_pos = e_pos[np.isfinite(e_pos) & (e_pos > 0)]
    if e_pos.size == 0:
        e_vmin, e_vmax = 1e-6, 1.0
    else:
        e_vmin = float(np.percentile(e_pos, 2))
        e_vmax = float(np.percentile(e_pos, 98))
        e_vmin = max(e_vmin, 1e-6)
        # keep at least one decade so the log scale is not degenerate
        e_vmax = max(e_vmax, e_vmin * 10)

    dA_lim = sym_limit(dA)
    dE_lim = sym_limit(dE)

    fig, axes = plt.subplots(2, 3, figsize=(12, 8), dpi=150)
    try:
        fig.suptitle(title_prefix, fontsize=12)

        # (axis, image, title, imshow keyword arguments) for each panel.
        # Each log panel gets its own LogNorm instance.
        panels = [
            (axes[0, 0], A0, f"Area {y0}", {"vmin": a_vmin, "vmax": a_vmax}),
            (axes[0, 1], A1, f"Area {y1}", {"vmin": a_vmin, "vmax": a_vmax}),
            (axes[0, 2], dA, f"ΔArea {y1}-{y0}", {"vmin": -dA_lim, "vmax": dA_lim}),
            (axes[1, 0], E0, f"Edge {y0} (log)", {"norm": LogNorm(vmin=e_vmin, vmax=e_vmax)}),
            (axes[1, 1], E1, f"Edge {y1} (log)", {"norm": LogNorm(vmin=e_vmin, vmax=e_vmax)}),
            (axes[1, 2], dE, f"ΔEdge {y1}-{y0}", {"vmin": -dE_lim, "vmax": dE_lim}),
        ]
        for ax, img, title, imshow_kw in panels:
            im = ax.imshow(img, **imshow_kw)
            ax.set_title(title)
            ax.axis("off")
            fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)

        fig.tight_layout(rect=[0, 0, 1, 0.96])
        fig.savefig(out_png, bbox_inches="tight")
    finally:
        # Release the figure even on failure so repeated calls do not pile up
        # open figures.
        plt.close(fig)


# =========================
# Load dynamics stack (0.01°)
# =========================
def load_dynamics_stack():
    """
    Read band 1 of every dynamics raster into one float64 stack.

    Returns ``(stack, transform, nodata)`` where ``stack`` has one layer per
    entry of CLASSES (in that order) and ``transform`` / ``nodata`` come from
    the first raster.

    Raises FileNotFoundError if a raster is missing and RuntimeError if a
    raster's transform or size differs from the first one.
    """
    layers = []
    ref_transform = None
    ref_nodata = None
    ref_shape = None

    for cname in CLASSES:
        path = DYN_FILES[cname]
        if not os.path.exists(path):
            raise FileNotFoundError(path)

        with rasterio.open(path) as src:
            shape = (src.height, src.width)
            if ref_transform is None:
                # First raster defines the reference grid.
                ref_transform, ref_nodata, ref_shape = src.transform, src.nodata, shape
            elif src.transform != ref_transform or shape != ref_shape:
                raise RuntimeError(f"Dynamics rasters not aligned: {path}")

            layers.append(src.read(1).astype("float64"))

    # (n_classes, H, W)
    return np.stack(layers, axis=0), ref_transform, ref_nodata

def compute_winner(stack: np.ndarray, nodata: Optional[float]):
    """
    Find, per pixel, the layer with the largest (non-negative) value.

    Returns ``(valid, winner_idx, winner_val, stack_pos)``:
    - ``stack_pos``: ``stack`` with negative values clipped to 0;
    - ``winner_idx``: argmax over axis 0 of ``stack_pos``;
    - ``winner_val``: the value of ``stack_pos`` at ``winner_idx``;
    - ``valid``: True where every layer is finite (and not ``nodata``, when
      given) and ``winner_val`` exceeds MIN_WINNER_VALUE.
    """
    usable = np.isfinite(stack)
    if nodata is not None:
        usable = usable & (stack != nodata)
    all_layers_ok = np.all(usable, axis=0)

    # Negative values are not expected; clip them so they can never win.
    stack_pos = np.clip(stack, 0, None)
    winner_idx = np.argmax(stack_pos, axis=0)
    picked = np.take_along_axis(stack_pos, winner_idx[np.newaxis, ...], axis=0)
    winner_val = picked[0]

    valid = all_layers_ok & (winner_val > MIN_WINNER_VALUE)
    return valid, winner_idx, winner_val, stack_pos


# =========================
# Sampling dominant-class 0.01° pixels
# =========================
def sample_pixels_per_class(valid, winner_idx, winner_val, stack_pos, transform) -> pd.DataFrame:
    """
    Draw up to N_PER_CLASS spatially separated pixels for every class.

    For each class, the valid pixels whose winner is that class are shuffled
    (seeded with SEED) and visited in order; a pixel is kept only if it lies
    at least MIN_SEPARATION_DEG from every pixel already kept for the same
    class. At most MAX_TRIES_PER_CLASS candidates are visited per class.

    Returns one row per kept pixel with its class, raster row/col, centre,
    pixel bounds, winner value and all per-class dynamics values, plus a
    1-based ``sample_id`` as the first column.
    """
    np.random.seed(SEED)
    random.seed(SEED)

    rows_out = []

    for class_pos, cname in enumerate(CLASSES):
        cand_rows, cand_cols = np.where(valid & (winner_idx == class_pos))
        n_cand = cand_rows.size
        if n_cand == 0:
            print(f"[Warn] No candidates for class {cname}.")
            continue

        shuffle = np.random.permutation(n_cand)
        cand_rows = cand_rows[shuffle]
        cand_cols = cand_cols[shuffle]

        kept = []
        # One candidate is consumed per try, so both limits collapse into one.
        n_visit = min(n_cand, MAX_TRIES_PER_CLASS)
        for k in range(n_visit):
            if len(kept) >= N_PER_CLASS:
                break

            r = int(cand_rows[k])
            c = int(cand_cols[k])
            lon, lat = center_from_transform(transform, r, c)

            # enforce within-class spacing
            too_close = any(
                haversine_deg(lon, lat, plon, plat) < MIN_SEPARATION_DEG
                for plon, plat in kept
            )
            if too_close:
                continue

            kept.append((lon, lat))
            left, bottom, right, top = pixel_bounds_from_transform(transform, r, c)

            rec = {
                "class": cname,
                "row_001": r,
                "col_001": c,
                "center_lon": lon,
                "center_lat": lat,
                "pixel_left": left,
                "pixel_bottom": bottom,
                "pixel_right": right,
                "pixel_top": top,
                "winner_value_001": float(winner_val[r, c]),
            }
            rec.update(
                {
                    f"dyn_{k_name}_001": float(stack_pos[k_i, r, c])
                    for k_i, k_name in enumerate(CLASSES)
                }
            )
            rows_out.append(rec)

        print(f"Class {cname}: selected {len(kept)} / {N_PER_CLASS} (candidates={n_cand})")

    df = pd.DataFrame(rows_out)
    df.insert(0, "sample_id", np.arange(1, len(df) + 1))
    return df


# =========================
# Tile indices for high-res area/edge (per year)
# =========================
def build_tile_bounds_by_year(folder_tmpl: str) -> Dict[int, List[Tuple[str, float, float, float, float]]]:
    """
    Build a bounds index of the ``*.tif`` tiles for every year in YEARS.

    ``folder_tmpl`` must contain a ``{year}`` placeholder. Raises RuntimeError
    if a year's folder holds no tif.
    """
    index_by_year = {}
    for year in YEARS:
        year_dir = folder_tmpl.format(year=year)
        tif_paths = list_tifs(year_dir)
        if len(tif_paths) == 0:
            raise RuntimeError(f"No tif in {year_dir}")
        year_index = build_bounds_index(tif_paths)
        index_by_year[year] = year_index
        print(f"Indexed {len(year_index)} tiles for {year_dir}")
    return index_by_year


# =========================
# Extract patches + write per-sample folder (area, edge, quicklook)
# =========================
def _collect_year_patches(area_bounds, edge_bounds, lon, lat, bbox):
    """Crop the area/edge patch of every year; return None if any year fails.

    On success returns ``(area_bands, edge_bands, area_meta, edge_meta,
    fields)`` where ``fields`` holds the per-year tile names and statistics
    in the order they should appear in the output record.
    """
    area_bands: List[np.ndarray] = []
    edge_bands: List[np.ndarray] = []
    area_meta_ref = None
    edge_meta_ref = None
    fields = {}

    for y in YEARS:
        a_tile = find_tile(area_bounds[y], lon, lat)
        e_tile = find_tile(edge_bounds[y], lon, lat)
        if a_tile is None or e_tile is None:
            return None

        a_arr, a_meta, a_tile_name = crop_patch_from_tile(a_tile, *bbox)
        e_arr, e_meta, e_tile_name = crop_patch_from_tile(e_tile, *bbox)
        if a_arr is None or e_arr is None:
            return None

        # Every year must have the shape of the first year, otherwise the
        # patches cannot be stacked into one multi-band file.
        if area_meta_ref is None:
            area_meta_ref = a_meta
        elif (a_meta["height"], a_meta["width"]) != (area_meta_ref["height"], area_meta_ref["width"]):
            return None

        if edge_meta_ref is None:
            edge_meta_ref = e_meta
        elif (e_meta["height"], e_meta["width"]) != (edge_meta_ref["height"], edge_meta_ref["width"]):
            return None

        a_sum, a_mean, a_v = summarize_arr(a_arr, a_meta["nodata"])
        e_sum, e_mean, e_v = summarize_arr(e_arr, e_meta["nodata"])

        fields[f"area_tile_{y}"] = a_tile_name
        fields[f"edge_tile_{y}"] = e_tile_name

        fields[f"area_sum_{y}"] = a_sum
        fields[f"area_mean_{y}"] = a_mean
        fields[f"area_validfrac_{y}"] = a_v

        fields[f"edge_sum_{y}"] = e_sum
        fields[f"edge_mean_{y}"] = e_mean
        fields[f"edge_validfrac_{y}"] = e_v

        area_bands.append(a_arr)
        edge_bands.append(e_arr)

    return area_bands, edge_bands, area_meta_ref, edge_meta_ref, fields


def _write_sample_meta(meta_txt, sid, cname, lon, lat, bbox, rec):
    """Write the human-readable ``meta.txt`` summary of one sample."""
    left, bottom, right, top = bbox
    with open(meta_txt, "w", encoding="utf-8") as f:
        f.write(f"sample_id: {sid}\n")
        f.write(f"class: {cname}\n")
        f.write(f"center_lon, center_lat: {lon}, {lat}\n")
        f.write(f"0.01deg pixel bounds (L,B,R,T): {left}, {bottom}, {right}, {top}\n\n")
        f.write("0.01deg dynamics values:\n")
        for k in CLASSES:
            f.write(f"  {k}: {rec.get('dyn_'+k+'_001', np.nan)}\n")
        f.write("\nPer-year tiles and sums:\n")
        for y in YEARS:
            f.write(f"\nYear {y}\n")
            f.write(f"  area_tile: {rec.get(f'area_tile_{y}','')}\n")
            f.write(f"  edge_tile: {rec.get(f'edge_tile_{y}','')}\n")
            f.write(f"  area_sum: {rec.get(f'area_sum_{y}',np.nan)}\n")
            f.write(f"  edge_sum: {rec.get(f'edge_sum_{y}',np.nan)}\n")


def extract_and_write(df: pd.DataFrame) -> pd.DataFrame:
    """
    Extract the high-res area/edge patches of every sampled pixel and write
    one folder per sample.

    For each row of ``df`` (as produced by sample_pixels_per_class) the patch
    covering the pixel bounds is cropped from the area and edge tiles of
    every year in YEARS. A sample is skipped when a tile is missing, a crop
    fails, or the patch shape differs between years. For a kept sample the
    folder ``SAMPLES_DIR/<class>/sample_<id>`` receives ``area_5yr.tif``,
    ``edge_5yr.tif``, ``quicklook.png`` and ``meta.txt``.

    The sample folder is created only once extraction has succeeded, so
    skipped samples do not leave empty directories behind.

    Returns a DataFrame with one row per saved sample: the input columns plus
    per-year tile names/statistics and the paths of the written files.
    """
    os.makedirs(SAMPLES_DIR, exist_ok=True)

    area_bounds = build_tile_bounds_by_year(AREA_DIR_TMPL)
    edge_bounds = build_tile_bounds_by_year(EDGE_DIR_TMPL)

    out_rows = []
    saved = 0
    skipped = 0

    for _, r in df.iterrows():
        sid = int(r["sample_id"])
        cname = str(r["class"])
        lon = float(r["center_lon"])
        lat = float(r["center_lat"])
        bbox = (
            float(r["pixel_left"]),
            float(r["pixel_bottom"]),
            float(r["pixel_right"]),
            float(r["pixel_top"]),
        )

        sample_folder = os.path.join(SAMPLES_DIR, cname, f"sample_{sid:05d}")

        patches = _collect_year_patches(area_bounds, edge_bounds, lon, lat, bbox)
        if patches is None:
            skipped += 1
            continue
        area_bands, edge_bands, area_meta_ref, edge_meta_ref, year_fields = patches

        rec = r.to_dict()
        rec["sample_folder"] = sample_folder
        rec.update(year_fields)

        # Only now is it certain that the sample will produce output.
        os.makedirs(sample_folder, exist_ok=True)

        # Write both tifs into the same sample folder
        area_out = os.path.join(sample_folder, "area_5yr.tif")
        edge_out = os.path.join(sample_folder, "edge_5yr.tif")
        band_names = [str(y) for y in YEARS]

        write_multiband_tif(area_out, area_bands, area_meta_ref, band_names)
        write_multiband_tif(edge_out, edge_bands, edge_meta_ref, band_names)

        rec["area_patch_path"] = area_out
        rec["edge_patch_path"] = edge_out

        # Quicklook
        quicklook_out = os.path.join(sample_folder, "quicklook.png")
        title_prefix = f"{cname} | sample {sid:05d} | ({lon:.4f}, {lat:.4f})"
        save_quicklook_png(
            out_png=quicklook_out,
            area_bands=area_bands,
            edge_bands=edge_bands,
            years=YEARS,
            title_prefix=title_prefix
        )
        rec["quicklook_png_path"] = quicklook_out

        # meta.txt (handy)
        meta_txt = os.path.join(sample_folder, "meta.txt")
        _write_sample_meta(meta_txt, sid, cname, lon, lat, bbox, rec)
        rec["meta_txt_path"] = meta_txt

        out_rows.append(rec)
        saved += 1
        if saved % 50 == 0:
            print(f"Saved {saved} samples (skipped={skipped})")

    print(f"Done. Saved={saved}, Skipped={skipped}")
    return pd.DataFrame(out_rows)


# =========================
# MAIN
# =========================
def main():
    """Sample pixels, extract their patches and save the summary CSV."""
    os.makedirs(OUT_DIR, exist_ok=True)

    dyn_stack, dyn_transform, dyn_nodata = load_dynamics_stack()
    # compute_winner returns (valid, winner_idx, winner_val, stack_pos), which
    # is exactly the leading argument order of sample_pixels_per_class.
    winner_fields = compute_winner(dyn_stack, dyn_nodata)

    df_points = sample_pixels_per_class(*winner_fields, dyn_transform)
    print(f"Total sampled points (before extraction): {len(df_points)}")

    df_full = extract_and_write(df_points)
    df_full.to_csv(OUT_CSV, index=False)
    print(f"Saved CSV: {OUT_CSV}")
    print(df_full.head())


if __name__ == "__main__":
    main()

In [7]:
import os
import re
from pathlib import Path

# =======================
# USER PATH
# =======================
# Folder whose subfolders are scanned for a metadata file.
ROOT = r"G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\aaa"
# Candidate metadata filenames, tried in this order; the first that exists wins.
# NOTE(review): "meta.txty" looks like a fallback for a misspelled "meta.txt" — confirm it is intentional.
META_FILENAMES = ("meta.txty", "meta.txt")
# When True, make_kml() adds a point placemark at the sample centre.
ADD_CENTER_POINT = True

# =======================
# REGEX PATTERNS
# =======================
# Each pattern matches one "key: value" line of the metadata file.
# The numeric groups accept digits, "-", "." and "e"/"E" only, so values such
# as "nan" or an exponent written with "+" do not match.
RE_SAMPLE_ID = re.compile(r"sample_id\s*:\s*(.+)", re.IGNORECASE)
RE_CLASS = re.compile(r"class\s*:\s*(.+)", re.IGNORECASE)
# Captures (lon, lat).
RE_CENTER = re.compile(
    r"center_lon\s*,\s*center_lat\s*:\s*([-\d\.eE]+)\s*,\s*([-\d\.eE]+)",
    re.IGNORECASE
)
# Captures (left, bottom, right, top).
RE_BOUNDS = re.compile(
    r"0\.01deg\s+pixel\s+bounds\s*\(L,B,R,T\)\s*:\s*([-\d\.eE]+)\s*,\s*([-\d\.eE]+)\s*,\s*([-\d\.eE]+)\s*,\s*([-\d\.eE]+)",
    re.IGNORECASE
)

# =======================
# KML GENERATOR
# =======================
def make_kml(name, bounds, center=None):
    """
    Build the text of a KML document outlining one bounding box.

    Parameters
    ----------
    name : label used for the document, the polygon placemark and (with a
        ``_center`` suffix) the optional point placemark. It is XML-escaped,
        so characters such as ``&`` or ``<`` cannot break the document.
    bounds : ``(west, south, east, north)``.
    center : optional ``(lon, lat)``; a point placemark is added for it when
        ADD_CENTER_POINT is true.

    Returns the KML document as a string.
    """
    from xml.sax.saxutils import escape

    west, south, east, north = bounds
    safe_name = escape(str(name))

    # Closed ring: the first corner is repeated at the end.
    polygon_coords = f"""
            {west},{south},0
            {east},{south},0
            {east},{north},0
            {west},{north},0
            {west},{south},0
    """

    center_point = ""
    if center is not None and ADD_CENTER_POINT:
        clon, clat = center
        center_point = f"""
  <Placemark>
    <name>{safe_name}_center</name>
    <Point>
      <coordinates>{clon},{clat},0</coordinates>
    </Point>
  </Placemark>
"""

    return f"""<?xml version="1.0" encoding="UTF-8"?>
<kml xmlns="http://www.opengis.net/kml/2.2">
<Document>
  <name>{safe_name}</name>

  <Style id="box">
    <LineStyle>
      <color>ff0000ff</color>
      <width>2</width>
    </LineStyle>
    <PolyStyle>
      <color>3f0000ff</color>
    </PolyStyle>
  </Style>

  <Placemark>
    <name>{safe_name}</name>
    <styleUrl>#box</styleUrl>
    <Polygon>
      <outerBoundaryIs>
        <LinearRing>
          <coordinates>
{polygon_coords}
          </coordinates>
        </LinearRing>
      </outerBoundaryIs>
    </Polygon>
  </Placemark>
{center_point}
</Document>
</kml>
"""

# =======================
# META PARSER
# =======================
def parse_meta(meta_path):
    """
    Parse a metadata file into ``(sample_id, cls, center, bounds)``.

    Every line is tested against the patterns in a fixed order (sample id,
    class, centre, bounds) and only the first matching pattern is applied to
    it. A field that never matches stays None; if a field matches on several
    lines, the last one wins. ``center`` is a 2-tuple and ``bounds`` a
    4-tuple of floats.
    """
    sample_id = None
    cls = None
    center = None
    bounds = None

    text = meta_path.read_text(encoding="utf-8", errors="ignore")
    for line in text.splitlines():
        hit = RE_SAMPLE_ID.search(line)
        if hit:
            sample_id = hit.group(1).strip()
            continue

        hit = RE_CLASS.search(line)
        if hit:
            cls = hit.group(1).strip()
            continue

        hit = RE_CENTER.search(line)
        if hit:
            center = (float(hit.group(1)), float(hit.group(2)))
            continue

        hit = RE_BOUNDS.search(line)
        if hit:
            bounds = tuple(float(g) for g in hit.groups())

    return sample_id, cls, center, bounds

# =======================
# MAIN
# =======================
def main():
    """
    Write one KML file into every subfolder of ROOT that has a metadata file.

    Subfolders without any of META_FILENAMES are ignored; a metadata file
    without a bounds line is reported and skipped. The KML is named after the
    subfolder and saved inside it.

    Raises FileNotFoundError if ROOT does not exist, rather than silently
    reporting zero files.
    """
    root = Path(ROOT)
    if not root.exists():
        raise FileNotFoundError(f"ROOT not found: {root}")

    count = 0

    for dirpath, _dirs, _files in os.walk(root):
        subdir = Path(dirpath)
        if subdir == root:
            continue

        # First candidate filename present in this folder, if any.
        meta_path = next(
            (subdir / fn for fn in META_FILENAMES if (subdir / fn).exists()),
            None,
        )
        if meta_path is None:
            continue

        sample_id, cls, center, bounds = parse_meta(meta_path)

        if bounds is None:
            print(f"[SKIP] No bounds in {meta_path}")
            continue

        # output KML IN THE SAME SUBFOLDER
        kml_name = subdir.name
        kml_path = subdir / f"{kml_name}.kml"

        kml_text = make_kml(kml_name, bounds, center)
        kml_path.write_text(kml_text, encoding="utf-8")

        print(f"[OK] {kml_path}")
        count += 1

    print(f"\nDone. KML files created: {count}")

if __name__ == "__main__":
    main()

[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\aaa\sample_01010_A\sample_01010_A.kml
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\aaa\sample_01019_A\sample_01019_A.kml
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\aaa\sample_01026_B\sample_01026_B.kml
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\aaa\sample_01029_A\sample_01029_A.kml
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\aaa\sample_01030_B\sample_01030_B.kml
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\aaa\sample_01038_B\sample_01038_B.kml
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\aaa\sample_01040_A\sample_01040_A.kml
[OK] G:\Hangkai\Global_Fore

In [6]:
import os
from pathlib import Path

import numpy as np
import rasterio
import matplotlib.pyplot as plt

# =======================
# USER SETTINGS
# =======================
# Folder whose subfolders are scanned for the area/edge GeoTIFF pair.
ROOT = r"G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\output"
AREA_NAME = "area_5yr.tif"
EDGE_NAME = "edge_5yr.tif"
# Year labels, one per band/row of the preview.
YEARS = [2000, 2005, 2010, 2015, 2020]

# Filename of the generated preview image (written inside each subfolder)
OUT_PNG_NAME = "preview_area_edge_5yr.png"

# Percentile stretch, so that extreme values do not dominate the contrast
STRETCH_Q = (2, 98)   # (low, high) percentiles


def read_5band_tif(tif_path: Path):
    """
    Return bands 1-5 of a GeoTIFF as a (5, H, W) float32 array.

    Raises ValueError if the file has fewer than 5 bands; extra bands are
    ignored.
    """
    wanted = [1, 2, 3, 4, 5]
    with rasterio.open(tif_path) as src:
        if src.count < 5:
            raise ValueError(f"{tif_path} has {src.count} bands, expected >= 5.")
        raw = src.read(wanted)  # (5, H, W)
    return raw.astype(np.float32)


def mask_invalid(arr):
    """
    Return a float copy of ``arr`` with every value <= 0 replaced by NaN.

    The input is never modified. Floating-point input keeps its dtype; any
    other dtype (e.g. integers, which cannot hold NaN) is promoted to
    float64 first.
    """
    data = np.asarray(arr)
    if np.issubdtype(data.dtype, np.floating):
        out = data.copy()
    else:
        # astype always returns a new array, so the input stays untouched.
        out = data.astype(np.float64)
    out[out <= 0] = np.nan
    return out


def robust_vmin_vmax(arr, q=(2, 98)):
    """
    Compute a display range from the finite values of ``arr``.

    Returns the ``q`` (low, high) percentiles of the finite values. If they
    are almost equal, the finite min/max are used instead; if those are
    almost equal too, the upper limit becomes ``vmin + 1``. With no finite
    value at all the range is ``(0.0, 1.0)``.

    NaN and +/-inf are ignored throughout, so the returned limits are always
    finite.
    """
    finite = arr[np.isfinite(arr)]
    if finite.size == 0:
        return 0.0, 1.0
    vmin, vmax = np.percentile(finite, q)
    if np.isclose(vmin, vmax):
        # fallback if almost constant: widen to the full finite range
        vmin = float(finite.min())
        vmax = float(finite.max())
        if np.isclose(vmin, vmax):
            vmax = vmin + 1.0
    return float(vmin), float(vmax)


def make_preview(area_5, edge_5, out_png: Path, title: str):
    """
    Save a grid of per-year previews: rows = YEARS, columns = [area, edge].

    area_5, edge_5: (n_years, H, W) arrays with NaNs for invalid cells; band
    ``i`` is shown in the row of ``YEARS[i]``. Each column uses one robust
    colour range (STRETCH_Q percentiles over all its bands) and one shared
    colorbar.

    Raises ValueError if either array has fewer bands than YEARS. The figure
    is always closed, even when plotting or saving fails, so a caller that
    catches the error and continues does not accumulate open figures.
    """
    n_years = len(YEARS)
    if len(area_5) < n_years or len(edge_5) < n_years:
        raise ValueError(
            f"expected at least {n_years} bands, got area={len(area_5)}, edge={len(edge_5)}"
        )

    # set display ranges separately for area & edge (robust)
    area_vmin, area_vmax = robust_vmin_vmax(area_5, STRETCH_Q)
    edge_vmin, edge_vmax = robust_vmin_vmax(edge_5, STRETCH_Q)

    # 3.6 in per row gives the original 10x18 in figure for five years.
    # squeeze=False keeps `axes` 2-D even when YEARS has a single entry.
    fig, axes = plt.subplots(
        nrows=n_years, ncols=2,
        figsize=(10, 3.6 * n_years),
        dpi=150,
        constrained_layout=True,
        squeeze=False
    )
    try:
        im_a = im_e = None
        for i, yr in enumerate(YEARS):
            ax_a = axes[i, 0]
            ax_e = axes[i, 1]

            im_a = ax_a.imshow(area_5[i], vmin=area_vmin, vmax=area_vmax)
            ax_a.set_title(f"Area {yr}")
            ax_a.set_axis_off()

            im_e = ax_e.imshow(edge_5[i], vmin=edge_vmin, vmax=edge_vmax)
            ax_e.set_title(f"Edge {yr}")
            ax_e.set_axis_off()

        # colorbars (one for each column); every row of a column shares the
        # same limits, so the last image is representative.
        cbar_a = fig.colorbar(im_a, ax=axes[:, 0], fraction=0.02, pad=0.01)
        cbar_a.set_label("Area value (masked <= 0)")

        cbar_e = fig.colorbar(im_e, ax=axes[:, 1], fraction=0.02, pad=0.01)
        cbar_e.set_label("Edge value (masked <= 0)")

        fig.suptitle(title, fontsize=14)
        fig.savefig(out_png, bbox_inches="tight")
    finally:
        plt.close(fig)


def main():
    """
    Create a preview PNG in every subfolder of ROOT that holds both the area
    and the edge GeoTIFF.

    Folders missing either file are ignored. Any error while processing a
    folder is printed as ``[SKIP]`` and counted; processing continues with
    the next folder. Raises FileNotFoundError if ROOT does not exist.
    """
    root = Path(ROOT)
    if not root.exists():
        raise FileNotFoundError(f"ROOT not found: {root}")

    n_done = 0
    n_skip = 0

    for dirpath, _dirnames, _filenames in os.walk(root):
        folder = Path(dirpath)
        if folder == root:
            continue

        area_path = folder / AREA_NAME
        edge_path = folder / EDGE_NAME
        if not area_path.exists() or not edge_path.exists():
            continue

        try:
            area_stack = read_5band_tif(area_path)
            area_5 = mask_invalid(area_stack)
            edge_stack = read_5band_tif(edge_path)
            edge_5 = mask_invalid(edge_stack)

            out_png = folder / OUT_PNG_NAME
            make_preview(area_5, edge_5, out_png, f"{folder.name} | {AREA_NAME} & {EDGE_NAME}")

            print(f"[OK] {out_png}")
            n_done += 1
        except Exception as e:
            # Batch boundary: report the folder and keep going.
            print(f"[SKIP] {folder} -> {e}")
            n_skip += 1

    print(f"\nDone. previews created: {n_done}, skipped/errors: {n_skip}")


if __name__ == "__main__":
    main()

[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\output\sample_00502\preview_area_edge_5yr.png
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\output\sample_00507\preview_area_edge_5yr.png
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\output\sample_00508\preview_area_edge_5yr.png
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\output\sample_00538\preview_area_edge_5yr.png
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\output\sample_00539\preview_area_edge_5yr.png
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\output\sample_00540\preview_area_edge_5yr.png
[OK] G:\Hangkai\Global_Forest_edge_mapping_data\validation_samples_0p01deg_per_sample_folder\samples\output\sample_005