## pip install pandas numpy matplotlib scipy earthaccess netCDF4 h5py pyproj opencv-python pillow jupyter
#### HABNet Image Generator

In [None]:
import numpy as np
import h5py
from pathlib import Path
import cv2
from scipy.interpolate import griddata
from scipy.spatial import ConvexHull
import matplotlib.pyplot as plt
from PIL import Image
import warnings

warnings.filterwarnings('ignore')

In [None]:
# converts h5 datacubes to png images
class HABNetImageGenerator:
    """Convert HABNet HDF5 datacubes into per-modality, per-day greyscale PNGs.

    For every modality the projected sample points (``PointsProj``, indexed
    here as columns x, y, value, time) are interpolated onto a regular grid,
    scaled to 0-255 and written as ``<modality index>/<day>.png``.

    Args:
        input_res: Size of one input grid cell in metres.
        output_res: Size of one output pixel in metres.
        alpha_size: Stored on the instance; not used by any method of this
            class.
    """

    # Side length in metres of the square region rendered around the event
    # centre (100 km). With the default 2 km input resolution this yields the
    # 50 x 50 cell grid the original code hard-coded.
    EXTENT_M = 100000

    def __init__(self, input_res=2000, output_res=1000, alpha_size=2):
        self.input_res = input_res
        self.output_res = output_res
        self.alpha_size = alpha_size
        # Output pixel size expressed in input grid cells, e.g. 1000/2000 = 0.5.
        self.fract = output_res / input_res

    @property
    def grid_size(self):
        """Number of input grid cells along one side of the rendered region."""
        return self.EXTENT_M / self.input_res

    def convert_utm_to_grid_coords(self, utm_x, utm_y, center_utm_x, center_utm_y):
        """Map UTM coordinates to grid-cell coordinates centred on the event.

        Args:
            utm_x, utm_y: Sample coordinates in metres (scalars or arrays).
            center_utm_x, center_utm_y: Coordinates of the grid centre.

        Returns:
            Tuple ``(grid_x, grid_y)`` in units of input grid cells, where the
            centre maps to ``grid_size / 2`` (25 for the default settings).
        """
        half_extent = self.grid_size / 2
        grid_x = (utm_x - center_utm_x) / self.input_res + half_extent
        grid_y = (utm_y - center_utm_y) / self.input_res + half_extent
        return grid_x, grid_y

    def setup_output_grid(self, input_range_x, input_range_y):
        """Build the mesh of output pixel centres, in grid-cell coordinates.

        Args:
            input_range_x, input_range_y: ``[start, stop]`` extents of the
                grid along each axis.

        Returns:
            Tuple ``(xq, yq)`` of 2-D arrays as produced by ``np.meshgrid``.
        """
        # Offset by half a pixel so samples sit at pixel centres.
        x_coords = np.arange(
            input_range_x[0] + self.fract / 2,
            input_range_x[1],
            self.fract,
        )
        y_coords = np.arange(
            input_range_y[0] + self.fract / 2,
            input_range_y[1],
            self.fract,
        )
        xq, yq = np.meshgrid(x_coords, y_coords)
        return xq, yq

    def get_image(self, output_xq, output_yq, input_xp, input_yp, input_up):
        """Interpolate scattered samples onto the output grid.

        Linear interpolation is used inside the convex hull of the samples;
        pixels it cannot reach are filled from the nearest sample.

        Args:
            output_xq, output_yq: Query meshes from ``setup_output_grid``.
            input_xp, input_yp: Sample coordinates.
            input_up: Sample values.

        Returns:
            Array shaped like ``output_xq``. It is all-NaN when fewer than
            four usable samples exist or when gridding fails.
        """
        blank = np.full(output_xq.shape, np.nan)

        try:
            xs = np.asarray(input_xp, dtype=float).ravel()
            ys = np.asarray(input_yp, dtype=float).ravel()
            vals = np.asarray(input_up, dtype=float).ravel()

            # Drop non-finite samples: a NaN value would otherwise blank every
            # triangle it touches and then be copied by the nearest-neighbour
            # fill as well.
            keep = np.isfinite(xs) & np.isfinite(ys) & np.isfinite(vals)
            if np.count_nonzero(keep) < 4:  # need minimum points for interpolation
                return blank

            sample_xy = np.column_stack([xs[keep], ys[keep]])
            sample_values = vals[keep]
            query_xy = np.column_stack([output_xq.ravel(), output_yq.ravel()])

            output_image = griddata(
                sample_xy,
                sample_values,
                query_xy,
                method='linear',
                fill_value=np.nan,
            ).reshape(output_xq.shape)

            # Pixels outside the convex hull come back as NaN; fill them
            # from the nearest sample.
            nan_mask = np.isnan(output_image)
            if nan_mask.any():
                nn_image = griddata(
                    sample_xy,
                    sample_values,
                    query_xy,
                    method='nearest',
                ).reshape(output_xq.shape)
                output_image[nan_mask] = nn_image[nan_mask]

            return output_image

        except Exception as e:
            # Best effort: a failed day becomes a blank image rather than
            # aborting the whole datacube.
            print(f"gridding failed: {e}")
            return blank

    def normalize(self, image, min_val, max_val):
        """Scale ``image`` to uint8 using ``[min_val, max_val]`` as the range.

        Values are clipped to the range; NaN pixels become 0 (black). If the
        supplied range is degenerate or not finite, the finite range of the
        image itself is used instead.

        Args:
            image: Numeric array, may contain NaN.
            min_val, max_val: Values mapped to 0 and 255 respectively.

        Returns:
            ``uint8`` array shaped like ``image``. It is all zeros when the
            image has no finite value, and all 128 when no usable range exists.
        """
        finite_mask = np.isfinite(image)
        if not finite_mask.any():
            return np.zeros(image.shape, dtype=np.uint8)

        finite_values = image[finite_mask]
        actual_min = finite_values.min()
        actual_max = finite_values.max()

        print(f"    Data range: {actual_min:.6f} to {actual_max:.6f}")
        print(f"    Norm range: {min_val:.6f} to {max_val:.6f}")

        range_is_usable = (
            np.isfinite(min_val)
            and np.isfinite(max_val)
            and abs(max_val - min_val) >= 1e-10
        )
        if not range_is_usable:
            print("    Using actual range")
            min_val, max_val = actual_min, actual_max

            # Constant image: no contrast to show, return middle gray.
            if abs(max_val - min_val) < 1e-10:
                return np.full(image.shape, 128, dtype=np.uint8)

        normalized = np.clip((image - min_val) / (max_val - min_val), 0, 1)
        quantized = np.nan_to_num(np.round(255.0 * normalized), nan=0)

        unique_vals = len(np.unique(quantized[quantized > 0]))
        print(f"    Output: {unique_vals} unique non-zero values")

        return quantized.astype(np.uint8)

    def process_datacube_to_images(self, h5_file_path, output_base_dir, group_min_max=None):
        """Render one HDF5 datacube to PNG files.

        Images are written to ``output_base_dir/<modality index>/<day>.png``
        with 1-based modality indices and zero-padded 1-based day numbers.
        Missing modalities, modalities without ``PointsProj`` and days with
        no samples or with errors are written as black images.

        Args:
            h5_file_path: Path to the datacube file.
            output_base_dir: Directory receiving the modality sub-directories.
            group_min_max: Optional list of ``(min, max)`` tuples indexed by
                modality position. When ``None`` the ranges are computed from
                this datacube.

        Returns:
            None. The method prints an error and returns early when the
            modalities, the number of days or the centre cannot be determined.

        Raises:
            KeyError: If the file has no ``GroundTruth`` group or it lacks
                the ``thisLat`` / ``thisLon`` attributes.
        """
        h5_file_path = Path(h5_file_path)
        output_base_dir = Path(output_base_dir)

        print(f"Processing datacube: {h5_file_path.name}")

        with h5py.File(h5_file_path, 'r') as h5f:
            mod_names = self._read_modality_names(h5f)
            print(f"Found {len(mod_names)} modalities: {mod_names}")
            if not mod_names:
                print("ERROR: No modalities found in datacube!")
                return

            # The geographic centre is only reported; the grid centre is
            # derived from the projected points below.
            gt_group = h5f['GroundTruth']
            center_lat = gt_group.attrs['thisLat']
            center_lon = gt_group.attrs['thisLon']
            print(f"Event center: ({center_lat:.4f}, {center_lon:.4f})")

            grid_range = [0, self.grid_size]
            output_xq, output_yq = self.setup_output_grid(grid_range, grid_range)
            image_shape = output_xq.shape
            print(f"Output image size: {image_shape}")

            number_of_days = self._count_days(h5f, mod_names)
            if number_of_days is None:
                print("ERROR: Could not determine number of days (no 'Ims' dataset found)!")
                return
            print(f"Processing {number_of_days} days")

            if group_min_max is None:
                print("Computing min/max values from data...")
                group_min_max = self.get_min_max_from_datacube(h5f, mod_names)

            center = self._estimate_center_utm(h5f, mod_names)
            if center is None:
                print("ERROR: Could not determine center UTM coordinates!")
                return
            center_utm_x, center_utm_y = center
            print(f"Estimated center UTM: ({center_utm_x:.1f}, {center_utm_y:.1f})")

            for mod_idx, mod_name in enumerate(mod_names):
                print(f"\n  Modality {mod_idx+1}/{len(mod_names)}: {mod_name}")

                mod_output_dir = output_base_dir / str(mod_idx + 1)
                mod_output_dir.mkdir(parents=True, exist_ok=True)

                if mod_name not in h5f:
                    print("    Not found in file")
                    self._save_empty_images(mod_output_dir, image_shape, number_of_days)
                    continue

                mod_group = h5f[mod_name]
                points_proj = mod_group['PointsProj'][:] if 'PointsProj' in mod_group else None

                if mod_idx < len(group_min_max):
                    this_min, this_max = group_min_max[mod_idx]
                elif points_proj is not None and len(points_proj) > 0:
                    # Supplied range list is shorter than the modality list.
                    this_min = np.nanmin(points_proj[:, 2])
                    this_max = np.nanmax(points_proj[:, 2])
                else:
                    this_min, this_max = 0, 1

                print(f"    Min/Max: {this_min:.6f} / {this_max:.6f}")

                if points_proj is None:
                    print("    No PointsProj data found")
                    self._save_empty_images(mod_output_dir, image_shape, number_of_days)
                else:
                    for day in range(1, number_of_days + 1):
                        output_file = mod_output_dir / f"{day:02d}.png"
                        try:
                            output_image = self._grid_day(
                                points_proj, day, output_xq, output_yq,
                                center_utm_x, center_utm_y,
                            )
                            quantized_image = self.normalize(
                                output_image, this_min, this_max
                            )
                            Image.fromarray(quantized_image).save(output_file)
                        except Exception as e:
                            # Keep the day sequence complete even when one
                            # day cannot be rendered.
                            print(f"        Error processing day {day}: {e}")
                            empty_image = np.zeros(image_shape, dtype=np.uint8)
                            Image.fromarray(empty_image).save(output_file)

                print(f"    Saved {number_of_days} images to {mod_output_dir}")

        print(f"✓ Completed: {h5_file_path.name}")

    def get_min_max_from_datacube(self, h5f, mod_names):
        """Compute the value range of each modality in an open datacube.

        Args:
            h5f: Open datacube (or any mapping with the same layout).
            mod_names: Modality names, in output order.

        Returns:
            List of ``(min, max)`` tuples parallel to ``mod_names``. A
            modality that is missing, has no ``PointsProj`` or has no finite
            value gets the default ``(0, 1)``.
        """
        group_min_max = []

        for mod_name in mod_names:
            min_val, max_val = 0, 1
            if mod_name in h5f and 'PointsProj' in h5f[mod_name]:
                points_proj = h5f[mod_name]['PointsProj'][:]
                if len(points_proj) > 0:
                    values = points_proj[:, 2]  # value column
                    # Ignore NaN/inf so one bad sample cannot produce an
                    # unusable normalisation range.
                    finite_values = values[np.isfinite(values)]
                    if finite_values.size > 0:
                        min_val = finite_values.min()
                        max_val = finite_values.max()

            group_min_max.append((min_val, max_val))

        return group_min_max

    # list the modality names stored in (or implied by) the datacube
    @staticmethod
    def _read_modality_names(h5f):
        if 'Modnames' in h5f:
            return [name.decode('utf-8') if isinstance(name, bytes) else name
                    for name in h5f['Modnames'][:]]
        return [name for name in h5f.keys() if name != 'GroundTruth']

    # number of days, read from the first modality that carries an 'Ims' dataset
    @staticmethod
    def _count_days(h5f, mod_names):
        for mod_name in mod_names:
            if mod_name in h5f and 'Ims' in h5f[mod_name]:
                datacube_shape = h5f[mod_name]['Ims'].shape
                return datacube_shape[2] if len(datacube_shape) == 3 else 1
        return None

    # centroid of the first modality with usable projected points, or None
    @staticmethod
    def _estimate_center_utm(h5f, mod_names):
        # NOTE(review): this uses the point centroid as the grid centre, which
        # assumes the samples are spread evenly around the event - confirm.
        for mod_name in mod_names:
            if mod_name in h5f and 'PointsProj' in h5f[mod_name]:
                points_proj = h5f[mod_name]['PointsProj'][:]
                if len(points_proj) > 0:
                    center_x = np.nanmean(points_proj[:, 0])
                    center_y = np.nanmean(points_proj[:, 1])
                    if np.isfinite(center_x) and np.isfinite(center_y):
                        return center_x, center_y
        return None

    # write one black image per day into out_dir
    @staticmethod
    def _save_empty_images(out_dir, image_shape, number_of_days):
        for day in range(1, number_of_days + 1):
            empty_image = np.zeros(image_shape, dtype=np.uint8)
            Image.fromarray(empty_image).save(out_dir / f"{day:02d}.png")

    # interpolate the samples of one day (1-based) onto the output grid
    def _grid_day(self, points_proj, day, output_xq, output_yq,
                  center_utm_x, center_utm_y):
        # Day 1 covers time values in [0, 1), day 2 covers [1, 2), and so on.
        time_indices = (points_proj[:, 3] >= day - 1) & (points_proj[:, 3] < day)

        if not np.any(time_indices):
            print(f"      Day {day}: NO DATA")
            # All-NaN so normalisation renders it black regardless of the
            # value range. A zero-filled image would map to grey or white
            # whenever the range minimum is negative.
            return np.full(output_xq.shape, np.nan)

        day_points = points_proj[time_indices]
        utm_x = day_points[:, 0]
        utm_y = day_points[:, 1]
        values = day_points[:, 2]

        grid_x, grid_y = self.convert_utm_to_grid_coords(
            utm_x, utm_y, center_utm_x, center_utm_y
        )

        print(f"      Day {day}: {len(day_points)} points")
        print(f"        UTM: X({np.min(utm_x):.1f}-{np.max(utm_x):.1f}) Y({np.min(utm_y):.1f}-{np.max(utm_y):.1f})")
        print(f"        Grid: X({np.min(grid_x):.1f}-{np.max(grid_x):.1f}) Y({np.min(grid_y):.1f}-{np.max(grid_y):.1f})")

        return self.get_image(output_xq, output_yq, grid_x, grid_y, values)

In [None]:
def convert_all_datacubes_to_images(datacube_dir, output_dir, use_global_minmax=False):
    """Render every ``*.h5`` datacube in a directory to PNG images.

    Each datacube is written to ``output_dir/<file stem>/``. A datacube that
    fails is reported and skipped so the remaining files are still processed.

    Args:
        datacube_dir: Directory searched (non-recursively) for ``*.h5`` files.
        output_dir: Directory receiving one sub-directory per datacube.
        use_global_minmax: Reserved. Global normalisation is not implemented;
            when ``True`` a warning is printed and each datacube is still
            normalised with its own value range.

    Returns:
        None.
    """
    datacube_dir = Path(datacube_dir)
    output_dir = Path(output_dir)

    # Sort so the processing order does not depend on the filesystem.
    h5_files = sorted(datacube_dir.glob("*.h5"))
    if not h5_files:
        print(f"No H5 files found in {datacube_dir}")
        return

    print(f"Found {len(h5_files)} H5 datacube files")

    # Passing None makes the generator compute the range per datacube.
    global_min_max = None
    if use_global_minmax:
        # Say so explicitly instead of silently ignoring the flag.
        print("WARNING: use_global_minmax is not implemented; "
              "using per-datacube min/max instead")

    generator = HABNetImageGenerator(input_res=2000, output_res=1000, alpha_size=2)

    failed_names = []
    for idx, h5_file in enumerate(h5_files, start=1):
        print(f"\n{'='*60}")
        print(f"Processing datacube {idx}/{len(h5_files)}")

        datacube_output_dir = output_dir / h5_file.stem
        datacube_output_dir.mkdir(parents=True, exist_ok=True)

        try:
            generator.process_datacube_to_images(
                h5_file,
                datacube_output_dir,
                group_min_max=global_min_max,
            )
        except Exception as e:
            # Batch boundary: report and move on to the next file.
            print(f"✗ Failed to process {h5_file.name}: {e}")
            failed_names.append(h5_file.name)

    print(f"\n{'='*60}")
    print("Image conversion complete!")
    if failed_names:
        print(f"{len(failed_names)} of {len(h5_files)} datacubes failed: "
              f"{', '.join(failed_names)}")

In [None]:
if __name__ == "__main__":
    # Script entry point: render every datacube under the project's default
    # data folder. The two path names are kept at module level so they stay
    # available after this runs.
    datacube_directory = "habnet_datacube_data/processed_h5_datacubes"
    output_directory = "habnet_datacube_data/png_images"

    # Per-datacube normalisation (global min/max disabled).
    convert_all_datacubes_to_images(
        datacube_directory,
        output_directory,
        use_global_minmax=False,
    )