# Chl-a & NDCI

In [None]:
# NDCI -> Chlorophyll-a (µg/L) + NDCI mapping script (single-season, Sentinel-2)
# Requirements: rasterio, numpy, matplotlib, pyproj, scikit-image

import os
import numpy as np
import rasterio
from rasterio.transform import array_bounds
from rasterio.warp import reproject, Resampling
from pyproj import CRS, Transformer
import matplotlib.pyplot as plt
from skimage import morphology
import matplotlib.image as mpimg
import warnings
warnings.filterwarnings('ignore')

# ------------------ USER SETTINGS ------------------
# Single-band Sentinel-2 GeoTIFFs. B4 (red) defines the reference grid; B5,
# B3 and B8 are resampled onto it further down if they differ.
# NOTE(review): these four paths live in different folders (and one mentions
# 2022) — confirm they all belong to the same scene/season before running.
B4_PATH = "/content/drive/MyDrive/Pandas.Monsoon/Band_4.tif"
B5_PATH = "/content/drive/MyDrive/Pandas/Chl-a/Ysoon/Band_5.tif"
B3_PATH = "/content/drive/MyDrive/Pan/3.Monsoon/Band_3.tif"
B8_PATH = "/content/drive/MyDrive/Pta/2022/3.Monsoon/Band_8.tif"

# Folder for the output PNG maps; created if it does not exist.
OUT_DIR = "/content/drive/MyDrive/Pandast_Monsoon"
os.makedirs(OUT_DIR, exist_ok=True)

# NDCI -> Chl conversion
CONVERT_METHOD = 'piecewise'  # 'piecewise' or 'regression'
# Regression coefficients: chl = reg_a * NDCI + reg_b.
reg_a = None                  # used only if method='regression'
reg_b = None

# Masks
NDWI_THRESHOLD = 0.05     # NDWI above this value counts as water
VEG_THRESHOLD = 0.18      # vegetation index >= this value counts as phumdis
MIN_PHUMDIS_PIXELS = 30   # connected phumdis patches smaller than this are dropped

# ------------------ HELPER FUNCTIONS ------------------
def read_band(path):
    """Load band 1 of a raster as float32 reflectance.

    Pixels equal to the file's nodata value become NaN. If the largest finite
    value exceeds 1.5, the band is assumed to hold scaled integers and is
    divided by 10 000. The result is clipped to [0, 1.5]; NaN is preserved.

    Returns:
        (array, profile): the band and a copy of the rasterio profile.
    """
    with rasterio.open(path) as ds:
        band = ds.read(1).astype('float32')
        profile = ds.profile.copy()
        nodata_value = ds.nodata

    if nodata_value is not None:
        band[band == nodata_value] = np.nan

    good = np.isfinite(band)
    needs_scaling = good.any() and float(np.nanmax(band[good])) > 1.5
    if needs_scaling:
        band[good] /= 10000.0  # integer DN -> reflectance
    return np.clip(band, 0.0, 1.5), profile

def compute_ndci(b5, b4):
    """Normalized Difference Chlorophyll Index, (B5 - B4) / (B5 + B4).

    Args:
        b5: red-edge band (Sentinel-2 B5), array-like.
        b4: red band (Sentinel-2 B4), array-like, broadcastable with ``b5``.

    Returns:
        Floating array of NDCI values with the broadcast shape of the inputs.
        Entries are NaN wherever either input is non-finite. A tiny epsilon in
        the denominator avoids division by zero: two zero inputs give 0.
        The dtype follows ``b5`` when it is floating; otherwise it is float64.
    """
    eps = 1e-9
    b5 = np.asarray(b5)
    b4 = np.asarray(b4)
    # Integer inputs must be promoted first: uint subtraction would wrap,
    # NaN cannot be stored in an int array, and np.divide cannot write a
    # float quotient into an integer `out`.
    if not np.issubdtype(b5.dtype, np.floating):
        b5 = b5.astype(np.float64)
    if not np.issubdtype(b4.dtype, np.floating):
        b4 = b4.astype(np.float64)

    out = np.full(np.broadcast_shapes(b5.shape, b4.shape), np.nan, dtype=b5.dtype)
    return np.divide(
        (b5 - b4), (b5 + b4 + eps),
        out=out,
        where=np.isfinite(b5) & np.isfinite(b4)
    )

def compute_water_mask(b3, b8, ndwi_threshold=NDWI_THRESHOLD):
    """Threshold NDWI = (B3 - B8) / (B3 + B8) into a water mask.

    Args:
        b3: green band (Sentinel-2 B3), array-like.
        b8: NIR band (Sentinel-2 B8), array-like, broadcastable with ``b3``.
        ndwi_threshold: pixels with NDWI strictly above this value are water.

    Returns:
        (water, ndwi): a boolean mask (False wherever NDWI is NaN) and the NDWI
        array (NaN where either band is non-finite). The dtype follows ``b3``
        when it is floating; otherwise it is float64.
    """
    eps = 1e-9
    b3 = np.asarray(b3)
    b8 = np.asarray(b8)
    # Promote integer bands so subtraction cannot wrap and NaN can be stored.
    if not np.issubdtype(b3.dtype, np.floating):
        b3 = b3.astype(np.float64)
    if not np.issubdtype(b8.dtype, np.floating):
        b8 = b8.astype(np.float64)

    ndwi = np.divide(
        (b3 - b8), (b3 + b8 + eps),
        out=np.full(np.broadcast_shapes(b3.shape, b8.shape), np.nan, dtype=b3.dtype),
        where=np.isfinite(b3) & np.isfinite(b8)
    )
    water = (ndwi > ndwi_threshold) & np.isfinite(ndwi)
    return water, ndwi

def compute_phumdis_mask(b8, b4=None, b3=None,
                         veg_threshold=VEG_THRESHOLD,
                         water_mask=None,
                         min_size=MIN_PHUMDIS_PIXELS):
    """Simple vegetation ('phumdis') mask from a NIR–Red or NIR–Green index.

    The index is (B8 - B4) / (B8 + B4) when ``b4`` is given; otherwise it is
    (B8 - B3) / (B8 + B3). Pixels with index >= ``veg_threshold`` are
    vegetation. The mask is optionally restricted to ``water_mask`` and then
    cleaned: connected patches smaller than ``min_size`` pixels are removed,
    and holes smaller than ``min_size / 2`` are filled.

    Args:
        b8: NIR band.
        b4: red band (preferred).
        b3: green band, used only when ``b4`` is None.
        veg_threshold: vegetation-index threshold.
        water_mask: optional boolean-like mask; only pixels inside it may be
            phumdis.
        min_size: minimum patch size in pixels.

    Returns:
        (mask, vi): a uint8 0/1 phumdis mask and the vegetation-index array.

    Raises:
        TypeError: if neither ``b4`` nor ``b3`` is supplied.
    """
    eps = 1e-9
    if b4 is not None:
        other = b4
    elif b3 is not None:
        other = b3
    else:
        raise TypeError("compute_phumdis_mask() requires b4 (red) or b3 (green)")

    b8 = np.asarray(b8)
    other = np.asarray(other)
    # Promote integer bands so subtraction cannot wrap and NaN can be stored.
    if not np.issubdtype(b8.dtype, np.floating):
        b8 = b8.astype(np.float64)
    if not np.issubdtype(other.dtype, np.floating):
        other = other.astype(np.float64)

    num = b8 - other
    den = b8 + other + eps

    vi = np.divide(
        num, den,
        out=np.full(num.shape, np.nan, dtype=num.dtype),
        where=np.isfinite(num) & np.isfinite(den)
    )

    raw = (vi >= veg_threshold) & np.isfinite(vi)
    if water_mask is not None:
        # NOTE(review): vegetation usually has NIR > green (negative NDWI), so
        # intersecting with an NDWI water mask may discard most phumdis
        # pixels — confirm this restriction is intended.
        # Cast to bool so uint8/0-1 masks work with the in-place AND.
        raw &= np.asarray(water_mask, dtype=bool)

    cleaned = morphology.remove_small_objects(raw.astype(bool), min_size=min_size)
    cleaned = morphology.remove_small_holes(cleaned, area_threshold=int(min_size/2))
    return cleaned.astype('uint8'), vi

def convert_ndci_to_chl(ndci_arr, method='piecewise', a=None, b=None):
    """Convert NDCI to approximate Chl-a (µg/L).

    Args:
        ndci_arr: NDCI values (array-like or scalar); NaN entries stay NaN.
        method: 'piecewise' (class-based lookup below) or 'regression'
            (linear model chl = a * NDCI + b).
        a, b: slope and intercept; both are required for 'regression'.

    Returns:
        float32 array with the same shape as ``ndci_arr``.

    Raises:
        ValueError: if ``method`` is unknown, or 'regression' is requested
            without both coefficients.

    Piecewise mapping (continuous on (-0.1, 0.5], constant outside):
        NDCI <= -0.1         -> 3.5
        -0.1 <  NDCI <= 0.1  -> 7.5 .. 25  (linear)
         0.1 <  NDCI <= 0.2  -> 25  .. 33  (linear)
         0.2 <  NDCI <= 0.5  -> 33  .. 50  (linear)
         NDCI >  0.5         -> 60
    """
    ndci_arr = np.asarray(ndci_arr)

    if method == 'regression':
        if a is None or b is None:
            raise ValueError("method='regression' requires both coefficients a and b")
        chl = np.asarray(a * ndci_arr + b)
        chl = np.where(np.isfinite(chl), chl, np.nan)
        return chl.astype('float32')

    if method != 'piecewise':
        raise ValueError(
            f"Unknown method {method!r}; expected 'piecewise' or 'regression'"
        )

    # --- Piecewise (class-based) relationship ---
    chl = np.full_like(ndci_arr, np.nan, dtype='float32')
    finite = np.isfinite(ndci_arr)

    # Very low NDCI -> low Chl
    chl[finite & (ndci_arr <= -0.1)] = 3.5

    # Linear segments: (NDCI low, NDCI high, Chl at low, Chl at high).
    # The low bound is exclusive and the high bound inclusive.
    segments = (
        (-0.1, 0.1, 7.5, 25),
        (0.1, 0.2, 25, 33),
        (0.2, 0.5, 33, 50),
    )
    for lo, hi, chl_lo, chl_hi in segments:
        mask = finite & (ndci_arr > lo) & (ndci_arr <= hi)
        if np.any(mask):
            x = ndci_arr[mask]
            chl[mask] = chl_lo + (x - lo) / (hi - lo) * (chl_hi - chl_lo)

    # Very high NDCI
    chl[finite & (ndci_arr > 0.5)] = 60

    return chl

def add_scalebar(ax, extent, length_km=2, linewidth=3, fontsize=12):
    """
    Draw a horizontal ``length_km`` scale bar near the lower-left of a lon/lat map.

    ``extent`` is (lon_min, lon_max, lat_min, lat_max). The bar length in
    degrees of longitude uses ~111 320 m per degree scaled by cos(mid-latitude).
    The bar starts 5 % in from the left and bottom edges, with a bold
    "<length> km" label centred just above it.
    """
    west, east, south, north = extent
    span_x = east - west
    span_y = north - south

    # metres per degree of longitude shrink with cos(latitude)
    mid_lat_rad = np.deg2rad((south + north) / 2.0)
    bar_deg = (length_km * 1000.0) / (111320 * np.cos(mid_lat_rad))

    x0 = west + 0.05 * span_x
    x1 = x0 + bar_deg
    y_bar = south + 0.05 * span_y

    ax.plot([x0, x1], [y_bar, y_bar], color='k', linewidth=linewidth)

    label_x = (x0 + x1) / 2.0
    label_y = y_bar + 0.01 * span_y
    ax.text(label_x, label_y, f"{length_km} km",
            ha='center', va='bottom', fontsize=fontsize, fontweight='bold')

# ------------------ MAIN WORKFLOW ------------------

# --- Read B4 and use it as reference grid ---
B4, prof_b4 = read_band(B4_PATH)
ref_shape = B4.shape
ref_transform = prof_b4['transform']
ref_crs = prof_b4['crs']


def _align_to_reference(arr, src_prof, name):
    """Return ``arr`` on the B4 reference grid (ref_shape / ref_transform / ref_crs).

    Resampling runs whenever the shape, affine transform or CRS differs.
    Checking the shape alone would pair a same-sized band on a shifted grid
    with the wrong pixels. The destination starts as NaN and NaN is declared
    as source and destination nodata. This keeps pixels the source does not
    cover at NaN instead of uninitialised memory, and stops NaN nodata from
    being interpolated as data.
    """
    if (arr.shape == ref_shape
            and src_prof['transform'] == ref_transform
            and src_prof['crs'] == ref_crs):
        return arr
    print(f"Warning: {name} grid (shape {arr.shape}) differs from B4 grid "
          f"(shape {ref_shape}). Reprojecting {name}.")
    aligned = np.full(ref_shape, np.nan, dtype=arr.dtype)
    reproject(
        source=arr,
        destination=aligned,
        src_transform=src_prof['transform'],
        src_crs=src_prof['crs'],
        src_nodata=np.nan,
        dst_transform=ref_transform,
        dst_crs=ref_crs,
        dst_nodata=np.nan,
        resampling=Resampling.bilinear
    )
    return aligned


# --- Read and align B5, B3 and B8 onto the B4 grid ---
B5, prof_b5 = read_band(B5_PATH)
B5 = _align_to_reference(B5, prof_b5, "B5")

B3, prof_b3 = read_band(B3_PATH)
B3 = _align_to_reference(B3, prof_b3, "B3")

B8, prof_b8 = read_band(B8_PATH)
B8 = _align_to_reference(B8, prof_b8, "B8")

# Use B4 profile for everything
prof = prof_b4

# --- Indices & masks ---
ndci = compute_ndci(B5, B4)
water_mask, ndwi = compute_water_mask(B3, B8)
phumdis_mask, veg_index = compute_phumdis_mask(B8, b4=B4, b3=B3, water_mask=water_mask)

# Blank out phumdis pixels on a copy so the raw NDCI stays available.
# NOTE(review): only phumdis pixels are removed here; pixels outside the
# water mask keep their NDCI/Chl values (water_mask only constrains phumdis).
is_phumdis = phumdis_mask == 1
ndci_masked = ndci.copy()
ndci_masked[is_phumdis] = np.nan

# Approximate Chl-a (µg/L) from the phumdis-free NDCI.
chl = convert_ndci_to_chl(ndci_masked, method=CONVERT_METHOD, a=reg_a, b=reg_b)

# ------------------ PLOTTING SETTINGS ------------------
NORTH_ARROW_PATH = "/content/drive/MyDrive/Pandas/Chl-a/Yearly Cliped_Data/Compass.jpg"
# North-arrow inset position/size in figure coordinates (0..1).
NORTH_X = 0.13
NORTH_Y = 0.80
NORTH_SIZE = 0.13

FONT_TICK = 14
FONT_CBAR = 14
FONT_CBAR_LABEL = 16

# ---- Map extent in lon/lat, in imshow order (left, right, bottom, top) ----
height, width = chl.shape
left, bottom, right, top = array_bounds(height, width, prof['transform'])
src_crs = CRS.from_user_input(prof['crs'])

if not src_crs.is_geographic:
    # Projected grid: convert the lower-left and upper-right corners to WGS84.
    transformer = Transformer.from_crs(src_crs, "EPSG:4326", always_xy=True)
    lon_left, lat_bottom = transformer.transform(left, bottom)
    lon_right, lat_top = transformer.transform(right, top)
    extent = (lon_left, lon_right, lat_bottom, lat_top)
else:
    extent = (left, right, bottom, top)

# Five evenly spaced ticks per axis, shared by both maps.
xticks, yticks = (np.linspace(extent[0], extent[1], 5),
                  np.linspace(extent[2], extent[3], 5))

# ====================================================
# 1) CHLOROPHYLL-a MAP
# ====================================================
# Colour limits = 2nd / 98th percentiles of the finite Chl-a values, so a few
# extreme pixels do not stretch the colour scale.
# NOTE(review): if no pixel is finite, finite_chl is empty and these limits
# are meaningless — consider guarding.
finite_chl = chl[np.isfinite(chl)]
vmin_chl = np.nanpercentile(finite_chl, 2)
vmax_chl = np.nanpercentile(finite_chl, 98)

# The figure is created at 600 dpi; the PNG below is saved at 300 dpi.
fig, ax = plt.subplots(figsize=(10, 10), dpi=600)

# `extent` is (lon_min, lon_max, lat_min, lat_max), so the axes read in
# degrees; origin='upper' draws raster row 0 at the top.
im = ax.imshow(chl, cmap='viridis', origin='upper',
               vmin=vmin_chl, vmax=vmax_chl, extent=extent)

ax.set_xticks(xticks)
ax.set_yticks(yticks)

# Degree labels with 4 decimals. NOTE: the "°E"/"°N" suffixes are hard-coded
# (valid only east of Greenwich / north of the equator), and fontsize=18 is
# hard-coded, so FONT_TICK is not used here.
ax.set_xticklabels(
    [f"{x:.4f}°E" for x in xticks],
    rotation=15,
    fontsize=18,
    fontweight='bold'
)
ax.set_yticklabels(
    [f"{y:.4f}°N" for y in yticks],
    fontsize=18,
    fontweight='bold'
)

ax.set_xlabel("")
ax.set_ylabel("")
ax.tick_params(axis='both', width=1.5, length=6)

# North arrow: an inset image axes placed in figure coordinates; skipped
# when the image file does not exist.
if os.path.exists(NORTH_ARROW_PATH):
    north_img = mpimg.imread(NORTH_ARROW_PATH)
    ax_n = fig.add_axes([NORTH_X, NORTH_Y, NORTH_SIZE, NORTH_SIZE])
    ax_n.imshow(north_img)
    ax_n.axis("off")

# Scale bar (e.g., 2 km)
add_scalebar(ax, extent, length_km=2, linewidth=3, fontsize=14)

# Colorbar with bold tick labels and a bold unit label.
cbar = fig.colorbar(im, ax=ax, fraction=0.04, pad=0.02)
cbar.ax.tick_params(labelsize=FONT_CBAR, width=1.5)
for label in cbar.ax.get_yticklabels():
    label.set_fontweight('bold')
cbar.outline.set_linewidth(1.5)
cbar.set_label("Chlorophyll-a (µg L⁻¹)",
               fontsize=FONT_CBAR_LABEL,
               fontweight='bold')

# tight_layout may warn about the inset north-arrow axes; warnings are
# silenced by the filterwarnings('ignore') call at the top of this cell.
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "Chl_NDCI_map.png"),
            dpi=300, bbox_inches='tight')
plt.show()

# ====================================================
# 2) NDCI MAP (same style)
# ====================================================
# Colour limits = 2nd / 98th percentiles of the phumdis-masked NDCI.
# NOTE(review): an all-NaN ndci_masked gives an empty array here — consider
# guarding.
finite_ndci = ndci_masked[np.isfinite(ndci_masked)]
vmin_ndci = np.nanpercentile(finite_ndci, 2)
vmax_ndci = np.nanpercentile(finite_ndci, 98)

# The figure is created at 600 dpi; the PNG below is saved at 300 dpi.
fig2, ax2 = plt.subplots(figsize=(10, 10), dpi=600)

# Same lon/lat extent and row-0-at-top orientation as the Chl-a map.
im2 = ax2.imshow(ndci_masked, cmap='turbo', origin='upper',
                 vmin=vmin_ndci, vmax=vmax_ndci, extent=extent)

ax2.set_xticks(xticks)
ax2.set_yticks(yticks)

# Degree labels with 4 decimals; hemisphere suffixes and fontsize=18 are
# hard-coded, as on the Chl-a map.
ax2.set_xticklabels(
    [f"{x:.4f}°E" for x in xticks],
    rotation=15,
    fontsize=18,
    fontweight='bold'
)
ax2.set_yticklabels(
    [f"{y:.4f}°N" for y in yticks],
    fontsize=18,
    fontweight='bold'
)

ax2.set_xlabel("")
ax2.set_ylabel("")
ax2.tick_params(axis='both', width=1.5, length=6)

# North arrow: an inset image axes in figure coordinates; skipped when the
# image file does not exist.
if os.path.exists(NORTH_ARROW_PATH):
    north_img = mpimg.imread(NORTH_ARROW_PATH)
    ax2_n = fig2.add_axes([NORTH_X, NORTH_Y, NORTH_SIZE, NORTH_SIZE])
    ax2_n.imshow(north_img)
    ax2_n.axis("off")

# Scale bar for NDCI map (same length). NOTE: the label fontsize is 12 here
# versus 14 on the Chl-a map.
add_scalebar(ax2, extent, length_km=2, linewidth=3, fontsize=12)

# Colorbar for NDCI
cbar2 = fig2.colorbar(im2, ax=ax2, fraction=0.04, pad=0.02)
cbar2.ax.tick_params(labelsize=FONT_CBAR, width=1.5)
for label in cbar2.ax.get_yticklabels():
    label.set_fontweight('bold')
cbar2.outline.set_linewidth(1.5)
cbar2.set_label("NDCI",
                fontsize=FONT_CBAR_LABEL,
                fontweight='bold')

# plt.* calls act on the current figure, which is fig2 (created last).
plt.tight_layout()
plt.savefig(os.path.join(OUT_DIR, "NDCI_map.png"),
            dpi=300, bbox_inches='tight')
plt.show()


#ML

# Random Forest (before hyperparameter tuning)

In [None]:
# =====================================================
# IMPORTS
# =====================================================
import os
import glob
import numpy as np
import rasterio
from rasterio.features import geometry_mask
import geopandas as gpd

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.ensemble import RandomForestClassifier

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.colors as mcolors

from skimage.morphology import dilation, disk   # for bloom expansion

# =====================================================
# CONFIG  (EDIT THESE PATHS)
# =====================================================
# NOTE(review): both paths are placeholders pointing at the Drive root.
# SCENES_DIR must contain the stacked "<year>_<season>.tif" scenes (globbed
# as *.tif below). SHP_PATH must be the lake boundary readable by
# geopandas.read_file.
SCENES_DIR = "/content/drive/MyDrive"
SHP_PATH   = "/content/drive/MyDrive"
SAMPLES_PER_SCENE = 5000   # random samples per scene for ML

# NDVI thresholds for phumdis:
#   NDVI >= VEG_DENSE               -> dense phumdis (class 4)
#   VEG_SPARSE <= NDVI < VEG_DENSE  -> sparse phumdis (class 3)
#   NDVI < VEG_SPARSE               -> water (classes 0-2)
VEG_SPARSE = 0.20
VEG_DENSE  = 0.40

# Clean-water threshold fixed, bloom threshold from data
CHL_T1_FIXED      = 20.0     # clean vs nutrient-rich
# Percentile is taken over the in-lake Chl of ALL scenes pooled together.
BLOOM_PERCENTILE  = 85       # bloom = top 15% of Chl inside lake  (Option 1)

# Hybrid bloom rule thresholds (Option 2)
# NOTE(review): the spectral rule is only evaluated on water pixels
# (NDVI < VEG_SPARSE = 0.20), so NDVI_BLOOM_MIN = 0.25 can never be met and
# Option 2 never fires with these values.
NDCI_BLOOM_MIN = 0.15
NDVI_BLOOM_MIN = 0.25

# Morphological expansion radius in pixels (Option 3)
BLOOM_DILATE_RADIUS = 2      # 2–3 is reasonable

# =====================================================
# HELPER: Read stacked Sentinel-2 TIFF (B2–B12)
# =====================================================
def read_stack_reflectance(path):
    """Read every band of a stacked Sentinel-2 GeoTIFF as reflectance.

    Pixels equal to the file's nodata value become NaN (the same convention
    as read_band in the NDCI cell). If the largest finite value exceeds 1.5,
    the stack is assumed to hold scaled integers and is divided by 10 000.
    Values are clipped to [0, 1.5]; NaN is preserved.

    Returns:
        (stack, profile): float32 array of shape (bands, rows, cols) and a copy
        of the rasterio profile.
    """
    print(f"Reading {path}...")
    with rasterio.open(path) as src:
        stack = src.read().astype("float32")
        prof  = src.profile.copy()
        nodata = src.nodata

    # Without this, a fill value such as 0 is read as real zero reflectance:
    # NDCI/NDVI become 0 and the pixel is labelled (and trained on) as water.
    if nodata is not None:
        stack[stack == nodata] = np.nan

    finite = np.isfinite(stack)
    if finite.any() and float(np.nanmax(stack[finite])) > 1.5:
        stack = np.where(finite, stack / 10000.0, np.nan)

    return np.clip(stack, 0.0, 1.5), prof

# =====================================================
# HELPER: Parse "2021_Monsoon.tif" → (year, season, base)
# =====================================================
def parse_scene_name(path):
    """Split a scene filename such as ".../2021_Monsoon.tif" into its parts.

    A trailing .tif/.tiff extension (any case) is removed to form ``base``.
    The base is then split at the FIRST underscore, so "2021_Post_Monsoon"
    gives season "Post_Monsoon".

    Returns:
        (year, season, base); names without an underscore give
        ("Unknown", base, base).
    """
    base = os.path.basename(path)
    # splitext only touches the real extension; str.replace(".tif", "") would
    # also strip ".tif" mid-name and turn "x.tiff" into "xf".
    stem, ext = os.path.splitext(base)
    if ext.lower() in (".tif", ".tiff"):
        base = stem

    year, sep, season = base.partition("_")
    if not sep:
        return "Unknown", base, base
    return year, season, base

# =====================================================
# LOAD LAKE SHAPEFILE
# =====================================================
print("Loading lake mask shapefile:", SHP_PATH)
gdf_lake = gpd.read_file(SHP_PATH)

# Each GeoTIFF in SCENES_DIR is one year/season scene; sorting keeps the
# processing order stable between runs.
tif_pattern = os.path.join(SCENES_DIR, "*.tif")
scene_files = sorted(glob.glob(tif_pattern))
print("Found scenes:", len(scene_files))
if not scene_files:
    raise RuntimeError("No .tif files found in SCENES_DIR")

# =====================================================
# PASS 1: COLLECT Chl VALUES TO SET BLOOM THRESHOLD
# =====================================================
all_chl_values = []

for scene_path in scene_files:
    *_, base = parse_scene_name(scene_path)
    print(f"\n[PASS 1] Collecting Chl from: {base}")

    stack, prof = read_stack_reflectance(scene_path)
    H, W = stack.shape[-2:]

    # Rasterise the lake polygon(s) onto this scene's grid (True = inside).
    gdf_proj = gdf_lake.to_crs(prof["crs"])
    lake_mask = geometry_mask(
        gdf_proj.geometry,
        out_shape=(H, W),
        transform=prof["transform"],
        invert=True,
    )

    # NDCI from band index 2 (B4, red) and index 3 (B5, red edge 1), then the
    # quadratic NDCI -> Chl-a model.
    B4, B5 = stack[2], stack[3]
    eps = 1e-9
    ndci = (B5 - B4) / (B5 + B4 + eps)
    chl  = 14.039 + 86.115 * ndci + 194.325 * (ndci ** 2)

    valid = lake_mask & np.isfinite(chl)
    if valid.any():
        all_chl_values.append(chl[valid])

if not all_chl_values:
    raise RuntimeError("No valid Chl values inside lake in any scene.")

# Pool every in-lake pixel from every scene before taking percentiles.
all_chl_values = np.concatenate(all_chl_values)
p50, p75, p85, p90, p95 = np.percentile(all_chl_values, [50, 75, 85, 90, 95])

print("\nChl-a percentiles inside lake (µg/L):")
print(f"50%: {p50:.2f}, 75%: {p75:.2f}, 85%: {p85:.2f}, 90%: {p90:.2f}, 95%: {p95:.2f}")

# T1 is fixed; T2 (bloom cut-off) comes from the pooled distribution.
CHL_T1 = CHL_T1_FIXED
CHL_T2 = float(np.percentile(all_chl_values, BLOOM_PERCENTILE))  # Option 1

print("\nUsing thresholds:")
print(f"  CHL_T1 = {CHL_T1:.2f} µg/L (clean vs nutrient)")
print(f"  CHL_T2 = {CHL_T2:.2f} µg/L (nutrient vs bloom; {BLOOM_PERCENTILE}th percentile)")

# =====================================================
# HELPER: Build class map from stack (Options 1–3 inside)
# =====================================================
def build_classes_from_stack(stack, lake_mask):
    """
    Label every lake pixel of a stacked scene with one of 5 classes (0–4).

    Classes: 0 clean water, 1 nutrient-rich, 2 bloom, 3 sparse phumdis,
    4 dense phumdis. Pixels outside ``lake_mask`` (or with NaN NDVI) stay NaN.

    Rules:
      - Phumdis from NDVI: >= VEG_DENSE -> 4; [VEG_SPARSE, VEG_DENSE) -> 3;
        NDVI < VEG_SPARSE is treated as water.
      - Bloom (2), on water pixels only:
          Chl >= CHL_T2 (Option 1)
          OR (NDCI > NDCI_BLOOM_MIN and NDVI > NDVI_BLOOM_MIN) (Option 2),
          then dilated by BLOOM_DILATE_RADIUS and clipped back to water
          (Option 3).
      - Remaining water: Chl < CHL_T1 -> 0, otherwise 1.

    Args:
        stack: reflectance stack (bands, rows, cols); band indices 2, 3 and 6
            are read as B4 (red), B5 (red edge 1) and B8 (NIR).
        lake_mask: boolean (rows, cols) mask, True inside the lake.

    Returns:
        (class_map, chl, ndvi, ndci): float32 class map with NaN = unlabelled,
        plus the per-pixel Chl-a estimate, NDVI and NDCI.
    """
    B4 = stack[2]    # Red
    B5 = stack[3]    # RE1
    B8 = stack[6]    # NIR

    eps = 1e-9

    ndvi = (B8 - B4) / (B8 + B4 + eps)
    ndci = (B5 - B4) / (B5 + B4 + eps)
    # Same quadratic NDCI -> Chl-a model as PASS 1.
    chl  = 14.039 + 86.115 * ndci + 194.325 * (ndci ** 2)

    class_map = np.full(B4.shape, np.nan, dtype="float32")

    # Phumdis first
    dense  = (ndvi >= VEG_DENSE) & lake_mask
    sparse = (ndvi >= VEG_SPARSE) & (ndvi < VEG_DENSE) & lake_mask
    water  = (ndvi < VEG_SPARSE) & lake_mask

    class_map[dense]  = 4   # Dense phumdis
    class_map[sparse] = 3   # Sparse phumdis

    # --- Bloom logic ---

    # 1) Core bloom based on high Chl
    bloom_core = (chl >= CHL_T2) & water

    # 2) Spectral bloom: high NDCI & high NDVI.
    # NOTE(review): `water` already requires NDVI < VEG_SPARSE, so whenever
    # NDVI_BLOOM_MIN >= VEG_SPARSE (0.25 vs 0.20 in the config) this term is
    # always empty. Revisit the thresholds or the water restriction.
    bloom_spectral = ((ndci > NDCI_BLOOM_MIN) & (ndvi > NDVI_BLOOM_MIN)) & water

    # Combine
    bloom = bloom_core | bloom_spectral

    # 3) Morphological expansion (Option 3), clipped back to water pixels
    if BLOOM_DILATE_RADIUS > 0:
        bloom = dilation(bloom, disk(BLOOM_DILATE_RADIUS)) & water

    # Set bloom first
    class_map[bloom] = 2

    # Remaining water pixels (not bloom & not phumdis)
    water_remain = water & (~bloom)

    class_map[(chl < CHL_T1) & water_remain] = 0             # Clean
    class_map[(chl >= CHL_T1) & water_remain] = 1            # Nutrient-rich

    return class_map, chl, ndvi, ndci

# =====================================================
# PASS 2: BUILD ML DATASET X, y, years, seasons
# =====================================================
X_list, y_list = [], []
year_list, season_list = [], []

# Seeded generator: the per-scene subsample (and therefore the split and the
# RF fitted below) is reproducible from run to run.
rng = np.random.default_rng(42)

for scene_path in scene_files:
    year, season, base = parse_scene_name(scene_path)
    print(f"\n[PASS 2] Building dataset from: {base}")

    stack, prof = read_stack_reflectance(scene_path)
    n_bands, H, W = stack.shape

    gdf_proj = gdf_lake.to_crs(prof["crs"])
    lake_mask = geometry_mask(
        gdf_proj.geometry,
        transform=prof["transform"],
        invert=True,
        out_shape=(H, W)
    )

    class_map, chl, ndvi, ndci = build_classes_from_stack(stack, lake_mask)

    # A usable pixel is labelled AND has every feature finite. A pixel can
    # carry a phumdis label (finite NDVI) while B5, and hence NDCI/Chl, is NaN.
    # scikit-learn < 1.4 also rejects NaN input to RandomForest.
    valid = (
        np.isfinite(class_map)
        & np.isfinite(stack).all(axis=0)
        & np.isfinite(ndvi)
        & np.isfinite(ndci)
        & np.isfinite(chl)
    )
    if not np.any(valid):
        print(" -> No valid labeled pixels, skipping.")
        continue

    # Per-pixel feature layout: all bands, then NDVI, NDCI, Chl (must match
    # generate_classification_map). Boolean indexing keeps row-major order.
    feat = np.column_stack([
        stack[:, valid].T,
        ndvi[valid],
        ndci[valid],
        chl[valid],
    ])
    labels = class_map[valid]

    if len(labels) > SAMPLES_PER_SCENE:
        sel = rng.choice(len(labels), SAMPLES_PER_SCENE, replace=False)
        feat   = feat[sel]
        labels = labels[sel]

    X_list.append(feat)
    y_list.append(labels)
    year_list.extend([year] * len(labels))
    season_list.extend([season] * len(labels))

if not X_list:
    raise RuntimeError("No valid labeled pixels in any scene; cannot build the ML dataset.")

# Concatenate all scenes
X = np.concatenate(X_list, axis=0)
y = np.concatenate(y_list, axis=0)
years   = np.array(year_list)
seasons = np.array(season_list)

print("\nFINAL DATASET SHAPES")
print("X:", X.shape)
print("y:", y.shape)

unique, counts = np.unique(y, return_counts=True)
print("\nClass counts:")
for cls, c in zip(unique, counts):
    print(f"  Class {int(cls)} = {c} samples")
print("Seasons:", np.unique(seasons))
print("Years:",   np.unique(years))

# =====================================================
# TRAIN RF + CONFUSION MATRICES
# =====================================================
X_train, X_test, y_train, y_test, seasons_train, seasons_test = train_test_split(
    X, y, seasons, test_size=0.20, random_state=42, stratify=y
)

print("\nTrain size:", X_train.shape, "Test size:", X_test.shape)

rf = RandomForestClassifier(
    n_estimators=300,
    max_depth=20,
    min_samples_split=5,
    min_samples_leaf=3,
    class_weight="balanced",
    random_state=42,
    n_jobs=-1
)
rf.fit(X_train, y_train)

y_pred = rf.predict(X_test)

# Pin the label order to classes 0-4. The matrix is then always 5x5 and
# lines up with the five tick names, even when a class is absent from both
# y_test and y_pred.
cm = confusion_matrix(y_test, y_pred, labels=[0, 1, 2, 3, 4])
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=["Clean", "Nutrient", "Bloom", "SparsePD", "DensePD"],
            yticklabels=["Clean", "Nutrient", "Bloom", "SparsePD", "DensePD"])
plt.title("Overall Confusion Matrix (All Years & Seasons)")
plt.xlabel("Predicted")
plt.ylabel("Actual")
plt.tight_layout()
plt.show()

print("\nClassification Report:\n")
print(classification_report(y_test, y_pred))

# Per-season confusion matrices
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams

# Global font settings (optional)
rcParams['font.weight'] = 'bold'
rcParams['axes.labelweight'] = 'bold'
rcParams['axes.titleweight'] = 'bold'

CLASS_NAMES = ["Clean", "Nutrient", "Bloom", "SparsePD", "DensePD"]

# One confusion matrix per season, on the held-out test pixels.
unique_seasons_test = np.unique(seasons_test)

for season in unique_seasons_test:
    in_season = seasons_test == season
    # A single season often lacks some classes. Without a fixed label list
    # the matrix would shrink and stop matching the CLASS_NAMES tick labels.
    cm_s = confusion_matrix(
        y_test[in_season], y_pred[in_season],
        labels=list(range(len(CLASS_NAMES)))
    )

    plt.figure(figsize=(9, 7))

    ax = sns.heatmap(
        cm_s, annot=True, fmt="d", cmap="viridis",
        xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
        annot_kws={"size": 16, "weight": "bold"}  # Bold annotation numbers
    )

    # Title
    plt.title(f"Confusion Matrix — {season}", fontsize=20, fontweight='bold')

    # Axis labels
    plt.xlabel("Predicted", fontsize=18, fontweight='bold')
    plt.ylabel("Actual", fontsize=18, fontweight='bold')

    # Tick parameters
    ax.tick_params(axis='both', labelsize=14, width=2, length=6)

    # Make tick labels bold
    for label in ax.get_xticklabels():
        label.set_fontweight('bold')
    for label in ax.get_yticklabels():
        label.set_fontweight('bold')

    # Colorbar formatting
    cbar = ax.collections[0].colorbar
    cbar.ax.tick_params(labelsize=14, width=2, length=6)
    for label in cbar.ax.get_yticklabels():
        label.set_fontweight('bold')
    cbar.outline.set_linewidth(2)

    plt.tight_layout()
    plt.show()


# =====================================================
# MAP GENERATION WITH SAME OPTIONS
# =====================================================
def generate_classification_map(scene_path, model):
    """
    Classify one scene with ``model`` inside the lake shapefile only, save the
    result as a GeoTIFF + PNG, and show the plot.

    Args:
        scene_path: path to a stacked "<year>_<season>.tif" scene.
        model: fitted classifier whose ``predict`` accepts the training
            feature layout [all bands, NDVI, NDCI, Chl].

    Returns:
        (rows, cols) uint8 class map. 255 marks pixels outside the lake or
        pixels whose features are not all finite (e.g. nodata).

    Side effects:
        Writes ``ML_<year>_<season>_RF_Classification.tif`` and ``.png`` using
        relative paths (the current working directory) and shows the figure.
    """
    year, season, base = parse_scene_name(scene_path)
    save_prefix = f"ML_{year}_{season}"

    print(f"\n--- Generating classification map for {base} ---")

    stack, prof = read_stack_reflectance(scene_path)
    n_bands, H, W = stack.shape

    # Lake mask in this raster's grid
    gdf_proj = gdf_lake.to_crs(prof["crs"])
    lake_mask = geometry_mask(
        gdf_proj.geometry,
        transform=prof["transform"],
        invert=True,
        out_shape=(H, W)
    )

    B4 = stack[2]
    B5 = stack[3]
    B8 = stack[6]
    eps = 1e-9

    ndvi = (B8 - B4) / (B8 + B4 + eps)
    ndci = (B5 - B4) / (B5 + B4 + eps)
    chl  = 14.039 + 86.115 * ndci + 194.325 * (ndci ** 2)

    # Same feature layout as training: all bands, then NDVI, NDCI, Chl.
    X_scene_all = np.column_stack([
        stack.reshape(n_bands, -1).T,
        ndvi.ravel(),
        ndci.ravel(),
        chl.ravel(),
    ])

    # Predict ONLY inside the lake and only where every feature is finite,
    # so nodata pixels keep the 255 fill instead of reaching model.predict.
    class_flat = np.full(H * W, 255, dtype=np.uint8)  # 255 = outside lake / nodata
    predictable = lake_mask.ravel() & np.isfinite(X_scene_all).all(axis=1)
    idx_lake = np.flatnonzero(predictable)
    if idx_lake.size > 0:
        preds_lake = model.predict(X_scene_all[idx_lake])
        class_flat[idx_lake] = preds_lake.astype(np.uint8)

    class_map = class_flat.reshape(H, W)

    # Save classification GeoTIFF
    out_tif = f"{save_prefix}_RF_Classification.tif"
    out_prof = prof.copy()
    out_prof.update(
        dtype=rasterio.uint8,
        count=1,
        nodata=255        # important: valid for uint8
    )
    with rasterio.open(out_tif, "w", **out_prof) as dst:
        dst.write(class_map, 1)
    print(f"Saved GeoTIFF: {out_tif}")

    # Prepare for plotting: mask out 255 (outside lake / nodata)
    plot_map = class_map.astype(float)
    plot_map[plot_map == 255] = np.nan

    # Discrete colormap for classes 0–4
    cmap = mcolors.ListedColormap([
        "#1f78b4",  # 0 Clean water
        "#33a02c",  # 1 Nutrient-rich
        "#ff7f00",  # 2 Bloom
        "#6a3d9a",  # 3 Sparse phumdis
        "#e31a1c"   # 4 Dense phumdis
    ])
    bounds = [-0.5, 0.5, 1.5, 2.5, 3.5, 4.5]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    plt.figure(figsize=(8, 8))
    im = plt.imshow(plot_map, cmap=cmap, norm=norm)
    plt.title(f"RF Ecosystem Classification — {year} {season}", fontsize=14)
    plt.axis("off")

    cbar = plt.colorbar(im, boundaries=bounds, ticks=[0, 1, 2, 3, 4])
    cbar.ax.set_yticklabels([
        "0 Clean water",
        "1 Nutrient-rich",
        "2 Bloom",
        "3 Sparse phumdis",
        "4 Dense phumdis"
    ], fontsize=9)

    out_png = f"{save_prefix}_RF_Classification.png"
    plt.savefig(out_png, dpi=300, bbox_inches="tight")
    plt.show()
    print(f"Saved PNG: {out_png}")

    return class_map

# Generate maps for all scenes (uses the function as defined at this point
# in the notebook, with the RF trained above).
for path in scene_files:
    generate_classification_map(path, rf)


# Overall confusion matrix (publication style). The ROC imports are also
# pulled in here for the ROC–AUC section further down.
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

# Global font settings (optional)
rcParams['font.weight'] = 'bold'
rcParams['axes.labelweight'] = 'bold'
rcParams['axes.titleweight'] = 'bold'

CLASS_NAMES = ["Clean", "Nutrient", "Bloom", "SparsePD", "DensePD"]

# A fixed label order keeps the matrix 5x5 and aligned with CLASS_NAMES,
# even when a class is absent from both y_test and y_pred.
cm = confusion_matrix(y_test, y_pred, labels=list(range(len(CLASS_NAMES))))

plt.figure(figsize=(9, 7))
ax = sns.heatmap(
    cm,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=CLASS_NAMES,
    yticklabels=CLASS_NAMES,
    annot_kws={"size": 16, "weight": "bold"}  # bold numbers in cells
)

# Title
ax.set_title(
    "Overall Confusion Matrix (All Years & Seasons)",
    fontsize=20,
    fontweight="bold"
)

# Axis labels
ax.set_xlabel("Predicted", fontsize=18, fontweight="bold")
ax.set_ylabel("Actual", fontsize=18, fontweight="bold")

# Tick style
ax.tick_params(axis="both", labelsize=14, width=2, length=6)
for label in ax.get_xticklabels():
    label.set_fontweight("bold")
for label in ax.get_yticklabels():
    label.set_fontweight("bold")

# Colorbar formatting
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=14, width=2, length=6)
for label in cbar.ax.get_yticklabels():
    label.set_fontweight("bold")
cbar.outline.set_linewidth(2)

plt.tight_layout()
plt.show()


def generate_classification_map(scene_path, model):
    """Classify one scene inside the lake and save a GeoTIFF plus a styled PNG.

    The figure uses lon/lat axes, falling back to the raster's native
    coordinates if the CRS cannot be converted. It has bold fonts, a discrete
    class colorbar and an optional compass image.

    Args:
        scene_path: path to a stacked "<year>_<season>.tif" scene.
        model: fitted classifier whose ``predict`` accepts the training
            feature layout [all bands, NDVI, NDCI, Chl].

    Returns:
        (rows, cols) uint8 class map. 255 marks pixels outside the lake or
        pixels whose features are not all finite.

    Side effects:
        Writes ``ML_<year>_<season>_RF_Classification.tif`` and ``.png`` using
        relative paths (the current working directory) and shows the figure.
    """
    import matplotlib.pyplot as plt
    import matplotlib.colors as mcolors
    import matplotlib.image as mpimg
    from pyproj import CRS, Transformer
    from pyproj.exceptions import ProjError

    year, season, base = parse_scene_name(scene_path)
    save_prefix = f"ML_{year}_{season}"

    print(f"\n--- Generating classification map for {base} ---")

    # ----------------------------
    # LOAD & PREPARE DATA
    # ----------------------------
    stack, prof = read_stack_reflectance(scene_path)
    n_bands, H, W = stack.shape

    gdf_proj = gdf_lake.to_crs(prof["crs"])
    lake_mask = geometry_mask(
        gdf_proj.geometry,
        transform=prof["transform"],
        invert=True,
        out_shape=(H, W)
    )

    B4 = stack[2]
    B5 = stack[3]
    B8 = stack[6]
    eps = 1e-9

    ndvi = (B8 - B4) / (B8 + B4 + eps)
    ndci = (B5 - B4) / (B5 + B4 + eps)
    chl  = 14.039 + 86.115 * ndci + 194.325 * (ndci ** 2)

    # Flatten for prediction (same feature layout as training)
    flat_stack = stack.reshape(n_bands, -1).T
    X_scene_all = np.column_stack([
        flat_stack,
        ndvi.flatten(),
        ndci.flatten(),
        chl.flatten()
    ])
    flat_lake = lake_mask.flatten()

    # ----------------------------
    # CLASSIFY
    # ----------------------------
    # Only lake pixels with all-finite features are sent to the model; the
    # rest (outside lake / nodata) keep the 255 fill.
    class_flat = np.full(H * W, 255, dtype=np.uint8)
    idx_lake = np.flatnonzero(flat_lake & np.isfinite(X_scene_all).all(axis=1))
    if idx_lake.size > 0:
        class_flat[idx_lake] = model.predict(X_scene_all[idx_lake]).astype(np.uint8)

    class_map = class_flat.reshape(H, W)

    # ----------------------------
    # SAVE GEOTIFF
    # ----------------------------
    out_tif = f"{save_prefix}_RF_Classification.tif"
    out_prof = prof.copy()
    out_prof.update(dtype=rasterio.uint8, count=1, nodata=255)

    with rasterio.open(out_tif, "w", **out_prof) as dst:
        dst.write(class_map, 1)

    print(f"Saved GeoTIFF: {out_tif}")

    # ----------------------------
    # PREP FOR PLOTTING
    # ----------------------------
    plot_map = class_map.astype(float)
    plot_map[plot_map == 255] = np.nan  # outside lake / nodata

    # Colors
    cmap = mcolors.ListedColormap([
        "#1f78b4",  # Clean
        "#33a02c",  # Nutrient-rich
        "#ff7f00",  # Bloom
        "#6a3d9a",  # Sparse PD
        "#e31a1c"   # Dense PD
    ])
    bounds = [-0.5, 0.5, 1.5, 2.5, 3.5, 4.5]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    # ----------------------------
    # GET LAT/LON EXTENT
    # ----------------------------
    h, w = class_map.shape
    left, bottom, right, top = rasterio.transform.array_bounds(h, w, prof["transform"])

    try:
        src_crs = CRS.from_user_input(prof["crs"])
        transformer = Transformer.from_crs(src_crs, CRS.from_epsg(4326), always_xy=True)
        lon_left, lat_bottom = transformer.transform(left, bottom)
        lon_right, lat_top   = transformer.transform(right, top)
        extent = [lon_left, lon_right, lat_bottom, lat_top]
        use_latlon = True
    except ProjError as err:
        # Missing/invalid CRS or failed transform (CRSError is a ProjError):
        # plot in the raster's native coordinates instead.
        print(f"Could not convert extent to lon/lat ({err}); using native coordinates.")
        extent = [left, right, bottom, top]
        use_latlon = False

    # ----------------------------
    # FINAL PLOTTING (UPGRADED)
    # ----------------------------
    fig, ax = plt.subplots(figsize=(10, 10), dpi=500)

    im = ax.imshow(plot_map, cmap=cmap, norm=norm,
                   extent=extent, origin="upper")

    # TITLE
    ax.set_title(
        f"Ecosystem Classification — {year} {season}",
        fontsize=22, fontweight="bold"
    )

    # AXIS LABELS
    if use_latlon:
        ax.set_xlabel("Longitude", fontsize=18, fontweight="bold")
        ax.set_ylabel("Latitude", fontsize=18, fontweight="bold")
    else:
        ax.set_xlabel("Easting", fontsize=18, fontweight="bold")
        ax.set_ylabel("Northing", fontsize=18, fontweight="bold")

    # TICKS (BOLD)
    ax.tick_params(axis="both", labelsize=15, width=2, length=6)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontweight("bold")

    # COLORBAR (BOLD, LARGE)
    cbar = plt.colorbar(im, boundaries=bounds, ticks=[0, 1, 2, 3, 4],
                        fraction=0.035, pad=0.02)
    cbar_labels = ["Clean", "Nutrient-rich", "Bloom", "Sparse PD", "Dense PD"]

    cbar.ax.set_yticklabels(cbar_labels, fontsize=14, fontweight="bold")
    cbar.outline.set_linewidth(2)
    cbar.ax.tick_params(width=2, length=6)

    # NORTH ARROW (OPTIONAL): only the file read is guarded. A missing or
    # unreadable image raises OSError (FileNotFoundError and PIL's
    # UnidentifiedImageError are subclasses).
    try:
        north_img = mpimg.imread("/content/drive/MyDrive/Pandas/Chl-a/Yearly Cliped_Data/Compass.jpg")
    except OSError:
        print("Compass image not found — skipping.")
    else:
        ax_n = fig.add_axes([0.88, 0.86, 0.09, 0.09])
        ax_n.imshow(north_img)
        ax_n.axis("off")

    # SAVE PNG
    out_png = f"{save_prefix}_RF_Classification.png"
    plt.savefig(out_png, dpi=400, bbox_inches="tight")
    plt.show()

    print(f"Saved PNG: {out_png}")

    return class_map

# ==========================================
# ROC–AUC (multiclass, one-vs-rest)
# ==========================================
# Top-level analysis of the RF trained above, using X_test / y_test.
CLASS_NAMES = ["Clean", "Nutrient", "Bloom", "SparsePD", "DensePD"]
n_classes = len(CLASS_NAMES)

# 1) Probability scores. Column j of predict_proba belongs to rf.classes_[j],
#    so y_test is binarized against that same order instead of assuming all
#    five classes were present during training.
y_score = rf.predict_proba(X_test)          # shape: (n_samples, len(rf.classes_))
model_classes = rf.classes_

# 2) Binarize y_test for one-vs-rest ROC
y_test_bin = label_binarize(y_test, classes=model_classes)
if y_test_bin.shape[1] == 1:
    # With only two classes label_binarize returns a single column (the
    # positive class); rebuild one column per class.
    y_test_bin = np.hstack([1 - y_test_bin, y_test_bin])

# 3) ROC curve and AUC per class, keyed by the class id (0-4)
fpr = dict()
tpr = dict()
roc_auc = dict()
evaluated = []

for j, cls in enumerate(model_classes):
    k = int(cls)
    truth = y_test_bin[:, j]
    # ROC is undefined when the test set has no positives or no negatives.
    if truth.min() == truth.max():
        print(f"Skipping ROC for {CLASS_NAMES[k]}: only one outcome in y_test.")
        continue
    fpr[k], tpr[k], _ = roc_curve(truth, y_score[:, j])
    roc_auc[k] = auc(fpr[k], tpr[k])
    evaluated.append(k)

if not evaluated:
    raise RuntimeError("ROC–AUC undefined: no class has both positive and negative test samples.")

# 4) Macro-average ROC over the classes that could be evaluated
#    - Merge all false positive rates
all_fpr = np.unique(np.concatenate([fpr[k] for k in evaluated]))

#    - Interpolate all ROC curves at these points & average
mean_tpr = np.zeros_like(all_fpr)
for k in evaluated:
    mean_tpr += np.interp(all_fpr, fpr[k], tpr[k])

mean_tpr /= len(evaluated)

fpr["macro"] = all_fpr
tpr["macro"] = mean_tpr
roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

print(f"\nMacro-average ROC–AUC: {roc_auc['macro']:.3f}")

# 5) Plot ROC curves (paper style)
plt.figure(figsize=(8, 7), dpi=300)

# Macro curve first (thicker)
plt.plot(
    fpr["macro"],
    tpr["macro"],
    label=f"Macro-average (AUC = {roc_auc['macro']:.3f})",
    linewidth=3,
    linestyle='-'
)

# Per-class curves (same colours as the classification maps)
colors = ["#1f78b4", "#33a02c", "#ff7f00", "#6a3d9a", "#e31a1c"]
for k in evaluated:
    plt.plot(
        fpr[k], tpr[k],
        lw=1.8,
        label=f"{CLASS_NAMES[k]} (AUC = {roc_auc[k]:.3f})",
        color=colors[k]
    )

# Diagonal no-skill line
plt.plot([0, 1], [0, 1], 'k--', lw=1.2)

# Axes limits
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])

# Labels & title (bold)
plt.xlabel("False Positive Rate", fontsize=16, fontweight="bold")
plt.ylabel("True Positive Rate", fontsize=16, fontweight="bold")
plt.title("Multiclass ROC Curves (RF – All Years & Seasons)", fontsize=18, fontweight="bold")

# Ticks bold
plt.xticks(fontsize=14, fontweight="bold")
plt.yticks(fontsize=14, fontweight="bold")

# Legend
leg = plt.legend(loc="lower right", fontsize=11)
for text in leg.get_texts():
    text.set_fontweight("bold")

plt.grid(True, linestyle="--", alpha=0.4)
plt.tight_layout()
plt.show()

# CNN

In [None]:



import os
import glob
import math
import random
import numpy as np
import pandas as pd
import rasterio
from rasterio.features import geometry_mask
import geopandas as gpd
from skimage.morphology import dilation, disk
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (classification_report, confusion_matrix,
                             accuracy_score, f1_score, precision_score,
                             recall_score, roc_auc_score, roc_curve, auc,
                             cohen_kappa_score, matthews_corrcoef)
from sklearn.preprocessing import label_binarize

# PyTorch imports
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
except Exception as e:
    raise ImportError("PyTorch not found. Install it first, e.g. `pip install torch torchvision`") from e

# ============================
# USER: edit these paths & params
# ============================
# Folder of stacked multi-band scene GeoTIFFs (globbed as *.tif further down).
SCENES_DIR = r"/content/drive/MyDrive/Pandas/Chl-a/Yearly Cliped_Data/Ml_ AL season"     # e.g. "/content/drive/.../Ml_ AL season"
# Lake boundary vector data, read with geopandas.read_file.
SHP_PATH   = r"/content/drive/MyDrive/Pandas/Chl-a/Yearly Cliped_Data/SHP"
MAP_SAVE_DIR = r"/content/drive/MyDrive/Pandas/Chl-a/CNN_Output"
# NOTE(review): "CNN_Outputt" differs from the other output folders ("CNN_Output");
# possibly a typo. Confirm where the checkpoint cnn_best.pth is expected to live.
MODEL_SAVE_DIR = r"/content/drive/MyDrive/Pandas/Chl-a/CNN_Outputt"
EXCEL_SAVE_DIR = r"/content/drive/MyDrive/Pandas/Chl-a/CNN_Output"
os.makedirs(MAP_SAVE_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(EXCEL_SAVE_DIR, exist_ok=True)

# Patch side length. SimpleCNN applies two 2x2 max-poolings, so it needs at
# least 4 px; in practice use an odd value >= 5.
PATCH_SIZE = 9                    # odd (e.g., 5,7,9,11)
SAMPLES_PER_CLASS_PER_SCENE = 500 # per scene (reduce if memory issues)
BATCH_SIZE = 64
EPOCHS = 30
RANDOM_STATE = 42
NUM_CLASSES = 5
EXPECTED_BANDS = 11               # B2..B12
SCALE_REFLECTANCE = True          # True if values are 0..10000
BLOOM_PERCENTILE = 85             # to set CHL_T2
CHL_T1_FIXED = 20.0               # clean vs nutrient threshold
# NDVI thresholds: >= VEG_DENSE -> DensePD, [VEG_SPARSE, VEG_DENSE) -> SparsePD,
# < VEG_SPARSE -> water (see build_classes_from_stack).
VEG_SPARSE = 0.20
VEG_DENSE  = 0.40
# Spectral bloom rule: NDCI > NDCI_BLOOM_MIN and NDVI > NDVI_BLOOM_MIN, applied to water only.
# NOTE(review): NDVI_BLOOM_MIN (0.25) > VEG_SPARSE (0.20), so no water pixel can meet the
# spectral rule; only the chlorophyll criterion (chl >= CHL_T2) produces bloom. Confirm intent.
NDCI_BLOOM_MIN = 0.15
NDVI_BLOOM_MIN = 0.25
BLOOM_DILATE_RADIUS = 2           # bloom mask dilation radius in pixels (0 disables dilation)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# Seed Python, NumPy (drives the np.random.choice patch sampling) and torch RNGs.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# ============================
# Helper: read stack reflectance
# ============================
def read_stack_reflectance(path):
    """Load every band of a stacked GeoTIFF as float32 reflectance.

    Args:
        path: Path of the multi-band GeoTIFF.

    Returns:
        (array, profile): ``array`` has shape (bands, H, W). When
        SCALE_REFLECTANCE is set and the largest finite value exceeds 1.5,
        the data are treated as integer-scaled and divided by 10000. Values
        are then clipped to [0, 1.5] (NaN stays NaN). ``profile`` is a copy
        of the rasterio profile.

    NOTE(review): the raster's nodata value is not masked here, so nodata
    pixels are scaled and clipped like real data. Confirm this is intended.
    """
    with rasterio.open(path) as dataset:
        profile = dataset.profile.copy()
        data = dataset.read().astype("float32")

    is_finite = np.isfinite(data)
    looks_integer_scaled = (
        SCALE_REFLECTANCE
        and is_finite.any()
        and float(np.nanmax(data[is_finite])) > 1.5
    )
    if looks_integer_scaled:
        data = np.where(is_finite, data / 10000.0, np.nan)
    return np.clip(data, 0.0, 1.5), profile

# ============================
# Helper: parse scene file name
# ============================
def parse_scene_name(path):
    """Split a scene file name of the form ``<year>_<season>.tif`` into parts.

    Only a trailing ``.tif`` is removed. The previous ``str.replace`` stripped
    ``.tif`` anywhere in the name (e.g. ``2022_a.tiff_b.tif`` became
    ``2022_af_b``).

    Args:
        path: File path of the scene.

    Returns:
        ``(year, season, base)``. ``base`` is the file name without the
        ``.tif`` suffix. The text before the first underscore is the year and
        the remainder is the season. A name without an underscore yields
        ``("Unknown", base, base)``.
    """
    base = os.path.basename(path)
    if base.endswith(".tif"):
        base = base[:-len(".tif")]
    year, sep, season = base.partition("_")
    if not sep:
        return "Unknown", base, base
    return year, season, base

# ============================
# Helper: compute class map (same rules as RF pipeline)
# ============================
def build_classes_from_stack(stack, lake_mask, CHL_T1, CHL_T2):
    """Label lake pixels of one scene with rule-based classes (same rules as the RF pipeline).

    Class codes (same order as CLASS_NAMES in this file):
        0 Clean     - water with chl < CHL_T1
        1 Nutrient  - water with chl >= CHL_T1 that is not bloom
        2 Bloom     - water with chl >= CHL_T2, or with NDCI > NDCI_BLOOM_MIN
                      and NDVI > NDVI_BLOOM_MIN; then dilated by
                      BLOOM_DILATE_RADIUS pixels, restricted to water
        3 SparsePD  - lake pixels with VEG_SPARSE <= NDVI < VEG_DENSE
        4 DensePD   - lake pixels with NDVI >= VEG_DENSE
    "Water" means lake pixels with NDVI < VEG_SPARSE.

    Args:
        stack: (bands, H, W) reflectance. Red (B4), red-edge 1 (B5) and NIR
            (B8) are read from indices 2, 3 and 6 (B2..B12 band order).
        lake_mask: (H, W) boolean mask, True inside the lake.
        CHL_T1: Clean / nutrient chlorophyll threshold.
        CHL_T2: Bloom chlorophyll threshold.

    Returns:
        (class_map, chl, ndvi, ndci). ``class_map`` is float32 and stays NaN
        outside the lake or where the indices are not finite. The other three
        are per-pixel arrays computed over the whole scene.

    NOTE(review): with the configured thresholds NDVI_BLOOM_MIN (0.25) is
    above VEG_SPARSE (0.20), so the spectral bloom test cannot fire on water
    pixels. Confirm that this is intended.
    """
    # Only Red, red-edge 1 and NIR are needed (B2/B3 were read but never used).
    B4 = stack[2]; B5 = stack[3]; B8 = stack[6]
    eps = 1e-9
    ndvi = (B8 - B4) / (B8 + B4 + eps)
    ndci = (B5 - B4) / (B5 + B4 + eps)
    # Empirical quadratic NDCI -> chlorophyll-a conversion.
    chl = 14.039 + 86.115 * ndci + 194.325 * (ndci ** 2)
    H, W = B4.shape
    class_map = np.full((H, W), np.nan, dtype="float32")

    dense  = (ndvi >= VEG_DENSE) & lake_mask
    sparse = (ndvi >= VEG_SPARSE) & (ndvi < VEG_DENSE) & lake_mask
    water  = (ndvi < VEG_SPARSE) & lake_mask

    class_map[dense]  = 4
    class_map[sparse] = 3

    bloom_core = (chl >= CHL_T2) & water
    bloom_spectral = ((ndci > NDCI_BLOOM_MIN) & (ndvi > NDVI_BLOOM_MIN)) & water
    bloom = bloom_core | bloom_spectral
    if BLOOM_DILATE_RADIUS > 0:
        # Grow bloom regions by a disk, but never beyond water pixels.
        bloom = dilation(bloom, disk(BLOOM_DILATE_RADIUS)) & water
    class_map[bloom] = 2

    water_remain = water & (~bloom)
    class_map[(chl < CHL_T1) & water_remain] = 0
    class_map[(chl >= CHL_T1) & water_remain] = 1

    return class_map, chl, ndvi, ndci

# ============================
# PASS 1: scan scenes to collect CHL percentiles -> CHL_T2
# ============================
# Collect every scene GeoTIFF in SCENES_DIR (sorted, so the order is reproducible).
scene_files = sorted(glob.glob(os.path.join(SCENES_DIR, "*.tif")))
if len(scene_files)==0:
    raise RuntimeError("No .tif scenes found in SCENES_DIR. Edit SCENES_DIR path.")

# Lake polygon(s); reprojected to each scene's CRS before rasterising.
gdf_lake = gpd.read_file(SHP_PATH)
print("Found scenes:", len(scene_files), "Shapefile rows:", len(gdf_lake))

# Pool chlorophyll estimates from the lake pixels of all scenes so that CHL_T2
# is one global threshold (scenes with more lake pixels carry more weight).
all_chl = []
for scene_path in scene_files:
    print("Collecting chl from:", os.path.basename(scene_path))
    stack, prof = read_stack_reflectance(scene_path)
    H = stack.shape[1]; W = stack.shape[2]
    gdf_proj = gdf_lake.to_crs(prof['crs'])
    # invert=True -> True inside the lake geometries.
    lake_mask = geometry_mask(gdf_proj.geometry, transform=prof['transform'], invert=True, out_shape=(H,W))
    # Only chl is used here; it does not depend on the thresholds passed in.
    _, chl, _, _ = build_classes_from_stack(stack, lake_mask, CHL_T1_FIXED, CHL_T1_FIXED+10)  # CHL_T2 placeholder
    valid = lake_mask & np.isfinite(chl)
    if np.any(valid):
        all_chl.append(chl[valid])
if len(all_chl)==0:
    raise RuntimeError("No valid chl values inside lake for any scene.")
all_chl = np.concatenate(all_chl)
# The percentiles are printed for reference only; BLOOM_PERCENTILE alone sets CHL_T2.
p50,p75,p85,p90,p95 = np.percentile(all_chl, [50,75,85,90,95])
CHL_T1 = CHL_T1_FIXED
CHL_T2 = float(np.percentile(all_chl, BLOOM_PERCENTILE))
print(f"Chl percentiles: 50%={p50:.2f},75%={p75:.2f},85%={p85:.2f},90%={p90:.2f},95%={p95:.2f}")
print(f"Using CHL_T1={CHL_T1:.2f}, CHL_T2({BLOOM_PERCENTILE}th)={CHL_T2:.2f}")

# ============================
# Patch extraction functions
# ============================
def extract_patches_from_scene(stack, class_map, lake_mask, patch_size=9, samples_per_class=500, num_classes=None):
    """Sample labelled square patches centred on classified lake pixels.

    For each class code 0..num_classes-1, up to ``samples_per_class`` lake
    pixels with that code are chosen: all of them if there are fewer, else a
    random subset without replacement drawn with ``np.random``. The
    surrounding window is then cut from the reflect-padded stack.

    Args:
        stack: (bands, H, W) reflectance array.
        class_map: (H, W) class codes; NaN (or any code outside the scanned
            range) is never sampled.
        lake_mask: (H, W) boolean mask; only True pixels are sampled.
        patch_size: Window side. The window spans ``2 * (patch_size // 2) + 1``
            pixels, i.e. exactly ``patch_size`` for the intended odd values.
        samples_per_class: Maximum number of patches per class.
        num_classes: Number of class codes to scan; defaults to NUM_CLASSES.

    Returns:
        ``(X, y)``: float32 patches of shape (N, win, win, bands) (channels
        last) and int labels of shape (N,). NaN reflectance inside a window is
        replaced by 0, so a single invalid neighbour cannot poison the loss.
        The empty result is float32 as well (it used to be float64).
    """
    from numpy.lib.stride_tricks import sliding_window_view

    if num_classes is None:
        num_classes = NUM_CLASSES
    half = patch_size // 2
    win = 2 * half + 1
    bands, H, W = stack.shape
    stack_p = np.pad(stack, pad_width=((0, 0), (half, half), (half, half)), mode='reflect')
    # Read-only view of shape (bands, H, W, win, win): windows[:, r, c] is the
    # patch centred on original pixel (r, c). Nothing is copied here.
    windows = sliding_window_view(stack_p, (win, win), axis=(1, 2))
    flat_cls = class_map.ravel()
    flat_lake = lake_mask.ravel()

    picked, labels = [], []
    for cls in range(num_classes):
        idxs = np.flatnonzero((flat_cls == cls) & flat_lake)
        if idxs.size == 0:
            continue
        # sample up to requested (uses np.random, so it follows np.random.seed)
        sel = idxs if idxs.size <= samples_per_class else np.random.choice(idxs, samples_per_class, replace=False)
        picked.append(sel)
        labels.append(np.full(sel.size, cls, dtype=int))

    if not picked:
        return np.empty((0, win, win, bands), dtype='float32'), np.empty((0,), dtype=int)

    rows, cols = np.divmod(np.concatenate(picked), W)
    # Gather all selected windows at once: (bands, N, win, win) -> (N, win, win, bands).
    X = windows[:, rows, cols].transpose(1, 2, 3, 0).astype('float32')
    X = np.nan_to_num(X, copy=False, nan=0.0)
    y = np.concatenate(labels)
    return X, y

def build_dataset_from_scenes(scene_files, gdf_lake, patch_size=9, samples_per_class=500):
    """Build one labelled patch dataset from all scenes.

    For every scene, the lake polygons are rasterised onto the scene grid,
    pixels are labelled with build_classes_from_stack (using the global
    CHL_T1 / CHL_T2), and patches are sampled with extract_patches_from_scene.

    Args:
        scene_files: Iterable of scene GeoTIFF paths.
        gdf_lake: GeoDataFrame holding the lake geometry (any CRS).
        patch_size: Patch side passed to extract_patches_from_scene.
        samples_per_class: Per-scene, per-class cap on sampled patches.

    Returns:
        (X_all, y_all): stacked patches (N, p, p, bands) and labels (N,).

    Raises:
        RuntimeError: if no scene yields any patch.
    """
    patch_chunks, label_chunks = [], []
    for scene_path in scene_files:
        print("Scene:", os.path.basename(scene_path))
        stack, prof = read_stack_reflectance(scene_path)
        height, width = stack.shape[1:3]
        lake_geoms = gdf_lake.to_crs(prof['crs']).geometry
        lake_mask = geometry_mask(lake_geoms, transform=prof['transform'], invert=True, out_shape=(height, width))
        class_map, _chl, _ndvi, _ndci = build_classes_from_stack(stack, lake_mask, CHL_T1, CHL_T2)
        patches, labels = extract_patches_from_scene(
            stack, class_map, lake_mask,
            patch_size=patch_size, samples_per_class=samples_per_class,
        )
        if patches.shape[0] == 0:
            continue
        patch_chunks.append(patches)
        label_chunks.append(labels)

    if not patch_chunks:
        raise RuntimeError("No patches extracted from scenes. Check masks/thresholds.")

    X_all = np.concatenate(patch_chunks, axis=0)
    y_all = np.concatenate(label_chunks, axis=0)
    print("Patches total:", X_all.shape, "Label counts:", np.unique(y_all, return_counts=True))
    return X_all, y_all

# ============================
# Build dataset (patches)
# ============================
# Sample labelled patches from every scene (labels use CHL_T1 / CHL_T2 from pass 1).
X, y = build_dataset_from_scenes(scene_files, gdf_lake, patch_size=PATCH_SIZE, samples_per_class=SAMPLES_PER_CLASS_PER_SCENE)

# If bands != expected, warn (but continue)
# bands_found is also used as the CNN's input channel count further down.
bands_found = X.shape[-1]
if bands_found != EXPECTED_BANDS:
    print(f"Warning: patches have {bands_found} bands (EXPECTED_BANDS={EXPECTED_BANDS}). If order differs, update code.")

# Train/val/test split
# 70 % train, then the remaining 30 % split in half -> 15 % val / 15 % test, each stratified by label
# (sklearn raises if a class has too few samples to stratify).
# NOTE(review): the split is per patch, not per scene, and patches of neighbouring pixels
# can overlap, so test metrics may be optimistic. Consider a scene-level hold-out.
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, random_state=RANDOM_STATE, stratify=y)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=RANDOM_STATE, stratify=y_temp)
print("Train/Val/Test sizes:", X_train.shape[0], X_val.shape[0], X_test.shape[0])

# ============================
# PyTorch Dataset & Dataloader
# ============================
class PatchDataset(Dataset):
    """Map-style dataset serving (patch, label) pairs to the 2-D CNN.

    Args:
        X: array of shape (N, p, p, bands) holding the patches (channels last).
        y: length-N sequence of integer class labels.
        transform: optional callable applied to the channel-first float32
            numpy patch (bands, p, p) before it is wrapped as a tensor.

    Each item is ``(float32 tensor (bands, p, p), int64 scalar tensor)``.
    """

    def __init__(self, X, y, transform=None):
        self.X, self.y, self.transform = X, y, transform

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # (p, p, bands) -> (bands, p, p), the layout Conv2d expects
        chw = self.X[idx].transpose(2, 0, 1).astype('float32')
        if self.transform:
            chw = self.transform(chw)
        label = torch.tensor(int(self.y[idx]), dtype=torch.long)
        return torch.from_numpy(chw), label

train_ds = PatchDataset(X_train, y_train)
val_ds   = PatchDataset(X_val, y_val)
test_ds  = PatchDataset(X_test, y_test)

# SimpleCNN contains BatchNorm1d, which raises in training mode on a batch of
# one sample. Drop the final partial batch only when it would have exactly one
# sample (data is reshuffled each epoch, so nothing is permanently skipped).
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
                          drop_last=(len(train_ds) % BATCH_SIZE == 1))
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# ============================
# Compute class weights for loss
# ============================
# 'balanced' weights exist only for classes present in y_train. They are
# scattered into a full NUM_CLASSES-length vector because CrossEntropyLoss
# requires exactly one weight per model output. Absent classes get 1.0; they
# have no samples, so they contribute no loss terms anyway.
cls_vals = np.unique(y_train)
weights = compute_class_weight(class_weight='balanced', classes=cls_vals, y=y_train)
full_weights = np.ones(NUM_CLASSES, dtype=np.float32)
full_weights[cls_vals] = weights
class_weights = torch.tensor(full_weights, dtype=torch.float32).to(DEVICE)
print("Class weights:", dict(zip(cls_vals, weights)))

# ============================
# Define CNN model (PyTorch)
# ============================
class SimpleCNN(nn.Module):
    """Two-stage 2-D CNN that classifies the centre pixel of a spectral patch.

    Input:  float tensor (B, in_channels, p, p), with bands as channels.
    Output: raw logits (B, num_classes); softmax/argmax is applied by callers.

    Each stage is two (3x3 conv -> BatchNorm -> ReLU) layers followed by 2x2
    max-pooling and spatial dropout. Global average pooling then feeds a
    256-unit fully connected head. Because of the two 2x2 poolings the patch
    side must be at least 4 pixels.

    The attribute names form the state_dict keys stored in cnn_best.pth, so
    they must stay unchanged.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Stage 1: in_channels -> 64 feature maps, then halve the spatial size.
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2)
        self.drop1 = nn.Dropout2d(0.25)

        # Stage 2: 64 -> 128 feature maps, then halve again.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2)
        self.drop2 = nn.Dropout2d(0.35)

        # Head: global average pool -> 256 hidden units -> class logits.
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(128, 256)
        self.bnfc = nn.BatchNorm1d(256)
        self.dropfc = nn.Dropout(0.4)
        self.out = nn.Linear(256, num_classes)

    def forward(self, x):
        stages = (
            (((self.conv1, self.bn1), (self.conv2, self.bn2)), self.pool1, self.drop1),
            (((self.conv3, self.bn3), (self.conv4, self.bn4)), self.pool2, self.drop2),
        )
        for conv_bn_pairs, pool, drop in stages:
            for conv, norm in conv_bn_pairs:
                x = F.relu(norm(conv(x)))
            x = drop(pool(x))

        features = torch.flatten(self.gap(x), start_dim=1)  # (B, 128)
        hidden = self.dropfc(F.relu(self.bnfc(self.fc1(features))))
        return self.out(hidden)

# 2-D CNN with one input channel per band found in the patches.
model = SimpleCNN(in_channels=bands_found, num_classes=NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Class-weighted cross-entropy; class_weights comes from compute_class_weight('balanced').
criterion = nn.CrossEntropyLoss(weight=class_weights)

# ============================
# Training loop (simple early stopping)
# ============================
# Start below any reachable accuracy so the first epoch always writes
# cnn_best.pth. With 0.0, a run whose val_acc stayed at 0 never saved a
# checkpoint, so torch.load below either failed or silently loaded a stale
# file left over from an earlier run.
best_val_acc = -1.0
patience = 6          # epochs without val_acc improvement before stopping
patience_counter = 0
history = {"train_loss":[], "val_loss":[], "train_acc":[], "val_acc":[]}

for epoch in range(1, EPOCHS+1):
    model.train()
    running_loss = 0.0; correct=0; total=0
    for xb, yb in train_loader:
        xb = xb.to(DEVICE); yb = yb.to(DEVICE)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        # loss is a per-batch mean; scale by batch size before averaging over the epoch
        running_loss += loss.item() * xb.size(0)
        preds = out.argmax(dim=1)
        correct += (preds==yb).sum().item()
        total += xb.size(0)
    train_loss = running_loss/total
    train_acc = correct/total

    # validation
    model.eval()
    vloss=0.0; vcorrect=0; vtotal=0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb=xb.to(DEVICE); yb=yb.to(DEVICE)
            out = model(xb)
            loss = criterion(out, yb)
            vloss += loss.item() * xb.size(0)
            preds = out.argmax(dim=1)
            vcorrect += (preds==yb).sum().item()
            vtotal += xb.size(0)
    val_loss = vloss/vtotal
    val_acc = vcorrect/vtotal

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch}/{EPOCHS}  train_loss={train_loss:.4f} train_acc={train_acc:.4f}  val_loss={val_loss:.4f} val_acc={val_acc:.4f}")

    # early stopping & save best (model selection is by validation accuracy)
    if val_acc > best_val_acc + 1e-5:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(MODEL_SAVE_DIR, "cnn_best.pth"))
        patience_counter = 0
        print("  Saved best model.")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping.")
            break

# load best
model.load_state_dict(torch.load(os.path.join(MODEL_SAVE_DIR, "cnn_best.pth"), map_location=DEVICE))
model.eval()

# ============================
# Evaluate on test set
# ============================
# Collect true labels, hard predictions and class probabilities on the test split.
y_true = []
y_pred = []
y_proba = []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(DEVICE)
        out = model(xb)
        probs = F.softmax(out, dim=1).cpu().numpy()
        preds = out.argmax(dim=1).cpu().numpy()
        y_true.extend(yb.numpy().tolist())
        y_pred.extend(preds.tolist())
        y_proba.extend(probs.tolist())
y_true = np.array(y_true); y_pred = np.array(y_pred); y_proba = np.vstack(y_proba)

# Metrics
# Pin the label set to all NUM_CLASSES codes so that per-class output always
# lines up with the five class names, even if a class is missing from the test
# split (classification_report raises on a names/labels size mismatch).
eval_labels = list(range(NUM_CLASSES))
acc  = accuracy_score(y_true, y_pred)
f1_m = f1_score(y_true, y_pred, average='macro')
f1_w = f1_score(y_true, y_pred, average='weighted')
prec = precision_score(y_true, y_pred, average='macro')
rec  = recall_score(y_true, y_pred, average='macro')
kappa = cohen_kappa_score(y_true, y_pred)
mcc = matthews_corrcoef(y_true, y_pred)
try:
    roc_macro = roc_auc_score(label_binarize(y_true, classes=eval_labels), y_proba, multi_class='ovr', average='macro')
except ValueError as err:
    # One-vs-rest AUC is undefined for a class with no positive (or no
    # negative) test sample; report NaN instead of aborting the cell.
    print("Warning: macro ROC-AUC undefined:", err)
    roc_macro = float("nan")

print("\n=== CNN Test-set Performance ===")
print(f"Accuracy          : {acc:.3f}")
print(f"F1-score (macro)  : {f1_m:.3f}")
print(f"F1-score (weighted): {f1_w:.3f}")
print(f"Precision (macro) : {prec:.3f}")
print(f"Recall (macro)    : {rec:.3f}")
print(f"Cohen's Kappa     : {kappa:.3f}")
print(f"MCC               : {mcc:.3f}")
print(f"ROC–AUC (macro)   : {roc_macro:.3f}")

print("\nClassification Report (per class):")
print(classification_report(y_true, y_pred, labels=eval_labels, target_names=["Clean","Nutrient","Bloom","SparsePD","DensePD"]))

# ============================
# Save performance to Excel
# ============================
# One-row summary of the headline test metrics.
summary_metrics = {
    "Model": ["CNN"],
    "Accuracy": [acc],
    "F1_macro": [f1_m],
    "F1_weighted": [f1_w],
    "Precision_macro": [prec],
    "Recall_macro": [rec],
    "Cohen_Kappa": [kappa],
    "MCC": [mcc],
    "ROC_AUC_macro": [roc_macro]
}
df_summary = pd.DataFrame(summary_metrics)
# labels= keeps the rows aligned with the five class names even when a class
# is absent from the test split (otherwise classification_report raises).
df_per_class = pd.DataFrame(classification_report(y_true, y_pred, labels=list(range(NUM_CLASSES)), target_names=["Clean","Nutrient","Bloom","SparsePD","DensePD"], output_dict=True)).transpose()

excel_out = os.path.join(EXCEL_SAVE_DIR, "CNN_Performance_Metrics.xlsx")
with pd.ExcelWriter(excel_out, engine="openpyxl") as writer:
    df_summary.to_excel(writer, sheet_name="Summary", index=False)
    df_per_class.to_excel(writer, sheet_name="Per_Class")
print("Saved performance Excel to:", excel_out)

# ============================
# Styled confusion matrix (with cell lines)
# ============================
CLASS_NAMES = ["Clean","Nutrient","Bloom","SparsePD","DensePD"]
# labels= forces a full 5x5 matrix (rows = actual, columns = predicted), so it
# always matches the tick labels even if a class never occurs in y_true/y_pred.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))
plt.figure(figsize=(9,7))
ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                 xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                 annot_kws={"size":16, "weight":"bold"},
                 linewidths=2, linecolor="black")
ax.set_title("CNN — Confusion Matrix (Test)", fontsize=20, fontweight="bold")
ax.set_xlabel("Predicted", fontsize=18, fontweight="bold"); ax.set_ylabel("Actual", fontsize=18, fontweight="bold")
ax.tick_params(axis="both", labelsize=14, width=2, length=6)
for lbl in ax.get_xticklabels()+ax.get_yticklabels(): lbl.set_fontweight("bold")
cbar = ax.collections[0].colorbar; cbar.ax.tick_params(labelsize=14); cbar.outline.set_linewidth(2)
plt.tight_layout()
plt.savefig(os.path.join(MAP_SAVE_DIR, "CNN_Confusion_Matrix.png"), dpi=600, bbox_inches="tight")
plt.show()

# ============================
# ROC curves plot (multiclass)
# ============================
# One-vs-rest ROC per class from the softmax probabilities, plus a macro
# average (per-class TPR interpolated onto the union of all FPR points, then averaged).
# NOTE(review): a class with no positive test sample yields a NaN TPR curve
# (sklearn warns), which also makes the macro curve NaN.
n_classes = NUM_CLASSES
y_bin = label_binarize(y_true, classes=list(range(n_classes)))
fpr = dict(); tpr = dict(); roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_bin[:,i], y_proba[:,i])
    roc_auc[i] = auc(fpr[i], tpr[i])
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes
fpr["macro"] = all_fpr; tpr["macro"] = mean_tpr; roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

plt.figure(figsize=(9,8), dpi=300)
plt.plot(fpr["macro"], tpr["macro"], label=f"Macro-average (AUC = {roc_auc['macro']:.3f})", linewidth=4, color="black")
line_styles = ["-", "--", "-.", ":", (0,(3,1,1,1))]
# `colors` is a module-level name in this notebook; check other cells before renaming it.
colors = ["#1f78b4","#33a02c","#ff7f00","#6a3d9a","#e31a1c"]
for i,(ls,color) in enumerate(zip(line_styles, colors)):
    plt.plot(fpr[i], tpr[i], lw=2.5, linestyle=ls, color=color, label=f"{CLASS_NAMES[i]} (AUC={roc_auc[i]:.3f})")
plt.plot([0,1],[0,1],'k--', lw=1.5)
plt.xlim([0,1]); plt.ylim([0,1.05])
plt.xlabel("False Positive Rate", fontsize=18, fontweight="bold")
plt.ylabel("True Positive Rate", fontsize=18, fontweight="bold")
plt.title("CNN — Multiclass ROC Curves", fontsize=20, fontweight="bold")
plt.xticks(fontsize=14); plt.yticks(fontsize=14)
leg = plt.legend(loc="lower right", fontsize=12)
for t in leg.get_texts(): t.set_fontweight("bold")
plt.grid(True, linestyle="--", alpha=0.4)
plt.tight_layout()
plt.savefig(os.path.join(MAP_SAVE_DIR, "CNN_ROC_Curves.png"), dpi=600, bbox_inches="tight")
plt.show()

# ============================
# Predict and save classification maps per scene (batch-patch prediction)
# ============================
def predict_map_with_cnn(scene_path, model, patch_size=PATCH_SIZE, batch_size=4096, save_dir=MAP_SAVE_DIR):
    """Classify every lake pixel of one scene with the CNN and save the map.

    The scene is read with read_stack_reflectance and the lake mask is
    rasterised from the module-level ``gdf_lake``. For each lake pixel, the
    square neighbourhood (reflect-padded at the image border, as in training)
    is fed to ``model`` in batches.

    Args:
        scene_path: Stacked scene GeoTIFF. Its file name is parsed by
            parse_scene_name for the output names and title.
        model: Trained network taking (B, bands, p, p) float tensors on DEVICE
            and returning per-class logits. It is switched to eval mode here.
        patch_size: Patch side (should match training).
        batch_size: Number of pixels classified per forward pass.
        save_dir: Folder for ``<year>_<season>_CNN_Map.tif`` and ``.png``.

    Returns:
        uint8 (H, W) array of predicted class indices, 255 outside the lake.
    """
    from numpy.lib.stride_tricks import sliding_window_view

    year, season, _ = parse_scene_name(scene_path)
    stack, prof = read_stack_reflectance(scene_path)
    bands, H, W = stack.shape
    gdf_proj = gdf_lake.to_crs(prof['crs'])
    lake_mask = geometry_mask(gdf_proj.geometry, transform=prof['transform'], invert=True, out_shape=(H,W))

    half = patch_size//2
    win = 2 * half + 1
    pad_width = ((0,0),(half,half),(half,half))
    stack_p = np.pad(stack, pad_width=pad_width, mode='reflect')
    # Read-only view (bands, H, W, win, win): windows[:, r, c] is the patch centred on (r, c).
    windows = sliding_window_view(stack_p, (win, win), axis=(1, 2))
    class_flat = np.full(H*W, 255, dtype=np.uint8)
    idx_lake = np.flatnonzero(lake_mask)
    print("Predicting for lake pixels:", idx_lake.size)

    # Dropout / BatchNorm must run in inference mode regardless of caller state.
    model.eval()
    with torch.no_grad():
        for start in range(0, idx_lake.size, batch_size):
            sel = idx_lake[start:start + batch_size]
            rows, cols = np.divmod(sel, W)
            # (bands, n, win, win). NaN -> 0 so one invalid neighbour cannot turn logits into NaN.
            Xb = np.nan_to_num(windows[:, rows, cols].astype('float32'), copy=False, nan=0.0)
            Xb_t = torch.from_numpy(np.ascontiguousarray(Xb.transpose(1, 0, 2, 3))).to(DEVICE)
            # argmax of the logits equals argmax of their softmax
            class_flat[sel] = model(Xb_t).argmax(dim=1).cpu().numpy().astype(np.uint8)
    class_map = class_flat.reshape(H,W)

    # save GeoTIFF
    out_tif = os.path.join(save_dir, f"{year}_{season}_CNN_Map.tif")
    prof2 = prof.copy(); prof2.update(dtype=rasterio.uint8, count=1, nodata=255)
    with rasterio.open(out_tif, "w", **prof2) as dst:
        dst.write(class_map, 1)
    print("Saved CNN GeoTIFF:", out_tif)

    # save PNG (clean map only). The palette is local instead of reading a
    # global `colors` that is only set by a plotting cell elsewhere.
    palette = ["#1f78b4", "#33a02c", "#ff7f00", "#6a3d9a", "#e31a1c"]
    plot_map = class_map.astype(float); plot_map[plot_map==255]=np.nan
    cmap = mcolors.ListedColormap(palette)
    bounds = [-0.5,0.5,1.5,2.5,3.5,4.5]; norm = mcolors.BoundaryNorm(bounds, cmap.N)
    fig, ax = plt.subplots(figsize=(10,10), dpi=300)
    ax.imshow(plot_map, cmap=cmap, norm=norm, origin='upper')
    ax.set_title(f"{year} {season} (CNN)", fontsize=22, fontweight='bold')
    ax.axis('off')
    out_png = os.path.join(save_dir, f"{year}_{season}_CNN_Map.png")
    fig.savefig(out_png, dpi=400, bbox_inches='tight')
    plt.show()
    plt.close(fig)  # this runs once per scene; release each figure
    print("Saved CNN PNG:", out_png)
    return class_map

# run prediction maps for all scenes
# Classify every lake pixel of every scene and write a GeoTIFF + PNG per scene.
# NOTE(review): the maps also cover the pixels whose patches were used for training.
for scene in scene_files:
    predict_map_with_cnn(scene, model, patch_size=PATCH_SIZE, batch_size=4096, save_dir=MAP_SAVE_DIR)

print("All done. Maps in:", MAP_SAVE_DIR)
print("Models in:", MODEL_SAVE_DIR)
print("Excel in:", EXCEL_SAVE_DIR)


# 3D CNN

In [None]:
# Full runnable cell: 3D CNN pipeline for Sentinel-2 patch classification
# Edit paths and parameters below before running.

import os
import glob
import math
import random
import numpy as np
import pandas as pd
import rasterio
from rasterio.features import geometry_mask
import geopandas as gpd
from skimage.morphology import dilation, disk
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (classification_report, confusion_matrix,
                             accuracy_score, f1_score, precision_score,
                             recall_score, roc_auc_score, roc_curve, auc,
                             cohen_kappa_score, matthews_corrcoef)
from sklearn.preprocessing import label_binarize

# PyTorch imports
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms
except Exception as e:
    raise ImportError("PyTorch not found. Install it first, e.g. `pip install torch torchvision`") from e

# ============================
# USER: edit these paths & params
# ============================
# NOTE(review): these paths look truncated/placeholder ("_ AL season", "HP", "t");
# fill in real folders before running.
# Folder of stacked multi-band scene GeoTIFFs (globbed as *.tif further down).
SCENES_DIR = r"/content/drive/MyDrive/_ AL season"     # e.g. "/content/drive/.../Ml_ AL season"
# Lake boundary vector data, read with geopandas.read_file.
SHP_PATH   = r"/content/drive/MyDrive/HP"
MAP_SAVE_DIR = r"/content/drive/MyDrive/t"
MODEL_SAVE_DIR = r"/content/drive/MyDrive/Pandas/t"
EXCEL_SAVE_DIR = r"/content/drive/MyDrive/Pandas/t"
os.makedirs(MAP_SAVE_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(EXCEL_SAVE_DIR, exist_ok=True)

PATCH_SIZE = 9                    # odd (e.g., 5,7,9,11)
SAMPLES_PER_CLASS_PER_SCENE = 500 # per scene (reduce if memory issues)
BATCH_SIZE = 32
EPOCHS = 30
RANDOM_STATE = 42
NUM_CLASSES = 5
# Set expected bands according to your stacked TIFF order. For B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12 -> 10 bands
# (build_classes_from_stack reads indices 2, 3 and 6, i.e. B4, B5 and B8 in this order too.)
EXPECTED_BANDS = 10
SCALE_REFLECTANCE = True          # True if values are 0..10000 in TIFF
BLOOM_PERCENTILE = 85             # to set CHL_T2
CHL_T1_FIXED = 20.0               # clean vs nutrient threshold
# NDVI thresholds: >= VEG_DENSE -> DensePD, [VEG_SPARSE, VEG_DENSE) -> SparsePD, < VEG_SPARSE -> water.
VEG_SPARSE = 0.20
VEG_DENSE  = 0.40
# Spectral bloom rule: NDCI > NDCI_BLOOM_MIN and NDVI > NDVI_BLOOM_MIN, applied to water only.
# NOTE(review): NDVI_BLOOM_MIN (0.25) > VEG_SPARSE (0.20), so no water pixel can meet the
# spectral rule; only chl >= CHL_T2 produces bloom. Confirm intent.
NDCI_BLOOM_MIN = 0.15
NDVI_BLOOM_MIN = 0.25
BLOOM_DILATE_RADIUS = 2           # bloom mask dilation radius in pixels (0 disables dilation)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# Seed Python, NumPy (drives the np.random.choice patch sampling) and torch RNGs.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# color palette used for plotting & map (five classes)
# Index i of both lists corresponds to class code i (0 Clean ... 4 DensePD).
CLASS_NAMES = ["Clean","Nutrient","Bloom","SparsePD","DensePD"]
PLOT_COLORS = ["#1f78b4","#33a02c","#ff7f00","#6a3d9a","#e31a1c"]

# ============================
# Helper: read stack reflectance
# ============================
def read_stack_reflectance(path):
    """Load every band of a stacked GeoTIFF as float32 reflectance.

    Args:
        path: Path of the multi-band GeoTIFF.

    Returns:
        (array, profile): ``array`` has shape (bands, H, W). When
        SCALE_REFLECTANCE is set and the largest finite value exceeds 1.5,
        the data are treated as integer-scaled and divided by 10000. Values
        are then clipped to [0, 1.5] (NaN stays NaN). ``profile`` is a copy
        of the rasterio profile.

    NOTE(review): the raster's nodata value is not masked here, so nodata
    pixels are scaled and clipped like real data. Confirm this is intended.
    """
    with rasterio.open(path) as dataset:
        profile = dataset.profile.copy()
        data = dataset.read().astype("float32")

    is_finite = np.isfinite(data)
    looks_integer_scaled = (
        SCALE_REFLECTANCE
        and is_finite.any()
        and float(np.nanmax(data[is_finite])) > 1.5
    )
    if looks_integer_scaled:
        data = np.where(is_finite, data / 10000.0, np.nan)
    return np.clip(data, 0.0, 1.5), profile

# ============================
# Helper: parse scene file name
# ============================
def parse_scene_name(path):
    """Split a scene file name of the form ``<year>_<season>.tif`` into parts.

    Only a trailing ``.tif`` is removed. The previous ``str.replace`` stripped
    ``.tif`` anywhere in the name (e.g. ``2022_a.tiff_b.tif`` became
    ``2022_af_b``).

    Args:
        path: File path of the scene.

    Returns:
        ``(year, season, base)``. ``base`` is the file name without the
        ``.tif`` suffix. The text before the first underscore is the year and
        the remainder is the season. A name without an underscore yields
        ``("Unknown", base, base)``.
    """
    base = os.path.basename(path)
    if base.endswith(".tif"):
        base = base[:-len(".tif")]
    year, sep, season = base.partition("_")
    if not sep:
        return "Unknown", base, base
    return year, season, base

# ============================
# Helper: compute class map (same rules as RF pipeline)
# ============================
def build_classes_from_stack(stack, lake_mask, CHL_T1, CHL_T2):
    """Label lake pixels of one scene with rule-based classes (same rules as the RF pipeline).

    Class codes (same order as CLASS_NAMES in this cell):
        0 Clean     - water with chl < CHL_T1
        1 Nutrient  - water with chl >= CHL_T1 that is not bloom
        2 Bloom     - water with chl >= CHL_T2, or with NDCI > NDCI_BLOOM_MIN
                      and NDVI > NDVI_BLOOM_MIN; then dilated by
                      BLOOM_DILATE_RADIUS pixels, restricted to water
        3 SparsePD  - lake pixels with VEG_SPARSE <= NDVI < VEG_DENSE
        4 DensePD   - lake pixels with NDVI >= VEG_DENSE
    "Water" means lake pixels with NDVI < VEG_SPARSE.

    Args:
        stack: (bands, H, W) reflectance. Red (B4), red-edge 1 (B5) and NIR
            (B8) are read from indices 2, 3 and 6. Adjust these indices if
            your band order differs.
        lake_mask: (H, W) boolean mask, True inside the lake.
        CHL_T1: Clean / nutrient chlorophyll threshold.
        CHL_T2: Bloom chlorophyll threshold.

    Returns:
        (class_map, chl, ndvi, ndci). ``class_map`` is float32 and stays NaN
        outside the lake or where the indices are not finite. The other three
        are per-pixel arrays computed over the whole scene.

    NOTE(review): with the configured thresholds NDVI_BLOOM_MIN (0.25) is
    above VEG_SPARSE (0.20), so the spectral bloom test cannot fire on water
    pixels. Confirm that this is intended.
    """
    # Only Red, red-edge 1 and NIR are needed (B2/B3 were read but never used).
    B4 = stack[2]; B5 = stack[3]; B8 = stack[6]
    eps = 1e-9
    ndvi = (B8 - B4) / (B8 + B4 + eps)
    ndci = (B5 - B4) / (B5 + B4 + eps)
    # Empirical quadratic NDCI -> chlorophyll-a conversion.
    chl = 14.039 + 86.115 * ndci + 194.325 * (ndci ** 2)
    H, W = B4.shape
    class_map = np.full((H, W), np.nan, dtype="float32")

    dense  = (ndvi >= VEG_DENSE) & lake_mask
    sparse = (ndvi >= VEG_SPARSE) & (ndvi < VEG_DENSE) & lake_mask
    water  = (ndvi < VEG_SPARSE) & lake_mask

    class_map[dense]  = 4
    class_map[sparse] = 3

    bloom_core = (chl >= CHL_T2) & water
    bloom_spectral = ((ndci > NDCI_BLOOM_MIN) & (ndvi > NDVI_BLOOM_MIN)) & water
    bloom = bloom_core | bloom_spectral
    if BLOOM_DILATE_RADIUS > 0:
        # Grow bloom regions by a disk, but never beyond water pixels.
        bloom = dilation(bloom, disk(BLOOM_DILATE_RADIUS)) & water
    class_map[bloom] = 2

    water_remain = water & (~bloom)
    class_map[(chl < CHL_T1) & water_remain] = 0
    class_map[(chl >= CHL_T1) & water_remain] = 1

    return class_map, chl, ndvi, ndci

# ============================
# PASS 1: scan scenes to collect CHL percentiles -> CHL_T2
# ============================
# Collect every scene GeoTIFF in SCENES_DIR (sorted, so the order is reproducible).
scene_files = sorted(glob.glob(os.path.join(SCENES_DIR, "*.tif")))
if len(scene_files)==0:
    raise RuntimeError("No .tif scenes found in SCENES_DIR. Edit SCENES_DIR path.")

# Lake polygon(s); reprojected to each scene's CRS before rasterising.
gdf_lake = gpd.read_file(SHP_PATH)
print("Found scenes:", len(scene_files), "Shapefile rows:", len(gdf_lake))

# Pool chlorophyll estimates from the lake pixels of all scenes so that CHL_T2
# is one global threshold (scenes with more lake pixels carry more weight).
all_chl = []
for scene_path in scene_files:
    print("Collecting chl from:", os.path.basename(scene_path))
    stack, prof = read_stack_reflectance(scene_path)
    H = stack.shape[1]; W = stack.shape[2]
    gdf_proj = gdf_lake.to_crs(prof['crs'])
    # invert=True -> True inside the lake geometries.
    lake_mask = geometry_mask(gdf_proj.geometry, transform=prof['transform'], invert=True, out_shape=(H,W))
    # Only chl is used here; it does not depend on the thresholds passed in.
    _, chl, _, _ = build_classes_from_stack(stack, lake_mask, CHL_T1_FIXED, CHL_T1_FIXED+10)  # CHL_T2 placeholder
    valid = lake_mask & np.isfinite(chl)
    if np.any(valid):
        all_chl.append(chl[valid])
if len(all_chl)==0:
    raise RuntimeError("No valid chl values inside lake for any scene.")
all_chl = np.concatenate(all_chl)
# The percentiles are printed for reference only; BLOOM_PERCENTILE alone sets CHL_T2.
p50,p75,p85,p90,p95 = np.percentile(all_chl, [50,75,85,90,95])
CHL_T1 = CHL_T1_FIXED
CHL_T2 = float(np.percentile(all_chl, BLOOM_PERCENTILE))
print(f"Chl percentiles: 50%={p50:.2f},75%={p75:.2f},85%={p85:.2f},90%={p90:.2f},95%={p95:.2f}")
print(f"Using CHL_T1={CHL_T1:.2f}, CHL_T2({BLOOM_PERCENTILE}th)={CHL_T2:.2f}")

# ============================
# Patch extraction functions
# ============================
def extract_patches_from_scene(stack, class_map, lake_mask, patch_size=9, samples_per_class=500, num_classes=None):
    """Sample labelled square patches centred on classified lake pixels.

    For each class code 0..num_classes-1, up to ``samples_per_class`` lake
    pixels with that code are chosen: all of them if there are fewer, else a
    random subset without replacement drawn with ``np.random``. The
    surrounding window is then cut from the reflect-padded stack.

    Args:
        stack: (bands, H, W) reflectance array.
        class_map: (H, W) class codes; NaN (or any code outside the scanned
            range) is never sampled.
        lake_mask: (H, W) boolean mask; only True pixels are sampled.
        patch_size: Window side. The window spans ``2 * (patch_size // 2) + 1``
            pixels, i.e. exactly ``patch_size`` for the intended odd values.
        samples_per_class: Maximum number of patches per class.
        num_classes: Number of class codes to scan; defaults to NUM_CLASSES.

    Returns:
        ``(X, y)``: float32 patches of shape (N, win, win, bands) (channels
        last) and int labels of shape (N,). NaN reflectance inside a window is
        replaced by 0, so a single invalid neighbour cannot poison the loss.
        The empty result is float32 as well (it used to be float64).
    """
    from numpy.lib.stride_tricks import sliding_window_view

    if num_classes is None:
        num_classes = NUM_CLASSES
    half = patch_size // 2
    win = 2 * half + 1
    bands, H, W = stack.shape
    stack_p = np.pad(stack, pad_width=((0, 0), (half, half), (half, half)), mode='reflect')
    # Read-only view of shape (bands, H, W, win, win): windows[:, r, c] is the
    # patch centred on original pixel (r, c). Nothing is copied here.
    windows = sliding_window_view(stack_p, (win, win), axis=(1, 2))
    flat_cls = class_map.ravel()
    flat_lake = lake_mask.ravel()

    picked, labels = [], []
    for cls in range(num_classes):
        idxs = np.flatnonzero((flat_cls == cls) & flat_lake)
        if idxs.size == 0:
            continue
        # sample up to requested (uses np.random, so it follows np.random.seed)
        sel = idxs if idxs.size <= samples_per_class else np.random.choice(idxs, samples_per_class, replace=False)
        picked.append(sel)
        labels.append(np.full(sel.size, cls, dtype=int))

    if not picked:
        return np.empty((0, win, win, bands), dtype='float32'), np.empty((0,), dtype=int)

    rows, cols = np.divmod(np.concatenate(picked), W)
    # Gather all selected windows at once: (bands, N, win, win) -> (N, win, win, bands).
    X = windows[:, rows, cols].transpose(1, 2, 3, 0).astype('float32')
    X = np.nan_to_num(X, copy=False, nan=0.0)
    y = np.concatenate(labels)
    return X, y

def build_dataset_from_scenes(scene_files, gdf_lake, patch_size=9, samples_per_class=500):
    """Build one labelled patch dataset from all scenes.

    For every scene, the lake polygons are rasterised onto the scene grid,
    pixels are labelled with build_classes_from_stack (using the global
    CHL_T1 / CHL_T2), and patches are sampled with extract_patches_from_scene.

    Args:
        scene_files: Iterable of scene GeoTIFF paths.
        gdf_lake: GeoDataFrame holding the lake geometry (any CRS).
        patch_size: Patch side passed to extract_patches_from_scene.
        samples_per_class: Per-scene, per-class cap on sampled patches.

    Returns:
        (X_all, y_all): stacked patches (N, p, p, bands) and labels (N,).

    Raises:
        RuntimeError: if no scene yields any patch.
    """
    patch_chunks, label_chunks = [], []
    for scene_path in scene_files:
        print("Scene:", os.path.basename(scene_path))
        stack, prof = read_stack_reflectance(scene_path)
        height, width = stack.shape[1:3]
        lake_geoms = gdf_lake.to_crs(prof['crs']).geometry
        lake_mask = geometry_mask(lake_geoms, transform=prof['transform'], invert=True, out_shape=(height, width))
        class_map, _chl, _ndvi, _ndci = build_classes_from_stack(stack, lake_mask, CHL_T1, CHL_T2)
        patches, labels = extract_patches_from_scene(
            stack, class_map, lake_mask,
            patch_size=patch_size, samples_per_class=samples_per_class,
        )
        if patches.shape[0] == 0:
            continue
        patch_chunks.append(patches)
        label_chunks.append(labels)

    if not patch_chunks:
        raise RuntimeError("No patches extracted from scenes. Check masks/thresholds.")

    X_all = np.concatenate(patch_chunks, axis=0)
    y_all = np.concatenate(label_chunks, axis=0)
    print("Patches total:", X_all.shape, "Label counts:", np.unique(y_all, return_counts=True))
    return X_all, y_all

# ============================
# Build dataset (patches)
# ============================
# Sample labelled patches from every scene (labels use CHL_T1 / CHL_T2 from pass 1).
X, y = build_dataset_from_scenes(scene_files, gdf_lake, patch_size=PATCH_SIZE, samples_per_class=SAMPLES_PER_CLASS_PER_SCENE)

# If bands != expected, warn (but continue)
bands_found = X.shape[-1]
if bands_found != EXPECTED_BANDS:
    print(f"Warning: patches have {bands_found} bands (EXPECTED_BANDS={EXPECTED_BANDS}). If order differs, update code.")

# Train/val/test split
# 70 % train, then the remaining 30 % split in half -> 15 % val / 15 % test, each stratified by label
# (sklearn raises if a class has too few samples to stratify).
# NOTE(review): the split is per patch, not per scene, and patches of neighbouring pixels
# can overlap, so test metrics may be optimistic. Consider a scene-level hold-out.
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, random_state=RANDOM_STATE, stratify=y)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=RANDOM_STATE, stratify=y_temp)
print("Train/Val/Test sizes:", X_train.shape[0], X_val.shape[0], X_test.shape[0])

# ============================
# PyTorch Dataset & Dataloader (3D-ready)
# ============================
class PatchDataset3D(Dataset):
    """Map-style dataset that serves spectral patches shaped for ``nn.Conv3d``.

    Each item is ``(x, label)``. ``x`` is a float32 tensor of shape
    ``(1, bands, p, p)``: one input channel whose depth axis is the spectral
    dimension, so a batch becomes ``(B, 1, D=bands, H=p, W=p)``. ``label`` is a
    0-d ``torch.long`` tensor.
    """

    def __init__(self, X, y, transform=None):
        """
        Args:
            X: numpy array of shape (N, p, p, bands).
            y: class labels, indexable by sample position and castable to int.
            transform: optional callable applied to the (1, bands, p, p) numpy
                array before it is turned into a tensor.
        """
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        # (p, p, bands) -> (bands, p, p) -> (1, bands, p, p)
        sample = np.moveaxis(self.X[idx], -1, 0).astype('float32')[np.newaxis, ...]
        if self.transform:
            sample = self.transform(sample)
        label = torch.tensor(int(self.y[idx]), dtype=torch.long)
        return torch.from_numpy(sample), label

train_ds = PatchDataset3D(X_train, y_train)
val_ds   = PatchDataset3D(X_val, y_val)
test_ds  = PatchDataset3D(X_test, y_test)

# Simple3DCNN's classifier head uses BatchNorm1d, which raises in training mode
# when a batch holds a single sample. The last training batch would hold one
# sample exactly when len(train_ds) % BATCH_SIZE == 1, so drop it only in that
# case; every other configuration keeps all batches. Validation/test run in
# eval mode (running statistics), so they need no such guard.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
                          drop_last=(len(train_ds) % BATCH_SIZE == 1))
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# ============================
# Compute class weights for loss
# ============================
# Balanced class weights for CrossEntropyLoss.
# CrossEntropyLoss expects exactly one weight per model output (NUM_CLASSES),
# indexed by class id. compute_class_weight only returns weights for labels
# present in y_train, so if a class is missing the raw vector is too short and
# shifted (e.g. classes [0,1,3,4] would map to positions 0..3). Scatter the
# weights into a full-length vector so weight[k] always belongs to class k.
# Classes absent from the training split get a neutral weight of 1.0.
cls_vals = np.unique(y_train)
weights = compute_class_weight(class_weight='balanced', classes=cls_vals, y=y_train)
full_weights = np.ones(NUM_CLASSES, dtype=np.float64)
full_weights[cls_vals] = weights
class_weights = torch.tensor(full_weights, dtype=torch.float32).to(DEVICE)
print("Class weights:", dict(zip(cls_vals, weights)))

# ============================
# Define 3D CNN model (PyTorch)
# ============================
class Simple3DCNN(nn.Module):
    """3D CNN classifier that convolves jointly over spectral and spatial axes.

    The spectral bands form the depth axis of a single-channel volume.
    Architecture: five Conv3d(3x3x3, padding 1) + BatchNorm3d + ReLU stages
    (32, 32, 64, 64, 128 filters), a spatial-only 2x2 max-pool followed by
    channel dropout after stages 2 and 4, global average pooling, then
    Linear(128->256) + BatchNorm1d + ReLU + Dropout(0.4) and a final linear
    layer that returns ``num_classes`` raw logits (no softmax).
    """
    def __init__(self, num_bands, num_classes):
        """
        Input shape: (batch, 1, D=num_bands, H=patch, W=patch)

        Args:
            num_bands: number of spectral bands (depth D). No layer references
                it: padding=1 keeps D unchanged and global pooling removes it,
                so any band count is accepted.
            num_classes: number of output logits.

        NOTE(review): each pool halves H and W (floor), so the spatial patch
        must be at least 4 pixels or the second pool yields an empty tensor.
        """
        super().__init__()
        in_ch = 1  # single input channel (depth = spectral)
        # kernel 3 with padding 1 preserves D, H and W
        self.conv1 = nn.Conv3d(in_ch, 32, kernel_size=(3,3,3), padding=(1,1,1))
        self.bn1 = nn.BatchNorm3d(32)
        self.conv2 = nn.Conv3d(32, 32, kernel_size=(3,3,3), padding=(1,1,1))
        self.bn2 = nn.BatchNorm3d(32)
        # pool spatially only (preserve spectral resolution)
        self.pool1 = nn.MaxPool3d(kernel_size=(1,2,2))

        self.conv3 = nn.Conv3d(32, 64, kernel_size=(3,3,3), padding=(1,1,1))
        self.bn3 = nn.BatchNorm3d(64)
        self.conv4 = nn.Conv3d(64, 64, kernel_size=(3,3,3), padding=(1,1,1))
        self.bn4 = nn.BatchNorm3d(64)
        self.pool2 = nn.MaxPool3d(kernel_size=(1,2,2))

        self.conv5 = nn.Conv3d(64, 128, kernel_size=(3,3,3), padding=(1,1,1))
        self.bn5 = nn.BatchNorm3d(128)

        # global pooling to (B,128,1,1,1)
        self.gap = nn.AdaptiveAvgPool3d((1,1,1))

        self.fc1 = nn.Linear(128, 256)
        # NOTE: BatchNorm1d raises in training mode on a batch of one sample.
        self.bnfc = nn.BatchNorm1d(256)
        self.dropfc = nn.Dropout(0.4)
        self.out = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return logits of shape (B, num_classes) for input (B, 1, D, H, W)."""
        # x: (B,1,D,H,W)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool1(x)  # (B,32,D,H//2,W//2)
        # drops whole feature channels; active only when self.training is True
        x = F.dropout3d(x, p=0.15, training=self.training)

        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.pool2(x)  # (B,64,D,H//4,W//4)
        x = F.dropout3d(x, p=0.25, training=self.training)

        x = F.relu(self.bn5(self.conv5(x)))

        x = self.gap(x)  # (B,128,1,1,1)
        x = x.view(x.size(0), -1)  # (B,128)

        x = F.relu(self.bnfc(self.fc1(x)))
        x = self.dropfc(x)
        x = self.out(x)  # raw logits; CrossEntropyLoss applies log-softmax itself
        return x

# One depth slice per spectral band found in the training patches.
model = Simple3DCNN(num_bands=bands_found, num_classes=NUM_CLASSES).to(DEVICE)
# Adam with a small, fixed learning rate (no scheduler).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Class-weighted cross-entropy to counter class imbalance (class_weights from above).
criterion = nn.CrossEntropyLoss(weight=class_weights)

# ============================
# Training loop (same logic)
# ============================
# Best validation accuracy seen so far. Starting at -inf (rather than 0.0)
# guarantees that epoch 1 writes a checkpoint from THIS run. With 0.0, a model
# whose validation accuracy never rose above 0 would save nothing. The load
# after the loop would then fail, or silently pick up a stale cnn3d_best.pth
# left in MODEL_SAVE_DIR by an earlier run.
best_val_acc = float("-inf")
patience = 6           # epochs without improvement tolerated before stopping
patience_counter = 0
history = {"train_loss":[], "val_loss":[], "train_acc":[], "val_acc":[]}

for epoch in range(1, EPOCHS+1):
    # ---- training pass ----
    model.train()
    running_loss = 0.0; correct=0; total=0
    for xb, yb in train_loader:
        xb = xb.to(DEVICE); yb = yb.to(DEVICE)
        optimizer.zero_grad()
        out = model(xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        # loss.item() is the batch's (class-weighted) mean; weight by batch size
        # so dividing by `total` gives a per-sample epoch average
        running_loss += loss.item() * xb.size(0)
        preds = out.argmax(dim=1)
        correct += (preds==yb).sum().item()
        total += xb.size(0)
    train_loss = running_loss/total
    train_acc = correct/total

    # ---- validation pass (no gradients, eval-mode BatchNorm/dropout) ----
    model.eval()
    vloss=0.0; vcorrect=0; vtotal=0
    with torch.no_grad():
        for xb, yb in val_loader:
            xb=xb.to(DEVICE); yb=yb.to(DEVICE)
            out = model(xb)
            loss = criterion(out, yb)
            vloss += loss.item() * xb.size(0)
            preds = out.argmax(dim=1)
            vcorrect += (preds==yb).sum().item()
            vtotal += xb.size(0)
    val_loss = vloss/vtotal
    val_acc = vcorrect/vtotal

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch}/{EPOCHS}  train_loss={train_loss:.4f} train_acc={train_acc:.4f}  val_loss={val_loss:.4f} val_acc={val_acc:.4f}")

    # Early stopping on validation accuracy; the best weights are kept on disk.
    if val_acc > best_val_acc + 1e-5:
        best_val_acc = val_acc
        torch.save(model.state_dict(), os.path.join(MODEL_SAVE_DIR, "cnn3d_best.pth"))
        patience_counter = 0
        print("  Saved best model.")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping.")
            break

# Restore the best-scoring weights (always written by this run) before evaluation.
model.load_state_dict(torch.load(os.path.join(MODEL_SAVE_DIR, "cnn3d_best.pth"), map_location=DEVICE))
model.eval()

# ============================
# Evaluate on test set
# ============================
y_true = []
y_pred = []
y_proba = []
# Collect labels, predicted class ids and softmax probabilities over the test set.
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(DEVICE)
        out = model(xb)
        probs = F.softmax(out, dim=1).cpu().numpy()
        preds = out.argmax(dim=1).cpu().numpy()
        y_true.extend(yb.numpy().tolist())
        y_pred.extend(preds.tolist())
        y_proba.extend(probs.tolist())
y_true = np.array(y_true); y_pred = np.array(y_pred); y_proba = np.vstack(y_proba)

# Metrics
# NOTE: macro averages below use the labels observed in y_true/y_pred
# (scikit-learn's default), not necessarily all NUM_CLASSES classes.
acc  = accuracy_score(y_true, y_pred)
f1_m = f1_score(y_true, y_pred, average='macro')
f1_w = f1_score(y_true, y_pred, average='weighted')
prec = precision_score(y_true, y_pred, average='macro')
rec  = recall_score(y_true, y_pred, average='macro')
kappa = cohen_kappa_score(y_true, y_pred)
mcc = matthews_corrcoef(y_true, y_pred)
# One-vs-rest macro ROC-AUC on binarised labels. It is undefined when a class
# has no positive sample in the test split (e.g. a class that never occurred
# in any scene). Report NaN with a warning instead of aborting the run.
try:
    roc_macro = roc_auc_score(label_binarize(y_true, classes=list(range(NUM_CLASSES))), y_proba, multi_class='ovr', average='macro')
except ValueError as err:
    print(f"Warning: macro ROC-AUC could not be computed ({err}); reporting NaN.")
    roc_macro = float("nan")

print("\n=== 3D-CNN Test-set Performance ===")
print(f"Accuracy          : {acc:.3f}")
print(f"F1-score (macro)  : {f1_m:.3f}")
print(f"F1-score (weighted): {f1_w:.3f}")
print(f"Precision (macro) : {prec:.3f}")
print(f"Recall (macro)    : {rec:.3f}")
print(f"Cohen's Kappa     : {kappa:.3f}")
print(f"MCC               : {mcc:.3f}")
print(f"ROC–AUC (macro)   : {roc_macro:.3f}")

print("\nClassification Report (per class):")
# Pass the full label list so there is one row per CLASS_NAMES entry. Without
# it, scikit-learn raises when the number of observed classes differs from
# len(target_names), e.g. when a class is absent from both y_true and y_pred.
print(classification_report(y_true, y_pred, labels=list(range(NUM_CLASSES)), target_names=CLASS_NAMES))

# ============================
# Save performance to Excel
# ============================
# One-row summary table of the test-set metrics computed above.
summary_metrics = {
    "Model": ["3D_CNN"],
    "Accuracy": [acc],
    "F1_macro": [f1_m],
    "F1_weighted": [f1_w],
    "Precision_macro": [prec],
    "Recall_macro": [rec],
    "Cohen_Kappa": [kappa],
    "MCC": [mcc],
    "ROC_AUC_macro": [roc_macro]
}
df_summary = pd.DataFrame(summary_metrics)
# Per-class report with one row per CLASS_NAMES entry. The explicit label list
# keeps classification_report from raising when a class is absent from both
# y_true and y_pred (observed label count != len(target_names)).
df_per_class = pd.DataFrame(classification_report(y_true, y_pred, labels=list(range(NUM_CLASSES)), target_names=CLASS_NAMES, output_dict=True)).transpose()

excel_out = os.path.join(EXCEL_SAVE_DIR, "CNN3D_Performance_Metrics.xlsx")
with pd.ExcelWriter(excel_out, engine="openpyxl") as writer:
    df_summary.to_excel(writer, sheet_name="Summary", index=False)
    df_per_class.to_excel(writer, sheet_name="Per_Class")
print("Saved performance Excel to:", excel_out)

# ============================
# Styled confusion matrix (with cell lines)
# ============================
# Fix the label set/order to all NUM_CLASSES classes so the matrix always lines
# up with the CLASS_NAMES tick labels. The inferred label set would shrink the
# matrix (and mislabel or break the axes) when a class is absent from both
# y_true and y_pred.
cm = confusion_matrix(y_true, y_pred, labels=list(range(NUM_CLASSES)))
plt.figure(figsize=(9,7))
ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                 xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                 annot_kws={"size":14, "weight":"bold"},
                 linewidths=2, linecolor="black")
ax.set_title("3D-CNN — Confusion Matrix (Test)", fontsize=20, fontweight="bold")
ax.set_xlabel("Predicted", fontsize=16, fontweight="bold"); ax.set_ylabel("Actual", fontsize=16, fontweight="bold")
ax.tick_params(axis="both", labelsize=12, width=2, length=6)
for lbl in ax.get_xticklabels()+ax.get_yticklabels(): lbl.set_fontweight("bold")
cbar = ax.collections[0].colorbar; cbar.ax.tick_params(labelsize=12); cbar.outline.set_linewidth(2)
plt.tight_layout()
plt.savefig(os.path.join(MAP_SAVE_DIR, "CNN3D_Confusion_Matrix.png"), dpi=600, bbox_inches="tight")
plt.show()

# ============================
# ROC curves plot (multiclass)
# ============================
n_classes = NUM_CLASSES
# One binary indicator column per class, for one-vs-rest ROC curves.
y_bin = label_binarize(y_true, classes=list(range(n_classes)))
fpr = dict(); tpr = dict(); roc_auc = dict()
# Per-class ROC curve and AUC, scored with that class's softmax probability.
# NOTE(review): a class with no positives in y_true gives NaN TPR (scikit-learn
# warns) and a NaN AUC, which then propagates into the macro average below.
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_bin[:,i], y_proba[:,i])
    roc_auc[i] = auc(fpr[i], tpr[i])
# Macro-average curve: interpolate each class's TPR onto the union of all
# per-class FPR points, average the curves, then integrate.
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes
fpr["macro"] = all_fpr; tpr["macro"] = mean_tpr; roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

plt.figure(figsize=(9,8), dpi=300)
plt.plot(fpr["macro"], tpr["macro"], label=f"Macro-average (AUC = {roc_auc['macro']:.3f})", linewidth=4, color="black")
# One line style per class; zip stops at the shorter of line_styles (5 entries)
# and PLOT_COLORS, so at most that many per-class curves are drawn.
line_styles = ["-", "--", "-.", ":", (0,(3,1,1,1))]
for i,(ls,color) in enumerate(zip(line_styles, PLOT_COLORS)):
    plt.plot(fpr[i], tpr[i], lw=2.5, linestyle=ls, color=color, label=f"{CLASS_NAMES[i]} (AUC={roc_auc[i]:.3f})")
# Chance-level diagonal.
plt.plot([0,1],[0,1],'k--', lw=1.5)
plt.xlim([0,1]); plt.ylim([0,1.05])
plt.xlabel("False Positive Rate", fontsize=14, fontweight="bold")
plt.ylabel("True Positive Rate", fontsize=14, fontweight="bold")
plt.title("3D-CNN — Multiclass ROC Curves", fontsize=18, fontweight="bold")
plt.xticks(fontsize=12); plt.yticks(fontsize=12)
leg = plt.legend(loc="lower right", fontsize=11)
for t in leg.get_texts(): t.set_fontweight("bold")
plt.grid(True, linestyle="--", alpha=0.4)
plt.tight_layout()
plt.savefig(os.path.join(MAP_SAVE_DIR, "CNN3D_ROC_Curves.png"), dpi=600, bbox_inches="tight")
plt.show()

# ============================
# Predict and save classification maps per scene (batch-patch prediction)
# ============================
def predict_map_with_cnn3d(scene_path, model, patch_size=PATCH_SIZE, batch_size=4096, save_dir=MAP_SAVE_DIR):
    """Classify every lake pixel of one scene with the 3D-CNN and save the map.

    Each lake pixel is classified from the reflect-padded spatial window (all
    bands) centred on it. Windows are fed to the model in batches of
    ``batch_size`` pixels. Pixels outside the lake polygons are written as 255
    (nodata).

    Args:
        scene_path: path to a stacked multi-band GeoTIFF.
        model: trained network mapping (B, 1, bands, p, p) inputs to class
            logits. It is evaluated in eval mode, and its previous
            train/eval mode is restored afterwards.
        patch_size: spatial window size; should match the size used in training.
        batch_size: number of pixels classified per forward pass.
        save_dir: directory receiving the GeoTIFF and PNG outputs.

    Returns:
        uint8 array (H, W) of predicted class ids, 255 outside the lake.

    Side effects:
        Writes ``{year}_{season}_CNN3D_Map.tif`` and ``{year}_{season}_CNN3D_Map.png``
        into ``save_dir`` and displays the PNG figure.
    """
    year, season, base = parse_scene_name(scene_path)
    stack, prof = read_stack_reflectance(scene_path)
    bands, H, W = stack.shape
    gdf_proj = gdf_lake.to_crs(prof['crs'])
    lake_mask = geometry_mask(gdf_proj.geometry, transform=prof['transform'], invert=True, out_shape=(H,W))

    half = patch_size//2
    pad_width = ((0,0),(half,half),(half,half))
    # Reflect-pad so windows centred on border pixels stay inside the array.
    stack_p = np.pad(stack, pad_width=pad_width, mode='reflect')
    class_flat = np.full(H*W, 255, dtype=np.uint8)  # 255 = nodata / outside lake
    idx_lake = np.where(lake_mask.flatten())[0]
    print("Predicting for lake pixels:", idx_lake.size)
    n = idx_lake.size

    # Inference must use BatchNorm running statistics with dropout disabled,
    # whatever mode the caller left the model in. Restore that mode afterwards
    # so calling this function has no lasting effect on the model.
    was_training = model.training
    model.eval()
    try:
        for i in range(0, n, batch_size):
            sel = idx_lake[i:i+batch_size]
            patches = []
            for ind in sel:
                r = ind // W; c = ind % W
                rp = r + half; cp = c + half
                p = stack_p[:, rp-half:rp+half+1, cp-half:cp+half+1]  # (bands, p, p)
                p = p.astype('float32')
                # Add leading channel -> (1, bands, p, p), the layout PatchDataset3D uses
                p = np.expand_dims(p, axis=0)
                patches.append(p)
            # Xb shape: (batch, 1, bands, p, p)
            Xb = np.stack(patches, axis=0)
            Xb_t = torch.from_numpy(Xb).to(DEVICE)
            with torch.no_grad():
                out = model(Xb_t)
                probs = F.softmax(out, dim=1).cpu().numpy()
                preds = probs.argmax(axis=1).astype(np.uint8)
            class_flat[sel] = preds
    finally:
        model.train(was_training)
    class_map = class_flat.reshape(H,W)

    # save GeoTIFF (single uint8 band, 255 declared as nodata)
    out_tif = os.path.join(save_dir, f"{year}_{season}_CNN3D_Map.tif")
    prof2 = prof.copy(); prof2.update(dtype=rasterio.uint8, count=1, nodata=255)
    with rasterio.open(out_tif, "w", **prof2) as dst:
        dst.write(class_map, 1)
    print("Saved CNN3D GeoTIFF:", out_tif)

    # save PNG (clean map only); nodata becomes NaN so it is left unpainted
    plot_map = class_map.astype(float); plot_map[plot_map==255]=np.nan
    cmap = mcolors.ListedColormap(PLOT_COLORS)
    # One colour bin per class id 0..4
    bounds = [-0.5,0.5,1.5,2.5,3.5,4.5]; norm = mcolors.BoundaryNorm(bounds, cmap.N)
    fig, ax = plt.subplots(figsize=(10,10), dpi=300)
    ax.imshow(plot_map, cmap=cmap, norm=norm, origin='upper')
    ax.set_title(f"{year} {season} (3D-CNN)", fontsize=18, fontweight='bold')
    ax.axis('off')
    out_png = os.path.join(save_dir, f"{year}_{season}_CNN3D_Map.png")
    fig.savefig(out_png, dpi=400, bbox_inches='tight')
    plt.show()
    # This runs once per scene; release the large figure so open pyplot figures
    # do not accumulate across scenes.
    plt.close(fig)
    print("Saved CNN3D PNG:", out_png)
    return class_map

# run prediction maps for all scenes
# Writes one GeoTIFF + PNG classification map per scene into MAP_SAVE_DIR,
# using the best checkpoint loaded after training.
for scene in scene_files:
    predict_map_with_cnn3d(scene, model, patch_size=PATCH_SIZE, batch_size=4096, save_dir=MAP_SAVE_DIR)

# Summary of output locations.
print("All done. Maps in:", MAP_SAVE_DIR)
print("Models in:", MODEL_SAVE_DIR)
print("Excel in:", EXCEL_SAVE_DIR)


#Vision_Transformer

In [None]:
# Full runnable cell: Vision Transformer (ViT) pipeline for Sentinel-2 patch classification
# Edit paths and parameters below before running.

import os
import glob
import math
import random
import numpy as np
import pandas as pd
import rasterio
from rasterio.features import geometry_mask
import geopandas as gpd
from skimage.morphology import dilation, disk
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (classification_report, confusion_matrix,
                             accuracy_score, f1_score, precision_score,
                             recall_score, roc_auc_score, roc_curve, auc,
                             cohen_kappa_score, matthews_corrcoef)
from sklearn.preprocessing import label_binarize

# PyTorch imports
try:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.utils.data import Dataset, DataLoader
except Exception as e:
    raise ImportError("PyTorch not found. Install it first, e.g. `pip install torch torchvision`") from e

# ============================
# USER: edit these paths & params
# ============================
# Folder containing the stacked multi-band scene GeoTIFFs (*.tif) to process.
SCENES_DIR = r"/content/drason"     # e.g. "/content/drive/.../Ml_ AL season"
# Lake boundary vector data, read with geopandas.read_file.
SHP_PATH   = r"/content/drive/SHP"
# Output folders: maps/plots, model checkpoints, Excel metric sheets (created if missing).
MAP_SAVE_DIR = r"/content/drivOutput"
MODEL_SAVE_DIR = r"/content/dutputt"
EXCEL_SAVE_DIR = r"/content/dOutput"
os.makedirs(MAP_SAVE_DIR, exist_ok=True)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(EXCEL_SAVE_DIR, exist_ok=True)

# Spatial window extracted around each sampled pixel. Use an odd value: the
# extraction slices 2*(PATCH_SIZE//2)+1 pixels, so an even value yields PATCH_SIZE+1.
PATCH_SIZE = 11                    # odd (spatial neighborhood size used to form input patches)
# Note: PATCH_SIZE_PATCH below controls the patch token size used by ViT patch embedding
PATCH_SIZE_PATCH = 1               # patch token size (1 = each pixel becomes a token); larger -> fewer tokens
SAMPLES_PER_CLASS_PER_SCENE = 400  # adjust for memory
BATCH_SIZE = 64
EPOCHS = 40
RANDOM_STATE = 42                  # seeds random/NumPy/PyTorch and the train/val/test splits
NUM_CLASSES = 5                    # class ids 0..4, named in CLASS_NAMES below
# Set expected bands according to your stacked TIFF order. For B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12 -> 10 bands
EXPECTED_BANDS = 10
SCALE_REFLECTANCE = True          # True if values are 0..10000 in TIFF
BLOOM_PERCENTILE = 85             # to set CHL_T2 (percentile of pooled in-lake chl over all scenes)
CHL_T1_FIXED = 20.0               # clean vs nutrient threshold
# NDVI thresholds: [VEG_SPARSE, VEG_DENSE) -> SparsePD, >= VEG_DENSE -> DensePD, below -> water
VEG_SPARSE = 0.20
VEG_DENSE  = 0.40
# A water pixel with NDCI > NDCI_BLOOM_MIN and NDVI > NDVI_BLOOM_MIN is labelled Bloom
NDCI_BLOOM_MIN = 0.15
NDVI_BLOOM_MIN = 0.25
# Disk radius (pixels) used to dilate the bloom mask; 0 disables dilation
BLOOM_DILATE_RADIUS = 2

# ViT hyperparameters
# presumably passed to the ViT constructor further below — confirm
EMBED_DIM = 128
TRANSFORMER_DEPTH = 6
NUM_HEADS = 8
MLP_RATIO = 4.0
DROPOUT = 0.1
CLS_TOKEN = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)

# Seed Python, NumPy and PyTorch RNGs for reproducible sampling and initialisation.
# NOTE(review): GPU kernels may still be nondeterministic.
random.seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
torch.manual_seed(RANDOM_STATE)

# color palette used for plotting & map (five classes)
# Index i of CLASS_NAMES / PLOT_COLORS corresponds to class id i from build_classes_from_stack.
CLASS_NAMES = ["Clean","Nutrient","Bloom","SparsePD","DensePD"]
PLOT_COLORS = ["#1f78b4","#33a02c","#ff7f00","#6a3d9a","#e31a1c"]

# ============================
# Helper: read stack reflectance
# ============================
def read_stack_reflectance(path):
    """Load every band of a stacked GeoTIFF as float32 reflectance.

    Returns ``(arr, profile)``, where ``arr`` has shape (bands, H, W) and
    ``profile`` is a copy of the raster profile. When ``SCALE_REFLECTANCE`` is
    enabled and the largest finite value exceeds 1.5, the data are treated as
    integer-scaled (0..10000) and divided by 10000; non-finite cells become
    NaN. All values are finally clipped to [0, 1.5] (NaN passes through).
    """
    with rasterio.open(path) as src:
        profile = src.profile.copy()
        data = src.read().astype("float32")  # (bands, H, W)

    is_finite = np.isfinite(data)
    needs_scaling = (
        SCALE_REFLECTANCE
        and is_finite.any()
        and float(np.nanmax(data[is_finite])) > 1.5
    )
    if needs_scaling:
        data = np.where(is_finite, data / 10000.0, np.nan)
    return np.clip(data, 0.0, 1.5), profile

# ============================
# Helper: parse scene file name
# ============================
def parse_scene_name(path):
    """Split a scene file name of the form ``<year>_<season>.tif`` into parts.

    Only a trailing ``.tif``/``.tiff`` extension (any case) is removed. The
    name is split on the FIRST underscore, so ``2021_Post_Monsoon.tif`` gives
    year ``"2021"`` and season ``"Post_Monsoon"``.

    Returns:
        (year, season, base): ``base`` is the file name without its extension.
        If the base has no underscore, year is ``"Unknown"`` and season is the
        whole base name.
    """
    name = os.path.basename(path)
    # Strip only a real trailing extension: str.replace(".tif", "") would also
    # delete ".tif" inside the name and turn ".tiff" into a stray "f".
    stem, ext = os.path.splitext(name)
    base = stem if ext.lower() in (".tif", ".tiff") else name
    year, sep, season = base.partition("_")
    if not sep:
        return "Unknown", base, base
    return year, season, base

# ============================
# Helper: compute class map (same rules as before)
# ============================
def build_classes_from_stack(stack, lake_mask, CHL_T1, CHL_T2):
    """Rule-based 5-class labelling of lake pixels from a band stack.

    Expected band layout along axis 0 (standard ordering used in this file):
    B2, B3, B4, B5, B6, B7, B8, ...; this function reads B4 = stack[2],
    B5 = stack[3] and B8 = stack[6].

    Indices:
        NDVI = (B8 - B4) / (B8 + B4 + eps)
        NDCI = (B5 - B4) / (B5 + B4 + eps)
        chl  = 14.039 + 86.115 * NDCI + 194.325 * NDCI**2

    Class ids written into a float32 map (NaN where no rule applies):
        4 DensePD  : NDVI >= VEG_DENSE inside the lake
        3 SparsePD : VEG_SPARSE <= NDVI < VEG_DENSE inside the lake
        2 Bloom    : water (NDVI < VEG_SPARSE) with chl >= CHL_T2, or with
                     NDCI > NDCI_BLOOM_MIN and NDVI > NDVI_BLOOM_MIN; dilated by
                     a disk of BLOOM_DILATE_RADIUS when > 0, then re-clipped to water
        1 Nutrient : remaining water with chl >= CHL_T1
        0 Clean    : remaining water with chl < CHL_T1

    Returns:
        (class_map, chl, ndvi, ndci), each shaped like a single band.
    """
    eps = 1e-9
    red, red_edge, nir = stack[2], stack[3], stack[6]  # B4, B5, B8

    ndvi = (nir - red) / (nir + red + eps)
    ndci = (red_edge - red) / (red_edge + red + eps)
    chl = 14.039 + 86.115 * ndci + 194.325 * (ndci ** 2)

    dense = (ndvi >= VEG_DENSE) & lake_mask
    sparse = (ndvi >= VEG_SPARSE) & (ndvi < VEG_DENSE) & lake_mask
    water = (ndvi < VEG_SPARSE) & lake_mask

    by_chl = (chl >= CHL_T2) & water
    by_spectra = ((ndci > NDCI_BLOOM_MIN) & (ndvi > NDVI_BLOOM_MIN)) & water
    bloom = by_chl | by_spectra
    if BLOOM_DILATE_RADIUS > 0:
        bloom = dilation(bloom, disk(BLOOM_DILATE_RADIUS)) & water
    water_remain = water & (~bloom)

    class_map = np.full(red.shape, np.nan, dtype="float32")
    # Assignment order matches the rule priority: vegetation, bloom, then water quality.
    for code, mask in ((4, dense),
                       (3, sparse),
                       (2, bloom),
                       (0, (chl < CHL_T1) & water_remain),
                       (1, (chl >= CHL_T1) & water_remain)):
        class_map[mask] = code

    return class_map, chl, ndvi, ndci

# ============================
# PASS 1: scan scenes to collect CHL percentiles -> CHL_T2
# ============================
# Every stacked scene GeoTIFF in SCENES_DIR, in sorted file-name order.
scene_files = sorted(glob.glob(os.path.join(SCENES_DIR, "*.tif")))
if len(scene_files)==0:
    raise RuntimeError("No .tif scenes found in SCENES_DIR. Edit SCENES_DIR path.")

# Lake boundary polygons; reprojected to each scene's CRS below.
gdf_lake = gpd.read_file(SHP_PATH)
print("Found scenes:", len(scene_files), "Shapefile rows:", len(gdf_lake))

# Pool finite in-lake chlorophyll values from ALL scenes, so the bloom threshold
# CHL_T2 is one percentile shared by every scene.
# NOTE(review): all pooled values are held in memory at once.
all_chl = []
for scene_path in scene_files:
    print("Collecting chl from:", os.path.basename(scene_path))
    stack, prof = read_stack_reflectance(scene_path)
    H = stack.shape[1]; W = stack.shape[2]
    gdf_proj = gdf_lake.to_crs(prof['crs'])
    # Rasterise the lake polygons: True inside the lake (invert=True).
    lake_mask = geometry_mask(gdf_proj.geometry, transform=prof['transform'], invert=True, out_shape=(H,W))
    # Only chl is used here. It depends on NDCI alone, not on the thresholds,
    # so the placeholder CHL_T2 value does not affect it.
    _, chl, _, _ = build_classes_from_stack(stack, lake_mask, CHL_T1_FIXED, CHL_T1_FIXED+10)  # CHL_T2 placeholder
    valid = lake_mask & np.isfinite(chl)
    if np.any(valid):
        all_chl.append(chl[valid])
if len(all_chl)==0:
    raise RuntimeError("No valid chl values inside lake for any scene.")
all_chl = np.concatenate(all_chl)
# Reference percentiles (printed for inspection only).
p50,p75,p85,p90,p95 = np.percentile(all_chl, [50,75,85,90,95])
CHL_T1 = CHL_T1_FIXED
# Bloom threshold: the BLOOM_PERCENTILE-th percentile of the pooled in-lake chl.
CHL_T2 = float(np.percentile(all_chl, BLOOM_PERCENTILE))
print(f"Chl percentiles: 50%={p50:.2f},75%={p75:.2f},85%={p85:.2f},90%={p90:.2f},95%={p95:.2f}")
print(f"Using CHL_T1={CHL_T1:.2f}, CHL_T2({BLOOM_PERCENTILE}th)={CHL_T2:.2f}")

# ============================
# Patch extraction functions (identical to earlier)
# ============================
def extract_patches_from_scene(stack, class_map, lake_mask, patch_size=11, samples_per_class=400, num_classes=None):
    """Sample labelled square patches centred on lake pixels of each class.

    Args:
        stack: array (bands, H, W).
        class_map: array (H, W) of class ids. Pixels holding other values
            (e.g. NaN) are never sampled.
        lake_mask: boolean array (H, W); only True pixels are sampled.
        patch_size: window side length. The window spans
            ``2 * (patch_size // 2) + 1`` pixels, i.e. exactly ``patch_size``
            for odd values. Border windows use reflect padding.
        samples_per_class: cap per class. Classes with more candidate pixels
            are subsampled without replacement using the global NumPy RNG.
        num_classes: class ids ``0..num_classes-1`` to sample; defaults to the
            module-level ``NUM_CLASSES``.

    Returns:
        (X, y): float32 patches of shape (N, win, win, bands) and int labels of
        shape (N,). The empty result has the same dtype and trailing shape as
        a non-empty one, so the two concatenate without dtype promotion.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES
    half = patch_size//2
    win = 2*half + 1  # actual window size sliced below
    bands, H, W = stack.shape
    pad_width = ((0,0),(half,half),(half,half))
    stack_p = np.pad(stack, pad_width=pad_width, mode='reflect')
    X = []
    y = []
    flat_cls = class_map.flatten()
    flat_lake = lake_mask.flatten()
    for cls in range(num_classes):
        idxs = np.where((flat_cls==cls) & (flat_lake))[0]
        if idxs.size==0:
            continue
        sel = idxs if idxs.size <= samples_per_class else np.random.choice(idxs, samples_per_class, replace=False)
        for ind in sel:
            r = ind // W; c = ind % W
            # Pixel (r, c) sits at (r+half, c+half) in the padded stack, so its
            # window covers padded rows r..r+win-1 and cols c..c+win-1.
            patch = stack_p[:, r:r+win, c:c+win]  # (bands, win, win)
            patch = np.transpose(patch, (1,2,0))  # (win, win, bands)
            X.append(patch)
            y.append(cls)
    if len(X)==0:
        return np.empty((0,win,win,bands), dtype='float32'), np.empty((0,), dtype=int)
    X = np.stack(X, axis=0).astype('float32')
    y = np.array(y, dtype=int)
    return X,y

def build_dataset_from_scenes(scene_files, gdf_lake, patch_size=11, samples_per_class=400):
    """Collect labelled training patches from every scene in ``scene_files``.

    Per scene: load the reflectance stack, rasterise ``gdf_lake`` (reprojected
    to the scene CRS) into a lake mask, label pixels with
    ``build_classes_from_stack`` using the global thresholds CHL_T1 / CHL_T2,
    and sample up to ``samples_per_class`` patches per class via
    ``extract_patches_from_scene``.

    Returns:
        (X_all, y_all): patches and labels concatenated over every scene that
        yielded at least one patch.

    Raises:
        RuntimeError: when no scene yields any patch.
    """
    collected = []  # (patches, labels) for each productive scene
    for scene_path in scene_files:
        print("Scene:", os.path.basename(scene_path))
        stack, prof = read_stack_reflectance(scene_path)
        H, W = stack.shape[1], stack.shape[2]
        lake_mask = geometry_mask(gdf_lake.to_crs(prof['crs']).geometry,
                                  transform=prof['transform'], invert=True, out_shape=(H, W))
        class_map, _chl, _ndvi, _ndci = build_classes_from_stack(stack, lake_mask, CHL_T1, CHL_T2)
        Xp, yp = extract_patches_from_scene(stack, class_map, lake_mask,
                                            patch_size=patch_size, samples_per_class=samples_per_class)
        if Xp.shape[0] > 0:
            collected.append((Xp, yp))

    if not collected:
        raise RuntimeError("No patches extracted from scenes. Check masks/thresholds.")

    X_all = np.concatenate([patches for patches, _ in collected], axis=0)
    y_all = np.concatenate([labels for _, labels in collected], axis=0)
    print("Patches total:", X_all.shape, "Label counts:", np.unique(y_all, return_counts=True))
    return X_all, y_all

# ============================
# Build dataset (patches)
# ============================
# X: (N, p, p, bands) float32 patches, y: (N,) int labels (see extract_patches_from_scene).
X, y = build_dataset_from_scenes(scene_files, gdf_lake, patch_size=PATCH_SIZE, samples_per_class=SAMPLES_PER_CLASS_PER_SCENE)

# The last axis is the band axis; warn (but continue) if it differs from EXPECTED_BANDS.
bands_found = X.shape[-1]
if bands_found != EXPECTED_BANDS:
    print(f"Warning: patches have {bands_found} bands (EXPECTED_BANDS={EXPECTED_BANDS}). If order differs, update code.")

# Train/val/test split
# 70% train; the remaining 30% is halved into validation and test (15% / 15%),
# both stratified on the label to preserve class proportions.
# NOTE(review): stratify raises ValueError if a class has too few samples to split.
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, random_state=RANDOM_STATE, stratify=y)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=RANDOM_STATE, stratify=y_temp)
print("Train/Val/Test sizes:", X_train.shape[0], X_val.shape[0], X_test.shape[0])

# ============================
# PyTorch Dataset & Dataloader for ViT
# ============================
class PatchDatasetViT(Dataset):
    """Dataset yielding channel-first patches for the Vision Transformer.

    Holds spatial-first patches ``X`` of shape (N, p, p, bands) and labels
    ``y``. Each item is ``(x, label)``, where ``x`` is a float32 tensor of shape
    (bands, p, p) and ``label`` is a 0-d ``torch.long`` tensor.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        chw = np.moveaxis(self.X[idx], -1, 0).astype('float32')  # (bands, p, p)
        target = torch.tensor(int(self.y[idx]), dtype=torch.long)
        return torch.from_numpy(chw), target

# Wrap each split; items are (bands, p, p) float32 tensors with a long label.
train_ds = PatchDatasetViT(X_train, y_train)
val_ds   = PatchDatasetViT(X_val, y_val)
test_ds  = PatchDatasetViT(X_test, y_test)

# Shuffle only the training split; loading runs in the main process (num_workers=0).
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
test_loader  = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# ============================
# Compute class weights for loss
# ============================
# Balanced class weights for the cross-entropy loss.
# CrossEntropyLoss expects exactly one weight per model output (NUM_CLASSES),
# indexed by class id. compute_class_weight only returns weights for labels
# present in y_train, so if a class is missing the raw vector is too short and
# shifted onto the wrong classes. Scatter the weights into a full-length vector
# so weight[k] always belongs to class k. Absent classes get a neutral 1.0.
cls_vals = np.unique(y_train)
weights = compute_class_weight(class_weight='balanced', classes=cls_vals, y=y_train)
full_weights = np.ones(NUM_CLASSES, dtype=np.float64)
full_weights[cls_vals] = weights
class_weights = torch.tensor(full_weights, dtype=torch.float32).to(DEVICE)
print("Class weights:", dict(zip(cls_vals, weights)))

# ============================
# ViT model definition (simple, configurable)
# ============================
class PatchEmbed(nn.Module):
    """Linear patch embedding implemented as a strided, unpadded ``nn.Conv2d``.

    Maps an input ``(B, C, H, W)`` to tokens ``(B, Hf * Wf, embed_dim)`` with
    ``Hf = (H - patch_size) // patch_size + 1`` (likewise for ``Wf``), in
    row-major grid order. It also returns the token grid size ``(Hf, Wf)``.
    """
    def __init__(self, in_chans, embed_dim, patch_size):
        super().__init__()
        self.patch_size = patch_size
        # every non-overlapping patch_size x patch_size window -> one embed_dim vector
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        feat = self.proj(x)                                  # (B, embed_dim, Hf, Wf)
        grid = tuple(feat.shape[-2:])                        # (Hf, Wf)
        tokens = feat.flatten(start_dim=2).permute(0, 2, 1)  # (B, Hf*Wf, embed_dim)
        return tokens, grid

class ViT(nn.Module):
    def __init__(self, in_chans, num_classes, embed_dim=128, depth=6, num_heads=8, mlp_ratio=4.0, patch_size=1, dropout=0.1, use_cls_token=True):
        super().__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.patch_embed = PatchEmbed(in_chans, embed_dim, patch_size)
        self.use_cls_token = use_cls_token

        # placeholder for num_patches until forward (pos_embed shape depends on grid)
        self.cls_token = nn.Parameter(torch.zeros(1,1,embed_dim)) if use_cls_token else None
        self.pos_embed = None

        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim*mlp_ratio), dropout=dropout, activation='gelu')
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

        # init
        nn.init.trunc_normal_(self.cls_token, std=0.02) if self.cls_token is not None else None
        # pos_embed will be created on first forward when we know num_patches

    def _init_pos_embed(self, num_patches, device):
        # create pos_embed shape (1, n_tokens, embed_dim) ; n_tokens = num_patches + (1 if cls)
        n_tokens = num_patches + (1 if self.use_cls_token else 0)
        self.pos_embed = nn.Parameter(torch.zeros(1, n_tokens, self.embed_dim).to(device))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        # x: (B, C, H, W)
        B = x.shape[0]
        x, (Hf, Wf) = self.patch_embed(x)  # x: (B, num_patches, embed_dim)
        num_patches = x.shape[1]
        device = x.device
        if self.pos_embed is None or self.pos_embed.shape[1] != num_patches + (1 if self.use_cls_token else 0):
            self._init_pos_embed(num_patches, device)

        if self.use_cls_token:
            cls_tok = self.cls_token.expand(B, -1, -1)  # (B,1,embed_dim)
            x = torch.cat([cls_tok, x], dim=1)          # (B, 1+num_patches, embed_dim)
        x = x + self.pos_embed
        # transformer expects (seq_len, batch, embed_dim) by default, so transpose
        x = x.transpose(0,1)  # (seq_len, B, E)
        x = self.transformer(x)  # (seq_len, B, E)
        x = x.transpose(0,1)  # (B, seq_len, E)
        x = self.norm(x)
        # classification token or mean pooling
        if self.use_cls_token:
            rep = x[:,0]  # (B, E)
        else:
            rep = x.mean(dim=1)
        logits = self.head(rep)
        return logits

# instantiate model
model = ViT(in_chans=bands_found, num_classes=NUM_CLASSES, embed_dim=EMBED_DIM, depth=TRANSFORMER_DEPTH,
            num_heads=NUM_HEADS, mlp_ratio=MLP_RATIO, patch_size=PATCH_SIZE_PATCH, dropout=DROPOUT, use_cls_token=CLS_TOKEN).to(DEVICE)

# ViT can create its positional embedding lazily, on the first forward pass.
# Run one dummy batch shaped like the training patches BEFORE building the
# optimizer, so pos_embed is already registered and included in
# model.parameters(). Otherwise AdamW never updates it and it stays at its
# random initialisation.
# Assumes X_train is (N, p, p, bands), the layout fed to PatchDatasetViT.
with torch.no_grad():
    model(torch.zeros(1, bands_found, X_train.shape[1], X_train.shape[2], device=DEVICE))

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss(weight=class_weights)

# ============================
# Training loop
# ============================
# Start at -inf (not 0.0) so the first epoch always writes a checkpoint.
# Otherwise an all-wrong first epoch followed by early stopping leaves no
# vit_best.pth, and the reload below fails.
best_val_acc = float("-inf")
patience = 8
patience_counter = 0
history = {"train_loss":[], "val_loss":[], "train_acc":[], "val_acc":[]}
best_ckpt_path = os.path.join(MODEL_SAVE_DIR, "vit_best.pth")


def _run_epoch(loader, train):
    """Run `model` once over `loader` and return (mean loss, accuracy).

    With ``train=True`` the model is in train mode and the optimizer steps after
    every batch. Otherwise the model is in eval mode and no gradients are
    tracked. The loss is the mean of the per-batch `criterion` values,
    weighted by batch size.
    """
    model.train(train)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(train):
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            if train:
                optimizer.zero_grad()
            out = model(xb)
            loss = criterion(out, yb)
            if train:
                loss.backward()
                optimizer.step()
            loss_sum += loss.item() * xb.size(0)
            n_correct += (out.argmax(dim=1) == yb).sum().item()
            n_seen += xb.size(0)
    return loss_sum / n_seen, n_correct / n_seen


for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = _run_epoch(train_loader, train=True)
    val_loss, val_acc = _run_epoch(val_loader, train=False)

    history["train_loss"].append(train_loss)
    history["val_loss"].append(val_loss)
    history["train_acc"].append(train_acc)
    history["val_acc"].append(val_acc)

    print(f"Epoch {epoch}/{EPOCHS}  train_loss={train_loss:.4f} train_acc={train_acc:.4f}  val_loss={val_loss:.4f} val_acc={val_acc:.4f}")

    # early stopping & save best (by validation accuracy)
    if val_acc > best_val_acc + 1e-5:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_ckpt_path)
        patience_counter = 0
        print("  Saved best model.")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping.")
            break

# load best
model.load_state_dict(torch.load(best_ckpt_path, map_location=DEVICE))
model.eval()

# ============================
# Evaluate on test set
# ============================
# Collect true labels, argmax predictions and softmax probabilities on the test set.
true_chunks, pred_chunks, proba_chunks = [], [], []
with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(DEVICE)
        out = model(xb)
        proba_chunks.append(F.softmax(out, dim=1).cpu().numpy())
        pred_chunks.append(out.argmax(dim=1).cpu().numpy())
        true_chunks.append(yb.numpy())
y_true = np.concatenate(true_chunks)
y_pred = np.concatenate(pred_chunks)
y_proba = np.concatenate(proba_chunks).astype(np.float64)

# Metrics
acc  = accuracy_score(y_true, y_pred)
f1_m = f1_score(y_true, y_pred, average='macro')
f1_w = f1_score(y_true, y_pred, average='weighted')
prec = precision_score(y_true, y_pred, average='macro')
rec  = recall_score(y_true, y_pred, average='macro')
kappa = cohen_kappa_score(y_true, y_pred)
mcc = matthews_corrcoef(y_true, y_pred)
try:
    if NUM_CLASSES == 2:
        # label_binarize returns a single column for two classes, which does
        # not match the (N, 2) probability matrix; score the positive class.
        roc_macro = roc_auc_score(y_true, y_proba[:, 1])
    else:
        roc_macro = roc_auc_score(label_binarize(y_true, classes=list(range(NUM_CLASSES))), y_proba,
                                  multi_class='ovr', average='macro')
except ValueError as err:
    # sklearn raises e.g. when a class has no positive samples in the test
    # split. Report NaN rather than aborting after training has finished.
    print(f"ROC-AUC undefined on this test split: {err}")
    roc_macro = float("nan")

print("\n=== ViT Test-set Performance ===")
# Labels are pre-padded so the colons line up in the console output.
for metric_label, metric_value in (
    ("Accuracy          ", acc),
    ("F1-score (macro)  ", f1_m),
    ("F1-score (weighted)", f1_w),
    ("Precision (macro) ", prec),
    ("Recall (macro)    ", rec),
    ("Cohen's Kappa     ", kappa),
    ("MCC               ", mcc),
    ("ROC–AUC (macro)   ", roc_macro),
):
    print(f"{metric_label}: {metric_value:.3f}")

print("\nClassification Report (per class):")
# labels= pins the report to every class index so its rows always match
# CLASS_NAMES. Without it, sklearn raises when a class is missing from both
# y_true and y_pred.
print(classification_report(y_true, y_pred, labels=list(range(NUM_CLASSES)), target_names=CLASS_NAMES))

# ============================
# Save performance to Excel
# ============================
# One-row table of headline metrics (each value wrapped in a list so pandas
# builds a single row) plus the per-class report, written to two sheets.
summary_metrics = {
    name: [value]
    for name, value in (
        ("Model", "ViT"),
        ("Accuracy", acc),
        ("F1_macro", f1_m),
        ("F1_weighted", f1_w),
        ("Precision_macro", prec),
        ("Recall_macro", rec),
        ("Cohen_Kappa", kappa),
        ("MCC", mcc),
        ("ROC_AUC_macro", roc_macro),
    )
}
df_summary = pd.DataFrame(summary_metrics)
# labels= keeps one row per CLASS_NAMES entry even if a class never appears
# in y_true or y_pred; otherwise sklearn raises on the target_names mismatch.
df_per_class = pd.DataFrame(
    classification_report(y_true, y_pred, labels=list(range(NUM_CLASSES)),
                          target_names=CLASS_NAMES, output_dict=True)
).transpose()

excel_out = os.path.join(EXCEL_SAVE_DIR, "ViT_Performance_Metrics.xlsx")
with pd.ExcelWriter(excel_out, engine="openpyxl") as writer:
    df_summary.to_excel(writer, sheet_name="Summary", index=False)
    df_per_class.to_excel(writer, sheet_name="Per_Class")
print("Saved performance Excel to:", excel_out)

# ============================
# Styled confusion matrix (with cell lines)
# ============================
# Pass labels explicitly so the matrix is always NUM_CLASSES x NUM_CLASSES and
# its rows/columns line up with CLASS_NAMES, even when a class is absent from
# both y_true and y_pred.
cm = confusion_matrix(y_true, y_pred, labels=list(range(NUM_CLASSES)))
plt.figure(figsize=(9,7))
ax = sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                 xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
                 annot_kws={"size":14, "weight":"bold"},
                 linewidths=2, linecolor="black")
ax.set_title("ViT — Confusion Matrix (Test)", fontsize=20, fontweight="bold")
ax.set_xlabel("Predicted", fontsize=16, fontweight="bold")
ax.set_ylabel("Actual", fontsize=16, fontweight="bold")
ax.tick_params(axis="both", labelsize=12, width=2, length=6)
for lbl in ax.get_xticklabels() + ax.get_yticklabels():
    lbl.set_fontweight("bold")
cbar = ax.collections[0].colorbar
cbar.ax.tick_params(labelsize=12)
cbar.outline.set_linewidth(2)
plt.tight_layout()
plt.savefig(os.path.join(MAP_SAVE_DIR, "ViT_Confusion_Matrix.png"), dpi=600, bbox_inches="tight")
plt.show()

# ============================
# ROC curves plot (multiclass)
# ============================
from itertools import cycle

n_classes = NUM_CLASSES
y_bin = label_binarize(y_true, classes=list(range(n_classes)))
if n_classes == 2:
    # For two classes label_binarize returns only the class-1 column; add the
    # complementary column so y_bin[:, i] exists for both classes.
    y_bin = np.hstack([1 - y_bin, y_bin])

# One-vs-rest ROC per class, then a macro average: every class TPR is
# interpolated onto the union of all FPR points and the results are averaged.
fpr = dict(); tpr = dict(); roc_auc = dict()
for i in range(n_classes):
    fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_proba[:, i])
    roc_auc[i] = auc(fpr[i], tpr[i])
all_fpr = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
mean_tpr = np.zeros_like(all_fpr)
for i in range(n_classes):
    mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
mean_tpr /= n_classes
fpr["macro"] = all_fpr; tpr["macro"] = mean_tpr; roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

plt.figure(figsize=(9,8), dpi=300)
plt.plot(fpr["macro"], tpr["macro"], label=f"Macro-average (AUC = {roc_auc['macro']:.3f})", linewidth=4, color="black")
line_styles = ["-", "--", "-.", ":", (0,(3,1,1,1))]
# Plot every class exactly once. Styles and colours repeat cyclically rather
# than truncating the class list (or over-running it) to their length.
for i, ls, color in zip(range(n_classes), cycle(line_styles), cycle(PLOT_COLORS)):
    plt.plot(fpr[i], tpr[i], lw=2.5, linestyle=ls, color=color, label=f"{CLASS_NAMES[i]} (AUC={roc_auc[i]:.3f})")
plt.plot([0,1],[0,1],'k--', lw=1.5)
plt.xlim([0,1]); plt.ylim([0,1.05])
plt.xlabel("False Positive Rate", fontsize=14, fontweight="bold")
plt.ylabel("True Positive Rate", fontsize=14, fontweight="bold")
plt.title("ViT — Multiclass ROC Curves", fontsize=18, fontweight="bold")
plt.xticks(fontsize=12); plt.yticks(fontsize=12)
leg = plt.legend(loc="lower right", fontsize=11)
for t in leg.get_texts():
    t.set_fontweight("bold")
plt.grid(True, linestyle="--", alpha=0.4)
plt.tight_layout()
plt.savefig(os.path.join(MAP_SAVE_DIR, "ViT_ROC_Curves.png"), dpi=600, bbox_inches="tight")
plt.show()

