In [None]:
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Concatenate,
    Dropout,
    Layer,
    Lambda,
    GlobalAveragePooling2D,
    Reshape,
    Multiply
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc # Import garbage collector
import sys
from io import StringIO # To capture print output
import pickle # For saving dictionaries and other objects

# Define the single buffer size to use
BUFFER_METERS = 500

# ==================== 1. Load Data ==================== #
# NOTE: The data loading logic remains the same.
# Replace with your actual data paths if needed
orig = pd.read_csv("../../data/RainySeason.csv")
river_100 = pd.read_csv("../data/Samples_100.csv")

drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# Train-test split
train_orig = orig.sample(10, random_state=42)
test_orig = orig.drop(train_orig.index)
train_combined = pd.concat([river_100, train_orig], ignore_index=True)

# ==================== 2. Collect ALL Rasters ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDW/*.tif")

print(f"Using {len(raster_paths)} raster layers for CNN input.")
for r in raster_paths:
    print("  -", os.path.basename(r))

# ==================== 3. Create a Custom Data Generator ==================== #
def extract_patch_for_generator(coords, raster_files, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height):
    """
    Extracts a batch of patches from rasters for a given set of coordinates.
    This function is optimized to be called by the data generator for each batch.
    """
    patches = []
    # Loop through each coordinate pair in the batch
    for lon, lat in coords:
        channels = []
        # Loop through each raster file to get a single patch for each raster
        for rfile in raster_files:
            with rasterio.open(rfile) as src:
                try:
                    row, col = src.index(lon, lat)
                    win = Window(col - buffer_pixels_x, row - buffer_pixels_y, patch_width, patch_height)
                    arr = src.read(1, window=win, boundless=True, fill_value=0)
                    arr = arr.astype(np.float32)

                    if np.nanmax(arr) != 0:
                        arr /= np.nanmax(arr)
                except Exception as e:
                    print(f"Error processing {rfile} for coordinates ({lon}, {lat}): {e}")
                    arr = np.zeros((patch_width, patch_height), dtype=np.float32)
            channels.append(arr)
        patches.append(np.stack(channels, axis=-1))
    
    return np.array(patches)

class DataGenerator(Sequence):
    def __init__(self, coords, mlp_data, gnn_data, y, raster_paths, buffer_meters, batch_size=4, shuffle=True, **kwargs):
        super().__init__(**kwargs)
        self.coords = coords
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.raster_paths = raster_paths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.buffer_meters = buffer_meters

        # Pre-calculate patch size from the first raster
        with rasterio.open(raster_paths[0]) as src:
            res_x, res_y = src.res
            self.buffer_pixels_x = int(self.buffer_meters / res_x)
            self.buffer_pixels_y = int(self.buffer_meters / res_y)
            self.patch_width = 2 * self.buffer_pixels_x
            self.patch_height = 2 * self.buffer_pixels_y

        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
            
    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_coords = self.coords[batch_indices]
        batch_mlp = self.mlp_data[batch_indices]
        
        # Slice the GNN adjacency matrix for the current batch
        batch_gnn = self.gnn_data[batch_indices, :]

        batch_y = self.y[batch_indices]

        # Extract CNN patches for the current batch
        batch_cnn = extract_patch_for_generator(
            batch_coords,
            self.raster_paths,
            self.buffer_pixels_x,
            self.buffer_pixels_y,
            self.patch_width,
            self.patch_height
        )

        # Return a tuple of inputs and the target, which Keras expects
        return (batch_cnn, batch_mlp, batch_gnn), batch_y


# ==================== 4. Prepare GNN & MLP Input (only once) ==================== #
coords_train = train_combined[['Long','Lat']].values
coords_test = test_orig[['Long','Lat']].values
dist_mat_train = distance_matrix(coords_train, coords_train)
gnn_train = np.exp(-dist_mat_train/10)
dist_mat_test_train = distance_matrix(coords_test, coords_train)
gnn_test = np.exp(-dist_mat_test_train/10)

scaler = StandardScaler()
mlp_train = scaler.fit_transform(train_combined[numeric_cols])
mlp_test = scaler.transform(test_orig[numeric_cols])
y_train = train_combined['RI'].values
y_test = test_orig['RI'].values

# ==================== 5. Define Custom Attention Layers ==================== #

class SpatialAttention(Layer):
    """
    A custom layer to apply spatial attention to a feature map.
    It generates a spatial attention map and multiplies it with the input.
    """
    def __init__(self, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.conv1 = Conv2D(1, (1, 1), activation='sigmoid')

    def call(self, inputs):
        # Squeeze the channels and generate a 2D attention map
        attention_map = self.conv1(inputs)
        # Multiply the input feature map by the attention map
        return Multiply()([inputs, attention_map])

class FeatureAttention(Layer):
    """
    A custom layer to apply feature-wise attention.
    It learns a weight for each feature channel and multiplies it with the input.
    Inspired by Squeeze-and-Excitation networks.
    """
    def __init__(self, reduction_ratio=16, **kwargs):
        super(FeatureAttention, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        super(FeatureAttention, self).build(input_shape)
        if len(input_shape) == 4: # CNN output
            self.avg_pool = GlobalAveragePooling2D()
            self.dense1 = Dense(units=input_shape[-1] // self.reduction_ratio, activation='relu')
            self.dense2 = Dense(units=input_shape[-1], activation='sigmoid')
            self.reshape_output = Reshape((1, 1, input_shape[-1]))
        else: # MLP or GNN output
            self.dense1 = Dense(units=input_shape[-1] // self.reduction_ratio, activation='relu')
            self.dense2 = Dense(units=input_shape[-1], activation='sigmoid')

    def call(self, inputs):
        if len(inputs.shape) == 4: # CNN branch
            x = self.avg_pool(inputs)
            x = self.dense1(x)
            x = self.dense2(x)
            x = self.reshape_output(x)
        else: # MLP or GNN branch
            x = self.dense1(inputs)
            x = self.dense2(x)
            
        return Multiply()([inputs, x])

# ==================== 6. Define the Dual Attention Model ==================== #
def build_dual_attention_model(patch_shape, gnn_dim, mlp_dim):
    # Inputs for all branches
    cnn_input = Input(shape=patch_shape, name="cnn_input")
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    
    # --- CNN Branch with Spatial and Feature Attention ---
    cnn_branch = Conv2D(32, (3,3), activation="relu", padding="same")(cnn_input)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    cnn_branch = Conv2D(64, (3,3), activation="relu", padding="same")(cnn_branch)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    
    # Spatial Attention
    cnn_spatial_attn = SpatialAttention()(cnn_branch)
    
    # Feature Attention
    cnn_feature_attn = FeatureAttention()(cnn_spatial_attn)
    
    # Flatten and get embedding
    cnn_embedding = Flatten()(cnn_feature_attn)
    cnn_embedding = Dense(128, activation="relu", name="cnn_embedding")(cnn_embedding)

    # --- MLP Branch with Embedding ---
    mlp_embedding = Dense(64, activation="relu")(mlp_input)
    mlp_embedding = Dense(32, activation="relu", name="mlp_embedding")(mlp_embedding)

    # --- GNN Branch with Feature Attention and Embedding ---
    gnn_branch = Dense(64, activation="relu")(gnn_input)
    
    # Feature Attention
    gnn_feature_attn = FeatureAttention()(gnn_branch)
    gnn_embedding = Dense(32, activation="relu", name="gnn_embedding")(gnn_feature_attn)

    # --- Attention Fusion ---
    # Concatenate all embeddings
    combined_embedding = Concatenate(name="combined_embedding")([cnn_embedding, mlp_embedding, gnn_embedding])
    
    # Final dense layers for prediction
    f = Dense(128, activation="relu")(combined_embedding)
    f = Dropout(0.4)(f)
    f = Dense(64, activation="relu")(f)
    output = Dense(1, activation="linear", name="final_output")(f)

    # Build and compile the model
    model = Model(inputs=[cnn_input, mlp_input, gnn_input], outputs=output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

def evaluate_model(model, coords_test, mlp_test, gnn_test_matrix, y_test, raster_paths, buffer_meters, batch_size=4, return_preds=False):
    num_samples = len(y_test)
    y_pred_list = []
    
    with rasterio.open(raster_paths[0]) as src:
        res_x, res_y = src.res
        buffer_pixels_x = int(buffer_meters / res_x)
        buffer_pixels_y = int(buffer_meters / res_y)
        patch_width = 2 * buffer_pixels_x
        patch_height = 2 * buffer_pixels_y

    for i in range(0, num_samples, batch_size):
        batch_coords = coords_test[i:i+batch_size]
        batch_mlp = mlp_test[i:i+batch_size]
        
        batch_gnn = gnn_test_matrix[i:i+batch_size, :]
        batch_y = y_test[i:i+batch_size]

        batch_cnn = extract_patch_for_generator(
            batch_coords,
            raster_paths,
            buffer_pixels_x,
            buffer_pixels_y,
            patch_width,
            patch_height
        )
        
        y_pred_list.append(model.predict((batch_cnn, batch_mlp, batch_gnn)).flatten())
        
    y_pred = np.concatenate(y_pred_list)
    
    if return_preds:
        return y_pred
    else:
        r2 = r2_score(y_test, y_pred)
        rmse = np.sqrt(mean_squared_error(y_test, y_pred))
        return r2, rmse

# ==================== Run the Analysis ==================== #
# Capture all print statements to a string
old_stdout = sys.stdout
sys.stdout = captured_output = StringIO()

print("\n" + "="*80)
print(f"Analyzing for BUFFER_METERS = {BUFFER_METERS}m")
print("="*80)

batch_size = 4
gnn_input_dim = len(coords_train)

# Calculate CNN patch shape based on the current buffer size
with rasterio.open(raster_paths[0]) as src:
    res_x, res_y = src.res
    buffer_pixels_x = int(BUFFER_METERS / res_x)
    patch_width = 2 * buffer_pixels_x
    cnn_patch_shape = (patch_width, patch_width, len(raster_paths))

model = build_dual_attention_model(cnn_patch_shape, gnn_input_dim, mlp_train.shape[1])
model.summary(print_fn=lambda x: captured_output.write(x + '\n')) # Capture model summary

# ==================== 7. Create Data Generators ==================== #
train_generator = DataGenerator(
    coords=coords_train,
    mlp_data=mlp_train,
    gnn_data=gnn_train,
    y=y_train,
    raster_paths=raster_paths,
    buffer_meters=BUFFER_METERS,
    batch_size=batch_size,
    shuffle=True
)

# ==================== 8. Train Model ==================== #
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True
)

print("\nStarting model training...")
history = model.fit(
    train_generator,
    epochs=100,
    verbose=1,
    callbacks=[early_stopping],
    validation_data=train_generator
)
print("Training complete.")

# ==================== 9. Evaluate ==================== #
y_pred_train = model.predict(train_generator).flatten()
r2_train = r2_score(y_train[:len(y_pred_train)], y_pred_train)
rmse_train = np.sqrt(mean_squared_error(y_train[:len(y_pred_train)], y_pred_train))

# Get test predictions for saving as .npy
y_pred_test = evaluate_model(model, coords_test, mlp_test, gnn_test, y_test, raster_paths, buffer_meters=BUFFER_METERS, batch_size=batch_size, return_preds=True)
r2_test = r2_score(y_test, y_pred_test)
rmse_test = np.sqrt(mean_squared_error(y_test, y_pred_test))

print(f"\n Dual Attention Model Performance ({BUFFER_METERS}m):")
print(f"R² Train: {r2_train:.4f} | RMSE Train: {rmse_train:.4f}")
print(f"R² Test: {r2_test:.4f} | RMSE Test: {rmse_test:.4f}")

# ==================== 10. Feature Importance Analysis ==================== #
print("\n" + "-"*50)
print(f"Feature Importance Analysis for {BUFFER_METERS}m")
print("-"*50)

# --- 10.1 Combined Feature Importance (by Model Branch) ---
y_pred_baseline = y_pred_test
baseline_r2 = r2_test
print(f"\nBaseline Performance on Test Set: R² = {baseline_r2:.4f}")

# Ablate CNN branch
with rasterio.open(raster_paths[0]) as src:
    res_x, res_y = src.res
    buffer_pixels_x = int(BUFFER_METERS / res_x)
    buffer_pixels_y = int(BUFFER_METERS / res_y)
    patch_width = 2 * buffer_pixels_x
    patch_height = 2 * buffer_pixels_y

cnn_test_ablated = np.zeros_like(extract_patch_for_generator(
    coords_test, raster_paths, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height
))
y_pred_cnn_ablated = model.predict((cnn_test_ablated, mlp_test, gnn_test)).flatten()
r2_cnn_ablated = r2_score(y_test, y_pred_cnn_ablated)
importance_cnn = baseline_r2 - r2_cnn_ablated

# Ablate MLP branch
mlp_test_ablated = np.zeros_like(mlp_test)
y_pred_mlp_ablated = model.predict((extract_patch_for_generator(
    coords_test, raster_paths, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height
), mlp_test_ablated, gnn_test)).flatten()
r2_mlp_ablated = r2_score(y_test, y_pred_mlp_ablated)
importance_mlp = baseline_r2 - r2_mlp_ablated

# Ablate GNN branch
gnn_test_ablated = np.zeros_like(gnn_test)
y_pred_gnn_ablated = model.predict((extract_patch_for_generator(
    coords_test, raster_paths, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height
), mlp_test, gnn_test_ablated)).flatten()
r2_gnn_ablated = r2_score(y_test, y_pred_gnn_ablated)
importance_gnn = baseline_r2 - r2_gnn_ablated

print("\n--- Combined Feature Importance (by Model Branch) ---")
print(f"CNN Branch Importance (R² drop): {importance_cnn:.4f}")
print(f"MLP Branch Importance (R² drop): {importance_mlp:.4f}")
print(f"GNN Branch Importance (R² drop): {importance_gnn:.4f}")

# --- 10.2 MLP Feature Importance (Permutation-based) ---
mlp_feature_importance = {}
for i, feature_name in enumerate(numeric_cols):
    mlp_test_shuffled = np.copy(mlp_test)
    np.random.shuffle(mlp_test_shuffled[:, i])
    
    y_pred_shuffled = model.predict((extract_patch_for_generator(
        coords_test, raster_paths, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height
    ), mlp_test_shuffled, gnn_test)).flatten()
    r2_shuffled = r2_score(y_test, y_pred_shuffled)
    
    importance = baseline_r2 - r2_shuffled
    mlp_feature_importance[feature_name] = importance

print("\n--- MLP Feature Importance (Permutation-based) ---")
sorted_importance = sorted(mlp_feature_importance.items(), key=lambda item: item[1], reverse=True)
for feature, importance in sorted_importance:
    print(f"{feature:<20}: {importance:.4f}")
    
# Garbage collect to free up memory
del history, train_generator
gc.collect()

# ==================== 11. Save all info to a folder ==================== #
# Restore standard output
sys.stdout = old_stdout
printed_output = captured_output.getvalue()

output_folder = "dual_attention_analysis"
os.makedirs(output_folder, exist_ok=True)
print(f"\nCreating folder: '{output_folder}' and saving results...")

# Save the trained model in the .keras format
model_path = os.path.join(output_folder, "dual_attention.keras")
model.save(model_path)
print(f"Model saved to: {model_path}")

# Save the MLP feature importance to a .pkl file
mlp_importance_path = os.path.join(output_folder, "mlp_feature_importance.pkl")
with open(mlp_importance_path, 'wb') as f:
    pickle.dump(mlp_feature_importance, f)
print(f"MLP feature importance saved to: {mlp_importance_path}")

# Save all relevant data to .npy files
np.save(os.path.join(output_folder, "coords_train.npy"), coords_train)
np.save(os.path.join(output_folder, "coords_test.npy"), coords_test)
np.save(os.path.join(output_folder, "mlp_train.npy"), mlp_train)
np.save(os.path.join(output_folder, "mlp_test.npy"), mlp_test)
np.save(os.path.join(output_folder, "gnn_train.npy"), gnn_train)
np.save(os.path.join(output_folder, "gnn_test.npy"), gnn_test)
np.save(os.path.join(output_folder, "y_train.npy"), y_train)
np.save(os.path.join(output_folder, "y_test.npy"), y_test)
np.save(os.path.join(output_folder, "y_pred_train.npy"), y_pred_train)
np.save(os.path.join(output_folder, "y_pred_test.npy"), y_pred_test)
print(f"Coordinates, scaled data, GNN matrices, and labels/predictions saved as .npy files.")

# Save the printed output to a text file
output_path = os.path.join(output_folder, "analysis_output.txt")
with open(output_path, "w") as f:
    f.write(printed_output)
print(f"Analysis results saved to: {output_path}")

print("All information successfully saved.")

In [1]:
import pandas as pd
import numpy as np
import glob
import os
import rasterio
from rasterio.windows import Window
from scipy.spatial import distance_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, mean_absolute_error
from sklearn.model_selection import KFold
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Input,
    Conv2D,
    MaxPooling2D,
    Flatten,
    Dense,
    Concatenate,
    Dropout,
    Layer,
    GlobalAveragePooling2D,
    Reshape,
    Multiply
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.utils import Sequence
import tensorflow as tf
import gc
import sys
from io import StringIO
import pickle

# Define the single buffer size to use
BUFFER_METERS = 500
N_SPLITS = 5 # Number of folds for cross-validation

# ==================== 1. Load Data ==================== #
# NOTE: The data loading logic remains the same.
# Replace with your actual data paths if needed
orig = pd.read_csv("../../data/RainySeason.csv")
river_100 = pd.read_csv("../data/Samples_100.csv")

drop_cols = ['Stations','River','Lat','Long','geometry']
numeric_cols = orig.drop(columns=drop_cols).columns.drop('RI')

# Train-test split
train_orig = orig.sample(10, random_state=42)
test_orig = orig.drop(train_orig.index)
train_combined = pd.concat([river_100, train_orig], ignore_index=True)

# ==================== 2. Collect ALL Rasters ==================== #
raster_paths = []
raster_paths += glob.glob("../CalIndices/*.tif")
raster_paths += glob.glob("../LULCMerged/*.tif")
raster_paths += glob.glob("../IDW/*.tif")

print(f"Using {len(raster_paths)} raster layers for CNN input.")
for r in raster_paths:
    print("  -", os.path.basename(r))

# ==================== 3. Define Metric Functions ==================== #

def smape(y_true, y_pred):
    """
    Calculates the Symmetric Mean Absolute Percentage Error (SMAPE).
    """
    numerator = np.abs(y_pred - y_true)
    denominator = (np.abs(y_true) + np.abs(y_pred)) / 2
    # Avoid division by zero by setting the result to 0 where denominator is 0
    return np.mean(np.where(denominator == 0, 0, numerator / denominator)) * 100

# ==================== 4. Create a Custom Data Generator ==================== #

def extract_patch_for_generator(coords, raster_files, buffer_pixels_x, buffer_pixels_y, patch_width, patch_height):
    """
    Extracts a batch of patches from rasters for a given set of coordinates.
    This function is optimized to be called by the data generator for each batch.
    """
    patches = []
    # Loop through each coordinate pair in the batch
    for lon, lat in coords:
        channels = []
        # Loop through each raster file to get a single patch for each raster
        for rfile in raster_files:
            with rasterio.open(rfile) as src:
                try:
                    row, col = src.index(lon, lat)
                    win = Window(col - buffer_pixels_x, row - buffer_pixels_y, patch_width, patch_height)
                    arr = src.read(1, window=win, boundless=True, fill_value=0)
                    arr = arr.astype(np.float32)

                    if np.nanmax(arr) != 0:
                        arr /= np.nanmax(arr)
                except Exception as e:
                    print(f"Error processing {rfile} for coordinates ({lon}, {lat}): {e}")
                    arr = np.zeros((patch_width, patch_height), dtype=np.float32)
            channels.append(arr)
        patches.append(np.stack(channels, axis=-1))
    
    return np.array(patches)

class DataGenerator(Sequence):
    def __init__(self, coords, mlp_data, gnn_data, y, raster_paths, buffer_meters, batch_size=4, shuffle=True, **kwargs):
        super().__init__(**kwargs)
        self.coords = coords
        self.mlp_data = mlp_data
        self.gnn_data = gnn_data
        self.y = y
        self.raster_paths = raster_paths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(self.y))
        self.buffer_meters = buffer_meters

        # Pre-calculate patch size from the first raster
        with rasterio.open(raster_paths[0]) as src:
            res_x, res_y = src.res
            self.buffer_pixels_x = int(self.buffer_meters / res_x)
            self.buffer_pixels_y = int(self.buffer_meters / res_y)
            self.patch_width = 2 * self.buffer_pixels_x
            self.patch_height = 2 * self.buffer_pixels_y

        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.y) / self.batch_size))

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)
            
    def __getitem__(self, index):
        # Get batch indices
        batch_indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Get batch data
        batch_coords = self.coords[batch_indices]
        batch_mlp = self.mlp_data[batch_indices]
        
        # Slice the GNN adjacency matrix for the current batch
        batch_gnn = self.gnn_data[batch_indices, :]

        batch_y = self.y[batch_indices]

        # Extract CNN patches for the current batch
        batch_cnn = extract_patch_for_generator(
            batch_coords,
            self.raster_paths,
            self.buffer_pixels_x,
            self.buffer_pixels_y,
            self.patch_width,
            self.patch_height
        )

        # Return a tuple of inputs and the target, which Keras expects
        return (batch_cnn, batch_mlp, batch_gnn), batch_y

# ==================== 5. Define Custom Attention Layers ==================== #

class SpatialAttention(Layer):
    """
    A custom layer to apply spatial attention to a feature map.
    It generates a spatial attention map and multiplies it with the input.
    """
    def __init__(self, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.conv1 = Conv2D(1, (1, 1), activation='sigmoid')

    def call(self, inputs):
        # Squeeze the channels and generate a 2D attention map
        attention_map = self.conv1(inputs)
        # Multiply the input feature map by the attention map
        return Multiply()([inputs, attention_map])

class FeatureAttention(Layer):
    """
    A custom layer to apply feature-wise attention.
    It learns a weight for each feature channel and multiplies it with the input.
    Inspired by Squeeze-and-Excitation networks.
    """
    def __init__(self, reduction_ratio=16, **kwargs):
        super(FeatureAttention, self).__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        super(FeatureAttention, self).build(input_shape)
        if len(input_shape) == 4: # CNN output
            self.avg_pool = GlobalAveragePooling2D()
            self.dense1 = Dense(units=input_shape[-1] // self.reduction_ratio, activation='relu')
            self.dense2 = Dense(units=input_shape[-1], activation='sigmoid')
            self.reshape_output = Reshape((1, 1, input_shape[-1]))
        else: # MLP or GNN output
            self.dense1 = Dense(units=input_shape[-1] // self.reduction_ratio, activation='relu')
            self.dense2 = Dense(units=input_shape[-1], activation='sigmoid')

    def call(self, inputs):
        if len(inputs.shape) == 4: # CNN branch
            x = self.avg_pool(inputs)
            x = self.dense1(x)
            x = self.dense2(x)
            x = self.reshape_output(x)
        else: # MLP or GNN branch
            x = self.dense1(inputs)
            x = self.dense2(x)
            
        return Multiply()([inputs, x])

# ==================== 6. Define the Dual Attention Model ==================== #
def build_dual_attention_model(patch_shape, gnn_dim, mlp_dim):
    # Inputs for all branches
    cnn_input = Input(shape=patch_shape, name="cnn_input")
    mlp_input = Input(shape=(mlp_dim,), name="mlp_input")
    gnn_input = Input(shape=(gnn_dim,), name="gnn_input")
    
    # --- CNN Branch with Spatial and Feature Attention ---
    cnn_branch = Conv2D(32, (3,3), activation="relu", padding="same")(cnn_input)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    cnn_branch = Conv2D(64, (3,3), activation="relu", padding="same")(cnn_branch)
    cnn_branch = MaxPooling2D((2,2))(cnn_branch)
    
    # Spatial Attention
    cnn_spatial_attn = SpatialAttention()(cnn_branch)
    
    # Feature Attention
    cnn_feature_attn = FeatureAttention()(cnn_spatial_attn)
    
    # Flatten and get embedding
    cnn_embedding = Flatten()(cnn_feature_attn)
    cnn_embedding = Dense(128, activation="relu", name="cnn_embedding")(cnn_embedding)

    # --- MLP Branch with Embedding ---
    mlp_embedding = Dense(64, activation="relu")(mlp_input)
    mlp_embedding = Dense(32, activation="relu", name="mlp_embedding")(mlp_embedding)

    # --- GNN Branch with Feature Attention and Embedding ---
    gnn_branch = Dense(64, activation="relu")(gnn_input)
    
    # Feature Attention
    gnn_feature_attn = FeatureAttention()(gnn_branch)
    gnn_embedding = Dense(32, activation="relu", name="gnn_embedding")(gnn_feature_attn)

    # --- Attention Fusion ---
    # Concatenate all embeddings
    combined_embedding = Concatenate(name="combined_embedding")([cnn_embedding, mlp_embedding, gnn_embedding])
    
    # Final dense layers for prediction
    f = Dense(128, activation="relu")(combined_embedding)
    f = Dropout(0.4)(f)
    f = Dense(64, activation="relu")(f)
    output = Dense(1, activation="linear", name="final_output")(f)

    # Build and compile the model
    model = Model(inputs=[cnn_input, mlp_input, gnn_input], outputs=output)
    model.compile(optimizer=Adam(learning_rate=0.0005), loss="mse")
    return model

def evaluate_model(model, coords, mlp_data, gnn_data, y, raster_paths, buffer_meters, batch_size=4):
    """
    Evaluates a model on the given data and returns a dictionary of metrics.
    """
    num_samples = len(y)
    y_pred_list = []
    
    # Get patch dimensions
    with rasterio.open(raster_paths[0]) as src:
        res_x, res_y = src.res
        buffer_pixels_x = int(buffer_meters / res_x)
        buffer_pixels_y = int(buffer_meters / res_y)
        patch_width = 2 * buffer_pixels_x
        patch_height = 2 * buffer_pixels_y

    for i in range(0, num_samples, batch_size):
        batch_coords = coords[i:i+batch_size]
        batch_mlp = mlp_data[i:i+batch_size]
        batch_gnn = gnn_data[i:i+batch_size, :]
        
        batch_cnn = extract_patch_for_generator(
            batch_coords,
            raster_paths,
            buffer_pixels_x,
            buffer_pixels_y,
            patch_width,
            patch_height
        )
        
        y_pred_list.append(model.predict((batch_cnn, batch_mlp, batch_gnn)).flatten())
        
    y_pred = np.concatenate(y_pred_list)
    
    r2 = r2_score(y, y_pred)
    mae = mean_absolute_error(y, y_pred)
    rmse = np.sqrt(mean_squared_error(y, y_pred))
    smape_val = smape(y, y_pred)

    return {
        'R2': r2,
        'MAE': mae,
        'RMSE': rmse,
        'SMAPE': smape_val
    }

# ==================== 7. Run K-Fold Cross-Validation ==================== #


print("="*80)
print(f"Starting {N_SPLITS}-Fold Cross-Validation...")
print("="*80)

# Create a folder to save models
model_save_dir = "models/dual_attention"
os.makedirs(model_save_dir, exist_ok=True)
print(f"Models will be saved in: '{model_save_dir}'")

# Prepare data for K-Fold
combined_indices = np.arange(len(train_combined))
coords_all = train_combined[['Long','Lat']].values
data_all = train_combined[numeric_cols].values
gnn_input_all = np.exp(-distance_matrix(coords_all, coords_all)/10)
y_all = train_combined['RI'].values
batch_size = 4

# Store metrics for each fold
fold_metrics = []
test_metrics = []

# Initialize K-Fold
kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=42)

for fold, (train_indices, val_indices) in enumerate(kf.split(combined_indices)):
    print(f"\n--- Fold {fold+1}/{N_SPLITS} ---")

    # Split data for the current fold
    coords_train_fold = coords_all[train_indices]
    coords_val_fold = coords_all[val_indices]
    mlp_train_fold = data_all[train_indices]
    mlp_val_fold = data_all[val_indices]
    y_train_fold = y_all[train_indices]
    y_val_fold = y_all[val_indices]

    # Scale MLP data and prepare GNN matrices for the current fold
    scaler_fold = StandardScaler()
    mlp_train_scaled = scaler_fold.fit_transform(mlp_train_fold)
    mlp_val_scaled = scaler_fold.transform(mlp_val_fold)
    
    dist_mat_train_fold = distance_matrix(coords_train_fold, coords_train_fold)
    gnn_train_fold = np.exp(-dist_mat_train_fold/10)
    
    dist_mat_val_fold = distance_matrix(coords_val_fold, coords_train_fold)
    gnn_val_fold = np.exp(-dist_mat_val_fold/10)

    # Re-initialize and compile the model for each fold
    with rasterio.open(raster_paths[0]) as src:
        res_x, res_y = src.res
        buffer_pixels_x = int(BUFFER_METERS / res_x)
        patch_width = 2 * buffer_pixels_x
        cnn_patch_shape = (patch_width, patch_width, len(raster_paths))
    
    model = build_dual_attention_model(cnn_patch_shape, len(coords_train_fold), mlp_train_fold.shape[1])
    
    # Create Data Generators for the current fold
    train_generator = DataGenerator(
        coords=coords_train_fold,
        mlp_data=mlp_train_scaled,
        gnn_data=gnn_train_fold,
        y=y_train_fold,
        raster_paths=raster_paths,
        buffer_meters=BUFFER_METERS,
        batch_size=batch_size,
        shuffle=True
    )
    
    val_generator = DataGenerator(
        coords=coords_val_fold,
        mlp_data=mlp_val_scaled,
        gnn_data=gnn_val_fold,
        y=y_val_fold,
        raster_paths=raster_paths,
        buffer_meters=BUFFER_METERS,
        batch_size=batch_size,
        shuffle=False
    )
    
    early_stopping = EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True
    )

    print("Starting model training...")
    history = model.fit(
        train_generator,
        epochs=1,
        verbose=1,
        callbacks=[early_stopping],
        validation_data=val_generator
    )
    print("Training complete.")

    # Evaluate on the validation set
    fold_result = evaluate_model(model, coords_val_fold, mlp_val_scaled, gnn_val_fold, y_val_fold, raster_paths, BUFFER_METERS, batch_size)
    fold_metrics.append(fold_result)
    print("Validation Metrics:")
    print(f"R²: {fold_result['R2']:.4f}")
    print(f"MAE: {fold_result['MAE']:.4f}")
    print(f"RMSE: {fold_result['RMSE']:.4f}")
    print(f"SMAPE: {fold_result['SMAPE']:.4f}%")

    # Prepare and evaluate on the independent test set
    dist_mat_test_train = distance_matrix(test_orig[['Long','Lat']].values, coords_train_fold)
    gnn_test_fold = np.exp(-dist_mat_test_train/10)
    
    mlp_test_scaled = scaler_fold.transform(test_orig[numeric_cols].values)

    test_result = evaluate_model(model, test_orig[['Long','Lat']].values, mlp_test_scaled, gnn_test_fold, test_orig['RI'].values, raster_paths, BUFFER_METERS, batch_size)
    test_metrics.append(test_result)
    print("\nIndependent Test Set Metrics:")
    print(f"R²: {test_result['R2']:.4f}")
    print(f"MAE: {test_result['MAE']:.4f}")
    print(f"RMSE: {test_result['RMSE']:.4f}")
    print(f"SMAPE: {test_result['SMAPE']:.4f}%")
    
    # Save the trained model for the current fold
    model.save(os.path.join(model_save_dir, f"fold_{fold+1}.keras"))
    print(f"Model for Fold {fold+1} saved.")

    # Clean up to free memory
    del model, train_generator, val_generator, early_stopping, history
    tf.keras.backend.clear_session()
    gc.collect()

# ==================== 8. Print Final Averages ==================== #

# Calculate average metrics
avg_fold_metrics = pd.DataFrame(fold_metrics).mean().to_dict()
avg_test_metrics = pd.DataFrame(test_metrics).mean().to_dict()

print("\n" + "="*80)
print("Final Average Metrics Across All Folds:")
print("="*80)

print("\nAverage Validation Metrics:")
print(f"Average R²: {avg_fold_metrics['R2']:.4f}")
print(f"Average MAE: {avg_fold_metrics['MAE']:.4f}")
print(f"Average RMSE: {avg_fold_metrics['RMSE']:.4f}")
print(f"Average SMAPE: {avg_fold_metrics['SMAPE']:.4f}%")

print("\nAverage Independent Test Set Metrics:")
print(f"Average R²: {avg_test_metrics['R2']:.4f}")
print(f"Average MAE: {avg_test_metrics['MAE']:.4f}")
print(f"Average RMSE: {avg_test_metrics['RMSE']:.4f}")
print(f"Average SMAPE: {avg_test_metrics['SMAPE']:.4f}%")

# Save the printed output to a text file
sys.stdout = old_stdout
printed_output = captured_output.getvalue()
output_path = os.path.join("dual_attention_analysis", "kfold_analysis_output.txt")
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, "w") as f:
    f.write(printed_output)

print(f"\nFull analysis report saved to: {output_path}")

Using 26 raster layers for CNN input.
  - bui.tif
  - ndsi.tif
  - savi.tif
  - ndbsi.tif
  - ui.tif
  - ndwi.tif
  - ndbi.tif
  - awei.tif
  - evi.tif
  - mndwi.tif
  - ndvi.tif
  - LULC2020.tif
  - LULC2021.tif
  - LULC2022.tif
  - LULC2019.tif
  - LULC2018.tif
  - LULC2017.tif
  - Pb_R.tif
  - ClayR.tif
  - SandR.tif
  - CdR.tif
  - CrR.tif
  - AsR.tif
  - SiltR.tif
  - CuR.tif
  - NiR.tif

Full analysis report saved to: dual_attention_analysis/kfold_analysis_output.txt
