In [2]:
from pca_dataflow_V2_variables import *

In [None]:
import numpy as np
from osgeo import gdal, osr
from scipy.ndimage import generic_filter


def calculate_shannon_diversity(window):
    """
    Calculate Shannon diversity index for a given window of cluster values.
    """
    # Remove no-data values
    valid_values = window[window != -9999]
    if len(valid_values) == 0:
        return -9999  # No valid data in the window

    # Compute the frequency of each cluster
    unique, counts = np.unique(valid_values, return_counts=True)
    probabilities = counts / counts.sum()

    # Calculate Shannon diversity index
    shannon_diversity = -np.sum(probabilities * np.log(probabilities))
    return shannon_diversity


def create_shannon_diversity_map(input_file, output_file, window_size):
    """
    Generate a Shannon diversity map from a GeoTIFF containing cluster data.

    Args:
        input_file (str): Path to the input GeoTIFF file.
        output_file (str): Path to save the output GeoTIFF file.
        window_size (int): Window size for computing Shannon diversity.
    """
    # Open the input GeoTIFF
    dataset = gdal.Open(input_file)
    band = dataset.GetRasterBand(1)
    input_array = band.ReadAsArray()

    # Get geo-information from the input dataset
    geotransform = dataset.GetGeoTransform()
    projection = dataset.GetProjection()
    no_data_value = band.GetNoDataValue()

    # Apply the Shannon diversity calculation using a sliding window
    diversity_map = generic_filter(
        input_array,
        function=calculate_shannon_diversity,
        size=(window_size, window_size),
        mode='constant',
        cval=no_data_value
    )

    # Save the Shannon diversity map as a GeoTIFF
    driver = gdal.GetDriverByName("GTiff")
    out_dataset = driver.Create(
        output_file,
        dataset.RasterXSize,
        dataset.RasterYSize,
        1,
        gdal.GDT_Float32
    )
    out_dataset.SetGeoTransform(geotransform)
    out_dataset.SetProjection(projection)

    # Write the Shannon diversity data and set no-data value
    out_band = out_dataset.GetRasterBand(1)
    out_band.WriteArray(diversity_map)
    out_band.SetNoDataValue(-9999)

    # Save and close datasets
    out_band.FlushCache()
    out_dataset = None
    dataset = None
    print(f"Shannon diversity map saved to {output_file}")


# Example Usage
input_dir = "/Volumes/T9/new_pca_test/c_ang20180812t223939rfl/data/hs_raw_image" 
input_file = os.path.join(input_dir, "spectral_species.tif")
output_file = os.path.join(input_dir, "shannon_diversity_map.tif")  # Replace with desired output file path
window_size = 10  # Example window size

create_shannon_diversity_map(input_file, output_file, window_size)


In [3]:
import numpy as np
from osgeo import gdal
from scipy.spatial.distance import pdist, squareform
from tqdm import tqdm

def calculate_beta_diversity(cluster_image, window_size, no_data_value=-9999):
    """
    Calculate a beta diversity map (e.g., Jaccard index) for a cluster image using a sliding window.

    Args:
        cluster_image (numpy.ndarray): Input 2D array of cluster labels.
        window_size (int): Size of the moving window (must be odd).
        no_data_value (int): Value representing no data in the input image.

    Returns:
        numpy.ndarray: Beta diversity map.
    """
    rows, cols = cluster_image.shape
    half_window = window_size // 2
    beta_div_map = np.full((rows, cols), no_data_value, dtype=float)

    # Pad the input image to handle edge cases
    padded_image = np.pad(cluster_image, pad_width=half_window, mode='constant', constant_values=no_data_value)

    # Sliding window computation
    for i in tqdm(range(rows)):
        for j in range(cols):
            # Extract the local window
            local_window = padded_image[i:i + window_size, j:j + window_size]
            
            # Mask no_data_values
            local_window = local_window[local_window != no_data_value]
            if len(local_window) <= 1:
                beta_div_map[i, j] = no_data_value  # Assign no_data_value if the window is empty or has insufficient data
                continue
            
            # Compute pairwise Jaccard distance
            try:
                pairwise_distances = pdist(np.atleast_2d(local_window), metric='jaccard')
                beta_div_map[i, j] = np.mean(pairwise_distances)  # Average beta diversity for the window
            except Exception as e:
                # Handle any computation errors gracefully
                beta_div_map[i, j] = no_data_value

    return beta_div_map

def save_beta_diversity_map(output_path, beta_div_map, gdal_dataset, no_data_value=-9999, data_type=gdal.GDT_Float32):
    """
    Save the beta diversity map as a GeoTIFF file.

    Args:
        output_path (str): Output file path.
        beta_div_map (numpy.ndarray): Beta diversity map.
        gdal_dataset (gdal.Dataset): Input dataset for georeferencing.
        no_data_value (int): Value representing no data in the output file.
        data_type (GDAL data type): Data type for the output file.
    """
    driver = gdal.GetDriverByName('GTiff')
    rows, cols = beta_div_map.shape
    out_raster = driver.Create(output_path, cols, rows, 1, data_type)
    
    # Set georeference and projection
    out_raster.SetGeoTransform(gdal_dataset.GetGeoTransform())
    out_raster.SetProjection(gdal_dataset.GetProjection())
    
    # Write the data
    out_band = out_raster.GetRasterBand(1)
    out_band.WriteArray(beta_div_map)
    out_band.SetNoDataValue(no_data_value)
    out_raster.FlushCache()

# Example usage
input_dir = "/Volumes/T9/new_pca_test/c_ang20180812t223939rfl/data/hs_raw_image" 
input_file = os.path.join(input_dir, "spectral_species.tif")
output_file = os.path.join(input_dir, "beta_diversity_map.tif")  # Replace with desired output file path
window_size = 10

# Open the input cluster map
dataset = gdal.Open(input_file)
cluster_image = dataset.ReadAsArray()

# Calculate beta diversity map
beta_div_map = calculate_beta_diversity(cluster_image, window_size)

# Save the beta diversity map
save_beta_diversity_map(output_file, beta_div_map, dataset)


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
100%|██████████| 9404/9404 [00:39<00:00, 238.46it/s]
