## pip install pandas numpy matplotlib scipy earthaccess netCDF4 h5py pyproj opencv-python pillow jupyter

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import warnings
import earthaccess
import netCDF4 as nc
import os
from pathlib import Path
from scipy.interpolate import griddata
from scipy.spatial import ConvexHull
import pickle
import pyproj
from pyproj import Transformer
import h5py

warnings.filterwarnings('ignore')

### Configs

In [None]:
# Geometry of a single datacube: footprint, time window and bin sizes.
DATACUBE_CONFIG = dict(
    spatial_extent_km=100,         # 100 km footprint, as in HABNet
    temporal_extent_days=10,
    spatial_resolution_km=2,       # 2 km bins for the datacube
    temporal_resolution_days=1,
)

# MODIS-Aqua modalities used for now: chlorophyll-a, five Rrs bands, PAR.
_RRS_BANDS = (412, 443, 488, 531, 555)
HABNET_MODIS_AQUA_MODALITIES = (
    ['chlor_a'] + ['Rrs_%d' % band for band in _RRS_BANDS] + ['par']
)

# Gulf of Mexico bounding box in degrees.
GULF_BOUNDS = dict(
    lat_min=24.0, lat_max=30.5,
    lon_min=-88.0, lon_max=-80.0,
)

### Basic setup

In [None]:
def setup_directories():
    """Create the working directory tree (relative to the CWD) if needed.

    Returns:
        Tuple ``(base_dir, raw_dir, processed_dir)`` of ``Path`` objects:
        the root folder, the folder for downloaded MODIS L2 granules and
        the folder for the generated HDF5 datacubes.
    """
    base_dir = Path("habnet_datacube_data")
    layout = (
        base_dir,
        base_dir / "raw_modis_l2",
        base_dir / "processed_h5_datacubes",  # h5 format
    )

    # The parent is listed first, so each mkdir finds its parent in place.
    for folder in layout:
        folder.mkdir(exist_ok=True)

    return layout

# spatial bounds 100km
def calculate_spatial_bounds(lat, lon, extent_km=100):
    """Return a lat/lon box of roughly ``extent_km`` per side centred on a point.

    One degree of latitude is taken as ~111 km everywhere. One degree of
    longitude shrinks with cos(latitude), so the longitude half-width is
    widened accordingly; otherwise the box would be too narrow east-west
    away from the equator.

    Args:
        lat: Centre latitude in degrees.
        lon: Centre longitude in degrees.
        extent_km: Full side length of the box in kilometres.

    Returns:
        Dict with keys ``lat_min``, ``lat_max``, ``lon_min``, ``lon_max``
        (degrees).
    """
    km_per_degree = 111.0  # 1 degree of latitude is around 111 km
    half_lat_deg = extent_km / km_per_degree / 2

    # Guard against division by ~0 close to the poles.
    cos_lat = max(float(np.cos(np.radians(lat))), 1e-6)
    half_lon_deg = half_lat_deg / cos_lat

    return {
        'lat_min': lat - half_lat_deg,
        'lat_max': lat + half_lat_deg,
        'lon_min': lon - half_lon_deg,
        'lon_max': lon + half_lon_deg
    }


# temporal bounds 10 days total
def calculate_temporal_bounds(event_date, days_before=9):
    """Return the time window that leads up to (and includes) an event day.

    Args:
        event_date: Date/datetime of the event.
        days_before: Number of days before the event to include.

    Returns:
        Dict with ``start_date`` (``days_before`` days before the event) and
        ``end_date`` (one day after the event), so the span between them is
        ``days_before + 1`` days.
    """
    one_day = timedelta(days=1)
    window_end = event_date + one_day
    window_start = event_date - timedelta(days=days_before)
    return dict(start_date=window_start, end_date=window_end)

# utm from coords (like habnet)
def get_utm_zone_from_coords(lat, lon):
    """Return the EPSG code (as a string) of the WGS 84 UTM zone for a point.

    Args:
        lat: Latitude in degrees; its sign selects the hemisphere series
            (326xx north, 327xx south).
        lon: Longitude in degrees.

    Returns:
        Five-character string such as ``"32617"``.
    """
    utm_zone = int((lon + 180) // 6) + 1
    # Only zones 1..60 exist. Without the clamp lon == 180 gives "zone 61",
    # i.e. EPSG:32661, which is not a UTM zone at all.
    utm_zone = min(max(utm_zone, 1), 60)

    hemisphere_prefix = "326" if lat >= 0 else "327"  # north / south
    return f"{hemisphere_prefix}{utm_zone:02d}"

# utm projection (like habnet)
def setup_utm_projection(lat, lon):
    """Build the lon/lat <-> UTM transformers for the zone containing a point.

    Args:
        lat: Latitude in degrees.
        lon: Longitude in degrees.

    Returns:
        Tuple ``(transformer_to_utm, transformer_from_utm, epsg_code)``.
        Both transformers are created with ``always_xy=True``.
    """
    epsg_code = get_utm_zone_from_coords(lat, lon)
    geographic_crs = "EPSG:4326"
    utm_crs = f"EPSG:{epsg_code}"

    to_utm = Transformer.from_crs(geographic_crs, utm_crs, always_xy=True)
    from_utm = Transformer.from_crs(utm_crs, geographic_crs, always_xy=True)
    return to_utm, from_utm, epsg_code


def setup_nasa_earthdata():
    """Log in to NASA Earthdata through ``earthaccess``.

    Returns:
        True when the login reports an authenticated session, False when it
        does not or when the login call raises (the error is printed).
    """
    print("Setting up NASA Earthdata auth")
    try:
        auth = earthaccess.login()
    except Exception as e:
        # Best-effort boundary: report the problem and let the caller decide.
        print(f"  Authentication error: {e}")
        return False

    # login() hands back an Auth object; truthiness of that object alone does
    # not say whether the login worked, so look at its `authenticated` flag
    # (falling back to plain truthiness if the attribute is absent).
    if auth is not None and getattr(auth, "authenticated", bool(auth)):
        print(" - NASA Earthdata auth successful!")
        return True

    print(" - Auth failed")
    return False


### Loading and Filtering Events

In [None]:
# load ground truth data and filter
def load_and_filter_hab_events(csv_file='habsos_20240430.csv'):
    """Load the ground-truth CSV and keep labelled Karenia brevis events.

    Steps: keep rows whose SPECIES is 'brevis' and that have coordinates,
    date and cell count; label counts > 50,000 cells/L as positive (1) and
    counts == 0 as negative (0), dropping everything in between; restrict to
    ``GULF_BOUNDS`` and to 2003-2018; build ``STABLE_EVENT_ID`` and drop
    duplicates of it; sort by date.

    Args:
        csv_file: Path of the CSV. It must provide the columns SAMPLE_DATE,
            SPECIES, LATITUDE, LONGITUDE and CELLCOUNT.

    Returns:
        DataFrame of the filtered events with the added columns HAB_EVENT
        and STABLE_EVENT_ID, sorted by SAMPLE_DATE with a fresh index.

    Raises:
        FileNotFoundError: If ``csv_file`` does not exist.
    """
    print("Loading ground truth data")

    df = pd.read_csv(csv_file)
    df['SAMPLE_DATE'] = pd.to_datetime(df['SAMPLE_DATE'])

    # filter for Karenia brevis, complete rows only
    kb_clean = (
        df[df['SPECIES'] == 'brevis']
        .dropna(subset=['LATITUDE', 'LONGITUDE', 'SAMPLE_DATE', 'CELLCOUNT'])
        .copy()
    )

    # positive HAB when Karenia brevis > 50,000 cells/L
    # negative HAB when Karenia brevis = 0 cells/L
    positive_events = kb_clean[kb_clean['CELLCOUNT'] > 50000].copy()
    positive_events['HAB_EVENT'] = 1

    negative_events = kb_clean[kb_clean['CELLCOUNT'] == 0].copy()
    negative_events['HAB_EVENT'] = 0

    hab_events = pd.concat([positive_events, negative_events], ignore_index=True)

    # gulf of Mexico area (bounds inclusive)
    in_gulf = (
        hab_events['LATITUDE'].between(GULF_BOUNDS['lat_min'], GULF_BOUNDS['lat_max']) &
        hab_events['LONGITUDE'].between(GULF_BOUNDS['lon_min'], GULF_BOUNDS['lon_max'])
    )
    gulf_events = hab_events[in_gulf].copy()

    # MODIS period 2003-2018. The upper bound is exclusive at 2019-01-01 so
    # that samples taken at any time of day on 2018-12-31 are kept; comparing
    # with <= 2018-12-31 00:00 would silently drop them.
    modis_start = datetime(2003, 1, 1)
    modis_end_exclusive = datetime(2019, 1, 1)
    final_events = gulf_events[
        (gulf_events['SAMPLE_DATE'] >= modis_start) &
        (gulf_events['SAMPLE_DATE'] < modis_end_exclusive)
    ].copy()

    # create a UID from date, rounded position and cell count
    final_events['STABLE_EVENT_ID'] = (
        final_events['SAMPLE_DATE'].dt.strftime('%Y%m%d') + '_' +
        final_events['LATITUDE'].round(4).astype(str) + '_' +
        final_events['LONGITUDE'].round(4).astype(str) + '_' +
        final_events['CELLCOUNT'].astype(int).astype(str)
    )

    # remove dupes based on ID
    final_events = final_events.drop_duplicates(subset=['STABLE_EVENT_ID']).copy()

    # sort by date
    final_events = final_events.sort_values('SAMPLE_DATE').reset_index(drop=True)

    print(f"Total HAB events: {len(final_events):,}")
    print(f"Positive: {len(final_events[final_events['HAB_EVENT'] == 1]):,}")
    print(f"Negative: {len(final_events[final_events['HAB_EVENT'] == 0]):,}")
    print(f"Dates range: {final_events['SAMPLE_DATE'].min()} to {final_events['SAMPLE_DATE'].max()}")

    return final_events

### Create sample events

In [None]:
def create_sample_events(events_df, n_events=10):
    """Draw a reproducible, class-balanced sample of events from 2015 onwards.

    Up to ``n_events // 2`` positive and ``n_events // 2`` negative events
    are sampled (fewer if a class does not have that many), using a fixed
    random state so reruns pick the same rows.

    Args:
        events_df: DataFrame with at least SAMPLE_DATE (datetime) and
            HAB_EVENT (1/0) columns.
        n_events: Target total sample size.

    Returns:
        DataFrame of the sampled events sorted by SAMPLE_DATE with a fresh
        index.
    """
    print(f"\nCreating sample of {n_events} events")

    # restrict to the recent years (2015 and later)
    recent = events_df[events_df['SAMPLE_DATE'].dt.year >= 2015].copy()
    per_class = n_events // 2

    # positives first, then negatives; same seed for both draws
    draws = []
    for label in (1, 0):
        pool = recent[recent['HAB_EVENT'] == label]
        draws.append(pool.sample(n=min(per_class, len(pool)), random_state=42))

    mvp_sample = pd.concat(draws, ignore_index=True)
    mvp_sample = mvp_sample.sort_values('SAMPLE_DATE').reset_index(drop=True)

    labels = mvp_sample['HAB_EVENT']
    dates = mvp_sample['SAMPLE_DATE']
    print(f"MVP sample created: {len(mvp_sample)} events")
    print(f"  Positive (HAB): {int((labels == 1).sum())}")
    print(f"  Negative (No HAB): {int((labels == 0).sum())}")
    print(f"  Date range: {dates.min()} to {dates.max()}")

    return mvp_sample

### Modis L2 search

In [None]:
# get MODIS L2 Ocean Color data
def search_modis_l2_data(date, spatial_bounds):
    """Search for MODIS-Aqua L2 ocean colour granules for one day and box.

    Args:
        date: Day to search; formatted as YYYY-MM-DD for both ends of the
            temporal range.
        spatial_bounds: Dict with ``lon_min``, ``lat_min``, ``lon_max`` and
            ``lat_max``.

    Returns:
        Whatever ``earthaccess.search_data`` returns, or an empty list if the
        search raises (the error is printed).
    """
    # bounding box order: west, south, east, north
    west, south = spatial_bounds['lon_min'], spatial_bounds['lat_min']
    east, north = spatial_bounds['lon_max'], spatial_bounds['lat_max']
    bbox = (west, south, east, north)

    try:
        day = date.strftime('%Y-%m-%d')
        return earthaccess.search_data(
            short_name='MODISA_L2_OC',
            temporal=(day, day),
            bounding_box=bbox
        )
    except Exception as e:
        print(f"Search error for {date}: {e}")
        return []

# download satellite granule if not cached
def download_and_cache_granule(granule, raw_dir):
    """Return the local path of a granule, downloading it only when missing.

    The cache file is ``raw_dir / "<GranuleUR>.nc"``. When it is absent the
    granule is downloaded into ``raw_dir`` and the first downloaded file is
    renamed to the cache name if it differs.

    Args:
        granule: Search result; ``granule['umm']['GranuleUR']`` is read.
        raw_dir: ``Path`` of the download/cache directory.

    Returns:
        The cache file path as ``str``, or ``None`` if nothing was downloaded
        or the download raised (the error is printed).
    """
    cached_file = raw_dir / f"{granule['umm']['GranuleUR']}.nc"

    if cached_file.exists():
        print(f"- using cached file: {cached_file.name[:50]}")
        return str(cached_file)

    try:
        print(" - downloading new file")
        files = earthaccess.download([granule], local_path=str(raw_dir))
        if not files:
            return None

        fetched = Path(files[0])
        if fetched != cached_file:
            fetched.rename(cached_file)
        return str(cached_file)
    except Exception as e:
        print(f" - download failed: {e}")
        return None

# get the modalities
def extract_modality_from_l2(file_path, modality, spatial_bounds, transformer_to_utm, transformer_from_utm):
    """Extract the valid pixels of one variable inside a lat/lon box.

    Opens the netCDF file, reads ``latitude``/``longitude`` from the
    ``navigation_data`` group and ``modality`` from the ``geophysical_data``
    group, keeps the pixels that fall inside ``spatial_bounds`` and pass the
    validity checks, and projects their coordinates with
    ``transformer_to_utm``.

    Args:
        file_path: Path of the netCDF file to read.
        modality: Variable name inside ``geophysical_data``.
        spatial_bounds: Dict with ``lat_min``, ``lat_max``, ``lon_min``,
            ``lon_max`` (bounds are inclusive).
        transformer_to_utm: Object whose ``transform(lons, lats)`` returns
            projected ``(x, y)``.
        transformer_from_utm: Accepted but not used by this function.

    Returns:
        Dict with ``lats``, ``lons``, ``utm_x``, ``utm_y`` and ``values`` for
        the surviving pixels, or ``None`` when a required group or the
        variable is missing, shapes disagree, nothing lies inside the bounds,
        nothing is valid after filtering, or any exception occurs (a message
        is printed in every case).
    """
    try:
        with nc.Dataset(file_path, 'r') as ds:
            # check if required groups exist
            if 'geophysical_data' not in ds.groups or 'navigation_data' not in ds.groups:
                print(f" missing required groups in {Path(file_path).name}")
                return None

            geo_data = ds.groups['geophysical_data']
            nav_data = ds.groups['navigation_data']

            # get per-pixel coordinates (read fully into memory)
            lats = nav_data.variables['latitude'][:]
            lons = nav_data.variables['longitude'][:]

            # spatial filtering: boolean mask of pixels inside the box
            lat_mask = (lats >= spatial_bounds['lat_min']) & (lats <= spatial_bounds['lat_max'])
            lon_mask = (lons >= spatial_bounds['lon_min']) & (lons <= spatial_bounds['lon_max'])
            spatial_mask = lat_mask & lon_mask

            if not np.any(spatial_mask):
                print(f" - no data in spatial bounds")
                return None

            # get modality data
            if modality not in geo_data.variables:
                print(f" - {modality} not found in file")
                return None

            mod_data = geo_data.variables[modality][:]
            # the mask built from lats/lons is only usable if shapes agree
            if mod_data.shape != lats.shape:
                print(f" - {modality} shape mismatch")
                return None

            # apply spatial mask (boolean indexing flattens to 1-D)
            lats_roi = lats[spatial_mask]
            lons_roi = lons[spatial_mask]
            vals_roi = mod_data[spatial_mask]

            # filter valid values: use the array's own mask when it has one,
            # otherwise fall back to a finiteness check
            if hasattr(vals_roi, 'mask'):
                valid_mask = ~vals_roi.mask
            else:
                valid_mask = np.isfinite(vals_roi)

            # modality-specific filtering like habnet (all limits exclusive)
            if modality == 'chlor_a':
                valid_mask = valid_mask & (vals_roi > 0) & (vals_roi < 1000)
            elif modality.startswith('Rrs_'):
                valid_mask = valid_mask & (vals_roi > 0) & (vals_roi < 1.0)
            elif modality == 'par':
                valid_mask = valid_mask & (vals_roi > 0)

            if not np.any(valid_mask):
                print(f" - no valid {modality} data after filtering")
                return None

            # keep only valid data
            final_lats = lats_roi[valid_mask]
            final_lons = lons_roi[valid_mask]
            final_vals = vals_roi[valid_mask]

            # reproject to utm like habnet in getData.m
            # (argument order is x=lon, y=lat)
            utm_x, utm_y = transformer_to_utm.transform(final_lons, final_lats)

            print(f" - {modality}: {len(final_vals)} points, "
                  f"range {np.min(final_vals):.4f}-{np.max(final_vals):.4f}")

            return {
                'lats': final_lats,
                'lons': final_lons,
                'utm_x': utm_x,
                'utm_y': utm_y,
                'values': final_vals
            }

    except Exception as e:
        # best effort: any read/processing failure skips this file/modality
        print(f" - error with {Path(file_path).name}: {e}")
        return None


### Reprojecting to Grid

In [None]:
# reproject to regular grid
def reproject_to_regular_grid(utm_x, utm_y, values, center_utm_x, center_utm_y, config):
    """Interpolate scattered projected samples onto a square grid of cell centres.

    The grid is centred on ``(center_utm_x, center_utm_y)``, spans
    ``config['spatial_extent_km']`` per side and has
    ``spatial_extent_km // spatial_resolution_km`` cells per side. Row 0 is
    the southernmost row (y increases with the row index). Cells are filled
    by linear interpolation; cells that remain NaN are filled with the
    nearest sample.

    Args:
        utm_x, utm_y: Sample coordinates in metres.
        values: Sample values, same length as the coordinates.
        center_utm_x, center_utm_y: Grid centre in metres.
        config: Dict with ``spatial_extent_km`` and ``spatial_resolution_km``.

    Returns:
        2-D array of shape ``(grid_size, grid_size)``. It is all NaN when
        fewer than 4 samples are given or the interpolation raises (the error
        is printed).
    """
    extent_km = config['spatial_extent_km']
    cell_km = config['spatial_resolution_km']

    n_cells = extent_km // cell_km       # e.g. 100/2 = 50
    cell_m = cell_km * 1000              # e.g. 2000m
    half_width_m = extent_km * 1000 / 2  # km -> m, half of the footprint

    def _cell_centres(centre):
        # centres run from half a cell inside the low edge to half a cell
        # inside the high edge
        return np.linspace(
            centre - half_width_m + cell_m / 2,
            centre + half_width_m - cell_m / 2,
            n_cells,
        )

    grid_x, grid_y = np.meshgrid(_cell_centres(center_utm_x), _cell_centres(center_utm_y))

    samples_xy = np.column_stack([utm_x, utm_y])
    cells_xy = np.column_stack([grid_x.ravel(), grid_y.ravel()])

    def _empty_grid():
        return np.full((n_cells, n_cells), np.nan)

    # need at least 4 points for triangulation
    if len(values) < 4:
        return _empty_grid()

    try:
        filled = griddata(samples_xy, values, cells_xy, method='linear', fill_value=np.nan)

        # cells the linear pass left as NaN take the nearest sample
        holes = np.isnan(filled)
        if holes.any():
            nearest = griddata(samples_xy, values, cells_xy, method='nearest', fill_value=np.nan)
            filled[holes] = nearest[holes]

        return filled.reshape(grid_x.shape)

    except Exception as e:
        print(f"- interpolation failed: {e}")
        return _empty_grid()


### dataCube Generator

In [None]:
# datacube pipeline 
class HABNetDatacubeGenerator:
    """Build one HDF5 datacube per ground-truth event.

    For every day of the event's time window the first MODIS granule found
    for the event's bounding box is downloaded (or taken from the cache),
    each modality is extracted, projected to UTM and gridded, and everything
    is written to ``habnet_datacube_<STABLE_EVENT_ID>.h5``.

    File layout:
        GroundTruth        group; event metadata stored as attributes
        Modnames           dataset with the modality names (UTF-8 bytes)
        <modality>/Ims         (grid, grid, days) gridded values, NaN = no data
        <modality>/Points      rows of (lat, lon, value, day offset)
        <modality>/PointsProj  rows of (utm_x, utm_y, value, day offset)

    Attributes:
        raw_dir: Directory holding the downloaded granules.
        processed_dir: Directory receiving the datacubes.
        config: Datacube geometry (see ``DATACUBE_CONFIG``).
        search_cache: Memo of granule search results, keyed by date and box.
    """

    # minimum fraction of days with data for a datacube to be kept
    _MIN_COMPLETENESS = 0.4

    def __init__(self, raw_dir, processed_dir, config=DATACUBE_CONFIG):
        self.raw_dir = Path(raw_dir)
        self.processed_dir = Path(processed_dir)
        self.config = config
        self.search_cache = {}

    def _datacube_path(self, stable_event_id):
        """Return the final .h5 path of the datacube for an event id."""
        return self.processed_dir / f"habnet_datacube_{stable_event_id}.h5"

    def check_existing_datacube(self, stable_event_id):
        """Return True if the finished datacube for this event already exists."""
        return self._datacube_path(stable_event_id).exists()

    def cached_search_modis_data(self, date, spatial_bounds):
        """Search granules for a day/box, reusing an earlier identical search.

        The cache key is the date plus ``lat_min``/``lon_min`` rounded to two
        decimals, so it memoises search results (not downloaded files).
        """
        cache_key = f"{date.strftime('%Y-%m-%d')}_{spatial_bounds['lat_min']:.2f}_{spatial_bounds['lon_min']:.2f}"

        if cache_key not in self.search_cache:
            self.search_cache[cache_key] = search_modis_l2_data(date, spatial_bounds)
        return self.search_cache[cache_key]

    def generate_datacube_for_event(self, event_row):
        """Create the datacube for one event, unless it already exists.

        The file is first written under a temporary ``.partial`` name and
        only moved to its final name once it is complete and holds enough
        data. A crash or exception therefore cannot leave behind a
        half-written file that a later run would mistake for a finished
        datacube and skip.

        Args:
            event_row: Row providing STABLE_EVENT_ID, SAMPLE_DATE, LATITUDE,
                LONGITUDE, HAB_EVENT and CELLCOUNT.

        Returns:
            Path of the datacube (new or pre-existing), or ``None`` when less
            than 40% of the days produced data.
        """
        stable_event_id = event_row['STABLE_EVENT_ID']
        output_file = self._datacube_path(stable_event_id)

        # skip if already processed
        if output_file.exists():
            print(f"\nDatacube for event {stable_event_id} already exists, skipping")
            return output_file

        event_date = event_row['SAMPLE_DATE']
        event_lat = event_row['LATITUDE']
        event_lon = event_row['LONGITUDE']
        hab_label = event_row['HAB_EVENT']

        print(f"\nGenerating datacube for event: {stable_event_id}")
        print(f"Date: {event_date.strftime('%Y-%m-%d')}, Location: ({event_lat:.3f}, {event_lon:.3f})")
        print(f"HAB Event: {hab_label}")

        # setup utm projection for this location
        transformer_to_utm, transformer_from_utm, epsg_code = setup_utm_projection(event_lat, event_lon)
        center_utm_x, center_utm_y = transformer_to_utm.transform(event_lon, event_lat)

        # spatial and temporal bounds
        n_days = self.config['temporal_extent_days']
        spatial_bounds = calculate_spatial_bounds(event_lat, event_lon, self.config['spatial_extent_km'])
        temporal_bounds = calculate_temporal_bounds(event_date, n_days - 1)
        total_days = (temporal_bounds['end_date'] - temporal_bounds['start_date']).days

        grid_size = self.config['spatial_extent_km'] // self.config['spatial_resolution_km']
        n_modalities = len(HABNET_MODIS_AQUA_MODALITIES)

        # per-modality accumulators
        modality_datacubes = {}
        modality_points = {}       # original lat/lon points
        modality_points_proj = {}  # projected points
        successful_days = 0

        # write under a temporary name; renamed only on success (see docstring)
        partial_file = output_file.with_name(output_file.name + '.partial')

        try:
            with h5py.File(partial_file, 'w') as h5f:
                # create ground truth group
                gt_group = h5f.create_group('GroundTruth')
                gt_group.attrs['thisLat'] = event_lat
                gt_group.attrs['thisLon'] = event_lon
                gt_group.attrs['thisCount'] = event_row['CELLCOUNT']
                gt_group.attrs['HAB_EVENT'] = hab_label
                gt_group.attrs['dayEnd'] = event_date.timestamp()
                gt_group.attrs['dayStart'] = temporal_bounds['start_date'].timestamp()
                gt_group.attrs['resolution'] = self.config['spatial_resolution_km'] * 1000  # in meters
                gt_group.attrs['distance1'] = self.config['spatial_extent_km'] * 1000  # in meters
                gt_group.attrs['projection'] = f'utm {epsg_code}'

                # store modality names
                modality_names = [name.encode('utf-8') for name in HABNET_MODIS_AQUA_MODALITIES]
                h5f.create_dataset('Modnames', data=modality_names)

                # process each day in time window
                for day_idx in range(total_days):
                    current_date = temporal_bounds['start_date'] + timedelta(days=day_idx)
                    print(f"  Day {day_idx+1}/{n_days} ({current_date.strftime('%Y-%m-%d')}):", end=" ")

                    # search for modis data
                    granules = self.cached_search_modis_data(current_date, spatial_bounds)
                    if not granules:
                        print("No data found")
                        continue

                    # download and process first granule
                    file_path = download_and_cache_granule(granules[0], self.raw_dir)
                    if not file_path:
                        print("Download failed")
                        continue

                    # process all modalities from this granule
                    modalities_ok = 0
                    for modality in HABNET_MODIS_AQUA_MODALITIES:
                        daily_data = extract_modality_from_l2(
                            file_path, modality, spatial_bounds,
                            transformer_to_utm, transformer_from_utm
                        )
                        if not daily_data or len(daily_data['values']) == 0:
                            continue

                        # first data for this modality: allocate its cube
                        if modality not in modality_datacubes:
                            modality_datacubes[modality] = np.full(
                                (grid_size, grid_size, n_days), np.nan
                            )
                            modality_points[modality] = []
                            modality_points_proj[modality] = []

                        # reproject to regular grid
                        gridded_data = reproject_to_regular_grid(
                            daily_data['utm_x'], daily_data['utm_y'], daily_data['values'],
                            center_utm_x, center_utm_y, self.config
                        )
                        modality_datacubes[modality][:, :, day_idx] = gridded_data

                        # day offset from the start of the window
                        # (0 = first day, last index = event day)
                        time_column = np.full(len(daily_data['values']), day_idx)
                        modality_points[modality].append(np.column_stack([
                            daily_data['lats'], daily_data['lons'], daily_data['values'],
                            time_column
                        ]))
                        modality_points_proj[modality].append(np.column_stack([
                            daily_data['utm_x'], daily_data['utm_y'], daily_data['values'],
                            time_column
                        ]))

                        modalities_ok += 1

                    if modalities_ok:
                        successful_days += 1
                        # report how many modalities really produced data
                        print(f"SUCCESS: processed {modalities_ok}/{n_modalities} modalities")
                    else:
                        print("No valid data for any modality")

                # save all modality data to h5 file
                for modality, cube in modality_datacubes.items():
                    mod_group = h5f.create_group(modality)
                    mod_group.create_dataset('Ims', data=cube)

                    # combine points from all days
                    if modality_points[modality]:
                        mod_group.create_dataset('Points', data=np.vstack(modality_points[modality]))
                        mod_group.create_dataset('PointsProj', data=np.vstack(modality_points_proj[modality]))

            # check data completeness
            data_completeness = successful_days / n_days
            print(f"\nDatacube generation complete: {successful_days}/{n_days} days ({data_completeness:.1%})")

            if data_completeness >= self._MIN_COMPLETENESS:
                partial_file.replace(output_file)
                print(f"H5 datacube saved: {output_file}")
                return output_file

            print(f"Insufficient data ({data_completeness:.1%}) - datacube not saved")
            return None
        finally:
            # still present only on failure or insufficient data
            if partial_file.exists():
                partial_file.unlink()


### Main pipeline

In [None]:
def run_habnet_datacube_generation():
    """Run the whole pipeline: setup, event selection, datacube generation.

    Creates the working directories, logs in to NASA Earthdata, loads and
    samples the ground-truth events (the sample is saved as
    ``mvp_events.csv``), generates a datacube per sampled event and prints a
    summary.

    Returns:
        ``(results, base_dir)`` where ``results`` is a list with one dict per
        event (``event_id``, ``hab_label``, ``output_file``, ``success``,
        ``processing_time_min`` and, on failure, ``error``), or ``None`` when
        the login fails or the events CSV is missing.
    """
    print("Datacube Pipeline")
    print("=" * 50)

    # setup directories
    base_dir, raw_dir, processed_dir = setup_directories()
    print(f"Data directories created in: {base_dir}")

    # setup NASA auth
    if not setup_nasa_earthdata():
        print("NASA Earthdata auth needed")
        return None

    # load and filter HAB events
    try:
        hab_events = load_and_filter_hab_events()
    except FileNotFoundError:
        print("HAB events CSV file not found. Please ensure 'habsos_20240430.csv' is available.")
        return None

    # create sample
    mvp_events = create_sample_events(hab_events, n_events=100)

    # save events
    mvp_file = base_dir / 'mvp_events.csv'
    mvp_events.to_csv(mvp_file, index=False)
    print(f"Events saved to: {mvp_file}")

    # init datacube generator
    generator = HABNetDatacubeGenerator(raw_dir, processed_dir)
    cfg = generator.config
    # derive the grid size from the config instead of assuming 50x50
    grid_size = cfg['spatial_extent_km'] // cfg['spatial_resolution_km']

    print(f"\nDatacube Config:")
    print(f"  Spatial: {cfg['spatial_extent_km']}km x {cfg['spatial_extent_km']}km")
    print(f"  Temporal: {cfg['temporal_extent_days']} days")
    print(f"  Spatial resolution: {cfg['spatial_resolution_km']}km ({grid_size}x{grid_size} grid)")
    print(f"  Modalities: {HABNET_MODIS_AQUA_MODALITIES}")

    # generate h5 datacubes for events
    n_events = len(mvp_events)
    print(f"\nGenerating H5 datacubes for {n_events} MVP events...")
    print("=" * 60)

    results = []
    start_time = datetime.now()

    for position, (event_idx, event) in enumerate(mvp_events.iterrows(), start=1):
        print(f"\n{'='*50}")
        print(f"Processing event {position}/{n_events}: ID {event_idx}")

        event_start = datetime.now()
        record = {
            'event_id': event_idx,
            'hab_label': event['HAB_EVENT'],
            'output_file': None,
            'success': False,
        }
        try:
            output_file = generator.generate_datacube_for_event(event)
        except Exception as e:
            # one bad event must not stop the batch; record it and move on
            record['processing_time_min'] = (datetime.now() - event_start).total_seconds() / 60
            record['error'] = str(e)
            print(f"Failed to process event {event_idx}: {e}")
        else:
            event_time = (datetime.now() - event_start).total_seconds() / 60
            record['output_file'] = output_file
            record['success'] = output_file is not None
            record['processing_time_min'] = event_time
            print(f"Event completed in {event_time:.1f} minutes")
        results.append(record)

    # summary
    total_time = (datetime.now() - start_time).total_seconds() / 60
    successful_results = [r for r in results if r['success']]

    print(f"\n" + "="*60)
    print(f"HABNET DATACUBE GENERATION SUMMARY")
    print(f"Total processing time: {total_time:.1f} minutes")
    if results:
        # an empty sample would otherwise divide by zero here
        print(f"Average time per event: {total_time/len(results):.1f} minutes")
    print(f"Total events processed: {len(results)}")
    print(f"Successful: {len(successful_results)}")
    print(f"Failed: {len(results) - len(successful_results)}")

    if successful_results:
        print(f"\nSuccessful H5 datacubes saved in: {processed_dir}")
    return results, base_dir

if __name__ == "__main__":
    print("Starting HABNet-Style Pipeline")
    print(f"Configuration: {DATACUBE_CONFIG}")
    print(f"MODIS-Aqua Modalities: {HABNET_MODIS_AQUA_MODALITIES}")

    # The pipeline returns None (after printing the reason) when the login
    # fails or the events CSV is missing; unpacking that directly would raise
    # a TypeError.
    outcome = run_habnet_datacube_generation()
    if outcome is None:
        results, data_dir = None, None
    else:
        results, data_dir = outcome