In [None]:
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Dropout, Layer, MultiHeadAttention, LayerNormalization, Reshape
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
from io import StringIO
import pickle # Import the pickle library for saving objects

# Define the single buffer size to use
BUFFER_METERS = 500

# ==================== 1. Load Data ==================== #
orig = pd.read_csv("../../data/RainySeason.csv")
river_100 = pd.read_csv("../data/Samples_100.csv")

drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# Train-test split
train_orig = orig.sample(10, random_state=42)
test_orig = orig.drop(train_orig.index)
train_combined = pd.concat([river_100, train_orig], ignore_index=True)

# ==================== 2. Collect ALL Rasters ==================== #
# We are not using rasters in this GNN-MLP model, but the paths are still
# defined for consistency with previous versions.
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDW/*.tif")

print("Note: Raster data is not used in this GNN-MLP model.")

# ==================== 3. Create a Custom Data Generator ==================== #
class DataGenerator(Sequence):
    def __init__(self, mlp_data, gnn_data, y, batch_size=4, shuffle=True, **kwargs):
        super().__init__(**kwargs)
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
            
    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_mlp = self.mlp_data[batch_indices]
        batch_gnn = self.gnn_data[batch_indices, :]
        batch_y = self.y[batch_indices]
        
        return (batch_mlp, batch_gnn), batch_y

# ==================== 4. Prepare GNN & MLP Input (only once) ==================== #
coords_train = train_combined[['Long','Lat']].values
coords_test = test_orig[['Long','Lat']].values
dist_mat_train = distance_matrix(coords_train, coords_train)
gnn_train = np.exp(-dist_mat_train/10)
dist_mat_test_train = distance_matrix(coords_test, coords_train)
gnn_test = np.exp(-dist_mat_test_train/10)

scaler = StandardScaler()
mlp_train = scaler.fit_transform(train_combined[numeric_cols])
mlp_test = scaler.transform(test_orig[numeric_cols])
y_train = train_combined['RI'].values
y_test = test_orig['RI'].values

# ==================== 5. Define GNN-MLP Fusion Model ==================== #
def build_gnn_mlp_model(mlp_dim, gnn_dim):
    # Inputs for all branches
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    
    # --- MLP Branch ---
    mlp_embedding = Dense(128, activation="relu")(mlp_input)
    mlp_embedding = Dense(64, activation="relu", name="mlp_embedding")(mlp_embedding)

    # --- GNN Branch ---
    gnn_embedding = Dense(128, activation="relu")(gnn_input)
    gnn_embedding = Dense(64, activation="relu", name="gnn_embedding")(gnn_embedding)

    # --- Concatenate Embeddings ---
    combined = Concatenate()([mlp_embedding, gnn_embedding])
    
    # Final dense layers for prediction
    f = Dense(128, activation="relu")(combined)
    f = Dropout(0.4)(f)
    f = Dense(64, activation="relu")(f)
    output = Dense(1, activation="linear", name="final_output")(f)

    # Build and compile the model
    model = Model(inputs=[mlp_input, gnn_input], outputs=output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

def evaluate_model(model, mlp_test, gnn_test_matrix, y_test, return_preds=False):
    """
    Evaluates the model on given data and returns R², RMSE, and predictions.
    """
    y_pred = model.predict((mlp_test, gnn_test_matrix)).flatten()
    
    if return_preds:
        return y_pred
    else:
        r2 = r2_score(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        return r2, rmse

def calculate_permutation_importance(model, mlp_data, gnn_data, y_true):
    """
    Calculates permutation feature importance for the MLP and GNN branches.
    """
    print("\nStarting Permutation Feature Importance Analysis...")
    # Get baseline R² on the unshuffled data
    baseline_r2, _ = evaluate_model(model, mlp_data, gnn_data, y_true)
    print(f"Baseline R² on test set: {baseline_r2:.4f}")

    importance = {}
    
    # Permute MLP input
    shuffled_mlp_data = mlp_data.copy()
    np.random.shuffle(shuffled_mlp_data)
    shuffled_r2, _ = evaluate_model(model, shuffled_mlp_data, gnn_data, y_true)
    importance['MLP'] = baseline_r2 - shuffled_r2

    # Permute GNN input
    shuffled_gnn_data = gnn_data.copy()
    np.random.shuffle(shuffled_gnn_data)
    shuffled_r2, _ = evaluate_model(model, mlp_data, shuffled_gnn_data, y_true)
    importance['GNN'] = baseline_r2 - shuffled_r2

    return importance
        
# ==================== Run the Analysis ==================== #
# Redirect output to a string for later saving
old_stdout = sys.stdout
sys.stdout = captured_output = StringIO()

print("\n" + "="*80)
print(f"Analyzing GNN-MLP Fusion Model")
print("="*80)

batch_size = 4
gnn_input_dim = len(coords_train)
mlp_input_dim = mlp_train.shape[1]

model = build_gnn_mlp_model(mlp_input_dim, gnn_input_dim)
model.summary()

# ==================== 6. Create Data Generators ==================== #
train_generator = DataGenerator(
    mlp_data=mlp_train, gnn_data=gnn_train, y=y_train,
    batch_size=batch_size, shuffle=True
)

# ==================== 7. Train Model ==================== #
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True
)

history = model.fit(
    train_generator,
    epochs=100,
    verbose=1,
    callbacks=[early_stopping],
    validation_data=train_generator
)

# ==================== 8. Evaluate & Perform Feature Importance ==================== #
# Predict on the training data using the generator
y_pred_train = model.predict(train_generator).flatten()
r2_train = r2_score(y_train[:len(y_pred_train)], y_pred_train)
rmse_train = np.sqrt(mean_squared_error(y_train[:len(y_pred_train)], y_pred_train))

# Evaluate on the test data using the updated function
r2_test, rmse_test = evaluate_model(model, mlp_test, gnn_test, y_test)
y_pred_test = evaluate_model(model, mlp_test, gnn_test, y_test, return_preds=True)

print(f"\n GNN-MLP Fusion Model Performance:")
print(f"R² Train: {r2_train:.4f} | RMSE Train: {rmse_train:.4f}")
print(f"R² Test: {r2_test:.4f} | RMSE Test: {rmse_test:.4f}")

# Calculate and print feature importance
feature_importance = calculate_permutation_importance(model, mlp_test, gnn_test, y_test)
print("\n--- Feature Importance (Permutation) ---")
sorted_importance = sorted(feature_importance.items(), key=lambda item: item[1], reverse=True)
for feature, score in sorted_importance:
    print(f"{feature}: {score:.4f}")

# ==================== 9. Save all info to a folder ==================== #
# Restore standard output
sys.stdout = old_stdout
printed_output = captured_output.getvalue()

output_folder = "gnn_mlp"
os.makedirs(output_folder, exist_ok=True)
print(f"\nCreating folder: '{output_folder}' and saving results...")

# Save the model
model_path = os.path.join(output_folder, "gnn_mlp_model.keras")
model.save(model_path)
print(f"Model saved to: {model_path}")

# Save the predictions and true labels
np.save(os.path.join(output_folder, "y_train.npy"), y_train)
np.save(os.path.join(output_folder, "y_test.npy"), y_test)
np.save(os.path.join(output_folder, "y_pred_train.npy"), y_pred_train)
np.save(os.path.join(output_folder, "y_pred_test.npy"), y_pred_test)
print(f"Predictions and true labels saved as .npy files.")

# Save the printed output to a text file
output_path = os.path.join(output_folder, "analysis_output.txt")
with open(output_path, "w") as f:
    f.write(printed_output)
print(f"Analysis results saved to: {output_path}")

# Save the feature importance dictionary as a .pkl file
importance_path = os.path.join(output_folder, "feature_importance.pkl")
with open(importance_path, 'wb') as f:
    pickle.dump(feature_importance, f)
print(f"Feature importance results saved to: {importance_path}")

print("\nAll information successfully saved.")

# Garbage collect to free up memory now that everything is saved
del model, history, train_generator
gc.collect()

In [1]:
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Dropout, Layer, LayerNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
import pickle # Import the pickle library for saving objects

# Set a consistent seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# ==================== 1. Load Data ==================== #
# NOTE: This script assumes the following file paths are correct.
try:
    orig = pd.read_csv("../../data/WinterSeason1.csv")
    river_100 = pd.read_csv("../data/Samples_100W.csv")
except FileNotFoundError as e:
    print(f"Error: Required data file not found. Please check your file paths.")
    print(f"Details: {e}")
    sys.exit()

drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# ==================== 2. Collect ALL Rasters and Metadata ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDW/*.tif")

# Get the pixel resolution from the first raster to set a uniform patch size
try:
    with rasterio.open(raster_paths[0]) as src:
        pixel_size = src.transform.a
except IndexError:
    print("Error: No raster files found in the specified directories.")
    sys.exit()

# Create a dictionary to store raster metadata for fast access
raster_metadata = {}
for path in raster_paths:
    with rasterio.open(path) as src:
        raster_metadata[path] = {
            'transform': src.transform,
            'crs': src.crs,
            'width': src.width,
            'height': src.height
        }

# ==================== 3. Define a Custom Data Generator ==================== #
class DataGenerator(Sequence):
    """
    Custom Keras Sequence for generating batches of data.
    Handles three different input types: MLP features, GNN features,
    and raster image patches, loading rasters on-the-fly to save memory.
    """
    def __init__(self, mlp_data, gnn_data, y, coords, raster_paths, buffer_radius_m, pixel_size, batch_size=4, shuffle=True):
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.coords = coords
        self.raster_paths = raster_paths
        # Calculate the uniform patch size in pixels based on the buffer radius and pixel size
        # We need a square patch, so the size is 2 * radius / pixel_size
        self.patch_size = int(round((2 * buffer_radius_m) / pixel_size))
        # Ensure patch size is at least 1 and is an even number for easy centering
        if self.patch_size % 2 != 0:
            self.patch_size += 1
        self.patch_size = max(self.patch_size, 2)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.on_epoch_end()
    
    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))
        
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
            
    def get_raster_patches(self, coords_batch):
        """
        Extracts a patch of raster data for each coordinate in the batch.
        Loads rasters on-the-fly to save memory and robustly handles boundaries.
        """
        patches_for_rasters = []
        for path in self.raster_paths:
            patches_for_this_raster = []
            try:
                with rasterio.open(path) as src:
                    for lon, lat in coords_batch:
                        # Get pixel coordinates
                        row, col = src.index(lon, lat)
                        
                        # Define a window to read around the pixel, handling boundaries
                        half_patch = self.patch_size // 2
                        left = int(col - half_patch)
                        top = int(row - half_patch)
                        right = int(col + half_patch)
                        bottom = int(row + half_patch)

                        # Create a new, empty array for the final padded patch
                        padded_patch = np.zeros((self.patch_size, self.patch_size), dtype='float32')

                        # Calculate the window in the raster's coordinate space to read from
                        # And the offset in the padded_patch to write to
                        read_left = max(0, left)
                        read_top = max(0, top)
                        read_right = min(src.width, right)
                        read_bottom = min(src.height, bottom)

                        # Check if the calculated window has a valid size
                        read_width = read_right - read_left
                        read_height = read_bottom - read_top
                        
                        if read_width > 0 and read_height > 0:
                            write_left = read_left - left
                            write_top = read_top - top
                            write_right = write_left + read_width
                            write_bottom = write_top + read_height

                            # Create the window object for rasterio to read from
                            window = Window(read_left, read_top, read_width, read_height)

                            # Read the data from the raster
                            patch_data = src.read(1, window=window)
                            # Place the read data into the padded patch
                            padded_patch[write_top:write_bottom, write_left:write_right] = patch_data
                        
                        patches_for_this_raster.append(padded_patch)
            
                # Stack the patches for this raster
                patches_for_rasters.append(np.stack(patches_for_this_raster, axis=0))
            except Exception as e:
                # This handles cases where a raster file might be missing or corrupted
                patches_for_rasters.append(np.zeros((len(coords_batch), self.patch_size, self.patch_size), dtype='float32'))


        # Stack all raster patches together
        final_patches = np.stack(patches_for_rasters, axis=-1)
        return final_patches

    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_mlp = self.mlp_data[batch_indices]
        batch_gnn = self.gnn_data[batch_indices, :]
        batch_y = self.y[batch_indices]
        batch_coords = self.coords[batch_indices]
        
        # Get raster data for the current batch
        batch_rasters = self.get_raster_patches(batch_coords)
        
        # Return a dictionary of inputs and the output
        return {"mlp_input": batch_mlp, "gnn_input": batch_gnn, "raster_input": batch_rasters}, batch_y

# ==================== 4. Define GNN-MLP-Raster Fusion Model ==================== #
def build_fusion_model(mlp_dim, gnn_dim, raster_patch_size, num_rasters):
    """
    Builds the multi-input Keras model with branches for MLP, GNN, and Rasters.
    """
    # Inputs for all branches
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    raster_input = Input(shape=(raster_patch_size, raster_patch_size, num_rasters), name="raster_input")

    # --- MLP Branch ---
    mlp_embedding = Dense(128, activation="relu")(mlp_input)
    mlp_embedding = Dense(64, activation="relu", name="mlp_embedding")(mlp_embedding)

    # --- GNN Branch ---
    gnn_embedding = Dense(128, activation="relu")(gnn_input)
    gnn_embedding = Dense(64, activation="relu", name="gnn_embedding")(gnn_embedding)
    
    # --- Raster Branch (using a simple CNN) ---
    raster_conv = Conv2D(32, (3, 3), activation="relu")(raster_input)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_conv = Conv2D(64, (3, 3), activation="relu")(raster_pool)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_flatten = Flatten()(raster_pool)
    raster_embedding = Dense(64, activation="relu", name="raster_embedding")(raster_flatten)

    # --- Concatenate Embeddings ---
    combined = Concatenate()([mlp_embedding, gnn_embedding, raster_embedding])
    
    # Final dense layers for prediction
    f = Dense(128, activation="relu")(combined)
    f = Dropout(0.4)(f)
    f = Dense(64, activation="relu")(f)
    output = Dense(1, activation="linear", name="final_output")(f)

    # Build and compile the model
    model = Model(inputs=[mlp_input, gnn_input, raster_input], outputs=output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

# ==================== 5. Define Evaluation & Importance Functions ==================== #
def calculate_smape(y_true, y_pred):
    """Calculates Symmetric Mean Absolute Percentage Error (SMAPE)."""
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    # Avoid division by zero
    mask = denominator == 0
    smape_val = np.where(mask, 0, numerator / denominator)
    return 100 * np.mean(smape_val)

def evaluate_model(model, data_inputs, y_test, return_preds=False):
    """
    Evaluates the model on given data and returns R², RMSE, MAE, and SMAPE.
    Handles both Keras Generators and direct numpy arrays.
    """
    if isinstance(data_inputs, DataGenerator):
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    else:
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    
    if return_preds:
        return y_pred
    else:
        # Align true labels with predictions if using a generator
        y_true_aligned = y_test[:len(y_pred)]
        r2 = r2_score(y_true_aligned, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true_aligned, y_pred))
        mae = mean_absolute_error(y_true_aligned, y_pred)
        smape = calculate_smape(y_true_aligned, y_pred)
        return r2, rmse, mae, smape

def calculate_permutation_importance(model, mlp_data, gnn_data, raster_data, y_true, mlp_features, raster_features):
    """
    Calculates permutation feature importance for all individual features.
    """
    print("\nStarting Permutation Feature Importance Analysis...")
    
    # Create the combined input for the model
    initial_inputs = {"mlp_input": mlp_data, "gnn_input": gnn_data, "raster_input": raster_data}
    
    # Get baseline R² on the unshuffled data
    baseline_r2, _, _, _ = evaluate_model(model, initial_inputs, y_true)
    print(f"Baseline R²: {baseline_r2:.4f}")
    
    importance = {}
    
    # 1. Permute individual MLP features
    print("Permuting MLP features...")
    for i, feature in enumerate(mlp_features):
        shuffled_mlp_data = mlp_data.copy()
        np.random.shuffle(shuffled_mlp_data[:, i])
        shuffled_inputs = {"mlp_input": shuffled_mlp_data, "gnn_input": gnn_data, "raster_input": raster_data}
        shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
        importance[f'MLP_{feature}'] = baseline_r2 - shuffled_r2
    
    # 2. Permute GNN input
    print("Permuting GNN features...")
    shuffled_gnn_data = gnn_data.copy()
    np.random.shuffle(shuffled_gnn_data)
    shuffled_inputs = {"mlp_input": mlp_data, "gnn_input": shuffled_gnn_data, "raster_input": raster_data}
    shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
    importance['GNN'] = baseline_r2 - shuffled_r2
    
    # 3. Permute Raster inputs
    print("Permuting Raster features...")
    for i, feature in enumerate(raster_features):
        shuffled_raster_data = raster_data.copy()
        # Shuffle a single channel (raster band)
        shuffled_raster_data[:, :, :, i] = np.random.permutation(shuffled_raster_data[:, :, :, i].flatten()).reshape(shuffled_raster_data.shape[0], shuffled_raster_data.shape[1], shuffled_raster_data.shape[2])
        shuffled_inputs = {"mlp_input": mlp_data, "gnn_input": gnn_data, "raster_input": shuffled_raster_data}
        shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
        importance[f'Raster_{os.path.basename(feature)}'] = baseline_r2 - shuffled_r2
        
    return importance

# ==================== 6. Main Analysis with K-Fold CV ==================== #

print("\n" + "="*80)
print(f"Analyzing GNN-MLP-Raster Fusion Model with 5-Fold Cross-Validation")
print(f"Using a uniform patch size of {int(round((2 * 500) / pixel_size))} pixels for a 500m buffer.")
print("="*80)

# Combine all data for K-Fold splitting
full_data = pd.concat([orig, river_100], ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
full_coords = full_data[['Long','Lat']].values
full_y = full_data['RI'].values
full_mlp_data = full_data[numeric_cols].values
full_raster_data = full_coords # This will be processed by the generator

# Pre-process MLP data with StandardScaler
scaler = StandardScaler()
full_mlp_data = scaler.fit_transform(full_mlp_data)

# K-Fold setup
n_splits = 5
kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
fold_results = []
all_feature_importances = {}
buffer_radius_m = 500
raster_patch_size = int(round((2 * buffer_radius_m) / pixel_size))
if raster_patch_size % 2 != 0:
    raster_patch_size += 1
raster_patch_size = max(raster_patch_size, 2)
num_rasters = len(raster_paths)

for fold, (train_index, test_index) in enumerate(kf.split(full_data)):
    print(f"\n--- Starting Fold {fold+1}/{n_splits} ---")
    
    # Get train and test data for this fold
    train_mlp, test_mlp = full_mlp_data[train_index], full_mlp_data[test_index]
    train_coords, test_coords = full_coords[train_index], full_coords[test_index]
    y_train, y_test = full_y[train_index], full_y[test_index]
    
    # Prepare GNN input (adjacency matrix based on distances)
    dist_mat_train = distance_matrix(train_coords, train_coords)
    gnn_train = np.exp(-dist_mat_train / 10)
    
    dist_mat_test_train = distance_matrix(test_coords, train_coords)
    gnn_test = np.exp(-dist_mat_test_train / 10)

    # Clean up memory
    del dist_mat_train, dist_mat_test_train
    gc.collect()

    # Re-build and compile the model for each fold
    model = build_fusion_model(mlp_dim=train_mlp.shape[1], gnn_dim=gnn_train.shape[1], 
                               raster_patch_size=raster_patch_size, num_rasters=num_rasters)
    
    if fold == 0:
        model.summary()
    
    # Create data generators
    train_generator = DataGenerator(
        mlp_data=train_mlp, gnn_data=gnn_train, y=y_train, coords=train_coords,
        raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=True
    )

    test_generator = DataGenerator(
        mlp_data=test_mlp, gnn_data=gnn_test, y=y_test, coords=test_coords,
        raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=False
    )
    
    # Train the model
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    )

    history = model.fit(
        train_generator,
        epochs=100,
        verbose=1,
        callbacks=[early_stopping],
        validation_data=test_generator
    )

    # Evaluate on the test data
    r2_test, rmse_test, mae_test, smape_test = evaluate_model(model, test_generator, y_test)
    fold_results.append({'R2': r2_test, 'RMSE': rmse_test, 'MAE': mae_test, 'SMAPE': smape_test})
    
    print(f"Fold {fold+1} Test Metrics:")
    print(f"R²: {r2_test:.4f} | RMSE: {rmse_test:.4f} | MAE: {mae_test:.4f} | SMAPE: {smape_test:.4f}%")

    # Calculate and store feature importance for this fold
    # Get all test data as numpy arrays for importance calculation
    test_mlp_full = test_generator.mlp_data
    test_gnn_full = test_generator.gnn_data
    test_y_full = test_generator.y
    test_coords_full = test_generator.coords
    
    # Create a single batch for raster data
    test_rasters_full = test_generator.get_raster_patches(test_coords_full)
    
    importance = calculate_permutation_importance(model, test_mlp_full, test_gnn_full, test_rasters_full, test_y_full, numeric_cols, raster_paths)
    for feature, score in importance.items():
        if feature not in all_feature_importances:
            all_feature_importances[feature] = []
        all_feature_importances[feature].append(score)

    del model, history, train_generator, test_generator
    gc.collect()

# Calculate and print final averages
avg_results = pd.DataFrame(fold_results).mean()
print("\n" + "="*80)
print(f"Final Cross-Validation Results (Averaged over {n_splits} folds):")
print("="*80)
print(f"Average R²: {avg_results['R2']:.4f}")
print(f"Average RMSE: {avg_results['RMSE']:.4f}")
print(f"Average MAE: {avg_results['MAE']:.4f}")
print(f"Average SMAPE: {avg_results['SMAPE']:.4f}%")

# Calculate and print average feature importance
print("\n--- Average Feature Importance (Permutation) ---")
avg_importance = {k: np.mean(v) for k, v in all_feature_importances.items()}
sorted_importance = sorted(avg_importance.items(), key=lambda item: item[1], reverse=True)
for feature, score in sorted_importance:
    print(f"{feature}: {score:.4f}")

# ==================== 7. Save all info to a folder ==================== #
# NOTE: Removed the file saving functionality as requested. The output is now
# printed directly to the console.

print("\nAnalysis complete. Results are printed above.")


Analyzing GNN-MLP-Raster Fusion Model with 5-Fold Cross-Validation
Using a uniform patch size of 100 pixels for a 500m buffer.

--- Starting Fold 1/5 ---


Epoch 1/100


  self._warn_if_super_not_called()


[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 47ms/step - loss: 155860.8281 - val_loss: 14426.4189
Epoch 2/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 46ms/step - loss: 101962.0391 - val_loss: 3614.8186
Epoch 3/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 40ms/step - loss: 37470.3086 - val_loss: 4155.1509
Epoch 4/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 38ms/step - loss: 7083.7373 - val_loss: 2953.2900
Epoch 5/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 39ms/step - loss: 5079.8433 - val_loss: 2475.6829
Epoch 6/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 41ms/step - loss: 3330.6868 - val_loss: 2208.4880
Epoch 7/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 42ms/step - loss: 4411.3809 - val_loss: 2980.4490
Epoch 8/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 41ms/step - loss: 3998.7439 - val_loss: 1424.346

  self._warn_if_super_not_called()


[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 53ms/step - loss: 898850.8125 - val_loss: 100962.8047
Epoch 2/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 57ms/step - loss: 76930.0781 - val_loss: 26016.9219
Epoch 3/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 57ms/step - loss: 13245.5635 - val_loss: 10595.9238
Epoch 4/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 52ms/step - loss: 10291.3584 - val_loss: 9867.9541
Epoch 5/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 46ms/step - loss: 24398.7715 - val_loss: 5731.5601
Epoch 6/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 51ms/step - loss: 6266.1401 - val_loss: 6293.6641
Epoch 7/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 52ms/step - loss: 5192.0127 - val_loss: 5549.7563
Epoch 8/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 51ms/step - loss: 11044.3740 - val_loss: 416

  self._warn_if_super_not_called()


[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 56ms/step - loss: 63941.8008 - val_loss: 31298.9121
Epoch 2/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 48ms/step - loss: 75460.6484 - val_loss: 5819.8115
Epoch 3/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 44ms/step - loss: 17700.4863 - val_loss: 6886.2515
Epoch 4/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 60ms/step - loss: 16747.4512 - val_loss: 2777.5063
Epoch 5/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 53ms/step - loss: 7624.4854 - val_loss: 2883.0220
Epoch 6/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 51ms/step - loss: 5295.5073 - val_loss: 2517.1250
Epoch 7/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 57ms/step - loss: 3174.8770 - val_loss: 970.1010
Epoch 8/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 52ms/step - loss: 3553.7744 - val_loss: 1243.1707


  self._warn_if_super_not_called()


[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 62ms/step - loss: 169258.0625 - val_loss: 59306.2578
Epoch 2/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 56ms/step - loss: 259168.7500 - val_loss: 17853.1504
Epoch 3/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 62ms/step - loss: 30459.6953 - val_loss: 4875.4497
Epoch 4/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 55ms/step - loss: 14287.8984 - val_loss: 3428.8110
Epoch 5/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 58ms/step - loss: 3365.4368 - val_loss: 5042.1865
Epoch 6/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 60ms/step - loss: 3339.4985 - val_loss: 3768.7290
Epoch 7/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 56ms/step - loss: 4315.3833 - val_loss: 2215.3386
Epoch 8/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 56ms/step - loss: 2168.8752 - val_loss: 1417.0

  self._warn_if_super_not_called()


[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 59ms/step - loss: 395060.3750 - val_loss: 25604.2148
Epoch 2/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 67ms/step - loss: 63321.7031 - val_loss: 3526.9565
Epoch 3/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 76ms/step - loss: 17724.5312 - val_loss: 5470.0259
Epoch 4/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 62ms/step - loss: 11514.1846 - val_loss: 2879.8716
Epoch 5/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 63ms/step - loss: 6557.7617 - val_loss: 1100.4091
Epoch 6/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 72ms/step - loss: 5382.2637 - val_loss: 900.7040
Epoch 7/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 77ms/step - loss: 4155.2280 - val_loss: 1472.6522
Epoch 8/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 67ms/step - loss: 4155.8301 - val_loss: 1298.4431

In [3]:
# ==================== 0. Necessary Imports and Setup ==================== #
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
import pickle # Import the pickle library for saving objects
import lime
import lime.lime_tabular
from tensorflow.python.ops.numpy_ops import np_config

# Set a consistent seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Enable NumPy-like behavior in TensorFlow
np_config.enable_numpy_behavior()

# ==================== 1. Load Data ==================== #
# NOTE: This script assumes the following file paths are correct.
try:
    orig = pd.read_csv("../../data/WinterSeason1.csv")
    river_100 = pd.read_csv("../data/Samples_100W.csv")
except FileNotFoundError as e:
    print(f"Error: Required data file not found. Please check your file paths.")
    print(f"Details: {e}")
    sys.exit()

drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# ==================== 2. Collect ALL Rasters and Metadata ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDWW/*.tif")

# Get the pixel resolution from the first raster to set a uniform patch size
try:
    with rasterio.open(raster_paths[0]) as src:
        pixel_size = src.transform.a
except IndexError:
    print("Error: No raster files found in the specified directories.")
    sys.exit()

# Create a dictionary to store raster metadata for fast access
raster_metadata = {}
for path in raster_paths:
    with rasterio.open(path) as src:
        raster_metadata[path] = {
            'transform': src.transform,
            'crs': src.crs,
            'width': src.width,
            'height': src.height
        }

# ==================== 3. Define a Custom Data Generator ==================== #
class DataGenerator(Sequence):
    """
    Custom Keras Sequence for generating batches of data.
    Handles three different input types: MLP features, GNN features,
    and raster image patches, loading rasters on-the-fly to save memory.
    """
    def __init__(self, mlp_data, gnn_data, y, coords, raster_paths, buffer_radius_m, pixel_size, batch_size=4, shuffle=True):
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.coords = coords
        self.raster_paths = raster_paths
        # Calculate the uniform patch size in pixels based on the buffer radius and pixel size
        # We need a square patch, so the size is 2 * radius / pixel_size
        self.patch_size = int(round((2 * buffer_radius_m) / pixel_size))
        # Ensure patch size is at least 1 and is an even number for easy centering
        if self.patch_size % 2 != 0:
            self.patch_size += 1
        self.patch_size = max(self.patch_size, 2)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.on_epoch_end()
    
    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))
    
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
    
    def get_raster_patches(self, coords_batch):
        """
        Extracts a patch of raster data for each coordinate in the batch.
        Loads rasters on-the-fly to save memory and robustly handles boundaries.
        """
        patches_for_rasters = []
        for path in self.raster_paths:
            patches_for_this_raster = []
            try:
                with rasterio.open(path) as src:
                    for lon, lat in coords_batch:
                        # Get pixel coordinates
                        row, col = src.index(lon, lat)
                    
                        # Define a window to read around the pixel, handling boundaries
                        half_patch = self.patch_size // 2
                        left = int(col - half_patch)
                        top = int(row - half_patch)
                        right = int(col + half_patch)
                        bottom = int(row + half_patch)

                        # Create a new, empty array for the final padded patch
                        padded_patch = np.zeros((self.patch_size, self.patch_size), dtype='float32')

                        # Calculate the window in the raster's coordinate space to read from
                        # And the offset in the padded_patch to write to
                        read_left = max(0, left)
                        read_top = max(0, top)
                        read_right = min(src.width, right)
                        read_bottom = min(src.height, bottom)

                        # Check if the calculated window has a valid size
                        read_width = read_right - read_left
                        read_height = read_bottom - read_top
                    
                        if read_width > 0 and read_height > 0:
                            write_left = read_left - left
                            write_top = read_top - top
                            write_right = write_left + read_width
                            write_bottom = write_top + read_height

                            # Create the window object for rasterio to read from
                            window = Window(read_left, read_top, read_width, read_height)

                            # Read the data from the raster
                            patch_data = src.read(1, window=window)
                            # Place the read data into the padded patch
                            padded_patch[write_top:write_bottom, write_left:write_right] = patch_data
                    
                        patches_for_this_raster.append(padded_patch)
            
                # Stack the patches for this raster
                patches_for_rasters.append(np.stack(patches_for_this_raster, axis=0))
            except Exception as e:
                # This handles cases where a raster file might be missing or corrupted
                patches_for_rasters.append(np.zeros((len(coords_batch), self.patch_size, self.patch_size), dtype='float32'))


        # Stack all raster patches together
        final_patches = np.stack(patches_for_rasters, axis=-1)
        return final_patches

    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_mlp = self.mlp_data[batch_indices]
        batch_gnn = self.gnn_data[batch_indices, :]
        batch_y = self.y[batch_indices]
        batch_coords = self.coords[batch_indices]
        
        # Get raster data for the current batch
        batch_rasters = self.get_raster_patches(batch_coords)
        
        # Return a dictionary of inputs and the output
        return {"mlp_input": batch_mlp, "gnn_input": batch_gnn, "raster_input": batch_rasters}, batch_y

# ==================== 4. Define GNN-MLP-Raster Fusion Model ==================== #
def build_fusion_model(mlp_dim, gnn_dim, raster_patch_size, num_rasters):
    """
    Builds the multi-input Keras model with branches for MLP, GNN, and Rasters.
    """
    # Inputs for all branches
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    raster_input = Input(shape=(raster_patch_size, raster_patch_size, num_rasters), name="raster_input")

    # --- MLP Branch ---
    mlp_embedding = Dense(128, activation="relu")(mlp_input)
    mlp_embedding = Dense(64, activation="relu", name="mlp_embedding")(mlp_embedding)

    # --- GNN Branch ---
    gnn_embedding = Dense(128, activation="relu")(gnn_input)
    gnn_embedding = Dense(64, activation="relu", name="gnn_embedding")(gnn_embedding)
    
    # --- Raster Branch (using a simple CNN) ---
    raster_conv = Conv2D(32, (3, 3), activation="relu")(raster_input)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_conv = Conv2D(64, (3, 3), activation="relu")(raster_pool)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_flatten = Flatten()(raster_pool)
    raster_embedding = Dense(64, activation="relu", name="raster_embedding")(raster_flatten)

    # --- Concatenate Embeddings ---
    combined = Concatenate()([mlp_embedding, gnn_embedding, raster_embedding])
    
    # Final dense layers for prediction
    f = Dense(128, activation="relu")(combined)
    f = Dropout(0.4)(f)
    f = Dense(64, activation="relu")(f)
    output = Dense(1, activation="linear", name="final_output")(f)

    # Build and compile the model
    model = Model(inputs=[mlp_input, gnn_input, raster_input], outputs=output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

# ==================== 5. Define Evaluation & Importance Functions ==================== #
def calculate_smape(y_true, y_pred):
    """Calculates Symmetric Mean Absolute Percentage Error (SMAPE)."""
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    # Avoid division by zero
    mask = denominator == 0
    smape_val = np.where(mask, 0, numerator / denominator)
    return 100 * np.mean(smape_val)

def evaluate_model(model, data_inputs, y_test, return_preds=False):
    """
    Evaluates the model on given data and returns R², RMSE, MAE, and SMAPE.
    Handles both Keras Generators and direct numpy arrays.
    """
    if isinstance(data_inputs, DataGenerator):
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    else:
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    
    if return_preds:
        return y_pred
    else:
        # Align true labels with predictions if using a generator
        y_true_aligned = y_test[:len(y_pred)]
        r2 = r2_score(y_true_aligned, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true_aligned, y_pred))
        mae = mean_absolute_error(y_true_aligned, y_pred)
        smape = calculate_smape(y_true_aligned, y_pred)
        return r2, rmse, mae, smape

def calculate_permutation_importance(model, mlp_data, gnn_data, raster_data, y_true, mlp_features, raster_features):
    """
    Calculates permutation feature importance for all individual features.
    """
    print("\n--- Starting Permutation Feature Importance Analysis ---")
    
    # Create the combined input for the model
    initial_inputs = {"mlp_input": mlp_data, "gnn_input": gnn_data, "raster_input": raster_data}
    
    # Get baseline R² on the unshuffled data
    baseline_r2, _, _, _ = evaluate_model(model, initial_inputs, y_true)
    print(f"Baseline R²: {baseline_r2:.4f}")
    
    importance = {}
    
    # 1. Permute individual MLP features
    print("Permuting MLP features...")
    for i, feature in enumerate(mlp_features):
        shuffled_mlp_data = mlp_data.copy()
        np.random.shuffle(shuffled_mlp_data[:, i])
        shuffled_inputs = {"mlp_input": shuffled_mlp_data, "gnn_input": gnn_data, "raster_input": raster_data}
        shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
        importance[f'MLP_{feature}'] = baseline_r2 - shuffled_r2
    
    # 2. Permute GNN input (as a single block)
    print("Permuting GNN features...")
    shuffled_gnn_data = gnn_data.copy()
    np.random.shuffle(shuffled_gnn_data)
    shuffled_inputs = {"mlp_input": mlp_data, "gnn_input": shuffled_gnn_data, "raster_input": raster_data}
    shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
    importance['GNN'] = baseline_r2 - shuffled_r2
    
    # 3. Permute Raster inputs (each raster band as a feature)
    print("Permuting Raster features...")
    for i, feature in enumerate(raster_features):
        shuffled_raster_data = raster_data.copy()
        # Reshape the channel to a 2D array (samples, pixels) for easy shuffling
        reshaped_channel = shuffled_raster_data[:, :, :, i].reshape(shuffled_raster_data.shape[0], -1)
        # Shuffle each row independently to keep per-sample values
        np.random.shuffle(reshaped_channel)
        # Reshape back to the original shape
        shuffled_raster_data[:, :, :, i] = reshaped_channel.reshape(shuffled_raster_data.shape[0], shuffled_raster_data.shape[1], shuffled_raster_data.shape[2])
        shuffled_inputs = {"mlp_input": mlp_data, "gnn_input": gnn_data, "raster_input": shuffled_raster_data}
        shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
        importance[f'Raster_{os.path.basename(feature)}'] = baseline_r2 - shuffled_r2
        
    return importance

def calculate_intrinsic_importance(model, mlp_features, raster_features):
    """
    Calculates intrinsic feature importance based on the L2 norm of the weights
    of the connections from each branch's embedding layer to the first
    combined dense layer.
    """
    print("\n--- Starting Intrinsic Feature Importance Analysis ---")
    
    # Get the embedding layers
    mlp_embedding_layer = model.get_layer("mlp_embedding")
    gnn_embedding_layer = model.get_layer("gnn_embedding")
    raster_embedding_layer = model.get_layer("raster_embedding")

    # Get the weights connecting each branch's embedding layer to the output
    # For a simple feedforward network, we can look at the weights to the next layer
    
    # This assumes the first Dense layer after concatenation is the target
    # For a more rigorous approach, one would use integrated gradients or similar methods.
    
    # --- MLP Feature Importance ---
    # The weights from the MLP input to the first dense layer
    mlp_weights = model.get_layer(index=1).get_weights()[0]
    mlp_feature_importance = np.linalg.norm(mlp_weights, axis=1)
    
    print("\nIntrinsic Importance (MLP Features):")
    for feature, score in zip(mlp_features, mlp_feature_importance):
        # Use .item() to extract the scalar before formatting, regardless of type
        print(f"MLP_{feature}: {score.item():.4f}")
        
    # --- GNN Branch Importance (as a single unit) ---
    # The GNN input is the adjacency matrix, so we treat it as a single block
    gnn_weights = model.get_layer("gnn_embedding").get_weights()[0]
    gnn_branch_importance = np.linalg.norm(gnn_weights)
    print(f"\nIntrinsic Importance (GNN Branch): {gnn_branch_importance.item():.4f}")

    # --- Raster Channel Importance ---
    # The weights from the last CNN layer to the first dense layer in the raster branch
    # The weights are in the shape (pixels_flattened, embedding_dim)
    raster_embedding_weights = model.get_layer("raster_embedding").get_weights()[0]
    
    # We can get a rough per-channel importance by summing the absolute weights for each channel
    # This is an approximation as the CNN learns complex spatial features
    # A more precise method would be to analyze the filters, but this is a good proxy.
    raster_input_shape = model.get_layer("raster_input").input_shape[1:]
    num_rasters = raster_input_shape[-1]
    
    # Get the weights of the first Conv2D layer
    first_conv_weights = model.get_layer(index=3).get_weights()[0]
    
    print("\nIntrinsic Importance (Raster Channels - based on first layer filters):")
    for i in range(num_rasters):
        channel_weights = first_conv_weights[:, :, i, :]
        importance_score = np.linalg.norm(channel_weights)
        print(f"Raster_{os.path.basename(raster_features[i])}: {importance_score.item():.4f}")
    
def calculate_lime_importance(model, test_mlp_data, test_gnn_data, test_raster_data, mlp_features, raster_features):
    """
    Calculates LIME (Local Interpretable Model-agnostic Explanations) importance.
    LIME is applied to a combined set of MLP and flattened raster features,
    as GNN input is context-dependent and not suitable for LIME.
    Note: LIME can be memory intensive, so we use a small number of samples.
    """
    print("\n--- Starting LIME Feature Importance Analysis ---")
    
    # Flatten the raster data to a 2D array
    flat_raster_data = test_raster_data.reshape(test_raster_data.shape[0], -1)
    
    # Combine MLP and flattened raster data for LIME
    combined_data = np.hstack([test_mlp_data, flat_raster_data])
    
    # Create the full list of feature names for the combined data
    raster_feature_names = [f"Raster_{os.path.basename(path)}_{i}" for path in raster_features for i in range(test_raster_data.shape[1] * test_raster_data.shape[2])]
    feature_names = list(mlp_features) + raster_feature_names
    
    # Define a prediction function that LIME can use
    def predict_fn(x):
        # Unpack the combined features back to their original shapes
        mlp_slice = x[:, :len(mlp_features)]
        raster_slice = x[:, len(mlp_features):].reshape(x.shape[0], test_raster_data.shape[1], test_raster_data.shape[2], len(raster_features))
        
        # We need a dummy GNN input for the model prediction
        dummy_gnn = np.zeros((x.shape[0], test_gnn_data.shape[1]))
        
        # Return the model's predictions (LIME expects a single value per sample)
        return model.predict({"mlp_input": mlp_slice, "gnn_input": dummy_gnn, "raster_input": raster_slice}, verbose=0)
    
    # Initialize the LIME explainer
    explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=combined_data, 
        feature_names=feature_names, 
        class_names=["RI Prediction"], 
        mode='regression'
    )
    
    # Choose a few samples to explain
    num_samples = 3 # Reduced to 3 to avoid memory issues
    sample_indices = np.random.choice(range(len(test_mlp_data)), num_samples, replace=False)
    
    lime_importance_scores = {}
    
    for idx in sample_indices:
        print(f"Generating LIME explanation for sample {idx}...")
        explanation = explainer.explain_instance(
            data_row=combined_data[idx], 
            predict_fn=predict_fn, 
            num_features=10 # Explain the top 10 most important features, as requested
        )
        for feature, weight in explanation.as_list():
            if feature not in lime_importance_scores:
                lime_importance_scores[feature] = []
            lime_importance_scores[feature].append(abs(weight))
            
    # Aggregate and average the importance scores
    avg_lime_importance = {
        feature: np.mean(scores) for feature, scores in lime_importance_scores.items()
    }
    
    # Sort and print the top 10 features
    print("\nTop 10 LIME Features (Average Absolute Weight):")
    sorted_lime = sorted(avg_lime_importance.items(), key=lambda item: item[1], reverse=True)
    for feature, score in sorted_lime[:10]:
        print(f"{feature}: {score:.4f}")
    
    return avg_lime_importance

# ==================== 6. Main Analysis without K-Fold CV ==================== #

print("\n" + "="*80)
print("Analyzing GNN-MLP-Raster Fusion Model (Single Run)")
print(f"Using a uniform patch size of {int(round((2 * 500) / pixel_size))} pixels for a 500m buffer.")
print("="*80)

# Combine all data
full_data = pd.concat([orig, river_100], ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
full_coords = full_data[['Long','Lat']].values
full_y = full_data['RI'].values
full_mlp_data = full_data[numeric_cols].values

# Pre-process MLP data with StandardScaler
scaler = StandardScaler()
full_mlp_data = scaler.fit_transform(full_mlp_data)

# Split data into training and testing sets (e.g., 80% train, 20% test)
train_mlp, test_mlp, train_coords, test_coords, y_train, y_test = train_test_split(
    full_mlp_data, full_coords, full_y, test_size=0.2, random_state=42
)

# Prepare GNN input (adjacency matrix based on distances)
dist_mat_train = distance_matrix(train_coords, train_coords)
gnn_train = np.exp(-dist_mat_train / 10)
    
dist_mat_test_train = distance_matrix(test_coords, train_coords)
gnn_test = np.exp(-dist_mat_test_train / 10)

# Clean up memory
del dist_mat_train, dist_mat_test_train
gc.collect()

# Define patch size and number of rasters
buffer_radius_m = 500
raster_patch_size = int(round((2 * buffer_radius_m) / pixel_size))
if raster_patch_size % 2 != 0:
    raster_patch_size += 1
raster_patch_size = max(raster_patch_size, 2)
num_rasters = len(raster_paths)

# Build and compile the model
model = build_fusion_model(mlp_dim=train_mlp.shape[1], gnn_dim=gnn_train.shape[1], 
                             raster_patch_size=raster_patch_size, num_rasters=num_rasters)

# Print model summary for inspection
model.summary()
    
# Create data generators for training and testing
train_generator = DataGenerator(
    mlp_data=train_mlp, gnn_data=gnn_train, y=y_train, coords=train_coords,
    raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=True
)

test_generator = DataGenerator(
    mlp_data=test_mlp, gnn_data=gnn_test, y=y_test, coords=test_coords,
    raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=False
)
    
# Train the model
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True
)

print("\n--- Model Training ---")
history = model.fit(
    train_generator,
    epochs=100,
    verbose=1,
    callbacks=[early_stopping],
    validation_data=test_generator
)

# Evaluate on the test data
r2_test, rmse_test, mae_test, smape_test = evaluate_model(model, test_generator, y_test)
    
print("\n" + "="*80)
print("Final Model Performance on Test Set")
print("="*80)
print(f"R²: {r2_test:.4f}")
print(f"RMSE: {rmse_test:.4f}")
print(f"MAE: {mae_test:.4f}")
print(f"SMAPE: {smape_test:.4f}%")

# ==================== 7. Feature Importance Analysis ==================== #

# --- Prepare data for importance functions (needs to be full numpy arrays) ---
# Get all test data from the generator
test_mlp_full = test_generator.mlp_data
test_gnn_full = test_generator.gnn_data
test_y_full = test_generator.y
test_coords_full = test_generator.coords
test_rasters_full = test_generator.get_raster_patches(test_coords_full)

# --- Permutation Importance ---
permutation_importance_scores = calculate_permutation_importance(
    model, 
    test_mlp_full, 
    test_gnn_full, 
    test_rasters_full, 
    test_y_full, 
    numeric_cols, 
    raster_paths
)
print("\n--- Summary of Permutation Importance ---")
sorted_perm_importance = sorted(permutation_importance_scores.items(), key=lambda item: item[1], reverse=True)
for feature, score in sorted_perm_importance:
    print(f"{feature}: {score:.4f}")


Analyzing GNN-MLP-Raster Fusion Model (Single Run)
Using a uniform patch size of 100 pixels for a 500m buffer.



--- Model Training ---
Epoch 1/100


  self._warn_if_super_not_called()


[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 175ms/step - loss: 222932.2969 - val_loss: 28009.3730
Epoch 2/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 111ms/step - loss: 223441.4375 - val_loss: 4882.5195
Epoch 3/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 102ms/step - loss: 29547.0410 - val_loss: 5285.2686
Epoch 4/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 87ms/step - loss: 14201.7383 - val_loss: 4018.3577
Epoch 5/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 99ms/step - loss: 12184.5234 - val_loss: 3284.9578
Epoch 6/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 87ms/step - loss: 13595.1885 - val_loss: 2380.7761
Epoch 7/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 112ms/step - loss: 3847.5535 - val_loss: 2292.3977
Epoch 8/100
[1m23/23[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 110ms/step - loss: 5849.3936 - val_loss: 

In [1]:
# ==================== 0. Necessary Imports and Setup ==================== #
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
import pickle # Import the pickle library for saving objects
import lime
import lime.lime_tabular
from tensorflow.python.ops.numpy_ops import np_config

# Set a consistent seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Enable NumPy-like behavior in TensorFlow
np_config.enable_numpy_behavior()

# ==================== 1. Load Data ==================== #
# NOTE: This script assumes the following file paths are correct.
try:
    orig = pd.read_csv("../../data/WinterSeason1.csv")
    river_100 = pd.read_csv("../data/Samples_100W.csv")
except FileNotFoundError as e:
    print(f"Error: Required data file not found. Please check your file paths.")
    print(f"Details: {e}")
    sys.exit()

drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# ==================== 2. Collect ALL Rasters and Metadata ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDWW/*.tif")

# Get the pixel resolution from the first raster to set a uniform patch size
try:
    with rasterio.open(raster_paths[0]) as src:
        pixel_size = src.transform.a
except IndexError:
    print("Error: No raster files found in the specified directories.")
    sys.exit()

# Create a dictionary to store raster metadata for fast access
raster_metadata = {}
for path in raster_paths:
    with rasterio.open(path) as src:
        raster_metadata[path] = {
            'transform': src.transform,
            'crs': src.crs,
            'width': src.width,
            'height': src.height
        }

# ==================== 3. Define a Custom Data Generator ==================== #
class DataGenerator(Sequence):
    """
    Custom Keras Sequence for generating batches of data.
    Handles three different input types: MLP features, GNN features,
    and raster image patches, loading rasters on-the-fly to save memory.
    """
    def __init__(self, mlp_data, gnn_data, y, coords, raster_paths, buffer_radius_m, pixel_size, batch_size=4, shuffle=True):
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.coords = coords
        self.raster_paths = raster_paths
        # Calculate the uniform patch size in pixels based on the buffer radius and pixel size
        # We need a square patch, so the size is 2 * radius / pixel_size
        self.patch_size = int(round((2 * buffer_radius_m) / pixel_size))
        # Ensure patch size is at least 1 and is an even number for easy centering
        if self.patch_size % 2 != 0:
            self.patch_size += 1
        self.patch_size = max(self.patch_size, 2)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.on_epoch_end()
    
    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))
    
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
    
    def get_raster_patches(self, coords_batch):
        """
        Extracts a patch of raster data for each coordinate in the batch.
        Loads rasters on-the-fly to save memory and robustly handles boundaries.
        """
        patches_for_rasters = []
        for path in self.raster_paths:
            patches_for_this_raster = []
            try:
                with rasterio.open(path) as src:
                    for lon, lat in coords_batch:
                        # Get pixel coordinates
                        row, col = src.index(lon, lat)
                    
                        # Define a window to read around the pixel, handling boundaries
                        half_patch = self.patch_size // 2
                        left = int(col - half_patch)
                        top = int(row - half_patch)
                        right = int(col + half_patch)
                        bottom = int(row + half_patch)

                        # Create a new, empty array for the final padded patch
                        padded_patch = np.zeros((self.patch_size, self.patch_size), dtype='float32')

                        # Calculate the window in the raster's coordinate space to read from
                        # And the offset in the padded_patch to write to
                        read_left = max(0, left)
                        read_top = max(0, top)
                        read_right = min(src.width, right)
                        read_bottom = min(src.height, bottom)

                        # Check if the calculated window has a valid size
                        read_width = read_right - read_left
                        read_height = read_bottom - read_top
                    
                        if read_width > 0 and read_height > 0:
                            write_left = read_left - left
                            write_top = read_top - top
                            write_right = write_left + read_width
                            write_bottom = write_top + read_height

                            # Create the window object for rasterio to read from
                            window = Window(read_left, read_top, read_width, read_height)

                            # Read the data from the raster
                            patch_data = src.read(1, window=window)
                            # Place the read data into the padded patch
                            padded_patch[write_top:write_bottom, write_left:write_right] = patch_data
                    
                        patches_for_this_raster.append(padded_patch)
            
                # Stack the patches for this raster
                patches_for_rasters.append(np.stack(patches_for_this_raster, axis=0))
            except Exception as e:
                # This handles cases where a raster file might be missing or corrupted
                patches_for_rasters.append(np.zeros((len(coords_batch), self.patch_size, self.patch_size), dtype='float32'))


        # Stack all raster patches together
        final_patches = np.stack(patches_for_rasters, axis=-1)
        return final_patches

    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_mlp = self.mlp_data[batch_indices]
        batch_gnn = self.gnn_data[batch_indices, :]
        batch_y = self.y[batch_indices]
        batch_coords = self.coords[batch_indices]
        
        # Get raster data for the current batch
        batch_rasters = self.get_raster_patches(batch_coords)
        
        # Return a dictionary of inputs and the output
        return {"mlp_input": batch_mlp, "gnn_input": batch_gnn, "raster_input": batch_rasters}, batch_y

# ==================== 4. Define GNN-MLP-Raster Fusion Model ==================== #
def build_fusion_model(mlp_dim, gnn_dim, raster_patch_size, num_rasters):
    """
    Builds the multi-input Keras model with branches for MLP, GNN, and Rasters.
    """
    # Inputs for all branches
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    raster_input = Input(shape=(raster_patch_size, raster_patch_size, num_rasters), name="raster_input")

    # --- MLP Branch ---
    mlp_embedding = Dense(128, activation="relu")(mlp_input)
    mlp_embedding = Dense(64, activation="relu", name="mlp_embedding")(mlp_embedding)

    # --- GNN Branch ---
    gnn_embedding = Dense(128, activation="relu")(gnn_input)
    gnn_embedding = Dense(64, activation="relu", name="gnn_embedding")(gnn_embedding)
    
    # --- Raster Branch (using a simple CNN) ---
    raster_conv = Conv2D(32, (3, 3), activation="relu")(raster_input)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_conv = Conv2D(64, (3, 3), activation="relu")(raster_pool)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_flatten = Flatten()(raster_pool)
    raster_embedding = Dense(64, activation="relu", name="raster_embedding")(raster_flatten)

    # --- Concatenate Embeddings ---
    combined = Concatenate()([mlp_embedding, gnn_embedding, raster_embedding])
    
    # Final dense layers for prediction
    f = Dense(128, activation="relu")(combined)
    f = Dropout(0.4)(f)
    f = Dense(64, activation="relu")(f)
    output = Dense(1, activation="linear", name="final_output")(f)

    # Build and compile the model
    model = Model(inputs=[mlp_input, gnn_input, raster_input], outputs=output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

# ==================== 5. Define Evaluation & Importance Functions ==================== #
def calculate_smape(y_true, y_pred):
    """Calculates Symmetric Mean Absolute Percentage Error (SMAPE)."""
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    # Avoid division by zero
    mask = denominator == 0
    smape_val = np.where(mask, 0, numerator / denominator)
    return 100 * np.mean(smape_val)

def evaluate_model(model, data_inputs, y_test, return_preds=False):
    """
    Evaluates the model on given data and returns R², RMSE, MAE, and SMAPE.
    Handles both Keras Generators and direct numpy arrays.
    """
    if isinstance(data_inputs, DataGenerator):
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    else:
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    
    if return_preds:
        return y_pred
    else:
        # Align true labels with predictions if using a generator
        y_true_aligned = y_test[:len(y_pred)]
        r2 = r2_score(y_true_aligned, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true_aligned, y_pred))
        mae = mean_absolute_error(y_true_aligned, y_pred)
        smape = calculate_smape(y_true_aligned, y_pred)
        return r2, rmse, mae, smape

def calculate_permutation_importance(model, mlp_data, gnn_data, raster_data, y_true, mlp_features, raster_features):
    """
    Calculates permutation feature importance for all individual features.
    """
    print("\n--- Starting Permutation Feature Importance Analysis ---")
    
    # Create the combined input for the model
    initial_inputs = {"mlp_input": mlp_data, "gnn_input": gnn_data, "raster_input": raster_data}
    
    # Get baseline R² on the unshuffled data
    baseline_r2, _, _, _ = evaluate_model(model, initial_inputs, y_true)
    print(f"Baseline R²: {baseline_r2:.4f}")
    
    importance = {}
    
    # 1. Permute individual MLP features
    print("Permuting MLP features...")
    for i, feature in enumerate(mlp_features):
        shuffled_mlp_data = mlp_data.copy()
        np.random.shuffle(shuffled_mlp_data[:, i])
        shuffled_inputs = {"mlp_input": shuffled_mlp_data, "gnn_input": gnn_data, "raster_input": raster_data}
        shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
        importance[f'MLP_{feature}'] = baseline_r2 - shuffled_r2
    
    # 2. Permute GNN input (as a single block)
    print("Permuting GNN features...")
    shuffled_gnn_data = gnn_data.copy()
    np.random.shuffle(shuffled_gnn_data)
    shuffled_inputs = {"mlp_input": mlp_data, "gnn_input": shuffled_gnn_data, "raster_input": raster_data}
    shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
    importance['GNN'] = baseline_r2 - shuffled_r2
    
    # 3. Permute Raster inputs (each raster band as a feature)
    print("Permuting Raster features...")
    for i, feature in enumerate(raster_features):
        shuffled_raster_data = raster_data.copy()
        # Reshape the channel to a 2D array (samples, pixels) for easy shuffling
        reshaped_channel = shuffled_raster_data[:, :, :, i].reshape(shuffled_raster_data.shape[0], -1)
        # Shuffle each row independently to keep per-sample values
        np.random.shuffle(reshaped_channel)
        # Reshape back to the original shape
        shuffled_raster_data[:, :, :, i] = reshaped_channel.reshape(shuffled_raster_data.shape[0], shuffled_raster_data.shape[1], shuffled_raster_data.shape[2])
        shuffled_inputs = {"mlp_input": mlp_data, "gnn_input": gnn_data, "raster_input": shuffled_raster_data}
        shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
        importance[f'Raster_{os.path.basename(feature)}'] = baseline_r2 - shuffled_r2
        
    return importance

def calculate_intrinsic_importance(model, mlp_features, raster_features):
    """
    Calculates intrinsic feature importance based on the L2 norm of the weights
    of the connections from each branch's embedding layer to the first
    combined dense layer.
    """
    print("\n--- Starting Intrinsic Feature Importance Analysis ---")
    
    # Get the embedding layers
    mlp_embedding_layer = model.get_layer("mlp_embedding")
    gnn_embedding_layer = model.get_layer("gnn_embedding")
    raster_embedding_layer = model.get_layer("raster_embedding")

    # Get the weights connecting each branch's embedding layer to the output
    # For a simple feedforward network, we can look at the weights to the next layer
    
    # This assumes the first Dense layer after concatenation is the target
    # For a more rigorous approach, one would use integrated gradients or similar methods.
    
    # --- MLP Feature Importance ---
    # The weights from the MLP input to the first dense layer
    mlp_weights = model.get_layer(index=1).get_weights()[0]
    mlp_feature_importance = np.linalg.norm(mlp_weights, axis=1)
    
    print("\nIntrinsic Importance (MLP Features):")
    for feature, score in zip(mlp_features, mlp_feature_importance):
        # Use .item() to extract the scalar before formatting, regardless of type
        print(f"MLP_{feature}: {score.item():.4f}")
        
    # --- GNN Branch Importance (as a single unit) ---
    # The GNN input is the adjacency matrix, so we treat it as a single block
    gnn_weights = model.get_layer("gnn_embedding").get_weights()[0]
    gnn_branch_importance = np.linalg.norm(gnn_weights)
    print(f"\nIntrinsic Importance (GNN Branch): {gnn_branch_importance.item():.4f}")

    # --- Raster Channel Importance ---
    # The weights from the last CNN layer to the first dense layer in the raster branch
    # The weights are in the shape (pixels_flattened, embedding_dim)
    raster_embedding_weights = model.get_layer("raster_embedding").get_weights()[0]
    
    # We can get a rough per-channel importance by summing the absolute weights for each channel
    # This is an approximation as the CNN learns complex spatial features
    # A more precise method would be to analyze the filters, but this is a good proxy.
    raster_input_shape = model.get_layer("raster_input").input_shape[1:]
    num_rasters = raster_input_shape[-1]
    
    # Get the weights of the first Conv2D layer
    first_conv_weights = model.get_layer(index=3).get_weights()[0]
    
    print("\nIntrinsic Importance (Raster Channels - based on first layer filters):")
    for i in range(num_rasters):
        channel_weights = first_conv_weights[:, :, i, :]
        importance_score = np.linalg.norm(channel_weights)
        print(f"Raster_{os.path.basename(raster_features[i])}: {importance_score.item():.4f}")
    
def calculate_lime_importance(model, test_mlp_data, test_gnn_data, test_raster_data, mlp_features, raster_features):
    """
    Calculates LIME (Local Interpretable Model-agnostic Explanations) importance.
    LIME is applied to a combined set of MLP and flattened raster features,
    as GNN input is context-dependent and not suitable for LIME.
    Note: LIME can be memory intensive, so we use a small number of samples.
    """
    print("\n--- Starting LIME Feature Importance Analysis ---")
    
    # Flatten the raster data to a 2D array
    flat_raster_data = test_raster_data.reshape(test_raster_data.shape[0], -1)
    
    # Combine MLP and flattened raster data for LIME
    combined_data = np.hstack([test_mlp_data, flat_raster_data])
    
    # Create the full list of feature names for the combined data
    raster_feature_names = [f"Raster_{os.path.basename(path)}_{i}" for path in raster_features for i in range(test_raster_data.shape[1] * test_raster_data.shape[2])]
    feature_names = list(mlp_features) + raster_feature_names
    
    # Define a prediction function that LIME can use
    def predict_fn(x):
        # Unpack the combined features back to their original shapes
        mlp_slice = x[:, :len(mlp_features)]
        raster_slice = x[:, len(mlp_features):].reshape(x.shape[0], test_raster_data.shape[1], test_raster_data.shape[2], len(raster_features))
        
        # We need a dummy GNN input for the model prediction
        dummy_gnn = np.zeros((x.shape[0], test_gnn_data.shape[1]))
        
        # Return the model's predictions (LIME expects a single value per sample)
        return model.predict({"mlp_input": mlp_slice, "gnn_input": dummy_gnn, "raster_input": raster_slice}, verbose=0)
    
    # Initialize the LIME explainer
    explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=combined_data, 
        feature_names=feature_names, 
        class_names=["RI Prediction"], 
        mode='regression'
    )
    
    # Choose a few samples to explain
    num_samples = 3 # Reduced to 3 to avoid memory issues
    sample_indices = np.random.choice(range(len(test_mlp_data)), num_samples, replace=False)
    
    lime_importance_scores = {}
    
    for idx in sample_indices:
        print(f"Generating LIME explanation for sample {idx}...")
        explanation = explainer.explain_instance(
            data_row=combined_data[idx], 
            predict_fn=predict_fn, 
            num_features=10 # Explain the top 10 most important features, as requested
        )
        for feature, weight in explanation.as_list():
            if feature not in lime_importance_scores:
                lime_importance_scores[feature] = []
            lime_importance_scores[feature].append(abs(weight))
            
    # Aggregate and average the importance scores
    avg_lime_importance = {
        feature: np.mean(scores) for feature, scores in lime_importance_scores.items()
    }
    
    # Sort and print the top 10 features
    print("\nTop 10 LIME Features (Average Absolute Weight):")
    sorted_lime = sorted(avg_lime_importance.items(), key=lambda item: item[1], reverse=True)
    for feature, score in sorted_lime[:10]:
        print(f"{feature}: {score:.4f}")
    
    return avg_lime_importance

# ==================== 6. Main Analysis without K-Fold CV ==================== #

print("\n" + "="*80)
print("Analyzing GNN-MLP-Raster Fusion Model (Single Run)")
print(f"Using a uniform patch size of {int(round((2 * 500) / pixel_size))} pixels for a 500m buffer.")
print("="*80)

# Combine all data
full_data = pd.concat([orig, river_100], ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
full_coords = full_data[['Long','Lat']].values
full_y = full_data['RI'].values
full_mlp_data = full_data[numeric_cols].values

# Pre-process MLP data with StandardScaler
scaler = StandardScaler()
full_mlp_data = scaler.fit_transform(full_mlp_data)

# Split data into training and testing sets (e.g., 80% train, 20% test)
train_mlp, test_mlp, train_coords, test_coords, y_train, y_test = train_test_split(
    full_mlp_data, full_coords, full_y, test_size=0.2, random_state=42
)

# Prepare GNN input (adjacency matrix based on distances)
dist_mat_train = distance_matrix(train_coords, train_coords)
gnn_train = np.exp(-dist_mat_train / 10)
    
dist_mat_test_train = distance_matrix(test_coords, train_coords)
gnn_test = np.exp(-dist_mat_test_train / 10)

# Clean up memory
del dist_mat_train, dist_mat_test_train
gc.collect()

# Define patch size and number of rasters
buffer_radius_m = 500
raster_patch_size = int(round((2 * buffer_radius_m) / pixel_size))
if raster_patch_size % 2 != 0:
    raster_patch_size += 1
raster_patch_size = max(raster_patch_size, 2)
num_rasters = len(raster_paths)

# Build and compile the model
model = build_fusion_model(mlp_dim=train_mlp.shape[1], gnn_dim=gnn_train.shape[1], 
                             raster_patch_size=raster_patch_size, num_rasters=num_rasters)

# Print model summary for inspection
model.summary()
    
# Create data generators for training and testing
train_generator = DataGenerator(
    mlp_data=train_mlp, gnn_data=gnn_train, y=y_train, coords=train_coords,
    raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=True
)

test_generator = DataGenerator(
    mlp_data=test_mlp, gnn_data=gnn_test, y=y_test, coords=test_coords,
    raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=False
)
    
# Train the model
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True
)

print("\n--- Model Training ---")
history = model.fit(
    train_generator,
    epochs=100,
    verbose=0,
    callbacks=[early_stopping],
    validation_data=test_generator
)

# ==================== 7. Feature Importance Analysis ==================== #

# --- Prepare data for importance functions (needs to be full numpy arrays) ---
# Get all test data from the generator
test_mlp_full = test_generator.mlp_data
test_gnn_full = test_generator.gnn_data
test_y_full = test_generator.y
test_coords_full = test_generator.coords
test_rasters_full = test_generator.get_raster_patches(test_coords_full)

# --- Permutation Importance ---
permutation_importance_scores = calculate_permutation_importance(
    model, 
    test_mlp_full, 
    test_gnn_full, 
    test_rasters_full, 
    test_y_full, 
    numeric_cols, 
    raster_paths
)
print("\n--- Summary of Permutation Importance ---")
sorted_perm_importance = sorted(permutation_importance_scores.items(), key=lambda item: item[1], reverse=True)
for feature, score in sorted_perm_importance:
    print(f"{feature}: {score:.4f}")


Analyzing GNN-MLP-Raster Fusion Model (Single Run)
Using a uniform patch size of 100 pixels for a 500m buffer.



--- Model Training ---


  self._warn_if_super_not_called()



--- Starting Permutation Feature Importance Analysis ---
Baseline R²: 0.9537
Permuting MLP features...
Permuting GNN features...
Permuting Raster features...

--- Summary of Permutation Importance ---
Raster_PbW.tif: 1.8036
Raster_NiW.tif: 0.1315
Raster_SiltW.tif: 0.0815
Raster_CrW.tif: 0.0520
Raster_AsW.tif: 0.0498
Raster_CuW.tif: 0.0405
Raster_SandW.tif: 0.0144
MLP_CuW: 0.0039
MLP_FeW: 0.0026
MLP_SiltW: 0.0023
MLP_MW: 0.0017
MLP_CrW: 0.0014
Raster_ClayW.tif: 0.0009
MLP_SandW: 0.0007
MLP_hydro_dist_ind: 0.0006
MLP_NiW: 0.0004
MLP_CdW: 0.0002
GNN: 0.0001
MLP_num_brick_field: 0.0001
MLP_num_industry: 0.0000
Raster_bui.tif: 0.0000
Raster_ndsi.tif: 0.0000
Raster_savi.tif: 0.0000
Raster_ndbsi.tif: 0.0000
Raster_ui.tif: 0.0000
Raster_ndwi.tif: 0.0000
Raster_ndbi.tif: 0.0000
Raster_awei.tif: 0.0000
Raster_evi.tif: 0.0000
Raster_mndwi.tif: 0.0000
Raster_ndvi.tif: 0.0000
Raster_LULC2020.tif: 0.0000
Raster_LULC2021.tif: 0.0000
Raster_LULC2022.tif: 0.0000
Raster_LULC2019.tif: 0.0000
Raster_LULC

In [3]:
# ==================== 0. Necessary Imports and Setup ==================== #
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
import pickle # Import the pickle library for saving objects
from tensorflow.python.ops.numpy_ops import np_config

# Import LIME components for explanation
try:
    from lime.lime_tabular import LimeTabularExplainer
    from lime.lime_image import LimeImageExplainer
except ImportError:
    print("LIME is not installed. Please install it using: pip install lime")
    sys.exit()

# Set a consistent seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Enable NumPy-like behavior in TensorFlow
np_config.enable_numpy_behavior()

# ==================== 1. Load Data ==================== #
# NOTE: This script assumes the following file paths are correct.
try:
    orig = pd.read_csv("../../data/WinterSeason1.csv")
    river_100 = pd.read_csv("../data/Samples_100W.csv")
except FileNotFoundError as e:
    print(f"Error: Required data file not found. Please check your file paths.")
    print(f"Details: {e}")
    sys.exit()

drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# ==================== 2. Collect ALL Rasters and Metadata ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDWW/*.tif")

# Get the pixel resolution from the first raster to set a uniform patch size
try:
    with rasterio.open(raster_paths[0]) as src:
        pixel_size = src.transform.a
except IndexError:
    print("Error: No raster files found in the specified directories.")
    sys.exit()

# Create a dictionary to store raster metadata for fast access
raster_metadata = {}
for path in raster_paths:
    with rasterio.open(path) as src:
        raster_metadata[path] = {
            'transform': src.transform,
            'crs': src.crs,
            'width': src.width,
            'height': src.height
        }

# ==================== 3. Define a Custom Data Generator ==================== #
class DataGenerator(Sequence):
    """
    Custom Keras Sequence for generating batches of data.
    Handles three different input types: MLP features, GNN features,
    and raster image patches, loading rasters on-the-fly to save memory.
    """
    def __init__(self, mlp_data, gnn_data, y, coords, raster_paths, buffer_radius_m, pixel_size, batch_size=4, shuffle=True):
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.coords = coords
        self.raster_paths = raster_paths
        # Calculate the uniform patch size in pixels based on the buffer radius and pixel size
        self.patch_size = int(round((2 * buffer_radius_m) / pixel_size))
        if self.patch_size % 2 != 0:
            self.patch_size += 1
        self.patch_size = max(self.patch_size, 2)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.on_epoch_end()
    
    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))
    
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
    
    def get_raster_patches(self, coords_batch):
        patches_for_rasters = []
        for path in self.raster_paths:
            patches_for_this_raster = []
            try:
                with rasterio.open(path) as src:
                    for lon, lat in coords_batch:
                        row, col = src.index(lon, lat)
                        half_patch = self.patch_size // 2
                        left = int(col - half_patch)
                        top = int(row - half_patch)
                        right = int(col + half_patch)
                        bottom = int(row + half_patch)
                        padded_patch = np.zeros((self.patch_size, self.patch_size), dtype='float32')
                        read_left = max(0, left)
                        read_top = max(0, top)
                        read_right = min(src.width, right)
                        read_bottom = min(src.height, bottom)
                        read_width = read_right - read_left
                        read_height = read_bottom - read_top
                    
                        if read_width > 0 and read_height > 0:
                            write_left = read_left - left
                            write_top = read_top - top
                            write_right = write_left + read_width
                            write_bottom = write_top + read_height
                            window = Window(read_left, read_top, read_width, read_height)
                            patch_data = src.read(1, window=window)
                            padded_patch[write_top:write_bottom, write_left:write_right] = patch_data
                        patches_for_this_raster.append(padded_patch)
                patches_for_rasters.append(np.stack(patches_for_this_raster, axis=0))
            except Exception as e:
                patches_for_rasters.append(np.zeros((len(coords_batch), self.patch_size, self.patch_size), dtype='float32'))
        final_patches = np.stack(patches_for_rasters, axis=-1)
        return final_patches

    def __getitem__(self, index):
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]
        batch_mlp = self.mlp_data[batch_indices]
        batch_gnn = self.gnn_data[batch_indices, :]
        batch_y = self.y[batch_indices]
        batch_coords = self.coords[batch_indices]
        batch_rasters = self.get_raster_patches(batch_coords)
        return {"mlp_input": batch_mlp, "gnn_input": batch_gnn, "raster_input": batch_rasters}, batch_y

# ==================== 4. Define GNN-MLP-Raster Fusion Model ==================== #
def build_fusion_model(mlp_dim, gnn_dim, raster_patch_size, num_rasters):
    """
    Builds the multi-input Keras model with branches for MLP, GNN, and Rasters.
    """
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    raster_input = Input(shape=(raster_patch_size, raster_patch_size, num_rasters), name="raster_input")

    mlp_embedding = Dense(128, activation="relu")(mlp_input)
    mlp_embedding = Dense(64, activation="relu", name="mlp_embedding")(mlp_embedding)

    gnn_embedding = Dense(128, activation="relu")(gnn_input)
    gnn_embedding = Dense(64, activation="relu", name="gnn_embedding")(gnn_embedding)
    
    raster_conv = Conv2D(32, (3, 3), activation="relu")(raster_input)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_conv = Conv2D(64, (3, 3), activation="relu")(raster_pool)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_flatten = Flatten()(raster_pool)
    raster_embedding = Dense(64, activation="relu", name="raster_embedding")(raster_flatten)

    combined = Concatenate()([mlp_embedding, gnn_embedding, raster_embedding])
    
    f = Dense(128, activation="relu")(combined)
    f = Dropout(0.4)(f)
    f = Dense(64, activation="relu")(f)
    output = Dense(1, activation="linear", name="final_output")(f)

    model = Model(inputs=[mlp_input, gnn_input, raster_input], outputs=output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

# ==================== 5. Define Evaluation & Importance Functions ==================== #
def calculate_smape(y_true, y_pred):
    """Calculates Symmetric Mean Absolute Percentage Error (SMAPE)."""
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    mask = denominator == 0
    smape_val = np.where(mask, 0, numerator / denominator)
    return 100 * np.mean(smape_val)

def evaluate_model(model, data_inputs, y_test, return_preds=False):
    """
    Evaluates the model on given data and returns R², RMSE, MAE, and SMAPE.
    """
    if isinstance(data_inputs, DataGenerator):
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    else:
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    
    if return_preds:
        return y_pred
    else:
        y_true_aligned = y_test[:len(y_pred)]
        r2 = r2_score(y_true_aligned, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true_aligned, y_pred))
        mae = mean_absolute_error(y_true_aligned, y_pred)
        smape = calculate_smape(y_true_aligned, y_pred)
        return r2, rmse, mae, smape

# ==================== 6. Main Analysis without K-Fold CV ==================== #
print("\n" + "="*80)
print("Analyzing GNN-MLP-Raster Fusion Model (Single Run)")
print(f"Using a uniform patch size of {int(round((2 * 500) / pixel_size))} pixels for a 500m buffer.")
print("="*80)

full_data = pd.concat([orig, river_100], ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
full_coords = full_data[['Long','Lat']].values
full_y = full_data['RI'].values
full_mlp_data = full_data[numeric_cols].values

scaler = StandardScaler()
full_mlp_data = scaler.fit_transform(full_mlp_data)

train_mlp, test_mlp, train_coords, test_coords, y_train, y_test = train_test_split(
    full_mlp_data, full_coords, full_y, test_size=0.2, random_state=42
)

dist_mat_train = distance_matrix(train_coords, train_coords)
gnn_train = np.exp(-dist_mat_train / 10)
    
dist_mat_test_train = distance_matrix(test_coords, train_coords)
gnn_test = np.exp(-dist_mat_test_train / 10)

del dist_mat_train, dist_mat_test_train
gc.collect()

buffer_radius_m = 500
raster_patch_size = int(round((2 * buffer_radius_m) / pixel_size))
if raster_patch_size % 2 != 0:
    raster_patch_size += 1
raster_patch_size = max(raster_patch_size, 2)
num_rasters = len(raster_paths)

model = build_fusion_model(mlp_dim=train_mlp.shape[1], gnn_dim=gnn_train.shape[1], 
                             raster_patch_size=raster_patch_size, num_rasters=num_rasters)
    
train_generator = DataGenerator(
    mlp_data=train_mlp, gnn_data=gnn_train, y=y_train, coords=train_coords,
    raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=True
)

test_generator = DataGenerator(
    mlp_data=test_mlp, gnn_data=gnn_test, y=y_test, coords=test_coords,
    raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=False
)
    
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=15,
    restore_best_weights=True
)

print("\n--- Model Training ---")
history = model.fit(
    train_generator,
    epochs=100,
    verbose=0,
    callbacks=[early_stopping],
    validation_data=test_generator
)

# ==================== 7. LIME Feature Importance Analysis ==================== #

def predict_fn_for_lime(tabular_data, image_data, gnn_data_single_sample):
    """
    A wrapper prediction function for LIME.
    It combines the perturbed tabular and image data with the fixed GNN data
    to make a prediction.
    """
    # GNN data is fixed for LIME analysis as it's not a feature of a single sample
    gnn_batch = np.tile(gnn_data_single_sample, (tabular_data.shape[0], 1))
    
    # Run prediction on the combined data
    predictions = model.predict({
        "mlp_input": tabular_data,
        "gnn_input": gnn_batch,
        "raster_input": image_data
    }, verbose=0)
    
    # Return predictions in the format LIME expects for regression
    return predictions

# --- Prepare data for LIME analysis (needs to be full numpy arrays) ---
# Get all test data from the generator
test_mlp_full = test_generator.mlp_data
test_gnn_full = test_generator.gnn_data
test_y_full = test_generator.y
test_coords_full = test_generator.coords
test_rasters_full = test_generator.get_raster_patches(test_coords_full)

# --- Set up explainers for each input type ---
# LIME for MLP (tabular) data
explainer_mlp = LimeTabularExplainer(
    training_data=test_mlp_full,
    feature_names=numeric_cols,
    mode='regression'
)

# LIME for Raster (image) data
explainer_raster = LimeImageExplainer()

# --- Explain a single random sample to avoid memory crash ---
print("\n--- Starting LIME analysis for a SINGLE random sample ---")
random_index = np.random.randint(0, len(test_mlp_full))

# Select the single sample data
sample_mlp = test_mlp_full[random_index].reshape(1, -1)
sample_gnn = test_gnn_full[random_index].reshape(1, -1)
sample_raster = test_rasters_full[random_index]
sample_y = test_y_full[random_index]

# --- Get LIME explanation for MLP features ---
print("\nExplaining MLP features...")
# LIME's explain_instance expects a 1D array for tabular data
lime_exp_mlp = explainer_mlp.explain_instance(
    data_row=test_mlp_full[random_index], 
    predict_fn=lambda x: predict_fn_for_lime(x, np.tile(sample_raster, (x.shape[0], 1, 1, 1)), sample_gnn),
    num_features=len(numeric_cols)
)
print("LIME Explanation for MLP features:")
for feature, weight in lime_exp_mlp.as_list():
    print(f"- {feature}: {weight:.4f}")


Analyzing GNN-MLP-Raster Fusion Model (Single Run)
Using a uniform patch size of 100 pixels for a 500m buffer.

--- Model Training ---


  self._warn_if_super_not_called()



--- Starting LIME analysis for a SINGLE random sample ---

Explaining MLP features...
LIME Explanation for MLP features:
- FeW > 0.76: 2.6344
- PbW > -0.03: 1.9294
- CuW > -0.13: 1.7351
- num_brick_field <= -0.29: 0.9112
- NiW > 0.44: 0.8733
- -0.30 < SiltW <= 0.88: -0.6198
- -1.03 < CrW <= 0.64: -0.2316
- -0.36 < MW <= -0.10: -0.1238
- -0.07 < AsW <= 1.41: -0.1234
- -0.02 < SandW <= 1.36: -0.1224
- -0.59 < hydro_dist_ind <= -0.03: -0.1178
- -0.32 < ClayW <= -0.11: 0.0718
- num_industry <= -0.26: 0.0669
- CdW > -0.28: 0.0396
- hydro_dist_brick <= -0.76: -0.0193


In [4]:
model.save("GNN_MLP.keras")

In [None]:
# --- LIME Importance ---
# This can be computationally intensive, so it's run on a sample
lime_importance_scores = calculate_lime_importance(
    model, 
    test_mlp_full, 
    test_gnn_full, 
    test_rasters_full, 
    numeric_cols, 
    raster_paths
)

print("\nAnalysis complete. All results are printed above.")



--- Starting LIME Feature Importance Analysis ---
Generating LIME explanation for sample 12...


In [None]:
# ==================== 0. Necessary Imports and Setup ==================== #
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense, Concatenate, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
import pickle # Import the pickle library for saving objects
import lime
import lime.lime_tabular
from tensorflow.python.ops.numpy_ops import np_config

# Set a consistent seed for reproducibility
tf.random.set_seed(42)
np.random.seed(42)

# Enable NumPy-like behavior in TensorFlow
np_config.enable_numpy_behavior()

# ==================== 1. Load Data ==================== #
# NOTE: This script assumes the following file paths are correct.
try:
    orig = pd.read_csv("../../data/WinterSeason1.csv")
    river_100 = pd.read_csv("../data/Samples_100W.csv")
except FileNotFoundError as e:
    print(f"Error: Required data file not found. Please check your file paths.")
    print(f"Details: {e}")
    sys.exit()

drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# ==================== 2. Collect ALL Rasters and Metadata ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDWW/*.tif")

# Get the pixel resolution from the first raster to set a uniform patch size
try:
    with rasterio.open(raster_paths[0]) as src:
        pixel_size = src.transform.a
except IndexError:
    print("Error: No raster files found in the specified directories.")
    sys.exit()

# Create a dictionary to store raster metadata for fast access
raster_metadata = {}
for path in raster_paths:
    with rasterio.open(path) as src:
        raster_metadata[path] = {
            'transform': src.transform,
            'crs': src.crs,
            'width': src.width,
            'height': src.height
        }

# ==================== 3. Define a Custom Data Generator ==================== #
class DataGenerator(Sequence):
    """
    Custom Keras Sequence for generating batches of data.
    Handles three different input types: MLP features, GNN features,
    and raster image patches, loading rasters on-the-fly to save memory.
    """
    def __init__(self, mlp_data, gnn_data, y, coords, raster_paths, buffer_radius_m, pixel_size, batch_size=4, shuffle=True):
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.coords = coords
        self.raster_paths = raster_paths
        # Calculate the uniform patch size in pixels based on the buffer radius and pixel size
        # We need a square patch, so the size is 2 * radius / pixel_size
        self.patch_size = int(round((2 * buffer_radius_m) / pixel_size))
        # Ensure patch size is at least 1 and is an even number for easy centering
        if self.patch_size % 2 != 0:
            self.patch_size += 1
        self.patch_size = max(self.patch_size, 2)

        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.on_epoch_end()
    
    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))
    
    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
    
    def get_raster_patches(self, coords_batch):
        """
        Extracts a patch of raster data for each coordinate in the batch.
        Loads rasters on-the-fly to save memory and robustly handles boundaries.
        """
        patches_for_rasters = []
        for path in self.raster_paths:
            patches_for_this_raster = []
            try:
                with rasterio.open(path) as src:
                    for lon, lat in coords_batch:
                        # Get pixel coordinates
                        row, col = src.index(lon, lat)
                    
                        # Define a window to read around the pixel, handling boundaries
                        half_patch = self.patch_size // 2
                        left = int(col - half_patch)
                        top = int(row - half_patch)
                        right = int(col + half_patch)
                        bottom = int(row + half_patch)

                        # Create a new, empty array for the final padded patch
                        padded_patch = np.zeros((self.patch_size, self.patch_size), dtype='float32')

                        # Calculate the window in the raster's coordinate space to read from
                        # And the offset in the padded_patch to write to
                        read_left = max(0, left)
                        read_top = max(0, top)
                        read_right = min(src.width, right)
                        read_bottom = min(src.height, bottom)

                        # Check if the calculated window has a valid size
                        read_width = read_right - read_left
                        read_height = read_bottom - read_top
                    
                        if read_width > 0 and read_height > 0:
                            write_left = read_left - left
                            write_top = read_top - top
                            write_right = write_left + read_width
                            write_bottom = write_top + read_height

                            # Create the window object for rasterio to read from
                            window = Window(read_left, read_top, read_width, read_height)

                            # Read the data from the raster
                            patch_data = src.read(1, window=window)
                            # Place the read data into the padded patch
                            padded_patch[write_top:write_bottom, write_left:write_right] = patch_data
                    
                        patches_for_this_raster.append(padded_patch)
            
                # Stack the patches for this raster
                patches_for_rasters.append(np.stack(patches_for_this_raster, axis=0))
            except Exception as e:
                # This handles cases where a raster file might be missing or corrupted
                patches_for_rasters.append(np.zeros((len(coords_batch), self.patch_size, self.patch_size), dtype='float32'))


        # Stack all raster patches together
        final_patches = np.stack(patches_for_rasters, axis=-1)
        return final_patches

    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_mlp = self.mlp_data[batch_indices]
        batch_gnn = self.gnn_data[batch_indices, :]
        batch_y = self.y[batch_indices]
        batch_coords = self.coords[batch_indices]
        
        # Get raster data for the current batch
        batch_rasters = self.get_raster_patches(batch_coords)
        
        # Return a dictionary of inputs and the output
        return {"mlp_input": batch_mlp, "gnn_input": batch_gnn, "raster_input": batch_rasters}, batch_y

# ==================== 4. Define GNN-MLP-Raster Fusion Model ==================== #
def build_fusion_model(mlp_dim, gnn_dim, raster_patch_size, num_rasters):
    """
    Builds the multi-input Keras model with branches for MLP, GNN, and Rasters.
    """
    # Inputs for all branches
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    raster_input = Input(shape=(raster_patch_size, raster_patch_size, num_rasters), name="raster_input")

    # --- MLP Branch ---
    mlp_embedding = Dense(128, activation="relu", name="mlp_embedding_dense1")(mlp_input)
    mlp_embedding = Dense(64, activation="relu", name="mlp_embedding")(mlp_embedding)

    # --- GNN Branch ---
    gnn_embedding = Dense(128, activation="relu", name="gnn_embedding_dense1")(gnn_input)
    gnn_embedding = Dense(64, activation="relu", name="gnn_embedding")(gnn_embedding)
    
    # --- Raster Branch (using a simple CNN) ---
    raster_conv = Conv2D(32, (3, 3), activation="relu", name="raster_conv1")(raster_input)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_conv = Conv2D(64, (3, 3), activation="relu", name="raster_conv2")(raster_pool)
    raster_pool = MaxPooling2D((2, 2))(raster_conv)
    raster_flatten = Flatten()(raster_pool)
    raster_embedding = Dense(64, activation="relu", name="raster_embedding")(raster_flatten)

    # --- Concatenate Embeddings ---
    combined = Concatenate()([mlp_embedding, gnn_embedding, raster_embedding])
    
    # Final dense layers for prediction
    f = Dense(128, activation="relu")(combined)
    f = Dropout(0.4)(f)
    f = Dense(64, activation="relu")(f)
    output = Dense(1, activation="linear", name="final_output")(f)

    # Build and compile the model
    model = Model(inputs=[mlp_input, gnn_input, raster_input], outputs=output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

# ==================== 5. Define Evaluation & Importance Functions ==================== #
def calculate_smape(y_true, y_pred):
    """Calculates Symmetric Mean Absolute Percentage Error (SMAPE)."""
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    # Avoid division by zero
    mask = denominator == 0
    smape_val = np.where(mask, 0, numerator / denominator)
    return 100 * np.mean(smape_val)

def evaluate_model(model, data_inputs, y_test, return_preds=False):
    """
    Evaluates the model on given data and returns R², RMSE, MAE, and SMAPE.
    Handles both Keras Generators and direct numpy arrays.
    """
    if isinstance(data_inputs, DataGenerator):
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    else:
        y_pred = model.predict(data_inputs, verbose=0).flatten()
    
    if return_preds:
        return y_pred
    else:
        # Align true labels with predictions if using a generator
        y_true_aligned = y_test[:len(y_pred)]
        r2 = r2_score(y_true_aligned, y_pred)
        rmse = np.sqrt(mean_squared_error(y_true_aligned, y_pred))
        mae = mean_absolute_error(y_true_aligned, y_pred)
        smape = calculate_smape(y_true_aligned, y_pred)
        return r2, rmse, mae, smape

def calculate_permutation_importance(model, mlp_data, gnn_data, raster_data, y_true, mlp_features, raster_features):
    """
    Calculates permutation feature importance for all individual features.
    """
    print("\n--- Starting Permutation Feature Importance Analysis ---")
    
    # Create the combined input for the model
    initial_inputs = {"mlp_input": mlp_data, "gnn_input": gnn_data, "raster_input": raster_data}
    
    # Get baseline R² on the unshuffled data
    baseline_r2, _, _, _ = evaluate_model(model, initial_inputs, y_true)
    print(f"Baseline R²: {baseline_r2:.4f}")
    
    importance = {}
    
    # 1. Permute individual MLP features
    print("Permuting MLP features...")
    for i, feature in enumerate(mlp_features):
        shuffled_mlp_data = mlp_data.copy()
        np.random.shuffle(shuffled_mlp_data[:, i])
        shuffled_inputs = {"mlp_input": shuffled_mlp_data, "gnn_input": gnn_data, "raster_input": raster_data}
        shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
        importance[f'MLP_{feature}'] = baseline_r2 - shuffled_r2
    
    # 2. Permute GNN input (as a single block)
    print("Permuting GNN features...")
    shuffled_gnn_data = gnn_data.copy()
    np.random.shuffle(shuffled_gnn_data)
    shuffled_inputs = {"mlp_input": mlp_data, "gnn_input": shuffled_gnn_data, "raster_input": raster_data}
    shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
    importance['GNN'] = baseline_r2 - shuffled_r2
    
    # 3. Permute Raster inputs (each raster band as a feature)
    print("Permuting Raster features...")
    for i, feature in enumerate(raster_features):
        shuffled_raster_data = raster_data.copy()
        # Reshape the channel to a 2D array (samples, pixels) for easy shuffling
        reshaped_channel = shuffled_raster_data[:, :, :, i].reshape(shuffled_raster_data.shape[0], -1)
        # Shuffle each row independently to keep per-sample values
        np.random.shuffle(reshaped_channel)
        # Reshape back to the original shape
        shuffled_raster_data[:, :, :, i] = reshaped_channel.reshape(shuffled_raster_data.shape[0], shuffled_raster_data.shape[1], shuffled_raster_data.shape[2])
        shuffled_inputs = {"mlp_input": mlp_data, "gnn_input": gnn_data, "raster_input": shuffled_raster_data}
        shuffled_r2, _, _, _ = evaluate_model(model, shuffled_inputs, y_true)
        importance[f'Raster_{os.path.basename(feature)}'] = baseline_r2 - shuffled_r2
        
    return importance

def calculate_intrinsic_importance(model, mlp_features, raster_features):
    """
    Calculates intrinsic feature importance based on the weights of the model.
    This approach is more robust to the previous TypeError.
    """
    print("\n--- Starting Intrinsic Feature Importance Analysis ---")

    # === Branch-level Importance (L2 Norm of embedding layer weights) ===
    # This gives a single score for the overall importance of each data type.
    
    print("\nIntrinsic Importance (Branch-level):")
    
    # MLP Branch
    mlp_weights = model.get_layer("mlp_embedding").get_weights()[0]
    mlp_branch_importance = np.linalg.norm(mlp_weights)
    print(f"MLP Branch: {mlp_branch_importance.item():.4f}")

    # GNN Branch
    gnn_weights = model.get_layer("gnn_embedding").get_weights()[0]
    gnn_branch_importance = np.linalg.norm(gnn_weights)
    print(f"GNN Branch: {gnn_branch_importance.item():.4f}")
    
    # Raster Branch
    raster_weights = model.get_layer("raster_embedding").get_weights()[0]
    raster_branch_importance = np.linalg.norm(raster_weights)
    print(f"Raster Branch: {raster_branch_importance.item():.4f}")

    # === Feature-level Importance (Sum of absolute weights from input to first dense layer) ===
    # This provides a score for each individual feature within the MLP and Raster branches.
    
    print("\nIntrinsic Importance (Feature-level):")
    
    # MLP Features
    mlp_input_weights = model.get_layer("mlp_embedding_dense1").get_weights()[0]
    # Sum the absolute weights for each input feature across all its connections to the next layer
    mlp_feature_importance = np.sum(np.abs(mlp_input_weights), axis=1)
    
    print("\nMLP Features:")
    for feature, score in zip(mlp_features, mlp_feature_importance):
        print(f"  {feature}: {score.item():.4f}")
        
    # Raster Channels
    # Get the weights from the first convolutional layer
    raster_conv_weights = model.get_layer("raster_conv1").get_weights()[0]
    # Sum the absolute weights for each input channel
    raster_channel_importance = np.sum(np.abs(raster_conv_weights), axis=(0, 1, 3))
    
    print("\nRaster Channels:")
    for i, score in enumerate(raster_channel_importance):
        feature_name = os.path.basename(raster_features[i])
        print(f"  Raster_{feature_name}: {score.item():.4f}")
    
def calculate_lime_importance(model, test_mlp_data, test_gnn_data, test_raster_data, mlp_features, raster_features):
    """
    Calculates LIME (Local Interpretable Model-agnostic Explanations) importance.
    LIME is applied to a combined set of MLP and flattened raster features,
    as GNN input is context-dependent and not suitable for LIME.
    Note: LIME can be memory intensive, so we use a small number of samples.
    """
    print("\n--- Starting LIME Feature Importance Analysis ---")
    
    # Flatten the raster data to a 2D array
    flat_raster_data = test_raster_data.reshape(test_raster_data.shape[0], -1)
    
    # Combine MLP and flattened raster data for LIME
    combined_data = np.hstack([test_mlp_data, flat_raster_data])
    
    # Create the full list of feature names for the combined data
    raster_feature_names = [f"Raster_{os.path.basename(path)}_{i}" for path in raster_features for i in range(test_raster_data.shape[1] * test_raster_data.shape[2])]
    feature_names = list(mlp_features) + raster_feature_names
    
    # Define a prediction function that LIME can use
    def predict_fn(x):
        # Unpack the combined features back to their original shapes
        mlp_slice = x[:, :len(mlp_features)]
        raster_slice = x[:, len(mlp_features):].reshape(x.shape[0], test_raster_data.shape[1], test_raster_data.shape[2], len(raster_features))
        
        # We need a dummy GNN input for the model prediction
        dummy_gnn = np.zeros((x.shape[0], test_gnn_data.shape[1]))
        
        # Return the model's predictions (LIME expects a single value per sample)
        return model.predict({"mlp_input": mlp_slice, "gnn_input": dummy_gnn, "raster_input": raster_slice}, verbose=0)
    
    # Initialize the LIME explainer
    explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=combined_data, 
        feature_names=feature_names, 
        class_names=["RI Prediction"], 
        mode='regression'
    )
    
    # Choose a few samples to explain
    num_samples = 3 # Reduced to 3 to avoid memory issues
    sample_indices = np.random.choice(range(len(test_mlp_data)), num_samples, replace=False)
    
    lime_importance_scores = {}
    
    for idx in sample_indices:
        print(f"Generating LIME explanation for sample {idx}...")
        explanation = explainer.explain_instance(
            data_row=combined_data[idx], 
            predict_fn=predict_fn, 
            num_features=10 # Explain the top 10 most important features, as requested
        )
        for feature, weight in explanation.as_list():
            if feature not in lime_importance_scores:
                lime_importance_scores[feature] = []
            lime_importance_scores[feature].append(abs(weight))
            
    # Aggregate and average the importance scores
    avg_lime_importance = {
        feature: np.mean(scores) for feature, scores in lime_importance_scores.items()
    }
    
    # Sort and print the top 10 features
    print("\nTop 10 LIME Features (Average Absolute Weight):")
    sorted_lime = sorted(avg_lime_importance.items(), key=lambda item: item[1], reverse=True)
    for feature, score in sorted_lime[:10]:
        print(f"{feature}: {score:.4f}")
    
    return avg_lime_importance

# ==================== 6. Main Analysis without K-Fold CV ==================== #

print("\n" + "="*80)
print("Analyzing GNN-MLP-Raster Fusion Model (Single Run)")
print(f"Using a uniform patch size of {int(round((2 * 500) / pixel_size))} pixels for a 500m buffer.")
print("="*80)

# Combine all data
full_data = pd.concat([orig, river_100], ignore_index=True).sample(frac=1, random_state=42).reset_index(drop=True)
full_coords = full_data[['Long','Lat']].values
full_y = full_data['RI'].values
full_mlp_data = full_data[numeric_cols].values

# Pre-process MLP data with StandardScaler
scaler = StandardScaler()
full_mlp_data = scaler.fit_transform(full_mlp_data)

# Split data into training and testing sets (e.g., 80% train, 20% test)
train_mlp, test_mlp, train_coords, test_coords, y_train, y_test = train_test_split(
    full_mlp_data, full_coords, full_y, test_size=0.2, random_state=42
)

# Prepare GNN input (adjacency matrix based on distances)
dist_mat_train = distance_matrix(train_coords, train_coords)
gnn_train = np.exp(-dist_mat_train / 10)
    
dist_mat_test_train = distance_matrix(test_coords, train_coords)
gnn_test = np.exp(-dist_mat_test_train / 10)

# Clean up memory
del dist_mat_train, dist_mat_test_train
gc.collect()

# Define patch size and number of rasters
buffer_radius_m = 500
raster_patch_size = int(round((2 * buffer_radius_m) / pixel_size))
if raster_patch_size % 2 != 0:
    raster_patch_size += 1
raster_patch_size = max(raster_patch_size, 2)
num_rasters = len(raster_paths)

# Build and compile the model
model = build_fusion_model(mlp_dim=train_mlp.shape[1], gnn_dim=gnn_train.shape[1], 
                             raster_patch_size=raster_patch_size, num_rasters=num_rasters)

# Print model summary for inspection
model.summary()
    
# Create data generators for training and testing
train_generator = DataGenerator(
    mlp_data=train_mlp, gnn_data=gnn_train, y=y_train, coords=train_coords,
    raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=True
)

test_generator = DataGenerator(
    mlp_data=test_mlp, gnn_data=gnn_test, y=y_test, coords=test_coords,
    raster_paths=raster_paths, buffer_radius_m=buffer_radius_m, pixel_size=pixel_size, batch_size=4, shuffle=False
)
    
# Train the model
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True
)

print("\n--- Model Training ---")
history = model.fit(
    train_generator,
    epochs=1,
    verbose=1,
    callbacks=[early_stopping],
    validation_data=test_generator
)

# Evaluate on the test data
r2_test, rmse_test, mae_test, smape_test = evaluate_model(model, test_generator, y_test)
    
print("\n" + "="*80)
print("Final Model Performance on Test Set")
print("="*80)
print(f"R²: {r2_test:.4f}")
print(f"RMSE: {rmse_test:.4f}")
print(f"MAE: {mae_test:.4f}")
print(f"SMAPE: {smape_test:.4f}%")

# ==================== 7. Feature Importance Analysis ==================== #

# --- Prepare data for importance functions (needs to be full numpy arrays) ---
# Get all test data from the generator
test_mlp_full = test_generator.mlp_data
test_gnn_full = test_generator.gnn_data
test_y_full = test_generator.y
test_coords_full = test_generator.coords
test_rasters_full = test_generator.get_raster_patches(test_coords_full)

# --- Permutation Importance ---
permutation_importance_scores = calculate_permutation_importance(
    model, 
    test_mlp_full, 
    test_gnn_full, 
    test_rasters_full, 
    test_y_full, 
    numeric_cols, 
    raster_paths
)
print("\n--- Summary of Permutation Importance ---")
sorted_perm_importance = sorted(permutation_importance_scores.items(), key=lambda item: item[1], reverse=True)
for feature, score in sorted_perm_importance:
    print(f"{feature}: {score:.4f}")

# --- Intrinsic Importance ---
calculate_intrinsic_importance(model, numeric_cols, raster_paths)

# --- LIME Importance ---
# This can be computationally intensive, so it's run on a sample
lime_importance_scores = calculate_lime_importance(
    model, 
    test_mlp_full, 
    test_gnn_full, 
    test_rasters_full, 
    numeric_cols, 
    raster_paths
)

print("\nAnalysis complete. All results are printed above.")
