In [None]:
import tensorflow as tf
import os
import random
from keras_unet_collection import models, base, utils
import numpy as np
from tqdm import tqdm 
from skimage import io  
from keras import backend as K
from skimage.io import imread, imshow
from skimage.transform import resize
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, concatenate, Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, LearningRateScheduler, EarlyStopping
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.losses import *
import seaborn as sns
from sklearn.base import BaseEstimator, RegressorMixin
import pandas as pd
from datetime import datetime
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional
import logging

2024-12-11 22:18:25.182322: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-11 22:18:25.204960: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-11 22:18:25.204987: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-11 22:18:25.205610: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-11 22:18:25.209765: I tensorflow/core/platform/cpu_feature_guar

In [None]:
def load_remote_sensing_data(desired_num_patches, random_seed):
    """
    Load and preprocess remote sensing data with random patch selection.

    A random subset of patches is read from seven hard-coded directories
    (resolved relative to the current working directory). Files are paired
    across directories purely by their position in the sorted directory
    listing. For every selected patch the optical, SAR, FVS, slope, aspect and
    Landsat 8 rasters are min-max scaled (each raster with its own global
    min/max, across all of its bands) and concatenated along the channel axis
    in exactly that order. The AGB maps are min-max scaled jointly over all
    loaded patches, before the data is split.

    The data is split 80/20 into train+val/test, and the train+val part is
    split 80/20 again into train/val.

    Args:
        desired_num_patches (int): Number of patches to randomly select
            (capped at the number of available patches).
        random_seed (int): Random seed for reproducibility; used for the patch
            selection and for both train/test splits.

    Returns:
        dict: 'x_train', 'y_train', 'x_val', 'y_val', 'x_test', 'y_test'
        arrays plus 'input_shape' (shape of one input sample).

    Raises:
        AssertionError: If the directories do not all contain the same number
            of files.
        ValueError: If none of the selected patches could be loaded.
    """
    def normalize(img):
        # Min-max scale with the array's own global min/max; the epsilon keeps
        # constant-valued arrays from dividing by zero.
        return (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)

    # Set random seed for reproducibility.
    # NOTE: this reseeds NumPy's global legacy RNG, so later np.random.* calls
    # in the same process are affected as well.
    np.random.seed(random_seed)

    # Your data directories
    optic_images_dir = 'Sentinel 2_Patches'
    sar_images_dir = 'Sentinel1_Patches'
    agb_maps_dir = 'AGB_GroundTruth_Patches'
    fvs_images_dir = 'Climatic_Parameters_Patches'
    slope_images_dir = 'UNET_Slope_Patches/Patches'
    aspect_images_dir = 'Topography_Aspect_Patches'
    # NOTE(review): the Landsat 8 patches are read from a directory named
    # 'Topography_Slope_Patches' -- confirm this path is intended.
    landsat8_patches_dir = 'Topography_Slope_Patches'

    # Get the list of file names in the directories (sorted so that position
    # idx refers to the same patch in every directory)
    agb_map_files = sorted(os.listdir(agb_maps_dir))
    optic_image_files = sorted(os.listdir(optic_images_dir))
    sar_image_files = sorted(os.listdir(sar_images_dir))
    fvs_image_files = sorted(os.listdir(fvs_images_dir))
    slope_image_files = sorted(os.listdir(slope_images_dir))
    aspect_image_files = sorted(os.listdir(aspect_images_dir))
    landsat8_image_files = sorted(os.listdir(landsat8_patches_dir))

    # Verify all directories have the same number of files
    total_patches = len(agb_map_files)
    assert all(len(files) == total_patches for files in [
        optic_image_files, sar_image_files, fvs_image_files,
        slope_image_files, aspect_image_files, landsat8_image_files
    ]), "All directories must have the same number of files"

    print(f"Total available patches: {total_patches}")

    # Randomly select indices
    selected_indices = np.random.choice(
        total_patches,
        size=min(desired_num_patches, total_patches),
        replace=False
    )
    selected_indices = sorted(selected_indices)  # Sort for consistent file loading

    # Initialize empty lists to store data
    x_train_list = []
    agb_values_list = []

    # Loop through selected patches
    for idx in tqdm(selected_indices, desc='Loading Patches'):
        try:
            # Load images
            optic_image = io.imread(os.path.join(optic_images_dir, optic_image_files[idx]))
            sar_image = io.imread(os.path.join(sar_images_dir, sar_image_files[idx]))
            agb_map = io.imread(os.path.join(agb_maps_dir, agb_map_files[idx]))
            fvs_image = io.imread(os.path.join(fvs_images_dir, fvs_image_files[idx]))
            slope_image = io.imread(os.path.join(slope_images_dir, slope_image_files[idx]))
            aspect_image = io.imread(os.path.join(aspect_images_dir, aspect_image_files[idx]))
            landsat8_image = io.imread(os.path.join(landsat8_patches_dir, landsat8_image_files[idx]))

            # Verify Landsat 8 bands
            if landsat8_image.shape[-1] != 7:
                print(f"Skipping patch {idx}: Invalid Landsat 8 bands {landsat8_image.shape[-1]}")
                continue

            # Slope and aspect are single-channel rasters: add a trailing
            # channel axis so they can be concatenated with the others.
            slope_image_expanded = np.expand_dims(normalize(slope_image), axis=-1)
            aspect_image_expanded = np.expand_dims(normalize(aspect_image), axis=-1)

            # Concatenate all (min-max normalized) images along the channel axis
            x_train = np.concatenate([
                normalize(optic_image),
                normalize(sar_image),
                normalize(fvs_image),
                slope_image_expanded,
                aspect_image_expanded,
                normalize(landsat8_image)
            ], axis=-1)

            # Store the concatenated data
            x_train_list.append(x_train)
            agb_values_list.append(agb_map)

        except Exception as e:
            # Best effort: a patch that cannot be read or combined is reported
            # and left out instead of aborting the whole load.
            print(f"Error processing patch {idx}: {str(e)}")
            continue

    # Fail with a clear message instead of np.stack's generic
    # "need at least one array" error when nothing was loaded.
    if not x_train_list:
        raise ValueError(
            f"No patches could be loaded: all {len(selected_indices)} selected "
            "patches were skipped or failed to load"
        )

    # Combine all patches into single arrays
    x_data = np.stack(x_train_list)
    agb_values = np.stack(agb_values_list)

    # Normalize AGB values (jointly over every loaded patch) and add a
    # trailing channel axis
    agb_scaled = normalize(agb_values)
    agb_normalized = np.expand_dims(agb_scaled, axis=-1)

    # Split the data into training, validation, and test sets
    x_train, x_test, y_train, y_test = train_test_split(
        x_data, agb_normalized, test_size=0.20, random_state=random_seed
    )
    x_train, x_val, y_train, y_val = train_test_split(
        x_train, y_train, test_size=0.20, random_state=random_seed
    )

    # Print dataset information
    print('\nDataset shapes:')
    print('x_train:', x_train.shape)
    print('y_train:', y_train.shape)
    print('x_val:', x_val.shape)
    print('y_val:', y_val.shape)
    print('x_test:', x_test.shape)
    print('y_test:', y_test.shape)

    return {
        'x_train': x_train,
        'y_train': y_train,
        'x_val': x_val,
        'y_val': y_val,
        'x_test': x_test,
        'y_test': y_test,
        'input_shape': x_train.shape[1:]
    }

# Usage example: build the train/val/test splits from a reproducible random
# subset of 400 patches. `data` is the dict returned by the loader.
if __name__ == "__main__":
    data = load_remote_sensing_data(desired_num_patches=400, random_seed=30)

Total available patches: 1681


Loading Patches: 100%|██████████| 400/400 [00:01<00:00, 279.57it/s]



Dataset shapes:
x_train: (256, 128, 128, 31)
y_train: (256, 128, 128, 1)
x_val: (64, 128, 128, 31)
y_val: (64, 128, 128, 1)
x_test: (80, 128, 128, 31)
y_test: (80, 128, 128, 1)


In [3]:
from keras.utils import custom_object_scope
from keras.models import load_model
from keras_unet_collection.transformer_layers import patch_extract, patch_embedding, patch_merging, patch_expanding, drop_path,WindowAttention, SwinTransformerBlock  # Assuming these are part of the custom layers

# Custom keras_unet_collection layers, keyed by name, that are made available
# to load_model through custom_object_scope below.
# Add other custom layers or utilities as needed.
custom_objects = dict(
    patch_extract=patch_extract,
    patch_embedding=patch_embedding,
    patch_merging=patch_merging,
    patch_expanding=patch_expanding,
    drop_path=drop_path,
    WindowAttention=WindowAttention,
    SwinTransformerBlock=SwinTransformerBlock,
)

# Load the saved model while the custom objects are registered
with custom_object_scope(custom_objects):
    model = load_model('result/Keras-Collecetion_Att_UNET/AGB_Att_UNET/31bands/model_for_AttentionUNET.h5')

2024-12-11 22:09:58.426937: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-12-11 22:09:58.448145: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-12-11 22:09:58.448275: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-

In [None]:
# First, define the band names
def create_band_names() -> List[str]:
    """
    Build the ordered list of names for the 31 input channels.

    Names have the form ``<source>_<band>``. Sources appear in the order
    Sentinel-2, SAR, FVS, topography, Landsat 8, which is the order in which
    load_remote_sensing_data concatenates its inputs (optical, SAR, FVS,
    slope, aspect, Landsat 8).

    Returns:
        List[str]: 31 band names.

    Raises:
        AssertionError: If the table below does not add up to 31 names.
    """
    # (source prefix, band suffixes) in channel order.
    band_table = (
        # Sentinel-2 optical bands (13)
        ("Sentinel2", (
            "Coastal-Aerosol", "Blue", "Green", "Red",
            "Red Edge 1", "Red Edge 2", "Red Edge 3",
            "NIR", "Red Edge 4", "Water vapor", "Cirrus",
            "SWIR 1", "SWIR 2",
        )),
        # SAR polarizations (2): VH = vertical transmit / horizontal receive,
        # VV = vertical transmit / vertical receive.
        ("SAR", ("VH", "VV")),
        # "FVS" variables (7). NOTE(review): the abbreviations look like
        # climate variables (e.g. mat/map ~ mean annual temperature /
        # precipitation), and the loader reads this group from
        # 'Climatic_Parameters_Patches' -- confirm the exact definitions
        # with the data provider.
        ("FVS", ("ffp", "sday", "gsp", "mmax", "mmin", "mat", "map")),
        # Topography (2): terrain slope and aspect
        ("Topo", ("Slope", "Aspect")),
        # Landsat 8 bands (7)
        ("Landsat8", ("Coastal", "Blue", "Green", "Red", "NIR", "SWIR1", "SWIR2")),
    )

    # Flatten the table into "<prefix>_<suffix>" names, preserving order.
    names = [
        f"{prefix}_{suffix}"
        for prefix, suffixes in band_table
        for suffix in suffixes
    ]

    assert len(names) == 31, f"Expected 31 bands, got {len(names)}"
    return names

# Wrapper for TensorFlow models
class TFModelWrapper(BaseEstimator, RegressorMixin):
    """Adapter that exposes a TensorFlow/Keras model through the scikit-learn
    estimator interface (``fit`` / ``predict``)."""

    def __init__(self, model):
        # The wrapped model; only its ``predict`` method is used here.
        self.model = model

    def fit(self, X, y):
        """Do nothing and return ``self``; the wrapped model is not trained here."""
        return self

    def predict(self, X):
        """Return ``model.predict(X)`` (progress output disabled) as a NumPy array."""
        raw_output = self.model.predict(X, verbose=0)
        return np.array(raw_output)

# Main Feature Importance Analyzer class
class FeatureImportanceAnalyzer:
    """Feature importance analyzer for multi-band remote sensing models.

    Three analyses are available (permutation, occlusion, gradient). Each
    stores a ``{'mean': array, 'std': array}`` entry, one value per input
    channel, in ``self.results`` under the method's name. The results can then
    be plotted (per band or per source group) and exported to CSV.

    Inputs ``X`` are indexed as (samples, rows, cols, channels) NumPy arrays.
    """

    # (group name, band-name prefix) pairs, checked in order, used to assign
    # each band to its source group.
    _GROUP_PREFIXES = (
        ('Sentinel2', 'Sentinel2'),
        ('SAR', 'SAR'),
        ('FVS', 'FVS'),
        ('Topographic', 'Topo'),
        ('Landsat8', 'Landsat8'),
    )

    def __init__(self, model, output_dir='feature_importance_results', batch_size=1):
        """
        Args:
            model: Model exposing ``predict(X, verbose=0)`` and being callable
                on a tensor (both are used by the analyses).
            output_dir (str): Directory for plots, CSV files and the log file;
                created if missing.
            batch_size (int): Samples per batch in the gradient analysis.
        """
        self.model = model
        self.feature_names = create_band_names()
        self.output_dir = output_dir
        self.batch_size = batch_size
        self.results = {}
        self.feature_groups = self._create_feature_groups()

        os.makedirs(output_dir, exist_ok=True)
        # NOTE: basicConfig configures the root logger and does nothing if the
        # root logger already has handlers (e.g. configured earlier in the
        # session), in which case no 'analysis.log' is written.
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            filename=os.path.join(output_dir, 'analysis.log')
        )

    def _create_feature_groups(self):
        """Creates groups of features based on their source.

        Returns:
            dict: group name -> list of channel indices whose band name starts
            with that group's prefix. Groups without bands map to ``[]``.
        """
        groups = {group: [] for group, _ in self._GROUP_PREFIXES}

        for i, name in enumerate(self.feature_names):
            for group, prefix in self._GROUP_PREFIXES:
                if name.startswith(prefix):
                    groups[group].append(i)
                    break  # first matching prefix wins

        return groups

    @staticmethod
    def _require_samples(X):
        """Raise ValueError if X contains no samples."""
        if len(X) == 0:
            raise ValueError("X must contain at least one sample")

    @staticmethod
    def _resolve_rng(random_state):
        """Map a ``random_state`` argument to an object with ``permutation``.

        ``None`` keeps the legacy behavior (NumPy's global RNG); an existing
        RandomState/Generator is used as is; anything else seeds a new
        RandomState.
        """
        if random_state is None:
            return np.random
        if isinstance(random_state, (np.random.RandomState, np.random.Generator)):
            return random_state
        return np.random.RandomState(random_state)

    def _predict(self, X):
        """Return the model's predictions for X as a NumPy array (silent)."""
        return np.array(self.model.predict(X, verbose=0))

    def run_permutation_importance(self, X, y, n_repeats=5, random_state=None):
        """Compute permutation importance.

        For each channel, its values are shuffled across samples and the
        increase in mean squared error over the unshuffled baseline is
        recorded; this is repeated ``n_repeats`` times per channel.

        Args:
            X: Input array, channels on the last axis.
            y: Targets, broadcastable against the model's predictions.
            n_repeats (int): Shuffles per channel.
            random_state: Optional seed, ``RandomState`` or ``Generator`` for
                reproducible shuffles. ``None`` (default) uses NumPy's global
                RNG, as before.

        Returns:
            dict: ``{'mean': array, 'std': array}`` over the repeats, one
            entry per channel. Also stored in ``self.results['permutation']``.

        Raises:
            ValueError: If X is empty.
        """
        logging.info("Starting permutation importance analysis...")
        print("Computing permutation importance...")

        self._require_samples(X)
        rng = self._resolve_rng(random_state)

        baseline_pred = self._predict(X)
        baseline_score = np.mean((baseline_pred - y) ** 2)

        importances = []
        importance_std = []

        # One scratch copy for the whole analysis: only the channel under
        # test is overwritten, and it is restored afterwards.
        X_permuted = X.copy()

        for feature_idx in tqdm(range(X.shape[-1]), desc="Analyzing features"):
            feature_importances = []
            for _ in range(n_repeats):
                permuted_idx = rng.permutation(len(X))
                X_permuted[..., feature_idx] = X[permuted_idx, ..., feature_idx]

                permuted_pred = self._predict(X_permuted)
                permuted_score = np.mean((permuted_pred - y) ** 2)
                feature_importances.append(permuted_score - baseline_score)

            # Restore the original channel before moving on
            X_permuted[..., feature_idx] = X[..., feature_idx]

            importances.append(np.mean(feature_importances))
            importance_std.append(np.std(feature_importances))

        self.results['permutation'] = {
            'mean': np.array(importances),
            'std': np.array(importance_std)
        }
        return self.results['permutation']

    def run_occlusion_sensitivity(self, X, patch_size=128):
        """Compute occlusion sensitivity.

        For each channel, non-overlapping ``patch_size`` x ``patch_size``
        windows are set to 0 one at a time and the mean absolute change of the
        prediction is recorded. When ``patch_size`` covers the whole spatial
        extent there is a single window per channel, so 'std' is 0.

        Args:
            X: Input array indexed as (samples, rows, cols, channels).
            patch_size (int): Side length of the occlusion window.

        Returns:
            dict: ``{'mean': array, 'std': array}`` over the windows, sized by
            the number of feature names. Also stored in
            ``self.results['occlusion']``.

        Raises:
            ValueError: If X is empty.
        """
        logging.info("Starting occlusion sensitivity analysis...")
        print("Computing occlusion sensitivity...")

        self._require_samples(X)

        importances = np.zeros(len(self.feature_names))
        importance_std = np.zeros(len(self.feature_names))
        baseline = self._predict(X)

        # One scratch copy; each window is zeroed, evaluated, then restored.
        X_occluded = X.copy()

        for i in tqdm(range(X.shape[-1]), desc="Analyzing features"):
            impacts = []
            for start_row in range(0, X.shape[1], patch_size):
                end_row = min(start_row + patch_size, X.shape[1])
                for start_col in range(0, X.shape[2], patch_size):
                    end_col = min(start_col + patch_size, X.shape[2])
                    window = (
                        slice(None),
                        slice(start_row, end_row),
                        slice(start_col, end_col),
                        i,
                    )
                    X_occluded[window] = 0

                    prediction = self._predict(X_occluded)
                    impacts.append(np.mean(np.abs(baseline - prediction)))

                    X_occluded[window] = X[window]

            importances[i] = np.mean(impacts)
            importance_std[i] = np.std(impacts)

        self.results['occlusion'] = {
            'mean': importances,
            'std': importance_std
        }
        return self.results['occlusion']

    def run_gradient_importance(self, X):
        """Compute gradient-based importance.

        The gradient of the model output with respect to the input is taken
        in batches of ``self.batch_size``; its absolute value is averaged
        over the two spatial axes to give one value per sample and channel.
        Mean and std are then taken over samples, so the result does not
        depend on the batch size or on a smaller final batch.

        Args:
            X: Input array indexed as (samples, rows, cols, channels).

        Returns:
            dict: ``{'mean': array, 'std': array}`` per channel. Also stored
            in ``self.results['gradient']``.

        Raises:
            ValueError: If X is empty or ``self.batch_size`` is not positive.
        """
        logging.info("Starting gradient-based importance analysis...")
        print("Computing gradient-based importance...")

        self._require_samples(X)
        if self.batch_size < 1:
            raise ValueError("batch_size must be a positive integer")

        per_sample_importances = []
        n_batches = int(np.ceil(len(X) / self.batch_size))

        for batch_idx in tqdm(range(n_batches), desc="Processing batches"):
            start_idx = batch_idx * self.batch_size
            end_idx = min((batch_idx + 1) * self.batch_size, len(X))
            X_batch = X[start_idx:end_idx]

            with tf.GradientTape() as tape:
                X_tf = tf.convert_to_tensor(X_batch, dtype=tf.float32)
                # Inputs are not variables, so they must be watched explicitly
                tape.watch(X_tf)
                predictions = self.model(X_tf)

            gradients = tape.gradient(predictions, X_tf)
            # Average |gradient| over the spatial axes -> (batch, channels)
            per_sample_importances.append(
                np.mean(np.abs(gradients.numpy()), axis=(1, 2))
            )

        per_sample_importances = np.concatenate(per_sample_importances, axis=0)
        self.results['gradient'] = {
            'mean': np.mean(per_sample_importances, axis=0),
            'std': np.std(per_sample_importances, axis=0)
        }
        return self.results['gradient']

    def plot_feature_importance(self, method, grouped=False, figsize=(15, 10), color_palette='viridis'):
        """Plot feature importance results as a bar chart and save it as PNG.

        The file is written to ``<output_dir>/<method>_importance.png`` (with
        a ``_grouped`` suffix when ``grouped`` is true). The figure is closed
        afterwards, also when plotting or saving fails.

        Args:
            method (str): Key in ``self.results`` ('permutation', 'occlusion'
                or 'gradient').
            grouped (bool): If true, plot the average mean and average std of
                each source group instead of one bar per band.
            figsize (tuple): Figure size passed to matplotlib.
            color_palette (str): Seaborn palette name for the bars.

        Raises:
            ValueError: If the requested analysis has not been run.
        """
        if method not in self.results:
            raise ValueError(f"Results for method '{method}' not found. Run analysis first.")

        fig = plt.figure(figsize=figsize)
        try:
            # Style is set before the axes are created so it applies to them
            sns.set_style("whitegrid")
            ax = fig.add_subplot(111)

            if grouped:
                grouped_mean = {group: np.mean(self.results[method]['mean'][indices])
                                for group, indices in self.feature_groups.items()}
                grouped_std = {group: np.mean(self.results[method]['std'][indices])
                               for group, indices in self.feature_groups.items()}

                groups = list(grouped_mean.keys())
                means = list(grouped_mean.values())
                stds = list(grouped_std.values())
                positions = list(range(len(groups)))

                colors = sns.color_palette(color_palette, n_colors=len(groups))
                ax.bar(positions, means, yerr=stds, capsize=5, color=colors)
                ax.set_xticks(positions)
                ax.set_xticklabels(groups, rotation=45, ha='right')
            else:
                positions = list(range(len(self.feature_names)))
                colors = sns.color_palette(color_palette, n_colors=len(self.feature_names))
                ax.bar(positions,
                       self.results[method]['mean'],
                       yerr=self.results[method]['std'],
                       capsize=5,
                       color=colors)
                ax.set_xticks(positions)
                ax.set_xticklabels(self.feature_names, rotation=90, ha='right')

            ax.set_title(f"{method.capitalize()} Feature Importance Analysis")
            ax.set_xlabel("Features")
            ax.set_ylabel("Importance Score")
            fig.tight_layout()

            fig.savefig(
                os.path.join(self.output_dir, f'{method}_importance{"_grouped" if grouped else ""}.png'),
                dpi=300,
                bbox_inches='tight'
            )
        finally:
            # Always release the figure, even if plotting or saving failed
            plt.close(fig)

    def save_results(self):
        """Save analysis results to CSV.

        Writes ``feature_importance_detailed.csv`` (one row per band, with an
        ``<method>_importance`` and ``<method>_std`` column per analysis run)
        and ``feature_importance_summary.csv`` (mean and std of every
        ``*_importance`` column per category) into ``self.output_dir``.

        The category is the part of the band name before the first '_'.

        Returns:
            pandas.DataFrame: The detailed per-band table.
        """
        results_dict = {
            'Feature': self.feature_names,
            'Category': [name.split('_')[0] for name in self.feature_names]
        }

        for method, data in self.results.items():
            results_dict[f"{method}_importance"] = data['mean']
            results_dict[f"{method}_std"] = data['std']

        df = pd.DataFrame(results_dict)

        # Add summary statistics by category
        summary_df = df.groupby('Category').agg({
            col: ['mean', 'std'] for col in df.columns
            if col.endswith('_importance')
        })

        # Save both detailed and summary results
        df.to_csv(os.path.join(self.output_dir, 'feature_importance_detailed.csv'), index=False)
        summary_df.to_csv(os.path.join(self.output_dir, 'feature_importance_summary.csv'))

        return df

def analyze_feature_importance(model, X_test, y_test):
    """Run every feature importance analysis, plot the results and export them.

    Runs the permutation, occlusion and gradient analyses (in that order) on
    the given data, saves an ungrouped and a grouped plot per method, and
    writes the result tables.

    Args:
        model: Model to analyze; passed to FeatureImportanceAnalyzer.
        X_test: Inputs used for all three analyses.
        y_test: Targets used by the permutation analysis.

    Returns:
        tuple: ``(analyzer, results_df)`` -- the analyzer instance and the
        table returned by its ``save_results``.
    """
    analyzer = FeatureImportanceAnalyzer(model)

    # Analyses
    analyzer.run_permutation_importance(X_test, y_test)
    analyzer.run_occlusion_sensitivity(X_test)
    analyzer.run_gradient_importance(X_test)

    # Visualizations: per-band first, then per-group, for each method
    for method in ('permutation', 'occlusion', 'gradient'):
        for grouped in (False, True):
            analyzer.plot_feature_importance(method, grouped=grouped)

    # Export
    return analyzer, analyzer.save_results()
    
# Example usage: load the saved model and run every analysis on the test split.
if __name__ == "__main__":
    # Imported here so this cell carries everything needed to load the model
    from keras.utils import custom_object_scope
    from keras.models import load_model
    from keras_unet_collection.transformer_layers import (
        patch_extract, patch_embedding, patch_merging, patch_expanding,
        drop_path, WindowAttention, SwinTransformerBlock
    )

    # Custom objects dictionary also for Transformer models
    custom_objects = dict(
        patch_extract=patch_extract,
        patch_embedding=patch_embedding,
        patch_merging=patch_merging,
        patch_expanding=patch_expanding,
        drop_path=drop_path,
        WindowAttention=WindowAttention,
        SwinTransformerBlock=SwinTransformerBlock,
    )

    # Load the model with the custom objects registered
    with custom_object_scope(custom_objects):
        model = load_model('Model.h5')

    # `data` is the dict built by load_remote_sensing_data in the data-loading cell
    X_test, y_test = data['x_test'], data['y_test']
    analyzer, results = analyze_feature_importance(model, X_test, y_test)

2024-12-11 22:19:27.534360: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-12-11 22:19:27.554748: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-12-11 22:19:27.554886: I external/local_xla/xla/stream_executor/cuda/cuda_executor.cc:901] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-

Computing permutation importance...


2024-12-11 22:19:28.657286: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8901
2024-12-11 22:19:28.760161: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
2024-12-11 22:19:30.085842: I external/local_tsl/tsl/platform/default/subprocess.cc:304] Start cannot spawn child process: No such file or directory
Analyzing features: 100%|██████████| 31/31 [01:04<00:00,  2.10s/it]


Computing occlusion sensitivity...


Analyzing features: 100%|██████████| 31/31 [00:12<00:00,  2.44it/s]


Computing gradient-based importance...


Processing batches:   0%|          | 0/80 [00:00<?, ?it/s]2024-12-11 22:20:52.496684: W external/local_tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.11GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-12-11 22:20:52.532821: W external/local_tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.11GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-12-11 22:20:52.587052: W external/local_tsl/tsl/framework/bfc_allocator.cc:296] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.08GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-12-1