In [None]:
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Concatenate,
    Dropout,
    Layer,
    LayerNormalization,
    Lambda
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import pickle # For saving and loading the scaler and feature importance results
import json # For saving the feature importance results

# Define the single buffer size to use
BUFFER_METERS = 500

# ==================== 1. Load Data ==================== #
# NOTE: The data loading logic remains the same as it provides the inputs
# required for the new model architecture.
orig = pd.read_csv("../../data/RainySeason.csv")
river_100 = pd.read_csv("../data/Samples_100.csv")

# Define the columns to drop and the numeric columns to use for MLP
drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# Ensure there are no NaNs in the numeric columns before proceeding
orig[numeric_cols] = orig[numeric_cols].fillna(0)
river_100[numeric_cols] = river_100[numeric_cols].fillna(0)

# Train-test split
train_orig = orig.sample(10, random_state=42)
test_orig = orig.drop(train_orig.index)
train_combined = pd.concat([river_100, train_orig], ignore_index=True)

# ==================== 2. Collect ALL Rasters ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDW/*.tif")

print(f"Using {len(raster_paths)} raster layers for CNN input.")
for r in raster_paths:
    print("   -", os.path.basename(r))

# ==================== 3. Create a Custom Data Generator ==================== #
def extract_patch_for_generator(coords, raster_files, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height):
    """
    Extracts a batch of patches from rasters for a given set of coordinates.
    This function is optimized to be called by the data generator for each batch.
    
    This version includes robust NaN handling to prevent model training errors.
    """
    patches = []
    # Loop through each coordinate pair in the batch
    for lon, lat in coords:
        channels = []
        # Loop through each raster file to get a single patch for each raster
        for rfile in raster_files:
            with rasterio.open(rfile) as src:
                try:
                    row, col = src.index(lon, lat)
                    win = Window(col - buffer_pixels_x, row - buffer_pixels_y, patch_width, patch_height)
                    arr = src.read(1, window=win, boundless=True, fill_value=0)
                    
                    # Corrected logic: Convert any NaNs to a numerical value, e.g., 0,
                    # to prevent them from propagating through the model.
                    arr = np.nan_to_num(arr, nan=0.0)
                    
                    arr = arr.astype(np.float32)

                    # Get the maximum value, but check if it's a valid number and > 0.
                    # This prevents division by zero if the patch is all zeros.
                    max_val = np.nanmax(arr)
                    if np.isfinite(max_val) and max_val > 0:
                        arr /= max_val
                        
                except Exception as e:
                    print(f"Error processing {rfile} for coordinates ({lon}, {lat}): {e}")
                    arr = np.zeros((patch_width, patch_height), dtype=np.float32)
            channels.append(arr)
        patches.append(np.stack(channels, axis=-1))
    
    return np.array(patches)

class DataGenerator(Sequence):
    def __init__(self, coords, mlp_data, gnn_data, y, raster_paths, buffer_meters, batch_size=4, shuffle=True, **kwargs):
        super().__init__(**kwargs)
        self.coords = coords
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.raster_paths = raster_paths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.buffer_meters = buffer_meters

        # Pre-calculate patch size from the first raster
        with rasterio.open(raster_paths[0]) as src:
            res_x, res_y = src.res
            self.buffer_pixels_x = int(self.buffer_meters / res_x)
            self.buffer_pixels_y = int(self.buffer_meters / res_y)
            self.patch_width = 2 * self.buffer_pixels_x
            self.patch_height = 2 * self.buffer_pixels_y

        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
            
    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_coords = self.coords[batch_indices]
        batch_mlp = self.mlp_data[batch_indices]
        
        # Slice the GNN adjacency matrix for the current batch
        batch_gnn = self.gnn_data[batch_indices, :]

        batch_y = self.y[batch_indices]

        # Extract CNN patches for the current batch
        batch_cnn = extract_patch_for_generator(
            batch_coords,
            self.raster_paths,
            self.buffer_pixels_x,
            self.buffer_pixels_y,
            self.patch_width,
            self.patch_height
        )

        # Return a tuple of inputs and the target, which Keras expects
        return (batch_cnn, batch_mlp, batch_gnn), batch_y


# ==================== 4. Prepare GNN & MLP Input (only once) ==================== #
coords_train = train_combined[['Long','Lat']].values
coords_test = test_orig[['Long','Lat']].values

# We now split the training data into a training and validation set
train_split, val_split = train_test_split(train_combined, test_size=0.2, random_state=42)

coords_train_split = train_split[['Long','Lat']].values
coords_val_split = val_split[['Long','Lat']].values

dist_mat_train_split = distance_matrix(coords_train_split, coords_train_split)
gnn_train_split = np.exp(-dist_mat_train_split/10)
dist_mat_test_train = distance_matrix(coords_test, coords_train_split)
gnn_test = np.exp(-dist_mat_test_train/10)
dist_mat_val_train = distance_matrix(coords_val_split, coords_train_split)
gnn_val_split = np.exp(-dist_mat_val_train/10)


scaler = StandardScaler()
mlp_train_split = scaler.fit_transform(train_split[numeric_cols])
mlp_val_split = scaler.transform(val_split[numeric_cols])
mlp_test = scaler.transform(test_orig[numeric_cols])
y_train_split = train_split['RI'].values
y_val_split = val_split['RI'].values
y_test = test_orig['RI'].values

# ==================== 5. Define the Mixture of Experts Model ==================== #
def build_moe_model(patch_shape, gnn_dim, mlp_dim):
    # Inputs for all branches
    cnn_input = Input(shape=patch_shape, name="cnn_input")
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    
    # --- Expert 1: CNN Branch ---
    cnn_branch = Conv2D(32, (3,3), activation="relu", padding="same")(cnn_input)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    cnn_branch = Conv2D(64, (3,3), activation="relu", padding="same")(cnn_branch)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    cnn_branch_flattened = Flatten()(cnn_branch)
    cnn_branch_dense = Dense(128, activation="relu")(cnn_branch_flattened)
    # The CNN expert's final prediction
    cnn_expert_out = Dense(1, activation="linear", name="cnn_expert_out")(cnn_branch_dense)

    # --- Expert 2: MLP Branch ---
    mlp_branch = Dense(64, activation="relu")(mlp_input)
    mlp_branch = Dense(32, activation="relu")(mlp_branch)
    # The MLP expert's final prediction
    mlp_expert_out = Dense(1, activation="linear", name="mlp_expert_out")(mlp_branch)

    # --- Expert 3: GNN Branch ---
    gnn_branch = Dense(64, activation="relu")(gnn_input)
    gnn_branch = Dense(32, activation="relu")(gnn_branch)
    # The GNN expert's final prediction
    gnn_expert_out = Dense(1, activation="linear", name="gnn_expert_out")(gnn_branch)

    # --- Gating Network ---
    # The gating network needs features from all inputs to make its decision.
    # We use the outputs of the dense layers before the final predictions as features.
    gate_input = Concatenate()([cnn_branch_dense, mlp_branch, gnn_branch])
    gate_network = Dense(64, activation="relu")(gate_input)
    gate_network = Dense(32, activation="relu")(gate_network)
    # The output is a set of weights for each expert (summing to 1 via softmax)
    gate_weights = Dense(3, activation="softmax", name="gate_weights")(gate_network)

    # --- Combine Experts and Gating Network ---
    # Stack the predictions from each expert.
    # The shape will be (batch_size, 3)
    experts_stack = Concatenate(axis=1, name="experts_stack")([cnn_expert_out, mlp_expert_out, gnn_expert_out])
    
    # Perform the weighted sum.
    # This is done using a Lambda layer which takes the experts' outputs and
    # the gating network's weights, and computes the dot product for each sample.
    final_output = Lambda(lambda x: tf.reduce_sum(x[0] * x[1], axis=1, keepdims=True), name="final_output")([experts_stack, gate_weights])

    # Build and compile the model
    model = Model(inputs=[cnn_input, mlp_input, gnn_input], outputs=final_output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

def evaluate_model(model, coords_test, mlp_test, gnn_test_matrix, y_test, raster_paths, buffer_meters, batch_size=4, return_preds=False):
    num_samples = len(y_test)
    y_pred_list = []
    
    with rasterio.open(raster_paths[0]) as src:
        res_x, res_y = src.res
        buffer_pixels_x = int(buffer_meters / res_x)
        buffer_pixels_y = int(buffer_meters / res_y)
        patch_width = 2 * buffer_pixels_x
        patch_height = 2 * buffer_pixels_y

    for i in range(0, num_samples, batch_size):
        batch_coords = coords_test[i:i+batch_size]
        batch_mlp = mlp_test[i:i+batch_size]
        
        batch_gnn = gnn_test_matrix[i:i+batch_size, :]
        batch_y = y_test[i:i+batch_size]

        batch_cnn = extract_patch_for_generator(
            batch_coords,
            raster_paths,
            buffer_pixels_x,
            buffer_pixels_y,
            patch_width,
            patch_height
        )
        
        y_pred_list.append(model.predict((batch_cnn, batch_mlp, batch_gnn)).flatten())
        
    y_pred = np.concatenate(y_pred_list)
    
    if return_preds:
        return y_pred
    else:
        r2 = r2_score(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        return r2, rmse


# ==================== Run the Analysis ==================== #
print("\n" + "="*80)
print(f"Analyzing for BUFFER_METERS = {BUFFER_METERS}m")
print("="*80)

batch_size = 4
gnn_input_dim = len(train_split)

# Calculate CNN patch shape based on the current buffer size
with rasterio.open(raster_paths[0]) as src:
    res_x, res_y = src.res
    buffer_pixels_x = int(BUFFER_METERS / res_x)
    patch_width = 2 * buffer_pixels_x
    cnn_patch_shape = (patch_width, patch_width, len(raster_paths))

model = build_moe_model(cnn_patch_shape, gnn_input_dim, mlp_train_split.shape[1])
model.summary()

# ==================== 6. Create Data Generators ==================== #
train_generator = DataGenerator(
    coords=coords_train_split,
    mlp_data=mlp_train_split,
    gnn_data=gnn_train_split,
    y=y_train_split,
    raster_paths=raster_paths,
    buffer_meters=BUFFER_METERS,
    batch_size=batch_size,
    shuffle=True
)

validation_generator = DataGenerator(
    coords=coords_val_split,
    mlp_data=mlp_val_split,
    gnn_data=gnn_val_split,
    y=y_val_split,
    raster_paths=raster_paths,
    buffer_meters=BUFFER_METERS,
    batch_size=batch_size,
    shuffle=False # No need to shuffle validation data
)


# ==================== 7. Train Model ==================== #
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True
)

history = model.fit(
    train_generator,
    epochs=100,
    verbose=1,
    callbacks=[early_stopping],
    validation_data=validation_generator
)

# ==================== 8. Evaluate ==================== #
y_pred_train = model.predict(train_generator).flatten()
r2_train = r2_score(y_train_split[:len(y_pred_train)], y_pred_train)
rmse_train = np.sqrt(mean_squared_error(y_train_split[:len(y_pred_train)], y_pred_train))

r2_test, rmse_test = evaluate_model(model, coords_test, mlp_test, gnn_test, y_test, raster_paths, buffer_meters=BUFFER_METERS, batch_size=batch_size)

print(f"\n Mixture of Experts Model Performance ({BUFFER_METERS}m):")
print(f"R² Train: {r2_train:.4f} | RMSE Train: {rmse_train:.4f}")
print(f"R² Test: {r2_test:.4f} | RMSE Test: {rmse_test:.4f}")

# ==================== 9. Feature Importance Analysis ==================== #
print("\n" + "-"*50)
print(f"Feature Importance Analysis for {BUFFER_METERS}m")
print("-"*50)

# --- 9.1 Combined Feature Importance (by Model Branch) ---
y_pred_baseline = evaluate_model(model, coords_test, mlp_test, gnn_test, y_test, raster_paths, buffer_meters=BUFFER_METERS, batch_size=batch_size, return_preds=True)
baseline_r2 = r2_score(y_test, y_pred_baseline)

print(f"\nBaseline Performance on Test Set: R² = {baseline_r2:.4f}")

# Ablate CNN branch
with rasterio.open(raster_paths[0]) as src:
    res_x, res_y = src.res
    buffer_pixels_x = int(BUFFER_METERS / res_x)
    buffer_pixels_y = int(BUFFER_METERS / res_y)
    patch_width = 2 * buffer_pixels_x
    patch_height = 2 * buffer_pixels_y

cnn_test_ablated = np.zeros_like(extract_patch_for_generator(
    coords_test, raster_paths, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height
))
y_pred_cnn_ablated = model.predict((cnn_test_ablated, mlp_test, gnn_test)).flatten()
r2_cnn_ablated = r2_score(y_test, y_pred_cnn_ablated)
importance_cnn = baseline_r2 - r2_cnn_ablated

# Ablate MLP branch
mlp_test_ablated = np.zeros_like(mlp_test)
y_pred_mlp_ablated = model.predict((extract_patch_for_generator(
    coords_test, raster_paths, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height
), mlp_test_ablated, gnn_test)).flatten()
r2_mlp_ablated = r2_score(y_test, y_pred_mlp_ablated)
importance_mlp = baseline_r2 - r2_mlp_ablated

# Ablate GNN branch
gnn_test_ablated = np.zeros_like(gnn_test)
y_pred_gnn_ablated = model.predict((extract_patch_for_generator(
    coords_test, raster_paths, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height
), mlp_test, gnn_test_ablated)).flatten()
r2_gnn_ablated = r2_score(y_test, y_pred_gnn_ablated)
importance_gnn = baseline_r2 - r2_gnn_ablated

print("\n--- Combined Feature Importance (by Model Branch) ---")
print(f"CNN Branch Importance (R² drop): {importance_cnn:.4f}")
print(f"MLP Branch Importance (R² drop): {importance_mlp:.4f}")
print(f"GNN Branch Importance (R² drop): {importance_gnn:.4f}")

# --- 9.2 MLP Feature Importance (Permutation-based) ---
mlp_feature_importance = {}
for i, feature_name in enumerate(numeric_cols):
    mlp_test_shuffled = np.copy(mlp_test)
    np.random.shuffle(mlp_test_shuffled[:, i])
    
    y_pred_shuffled = model.predict((extract_patch_for_generator(
        coords_test, raster_paths, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height
    ), mlp_test_shuffled, gnn_test)).flatten()
    r2_shuffled = r2_score(y_test, y_pred_shuffled)
    
    importance = baseline_r2 - r2_shuffled
    mlp_feature_importance[feature_name] = importance

print("\n--- MLP Feature Importance (Permutation-based) ---")
sorted_importance = sorted(mlp_feature_importance.items(), key=lambda item: item[1], reverse=True)
for feature, importance in sorted_importance:
    print(f"{feature:<20}: {importance:.4f}")
    
# Garbage collect to free up memory
del model, history, train_generator, validation_generator
gc.collect()

# ==================== 10. Save Model and Results ==================== #
print("\n" + "="*80)
print("Saving Model and Analysis Results...")
print("="*80)

# Rebuild and re-train the model to ensure it's in a savable state.
# This is a good practice to avoid issues with saving during a live session.
# First, let's get the necessary dimensions again.
batch_size = 4
gnn_input_dim = len(train_split)
with rasterio.open(raster_paths[0]) as src:
    res_x, res_y = src.res
    buffer_pixels_x = int(BUFFER_METERS / res_x)
    patch_width = 2 * buffer_pixels_x
    cnn_patch_shape = (patch_width, patch_width, len(raster_paths))

final_model = build_moe_model(cnn_patch_shape, gnn_input_dim, mlp_train_split.shape[1])
# Re-fit the model on the full training and validation data
final_model.fit(
    DataGenerator(coords=coords_train_split, mlp_data=mlp_train_split, gnn_data=gnn_train_split, y=y_train_split, raster_paths=raster_paths, buffer_meters=BUFFER_METERS, batch_size=batch_size),
    epochs=2,
    callbacks=[EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)],
    validation_data=DataGenerator(coords=coords_val_split, mlp_data=mlp_val_split, gnn_data=gnn_val_split, y=y_val_split, raster_paths=raster_paths, buffer_meters=BUFFER_METERS, batch_size=batch_size, shuffle=False)
)

output_dir = "mixture_of_experts"
os.makedirs(output_dir, exist_ok=True)

# 10.1 Save the trained Keras model
model_name = "mixture_of_experts_model"
final_model.save(f'{output_dir}/{model_name}.keras')
print(f"Saved trained model to {output_dir}/{model_name}.keras")

# 10.2 Save the StandardScaler object
# Using pickle.dump for serialization
with open(f'{output_dir}/scaler.pkl', 'wb') as f:
    pickle.dump(scaler, f)
print(f"Saved StandardScaler object to {output_dir}/scaler.pkl")

# 10.3 Combine and save all feature importance results into a single file
feature_importance_results = {
    "combined_branch_importance": {
        "CNN_Importance_R2_drop": importance_cnn,
        "MLP_Importance_R2_drop": importance_mlp,
        "GNN_Importance_R2_drop": importance_gnn
    },
    "mlp_permutation_importance": mlp_feature_importance
}

# Using pickle.dump for serialization
with open(f'{output_dir}/feature_importance.pkl', 'wb') as f:
    pickle.dump(feature_importance_results, f)
print(f"Saved all feature importance results to {output_dir}/feature_importance.pkl")

# 10.4 Save the preprocessed data arrays for future use
np.save(f'{output_dir}/coords_train.npy', coords_train)
np.save(f'{output_dir}/mlp_train.npy', mlp_train_split)
np.save(f'{output_dir}/gnn_train.npy', gnn_train_split)
np.save(f'{output_dir}/y_train.npy', y_train_split)
np.save(f'{output_dir}/coords_test.npy', coords_test)
np.save(f'{output_dir}/mlp_test.npy', mlp_test)
np.save(f'{output_dir}/gnn_test.npy', gnn_test)
np.save(f'{output_dir}/y_test.npy', y_test)
print("Saved preprocessed data arrays to .npy files")

print("\nAll requested artifacts have been saved successfully.")


# AlphaEarth Integration Enabled

This notebook has been enhanced with AlphaEarth satellite embeddings.

## Integration Options:
- **Option A**: Replace indices with AlphaEarth (64 bands)
- **Option B**: Add AlphaEarth to features (RECOMMENDED)
- **Option C**: PCA-reduced AlphaEarth (20 components)
- **Option D**: MLP enhancement only

Expected improvement: +0.5% to +0.8% in R²

In [None]:
# ==================== ALPHAEARTH CONFIGURATION ====================
import pandas as pd
import numpy as np
import os

# Select which AlphaEarth option to use
ALPHA_EARTH_OPTION = 'B'  # Options: A, B (recommended), C, D
USE_ALPHA_EARTH = True

# Paths to AlphaEarth data files (created by 00_AlphaEarth_Data_Preparation.ipynb)
option_file = f'Option_{ALPHA_EARTH_OPTION}_RainyAE.csv'  # or WinterAE

# Load AlphaEarth data
if os.path.exists(option_file):
    ae_data = pd.read_csv(option_file)
    print(f'Loaded AlphaEarth Option {ALPHA_EARTH_OPTION}')
    print(f'Shape: {ae_data.shape}')
else:
    print(f'WARNING: {option_file} not found')
    print('Please run 00_AlphaEarth_Data_Preparation.ipynb first')
    USE_ALPHA_EARTH = False


In [1]:
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Concatenate,
    Dropout,
    Layer,
    Lambda
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
from io import StringIO # To capture print output
import pickle # For saving dictionaries and other objects
# Define the single buffer size to use
BUFFER_METERS = 500
N_SPLITS = 5 # Number of folds for cross-validation
# ==================== 1. Load Data ==================== #
orig = pd.read_csv("../../data/WinterSeason1.csv")
river_100 = pd.read_csv("../data/Samples_100W.csv")
# Define the columns to drop and the numeric columns to use for MLP
drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')
# Ensure there are no NaNs in the numeric columns before proceeding
orig[numeric_cols] = orig[numeric_cols].fillna(0)
river_100[numeric_cols] = river_100[numeric_cols].fillna(0)
# Train-test split
train_orig = orig.sample(10, random_state=42)
test_orig = orig.drop(train_orig.index)
train_combined = pd.concat([river_100, train_orig], ignore_index=True)
# ==================== 2. Collect ALL Rasters ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDW/*.tif")
print(f"Using {len(raster_paths)} raster layers for CNN input.")
for r in raster_paths:
    print("  -", os.path.basename(r))
# ==================== 3. Define Metric Functions ==================== #
def smape(y_true, y_pred):
    """
    Calculates the Symmetric Mean Absolute Percentage Error (SMAPE).
    """
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    # Avoid division by zero by setting the result to 0 where denominator is 0
    return np.mean(np.where(denominator == 0, 0, numerator / denominator)) * 100
# ==================== 4. Create a Custom Data Generator ==================== #
def extract_patch_for_generator(coords, raster_files, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height):
    """
    Extracts a batch of patches from rasters for a given set of coordinates.
    This function is optimized to be called by the data generator for each batch.
    """
    patches = []
    # Loop through each coordinate pair in the batch
    for lon, lat in coords:
        channels = []
        # Loop through each raster file to get a single patch for each raster
        for rfile in raster_files:
            with rasterio.open(rfile) as src:
                try:
                    row, col = src.index(lon, lat)
                    win = Window(col - buffer_pixels_x, row - buffer_pixels_y, patch_width, patch_height)
                    arr = src.read(1, window=win, boundless=True, fill_value=0)
                    
                    # Corrected logic: Convert any NaNs to a numerical value, e.g., 0,
                    # to prevent them from propagating through the model.
                    arr = np.nan_to_num(arr, nan=0.0)
                    arr = arr.astype(np.float32)
                    # Get the maximum value, but check if it's a valid number and > 0.
                    # This prevents division by zero if the patch is all zeros.
                    max_val = np.nanmax(arr)
                    if np.isfinite(max_val) and max_val > 0:
                        arr /= max_val
                        
                except Exception as e:
                    print(f"Error processing {rfile} for coordinates ({lon}, {lat}): {e}")
                    arr = np.zeros((patch_width, patch_height), dtype=np.float32)
            channels.append(arr)
        patches.append(np.stack(channels, axis=-1))
    
    return np.array(patches)
class DataGenerator(Sequence):
    def __init__(self, coords, mlp_data, gnn_data, y, raster_paths, buffer_meters, batch_size=4, shuffle=True, **kwargs):
        super().__init__(**kwargs)
        self.coords = coords
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.raster_paths = raster_paths
        self.buffer_meters = buffer_meters
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        # Pre-calculate patch size from the first raster
        with rasterio.open(raster_paths[0]) as src:
            res_x, res_y = src.res
            self.buffer_pixels_x = int(self.buffer_meters / res_x)
            self.buffer_pixels_y = int(self.buffer_meters / res_y)
            self.patch_width = 2 * self.buffer_pixels_x
            self.patch_height = 2 * self.buffer_pixels_y
        self.on_epoch_end()
    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
            
    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        # Get batch data
        batch_coords = self.coords[batch_indices]
        batch_mlp = self.mlp_data[batch_indices]
        batch_gnn = self.gnn_data[batch_indices, :]
        batch_y = self.y[batch_indices]
        # Extract CNN patches for the current batch
        batch_cnn = extract_patch_for_generator(
            batch_coords,
            self.raster_paths,
            self.buffer_pixels_x,
            self.buffer_pixels_y,
            self.patch_width,
            self.patch_height
        )
        # Return a tuple of inputs and the target, which Keras expects
        return (batch_cnn, batch_mlp, batch_gnn), batch_y
# ==================== 5. Define the Mixture of Experts Model ==================== #
def build_moe_model(patch_shape, gnn_dim, mlp_dim):
    # Inputs for all branches
    cnn_input = Input(shape=patch_shape, name="cnn_input")
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    
    # --- Expert 1: CNN Branch ---
    cnn_branch = Conv2D(32, (3,3), activation="relu", padding="same")(cnn_input)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    cnn_branch = Conv2D(64, (3,3), activation="relu", padding="same")(cnn_branch)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    cnn_branch_flattened = Flatten()(cnn_branch)
    cnn_branch_dense = Dense(128, activation="relu")(cnn_branch_flattened)
    cnn_expert_out = Dense(1, activation="linear", name="cnn_expert_out")(cnn_branch_dense)
    # --- Expert 2: MLP Branch ---
    mlp_branch = Dense(64, activation="relu")(mlp_input)
    mlp_branch = Dense(32, activation="relu")(mlp_branch)
    mlp_expert_out = Dense(1, activation="linear", name="mlp_expert_out")(mlp_branch)
    # --- Expert 3: GNN Branch ---
    gnn_branch = Dense(64, activation="relu")(gnn_input)
    gnn_branch = Dense(32, activation="relu")(gnn_branch)
    gnn_expert_out = Dense(1, activation="linear", name="gnn_expert_out")(gnn_branch)
    # --- Gating Network ---
    gate_input = Concatenate()([cnn_branch_dense, mlp_branch, gnn_branch])
    gate_network = Dense(64, activation="relu")(gate_input)
    gate_network = Dense(32, activation="relu")(gate_network)
    gate_weights = Dense(3, activation="softmax", name="gate_weights")(gate_network)
    # --- Combine Experts and Gating Network ---
    experts_stack = Concatenate(axis=1, name="experts_stack")([cnn_expert_out, mlp_expert_out, gnn_expert_out])
    final_output = Lambda(lambda x: tf.reduce_sum(x[0] * x[1], axis=1, keepdims=True), name="final_output")([experts_stack, gate_weights])
    # Build and compile the model
    model = Model(inputs=[cnn_input, mlp_input, gnn_input], outputs=final_output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model
def evaluate_model(model, coords, mlp_data, gnn_data, y, raster_paths, buffer_meters, batch_size=4):
    """
    Evaluates a model on the given data and returns a dictionary of metrics.
    """
    num_samples = len(y)
    y_pred_list = []
    
    # Get patch dimensions
    with rasterio.open(raster_paths[0]) as src:
        res_x, res_y = src.res
        buffer_pixels_x = int(buffer_meters / res_x)
        buffer_pixels_y = int(buffer_meters / res_y)
        patch_width = 2 * buffer_pixels_x
        patch_height = 2 * buffer_pixels_y
    for i in range(0, num_samples, batch_size):
        batch_coords = coords[i:i+batch_size]
        batch_mlp = mlp_data[i:i+batch_size]
        batch_gnn = gnn_data[i:i+batch_size, :]
        
        batch_cnn = extract_patch_for_generator(
            batch_coords,
            raster_paths,
            buffer_pixels_x,
            buffer_pixels_y,
            patch_width,
            patch_height
        )
        
        y_pred_list.append(model.predict((batch_cnn, batch_mlp, batch_gnn)).flatten())
        
    y_pred = np.concatenate(y_pred_list)
    
    r2 = r2_score(y, y_pred)
    mae = mean_absolute_error(y, y_pred)
    rmse = np.sqrt(mean_squared_error(y, y_pred))
    smape_val = smape(y, y_pred)
    return {
        'R2': r2,
        'MAE': mae,
        'RMSE': rmse,
        'SMAPE': smape_val
    }
# ==================== 6. Run Train/Test Single Split ==================== #
print("="*80)
print(f"Starting {N_SPLITS}-Fold Single Split...")
print("="*80)
# Create a folder to save models
model_save_dir = "models/mixture_of_experts"
os.makedirs(model_save_dir, exist_ok=True)
print(f"Models will be saved in: '{model_save_dir}'")
# Prepare data for Train/Test
combined_indices = np.arange(len(train_combined))
coords_all = train_combined[['Long','Lat']].values
data_all = train_combined[numeric_cols].values
gnn_input_all = np.exp(-distance_matrix(coords_all, coords_all)/10)
y_all = train_combined['RI'].values
batch_size = 4
# Store metrics for each fold
fold_metrics = []
test_metrics = []
# Initialize Train/Test
kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)
for fold, (train_indices, val_indices) in enumerate(kf.split(combined_indices)):
    print(f"\n--- Fold {fold+1}/{N_SPLITS} ---")
    # Split data for the current fold
    coords_train_fold = coords_all[train_indices]
    coords_val_fold = coords_all[val_indices]
    mlp_train_fold = data_all[train_indices]
    mlp_val_fold = data_all[val_indices]
    y_train_fold = y_all[train_indices]
    y_val_fold = y_all[val_indices]
    # Scale MLP data and prepare GNN matrices for the current fold
    scaler_fold = StandardScaler()
    mlp_train_scaled = scaler_fold.fit_transform(mlp_train_fold)
    mlp_val_scaled = scaler_fold.transform(mlp_val_fold)
    
    dist_mat_train_fold = distance_matrix(coords_train_fold, coords_train_fold)
    gnn_train_fold = np.exp(-dist_mat_train_fold/10)
    
    dist_mat_val_fold = distance_matrix(coords_val_fold, coords_train_fold)
    gnn_val_fold = np.exp(-dist_mat_val_fold/10)
    # Re-initialize and compile the model for each fold
    with rasterio.open(raster_paths[0]) as src:
        res_x, res_y = src.res
        buffer_pixels_x = int(BUFFER_METERS / res_x)
        patch_width = 2 * buffer_pixels_x
        cnn_patch_shape = (patch_width, patch_width, len(raster_paths))
    
    model = build_moe_model(cnn_patch_shape, len(coords_train_fold), mlp_train_fold.shape[1])
    
    # Create Data Generators for the current fold
    train_generator = DataGenerator(
        coords=coords_train_fold,
        mlp_data=mlp_train_scaled,
        gnn_data=gnn_train_fold,
        y=y_train_fold,
        raster_paths=raster_paths,
        buffer_meters=BUFFER_METERS,
        batch_size=batch_size,
        shuffle=True
    )
    
    val_generator = DataGenerator(
        coords=coords_val_fold,
        mlp_data=mlp_val_scaled,
        gnn_data=gnn_val_fold,
        y=y_val_fold,
        raster_paths=raster_paths,
        buffer_meters=BUFFER_METERS,
        batch_size=batch_size,
        shuffle=False
    )
    
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    )
    print("Starting model training...")
    history = model.fit(
        train_generator,
        epochs=100,
        verbose=1,
        callbacks=[early_stopping],
        validation_data=val_generator
    )
    print("Training complete.")
    # Evaluate on the validation set
    fold_result = evaluate_model(model, coords_val_fold, mlp_val_scaled, gnn_val_fold, y_val_fold, raster_paths, BUFFER_METERS, batch_size)
    fold_metrics.append(fold_result)
    print("Validation Metrics:")
    print(f"R²: {fold_result['R2']:.4f}")
    print(f"MAE: {fold_result['MAE']:.4f}")
    print(f"RMSE: {fold_result['RMSE']:.4f}")
    print(f"SMAPE: {fold_result['SMAPE']:.4f}%")
    # Prepare and evaluate on the independent test set
    dist_mat_test_train = distance_matrix(test_orig[['Long','Lat']].values, coords_train_fold)
    gnn_test_fold = np.exp(-dist_mat_test_train/10)
    
    mlp_test_scaled = scaler_fold.transform(test_orig[numeric_cols].values)
    test_result = evaluate_model(model, test_orig[['Long','Lat']].values, mlp_test_scaled, gnn_test_fold, test_orig['RI'].values, raster_paths, BUFFER_METERS, batch_size)
    test_metrics.append(test_result)
    print("\nIndependent Test Set Metrics:")
    print(f"R²: {test_result['R2']:.4f}")
    print(f"MAE: {test_result['MAE']:.4f}")
    print(f"RMSE: {test_result['RMSE']:.4f}")
    print(f"SMAPE: {test_result['SMAPE']:.4f}%")
    
    # Save the trained model for the current fold
    model.save(os.path.join(model_save_dir, f"fold_{fold+1}.keras"))
    print(f"Model for Fold {fold+1} saved.")
    # Clean up to free memory
    del model, train_generator, val_generator, early_stopping, history
    tf.keras.backend.clear_session()
    gc.collect()
# ==================== 7. Print Final Averages ==================== #
# Calculate average metrics
avg_fold_metrics = pd.DataFrame(fold_metrics).mean().to_dict()
avg_test_metrics = pd.DataFrame(test_metrics).mean().to_dict()
print("\n" + "="*80)
print("Final Average Metrics Across All Folds:")
print("="*80)
print("\nAverage Validation Metrics:")
print(f"Average R²: {avg_fold_metrics['R2']:.4f}")
print(f"Average MAE: {avg_fold_metrics['MAE']:.4f}")
print(f"Average RMSE: {avg_fold_metrics['RMSE']:.4f}")
print(f"Average SMAPE: {avg_fold_metrics['SMAPE']:.4f}%")
print("\nAverage Independent Test Set Metrics:")
print(f"Average R²: {avg_test_metrics['R2']:.4f}")
print(f"Average MAE: {avg_test_metrics['MAE']:.4f}")
print(f"Average RMSE: {avg_test_metrics['RMSE']:.4f}")
print(f"Average SMAPE: {avg_test_metrics['SMAPE']:.4f}%")


Using 26 raster layers for CNN input.
  - bui.tif
  - ndsi.tif
  - savi.tif
  - ndbsi.tif
  - ui.tif
  - ndwi.tif
  - ndbi.tif
  - awei.tif
  - evi.tif
  - mndwi.tif
  - ndvi.tif
  - LULC2020.tif
  - LULC2021.tif
  - LULC2022.tif
  - LULC2019.tif
  - LULC2018.tif
  - LULC2017.tif
  - Pb_R.tif
  - ClayR.tif
  - SandR.tif
  - CdR.tif
  - CrR.tif
  - AsR.tif
  - SiltR.tif
  - CuR.tif
  - NiR.tif
Starting 5-Fold Cross-Validation...
Models will be saved in: 'models/mixture_of_experts'

--- Fold 1/5 ---
Starting model training...
Epoch 1/100
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 217ms/step - loss: 168551.1094 - val_loss: 43592.1172
Epoch 2/100
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 183ms/step - loss: 30601.4355 - val_loss: 9954.8633
Epoch 3/100
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 189ms/step - loss: 7955.0679 - val_loss: 8405.6396
Epoch 4/100
[1m22/22[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 196ms/ste

In [2]:
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Concatenate,
    Dropout,
    Layer,
    Lambda
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
import warnings # For LIME
from io import StringIO
import pickle # For saving dictionaries and other objects

from lime.lime_tabular import LimeTabularExplainer

# Define the single buffer size to use
BUFFER_METERS = 500

# ==================== 1. Load Data ==================== #
# Load original data
orig = pd.read_csv("../../data/WinterSeason1.csv")
river_100 = pd.read_csv("../data/Samples_100W.csv")

# Define the columns to drop and the numeric columns to use for MLP
drop_cols = ['Stations', 'River', 'Lat', 'Long', 'geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# Ensure there are no NaNs in the numeric columns before proceeding
orig[numeric_cols] = orig[numeric_cols].fillna(0)
river_100[numeric_cols] = river_100[numeric_cols].fillna(0)

# Train-test split
train_orig = orig.sample(10, random_state=42)
test_orig = orig.drop(train_orig.index)
train_combined = pd.concat([river_100, train_orig], ignore_index=True)

# ==================== 2. Collect ALL Rasters ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDWW/*.tif")

print(f"Using {len(raster_paths)} raster layers for CNN input.")
for r in raster_paths:
    print("  -", os.path.basename(r))

# ==================== 3. Define Metric Functions ==================== #

def smape(y_true, y_pred):
    """
    Calculates the Symmetric Mean Absolute Percentage Error (SMAPE).
    """
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    # Avoid division by zero by setting the result to 0 where denominator is 0
    return np.mean(np.where(denominator == 0, 0, numerator / denominator)) * 100

# ==================== 4. Create a Custom Data Generator ==================== #
def extract_patch_for_generator(coords, raster_files, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height):
    """
    Extracts a batch of patches from rasters for a given set of coordinates.
    This function is optimized to be called by the data generator for each batch.
    """
    patches = []
    # Loop through each coordinate pair in the batch
    for lon, lat in coords:
        channels = []
        # Loop through each raster file to get a single patch for each raster
        for rfile in raster_files:
            with rasterio.open(rfile) as src:
                try:
                    row, col = src.index(lon, lat)
                    win = Window(col - buffer_pixels_x, row - buffer_pixels_y, patch_width, patch_height)
                    arr = src.read(1, window=win, boundless=True, fill_value=0)
                    
                    # Corrected logic: Convert any NaNs to a numerical value, e.g., 0,
                    # to prevent them from propagating through the model.
                    arr = np.nan_to_num(arr, nan=0.0)
                    arr = arr.astype(np.float32)

                    # Get the maximum value, but check if it's a valid number and > 0.
                    # This prevents division by zero if the patch is all zeros.
                    max_val = np.nanmax(arr)
                    if np.isfinite(max_val) and max_val > 0:
                        arr /= max_val
                        
                except Exception as e:
                    print(f"Error processing {rfile} for coordinates ({lon}, {lat}): {e}")
                    arr = np.zeros((patch_width, patch_height), dtype=np.float32)
            channels.append(arr)
        patches.append(np.stack(channels, axis=-1))
    
    return np.array(patches)

class DataGenerator(Sequence):
    def __init__(self, coords, mlp_data, gnn_data, y, raster_paths, buffer_meters, batch_size=4, shuffle=True, **kwargs):
        super().__init__(**kwargs)
        self.coords = coords
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.raster_paths = raster_paths
        self.buffer_meters = buffer_meters
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))

        # Pre-calculate patch size from the first raster
        with rasterio.open(raster_paths[0]) as src:
            res_x, res_y = src.res
            self.buffer_pixels_x = int(self.buffer_meters / res_x)
            self.buffer_pixels_y = int(self.buffer_meters / res_y)
            self.patch_width = 2 * self.buffer_pixels_x
            self.patch_height = 2 * self.buffer_pixels_y

        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
            
    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_coords = self.coords[batch_indices]
        batch_mlp = self.mlp_data[batch_indices]
        batch_gnn = self.gnn_data[batch_indices, :]
        batch_y = self.y[batch_indices]

        # Extract CNN patches for the current batch
        batch_cnn = extract_patch_for_generator(
            batch_coords,
            self.raster_paths,
            self.buffer_pixels_x,
            self.buffer_pixels_y,
            self.patch_width,
            self.patch_height
        )

        # Return a tuple of inputs and the target, which Keras expects
        return (batch_cnn, batch_mlp, batch_gnn), batch_y

# ==================== 5. Define the Mixture of Experts Model ==================== #
def build_moe_model(patch_shape, gnn_dim, mlp_dim):
    # Inputs for all branches
    cnn_input = Input(shape=patch_shape, name="cnn_input")
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    
    # --- Expert 1: CNN Branch ---
    cnn_branch = Conv2D(32, (3,3), activation="relu", padding="same")(cnn_input)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    cnn_branch = Conv2D(64, (3,3), activation="relu", padding="same")(cnn_branch)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    cnn_branch_flattened = Flatten()(cnn_branch)
    cnn_branch_dense = Dense(128, activation="relu")(cnn_branch_flattened)
    cnn_expert_out = Dense(1, activation="linear", name="cnn_expert_out")(cnn_branch_dense)

    # --- Expert 2: MLP Branch ---
    mlp_branch = Dense(64, activation="relu")(mlp_input)
    mlp_branch = Dense(32, activation="relu")(mlp_branch)
    mlp_expert_out = Dense(1, activation="linear", name="mlp_expert_out")(mlp_branch)

    # --- Expert 3: GNN Branch ---
    gnn_branch = Dense(64, activation="relu")(gnn_input)
    gnn_branch = Dense(32, activation="relu")(gnn_branch)
    gnn_expert_out = Dense(1, activation="linear", name="gnn_expert_out")(gnn_branch)

    # --- Gating Network ---
    gate_input = Concatenate()([cnn_branch_dense, mlp_branch, gnn_branch])
    gate_network = Dense(64, activation="relu")(gate_input)
    gate_network = Dense(32, activation="relu")(gate_network)
    gate_weights = Dense(3, activation="softmax", name="gate_weights")(gate_network)

    # --- Combine Experts and Gating Network ---
    experts_stack = Concatenate(axis=1, name="experts_stack")([cnn_expert_out, mlp_expert_out, gnn_expert_out])
    final_output = Lambda(lambda x: tf.reduce_sum(x[0] * x[1], axis=1, keepdims=True), name="final_output")([experts_stack, gate_weights])

    # Build and compile the model
    model = Model(inputs=[cnn_input, mlp_input, gnn_input], outputs=final_output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

def evaluate_model(model, coords, mlp_data, gnn_data, y, raster_paths, buffer_meters, batch_size=4):
    """
    Evaluates a model on the given data and returns a dictionary of metrics.
    """
    num_samples = len(y)
    y_pred_list = []
    
    # Get patch dimensions
    with rasterio.open(raster_paths[0]) as src:
        res_x, res_y = src.res
        buffer_pixels_x = int(buffer_meters / res_x)
        buffer_pixels_y = int(buffer_meters / res_y)
        patch_width = 2 * buffer_pixels_x
        patch_height = 2 * buffer_pixels_y

    for i in range(0, num_samples, batch_size):
        batch_coords = coords[i:i+batch_size]
        batch_mlp = mlp_data[i:i+batch_size]
        batch_gnn = gnn_data[i:i+batch_size, :]
        
        batch_cnn = extract_patch_for_generator(
            batch_coords,
            raster_paths,
            buffer_pixels_x,
            buffer_pixels_y,
            patch_width,
            patch_height
        )
        
        y_pred_list.append(model.predict((batch_cnn, batch_mlp, batch_gnn), verbose=0).flatten())
        
    y_pred = np.concatenate(y_pred_list)
    
    r2 = r2_score(y, y_pred)
    mae = mean_absolute_error(y, y_pred)
    rmse = np.sqrt(mean_squared_error(y, y_pred))
    smape_val = smape(y, y_pred)

    return {
        'R2': r2,
        'MAE': mae,
        'RMSE': rmse,
        'SMAPE': smape_val
    }

# ==================== 6. Main Script: Train, Evaluate, and Explain ==================== #

print("="*80)
print("Starting Single-Run Training and Evaluation...")
print("="*80)

# Prepare data for training and validation
coords_train_all = train_combined[['Long', 'Lat']].values
mlp_train_all = train_combined[numeric_cols].values
y_train_all = train_combined['RI'].values
batch_size = 4

# Create a train/validation split from the training data
coords_train, coords_val, mlp_train, mlp_val, y_train, y_val = train_test_split(
    coords_train_all, mlp_train_all, y_train_all, test_size=0.2, random_state=42
)

# Scale MLP data and prepare GNN matrices for training
scaler = StandardScaler()
mlp_train_scaled = scaler.fit_transform(mlp_train)
mlp_val_scaled = scaler.transform(mlp_val)

dist_mat_train = distance_matrix(coords_train, coords_train)
gnn_train = np.exp(-dist_mat_train / 10)

dist_mat_val = distance_matrix(coords_val, coords_train)
gnn_val = np.exp(-dist_mat_val / 10)

# Build and compile the model
with rasterio.open(raster_paths[0]) as src:
    res_x, res_y = src.res
    buffer_pixels_x = int(BUFFER_METERS / res_x)
    patch_width = 2 * buffer_pixels_x
    cnn_patch_shape = (patch_width, patch_width, len(raster_paths))

model = build_moe_model(cnn_patch_shape, gnn_train.shape[1], mlp_train_scaled.shape[1])

# Create Data Generators
train_generator = DataGenerator(
    coords=coords_train,
    mlp_data=mlp_train_scaled,
    gnn_data=gnn_train,
    y=y_train,
    raster_paths=raster_paths,
    buffer_meters=BUFFER_METERS,
    batch_size=batch_size,
    shuffle=True
)

val_generator = DataGenerator(
    coords=coords_val,
    mlp_data=mlp_val_scaled,
    gnn_data=gnn_val,
    y=y_val,
    raster_paths=raster_paths,
    buffer_meters=BUFFER_METERS,
    batch_size=batch_size,
    shuffle=False
)

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True
)

print("Starting model training...")
history = model.fit(
    train_generator,
    epochs=100,
    verbose=0,
    callbacks=[early_stopping],
    validation_data=val_generator
)
print("Training complete.")

# --- Evaluation on Independent Test Set ---
print("\n" + "="*80)
print("Evaluating on Independent Test Set...")
print("="*80)
coords_test = test_orig[['Long','Lat']].values
mlp_test = scaler.transform(test_orig[numeric_cols].values)
dist_mat_test_train = distance_matrix(coords_test, coords_train)
gnn_test = np.exp(-dist_mat_test_train/10)
y_test = test_orig['RI'].values

test_result = evaluate_model(model, coords_test, mlp_test, gnn_test, y_test, raster_paths, BUFFER_METERS, batch_size)


Using 26 raster layers for CNN input.
  - bui.tif
  - ndsi.tif
  - savi.tif
  - ndbsi.tif
  - ui.tif
  - ndwi.tif
  - ndbi.tif
  - awei.tif
  - evi.tif
  - mndwi.tif
  - ndvi.tif
  - LULC2020.tif
  - LULC2021.tif
  - LULC2022.tif
  - LULC2019.tif
  - LULC2018.tif
  - LULC2017.tif
  - ClayW.tif
  - CdW.tif
  - SandW.tif
  - SiltW.tif
  - AsW.tif
  - CrW.tif
  - NiW.tif
  - PbW.tif
  - CuW.tif
Starting Single-Run Training and Evaluation...
Starting model training...
Training complete.

Evaluating on Independent Test Set...


In [3]:
print("\nIndependent Test Set Metrics:")
print(f"R²: {test_result['R2']:.4f}")
print(f"MAE: {test_result['MAE']:.4f}")
print(f"RMSE: {test_result['RMSE']:.4f}")
print(f"SMAPE: {test_result['SMAPE']:.4f}%")



Independent Test Set Metrics:
R²: 0.7706
MAE: 29.3151
RMSE: 37.8735
SMAPE: 17.5464%


In [4]:
model.save("MixtureOfExperts.keras")


In [5]:
# ==================== 7. Feature Importance Analysis ==================== #

# --- Permutation Feature Importance (MLP Features only) ---
print("\n" + "="*80)
print("Calculating Permutation Feature Importance...")
print("="*80)

# Get baseline performance on the test set
baseline_r2 = test_result['R2']
importance = {}

for col in numeric_cols:
    print(f"Permuting feature: {col}")
    
    # Create a copy of the scaled test data
    mlp_test_shuffled = mlp_test.copy()
    
    # Find the index of the column to shuffle
    col_idx = list(numeric_cols).index(col)
    
    # Shuffle the values of the specific column
    np.random.shuffle(mlp_test_shuffled[:, col_idx])
    
    # Evaluate the model with the shuffled data
    shuffled_result = evaluate_model(model, coords_test, mlp_test_shuffled, gnn_test, y_test, raster_paths, BUFFER_METERS, batch_size)
    shuffled_r2 = shuffled_result['R2']
    
    # Calculate the importance as the drop in R²
    importance[col] = baseline_r2 - shuffled_r2

# Create a DataFrame for permutation importance
permutation_df = pd.DataFrame.from_dict(
    importance,
    orient='index',
    columns=['Permutation Importance (R² drop)']
).sort_values(by='Permutation Importance (R² drop)', ascending=False)
print("\nPermutation Feature Importance (MLP Features):")
print(permutation_df)

# --- LIME (Local Interpretable Model-Agnostic Explanations) ---
print("\n" + "="*80)
print("Calculating LIME Feature Importance for a few samples...")
print("="*80)

warnings.filterwarnings("ignore", category=UserWarning)

# Get CNN and GNN data for the entire test set once
print("Pre-loading CNN patches and GNN matrix for LIME...")
with rasterio.open(raster_paths[0]) as src:
    res_x, res_y = src.res
    buffer_pixels_x = int(BUFFER_METERS / res_x)
    buffer_pixels_y = int(BUFFER_METERS / res_y)
    patch_width = 2 * buffer_pixels_x
    patch_height = 2 * buffer_pixels_y
    
cnn_data_test = extract_patch_for_generator(
    coords_test,
    raster_paths,
    buffer_pixels_x,
    buffer_pixels_y,
    patch_width,
    patch_height
)
gnn_data_test = np.exp(-distance_matrix(coords_test, coords_train) / 10)


# LIME explainer setup
explainer = LimeTabularExplainer(
    training_data=mlp_test,
    feature_names=list(numeric_cols),
    class_names=['RI'],
    mode='regression'
)

# Select a few random samples from the test set for explanation
np.random.seed(42)
sample_indices = np.random.choice(len(test_orig), 5, replace=False)

lime_results = []
for i in sample_indices:
    print(f"\nExplaining sample at index {i}...")
    
    # Define a local prediction function for LIME
    def predict_wrapper(instances):
        # LIME provides a batch of perturbed tabular data (instances).
        # We need to replicate the CNN and GNN data for this batch.
        num_instances = instances.shape[0]
        
        # Replicate the CNN patch for the current sample
        cnn_batch = np.tile(cnn_data_test[i:i+1], (num_instances, 1, 1, 1))
        
        # Replicate the GNN row for the current sample
        gnn_batch = np.tile(gnn_data_test[i:i+1], (num_instances, 1))
        
        # Predict on the new batch
        predictions = model.predict((cnn_batch, instances, gnn_batch), verbose=0)
        return np.array(predictions).reshape(-1, 1)

    # Explain the prediction for a single instance using the wrapper
    explanation = explainer.explain_instance(
        data_row=mlp_test[i],
        predict_fn=predict_wrapper,
        num_features=len(numeric_cols)
    )
    
    # Get the feature importances from the explanation
    feat_imps = explanation.as_list()
    
    # Append to our list of results
    for feature, importance in feat_imps:
        lime_results.append({
            'Sample Index': i,
            'Feature': feature,
            'LIME Weight': importance
        })

# Create a DataFrame for LIME results
lime_df = pd.DataFrame(lime_results)
print("\nLIME Local Feature Importance for 5 Test Samples:")
print(lime_df)

warnings.filterwarnings("default", category=UserWarning)



Calculating Permutation Feature Importance...
Permuting feature: hydro_dist_brick
Permuting feature: num_brick_field
Permuting feature: hydro_dist_ind
Permuting feature: num_industry
Permuting feature: CrW
Permuting feature: NiW
Permuting feature: CuW
Permuting feature: AsW
Permuting feature: CdW
Permuting feature: PbW
Permuting feature: MW
Permuting feature: SandW
Permuting feature: SiltW
Permuting feature: ClayW
Permuting feature: FeW

Permutation Feature Importance (MLP Features):
                  Permutation Importance (R² drop)
hydro_dist_brick                               0.0
num_brick_field                                0.0
hydro_dist_ind                                 0.0
num_industry                                   0.0
CrW                                            0.0
NiW                                            0.0
CuW                                            0.0
AsW                                            0.0
CdW                                            0.0


In [6]:
lime_df.sort_values("LIME Weight").head(10)


Unnamed: 0,Sample Index,Feature,LIME Weight
0,0,hydro_dist_brick > -1.10,0.0
53,2,CdW <= -0.81,0.0
52,2,-0.71 < AsW <= 0.80,0.0
51,2,-0.13 < CuW <= 0.93,0.0
50,2,NiW <= -0.34,0.0
49,2,CrW > 1.00,0.0
48,2,3.64 < num_industry <= 7.10,0.0
47,2,-0.87 < hydro_dist_ind <= -0.87,0.0
46,2,num_brick_field > 7.23,0.0
45,2,hydro_dist_brick <= -1.10,0.0


In [7]:
permutation_df


Unnamed: 0,Permutation Importance (R² drop)
hydro_dist_brick,0.0
num_brick_field,0.0
hydro_dist_ind,0.0
num_industry,0.0
CrW,0.0
NiW,0.0
CuW,0.0
AsW,0.0
CdW,0.0
PbW,0.0
