#### Setup: `pip install pandas numpy matplotlib earthaccess netcdf4 scipy`

In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import warnings
import earthaccess
import netCDF4 as nc
import os
from pathlib import Path
from scipy.interpolate import griddata
import pickle

# NOTE(review): blanket filter - with no category or module argument this hides
# every warning raised after this point in the session, not only the noisy ones.
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# MVP datacube settings. The spatial extent/resolution and the temporal extent are
# read by OptimizedHABNetDatacubeGenerator; 'temporal_resolution_days' is not read
# anywhere in the code visible here.
DATACUBE_CONFIG = dict(
    spatial_extent_km=30,
    temporal_extent_days=5,
    spatial_resolution_km=2,
    temporal_resolution_days=1,
)

# Variable names looked up in each granule's 'geophysical_data' group.
# later we can use 'par', 'Rrs_443', 'Rrs_488', 'Rrs_531', 'Rrs_555'
HABNET_MODALITIES = ['chlor_a']

# Gulf of Mexico bounding box (latitude / longitude limits used to filter events)
GULF_BOUNDS = dict(
    lat_min=24.0,
    lat_max=30.5,
    lon_min=-88.0,
    lon_max=-80.0,
)

In [3]:
# create directory structure for data storage
def setup_directories():
    """Create the pipeline's data folders (if missing) and return their paths.

    The folders are created relative to the current working directory:
    ``habnet_mvp_data``, ``habnet_mvp_data/raw_modis_l2`` and
    ``habnet_mvp_data/processed_datacubes``.

    Returns:
        tuple[Path, Path, Path]: ``(base_dir, raw_dir, processed_dir)``.
    """
    root = Path("habnet_mvp_data")
    layout = (root, root / "raw_modis_l2", root / "processed_datacubes")

    # the root comes first, so the two children always have an existing parent
    for folder in layout:
        folder.mkdir(exist_ok=True)

    return layout

# get lat/lon around event 
def calculate_spatial_bounds(lat, lon, extent_km=100):
    """Return a lat/lon box, centred on (lat, lon), that is ``extent_km`` wide.

    One degree of latitude is taken as 111 km. One degree of longitude covers
    only ``111 * cos(lat)`` km, so the longitude half-width is widened by
    ``1 / cos(lat)``; without that correction the box is narrower than
    ``extent_km`` in the east-west direction everywhere except on the equator.

    Args:
        lat: centre latitude in degrees.
        lon: centre longitude in degrees.
        extent_km: full side length of the box in kilometres.

    Returns:
        dict with float values under 'lat_min', 'lat_max', 'lon_min', 'lon_max'.
        Values are not wrapped at +/-180 or clipped at +/-90.
    """
    km_per_degree = 111.0  # 1 degree of latitude is about 111 km
    half_lat_deg = extent_km / km_per_degree / 2

    # Clamp cos(lat) so a centre at (or numerically at) a pole cannot divide by
    # zero, and cap the result at half the globe.
    cos_lat = max(float(np.cos(np.radians(lat))), 1e-6)
    half_lon_deg = min(half_lat_deg / cos_lat, 180.0)

    return {
        'lat_min': lat - half_lat_deg,
        'lat_max': lat + half_lat_deg,
        'lon_min': lon - half_lon_deg,
        'lon_max': lon + half_lon_deg
    }

# get temporal bounds
def calculate_temporal_bounds(event_date, days_before=4):
    """Return the date window surrounding an event.

    Args:
        event_date: date/datetime of the event (anything that supports
            adding and subtracting a ``timedelta``).
        days_before: how many days before the event the window opens.

    Returns:
        dict with 'start_date' (``event_date - days_before`` days) and
        'end_date' (``event_date + 1`` day).
    """
    return {
        'start_date': event_date - timedelta(days=days_before),
        'end_date': event_date + timedelta(days=1),
    }


In [4]:
# load ground truth data and filter
def load_and_filter_hab_events(csv_file='habsos_20240430.csv'):
    """Read the sample CSV and return labelled 'brevis' events.

    Steps, in order:
      1. parse SAMPLE_DATE; keep rows whose SPECIES is 'brevis' and whose
         LATITUDE, LONGITUDE, SAMPLE_DATE and CELLCOUNT are all present;
      2. label CELLCOUNT > 50,000 as HAB_EVENT = 1 and CELLCOUNT == 0 as
         HAB_EVENT = 0 (rows with counts in between are dropped);
      3. keep rows inside GULF_BOUNDS and dated 2003-01-01 .. 2018-12-31;
      4. add STABLE_EVENT_ID (date_lat_lon_cellcount), drop repeated IDs and
         sort by SAMPLE_DATE.

    Args:
        csv_file: path of the CSV to read.

    Returns:
        pandas.DataFrame of the surviving rows with HAB_EVENT and
        STABLE_EVENT_ID columns added. A summary is printed as a side effect.
    """
    print("Loading ground truth data")

    samples = pd.read_csv(csv_file)
    samples['SAMPLE_DATE'] = pd.to_datetime(samples['SAMPLE_DATE'])

    # Karenia brevis rows that have every field needed downstream
    required = ['LATITUDE', 'LONGITUDE', 'SAMPLE_DATE', 'CELLCOUNT']
    brevis = samples.loc[samples['SPECIES'] == 'brevis'].dropna(subset=required)

    # positives first, then negatives - the later de-duplication and sort see this order
    blooms = brevis.loc[brevis['CELLCOUNT'] > 50000].assign(HAB_EVENT=1)
    absent = brevis.loc[brevis['CELLCOUNT'] == 0].assign(HAB_EVENT=0)
    labelled = pd.concat([blooms, absent], ignore_index=True)

    # Gulf of Mexico area and the 2003-2018 date window (both ends inclusive)
    in_gulf = (
        labelled['LATITUDE'].between(GULF_BOUNDS['lat_min'], GULF_BOUNDS['lat_max'])
        & labelled['LONGITUDE'].between(GULF_BOUNDS['lon_min'], GULF_BOUNDS['lon_max'])
    )
    in_window = labelled['SAMPLE_DATE'].between(
        datetime(2003, 1, 1), datetime(2018, 12, 31)
    )
    events = labelled.loc[in_gulf & in_window].copy()

    # stable ID = date, rounded coordinates and integer cell count joined by '_'
    id_parts = (
        events['SAMPLE_DATE'].dt.strftime('%Y%m%d'),
        events['LATITUDE'].round(4).astype(str),
        events['LONGITUDE'].round(4).astype(str),
        events['CELLCOUNT'].astype(int).astype(str),
    )
    stable_id = id_parts[0]
    for part in id_parts[1:]:
        stable_id = stable_id + '_' + part
    events['STABLE_EVENT_ID'] = stable_id

    events = (
        events.drop_duplicates(subset=['STABLE_EVENT_ID'])
        .sort_values('SAMPLE_DATE')
        .reset_index(drop=True)
    )

    n_positive = int((events['HAB_EVENT'] == 1).sum())
    n_negative = int((events['HAB_EVENT'] == 0).sum())
    first_day = events['SAMPLE_DATE'].min()
    last_day = events['SAMPLE_DATE'].max()

    print(f"Total HAB events: {len(events):,}")
    print(f"Positive: {n_positive:,}")
    print(f"Negative: {n_negative:,}")
    print(f"Dates range: {first_day} to {last_day}")

    return events

# create a balanced sample
def create_mvp_sample(events_df, n_events=10, min_year=2015, random_state=42):
    """Draw a class-balanced random sample of recent events.

    Only rows whose SAMPLE_DATE year is >= ``min_year`` are considered. The
    same number of positive (HAB_EVENT == 1) and negative (HAB_EVENT == 0)
    rows is drawn. If either class has fewer than ``n_events // 2`` rows,
    both classes are reduced to the size of the scarcer one so the sample
    stays balanced (previously each class was clipped on its own, which
    silently produced an imbalanced sample).

    Args:
        events_df: DataFrame with a datetime SAMPLE_DATE column and a
            HAB_EVENT column of 0/1 labels.
        n_events: requested total sample size; ``n_events // 2`` per class.
        min_year: earliest SAMPLE_DATE year to sample from.
        random_state: seed passed to ``DataFrame.sample`` so reruns match.

    Returns:
        pandas.DataFrame sorted by SAMPLE_DATE with a fresh 0..n-1 index.
    """
    print(f"\nCreating  sample of {n_events} events")

    recent_events = events_df[events_df['SAMPLE_DATE'].dt.year >= min_year]
    positives = recent_events[recent_events['HAB_EVENT'] == 1]
    negatives = recent_events[recent_events['HAB_EVENT'] == 0]

    # class balance: never take more from one class than the other can supply
    requested_per_class = n_events // 2
    n_per_class = max(0, min(requested_per_class, len(positives), len(negatives)))
    if n_per_class < requested_per_class:
        print(f"  Only {n_per_class} events per class available "
              f"(requested {requested_per_class})")

    positive_sample = positives.sample(n=n_per_class, random_state=random_state)
    negative_sample = negatives.sample(n=n_per_class, random_state=random_state)

    mvp_sample = pd.concat([positive_sample, negative_sample], ignore_index=True)
    mvp_sample = mvp_sample.sort_values('SAMPLE_DATE').reset_index(drop=True)

    print(f"MVP sample created: {len(mvp_sample)} events")
    print(f"  Positive (HAB): {len(mvp_sample[mvp_sample['HAB_EVENT'] == 1])}")
    print(f"  Negative (No HAB): {len(mvp_sample[mvp_sample['HAB_EVENT'] == 0])}")
    print(f"  Date range: {mvp_sample['SAMPLE_DATE'].min()} to {mvp_sample['SAMPLE_DATE'].max()}")

    return mvp_sample

In [5]:
# get MODIS L2 Ocean Color data
def search_modis_l2_data(date, spatial_bounds):
    """Search for 'MODISA_L2_OC' granules on one day inside a bounding box.

    Args:
        date: object with ``strftime``; the same 'YYYY-MM-DD' string is used
            as both ends of the temporal range.
        spatial_bounds: dict with 'lon_min', 'lat_min', 'lon_max', 'lat_max'.

    Returns:
        Whatever ``earthaccess.search_data`` returns, or ``[]`` if the search
        raised (the error is printed, not re-raised).
    """
    # order passed to the search: (lon_min, lat_min, lon_max, lat_max)
    corner_keys = ('lon_min', 'lat_min', 'lon_max', 'lat_max')
    bbox = tuple(spatial_bounds[key] for key in corner_keys)

    try:
        day = date.strftime('%Y-%m-%d')
        return earthaccess.search_data(
            short_name='MODISA_L2_OC',
            temporal=(day, date.strftime('%Y-%m-%d')),
            bounding_box=bbox,
        )
    except Exception as err:
        print(f"Search error for {date}: {err}")
        return []

# download satellite granule if not cached
def download_and_cache_granule(granule, raw_dir):
    """Return the local path of a granule, downloading it only when needed.

    The cache file is ``raw_dir / "<GranuleUR>.nc"``. If it already exists its
    path is returned immediately. Otherwise the granule is downloaded into
    ``raw_dir`` and the first downloaded file is renamed to the cache name.

    Args:
        granule: mapping exposing ``granule['umm']['GranuleUR']``.
        raw_dir: ``Path`` of the download/cache directory.

    Returns:
        str path of the cached file, or ``None`` if nothing was downloaded or
        the download raised (the error is printed, not re-raised).
    """
    target = raw_dir / f"{granule['umm']['GranuleUR']}.nc"

    if target.exists():
        print(f"- using cached file: {target.name[:50]}")
        return str(target)

    try:
        print(" - downloading new file")
        fetched = earthaccess.download([granule], local_path=str(raw_dir))
        if not (fetched and len(fetched) > 0):
            return None

        # normalise whatever name the download used to the cache name
        source = Path(fetched[0])
        if source != target:
            source.rename(target)
        return str(target)
    except Exception as err:
        print(f" - download failed: {err}")
        return None

# get the modalities
def _to_float_ndarray(array):
    """Return ``array`` as a plain float ndarray with masked entries set to NaN."""
    return np.ma.filled(np.ma.asarray(array, dtype=float), np.nan)


def extract_habnet_modalities_from_l2(file_path, spatial_bounds):
    """Read the configured modalities from one L2 file, limited to a lat/lon box.

    The file must contain the groups 'geophysical_data' and 'navigation_data';
    coordinates come from navigation_data 'latitude' / 'longitude'.

    Each returned modality array is aligned one-to-one with the returned
    'lats' / 'lons': pixels that are masked, non-finite or (for 'chlor_a')
    outside the open interval (0, 1000) are kept as NaN rather than removed.
    Removing them - as this function used to - made the value array shorter
    than the coordinate arrays, so any caller pairing them by position
    attached values to the wrong coordinates.

    Args:
        file_path: path of the netCDF file.
        spatial_bounds: dict with 'lat_min', 'lat_max', 'lon_min', 'lon_max'.

    Returns:
        ``{'lats': ndarray, 'lons': ndarray, 'modalities': {name: ndarray}}``
        with all arrays 1-D and of equal length, or ``None`` when the groups
        are missing, no pixel lies inside the bounds, no modality has a valid
        value, or reading the file raised (the error is printed).
    """
    try:
        with nc.Dataset(file_path, 'r') as ds:
            # check if required groups exist
            if 'geophysical_data' not in ds.groups or 'navigation_data' not in ds.groups:
                print(f" missing required groups in {Path(file_path).name}")
                return None

            geo_data = ds.groups['geophysical_data']
            nav_data = ds.groups['navigation_data']

            # plain float arrays: masked coordinates become NaN and therefore
            # fail every comparison below
            lats = _to_float_ndarray(nav_data.variables['latitude'][:])
            lons = _to_float_ndarray(nav_data.variables['longitude'][:])

            with np.errstate(invalid='ignore'):
                in_bounds = (
                    (lats >= spatial_bounds['lat_min']) & (lats <= spatial_bounds['lat_max'])
                    & (lons >= spatial_bounds['lon_min']) & (lons <= spatial_bounds['lon_max'])
                )

            if not np.any(in_bounds):
                print(f" - no data in spatial bounds")
                return None

            result = {
                'lats': lats[in_bounds],
                'lons': lons[in_bounds],
                'modalities': {}
            }

            for modality in HABNET_MODALITIES:
                if modality not in geo_data.variables:
                    continue

                mod_data = geo_data.variables[modality][:]
                if mod_data.shape != lats.shape:
                    continue

                # boolean indexing copies, so the NaN assignment below cannot
                # touch the data read from the file
                values = _to_float_ndarray(mod_data)[in_bounds]
                valid_mask = np.isfinite(values)

                # filter valid chlorophyll values
                if modality == 'chlor_a':
                    with np.errstate(invalid='ignore'):
                        valid_mask &= (values > 0) & (values < 1000)

                if not np.any(valid_mask):
                    continue

                values[~valid_mask] = np.nan
                result['modalities'][modality] = values

                kept = values[valid_mask]
                print(f" - {modality}: {np.sum(valid_mask)} points, "
                      f"range {np.min(kept):.4f}-{np.max(kept):.4f}")

            if result['modalities']:
                return result

            print(f"- no valid modality data")
            return None

    except Exception as e:
        print(f" - error with {Path(file_path).name}: {e}")
        return None

In [6]:
# datacube pipeline 
class OptimizedHABNetDatacubeGenerator:
    """Build one (rows, cols, days) datacube per HAB event and pickle it.

    For every day in an event's time window the generator searches for
    granules (memoised per day and bounding box), downloads/caches the first
    one, extracts the configured modalities and interpolates them onto a
    regular lat/lon grid. A datacube is saved only when at least 80% of the
    days produced data.

    Attributes:
        raw_dir: Path where granule files are cached.
        processed_dir: Path where datacube pickles are written.
        config: dict providing 'spatial_extent_km', 'spatial_resolution_km'
            and 'temporal_extent_days'.
        search_cache: dict memoising granule searches for this instance.
    """

    def __init__(self, raw_dir, processed_dir, config=DATACUBE_CONFIG):
        self.raw_dir = Path(raw_dir)
        self.processed_dir = Path(processed_dir)
        self.config = config
        self.search_cache = {}

    def _datacube_path(self, stable_event_id):
        """Return the pickle path used for ``stable_event_id``."""
        return self.processed_dir / f"habnet_datacube_{stable_event_id}.pkl"

    def check_existing_datacube(self, stable_event_id):
        """Return True if a datacube pickle already exists for this event."""
        return self._datacube_path(stable_event_id).exists()

    def cached_search_modis_data(self, date, spatial_bounds):
        """Search for granules, reusing an earlier result for the same day and box.

        The key contains all four box edges (rounded to 2 decimals); keying on
        the lower-left corner alone would let boxes of different size share a
        search result.
        """
        cache_key = (
            f"{date.strftime('%Y-%m-%d')}"
            f"_{spatial_bounds['lat_min']:.2f}_{spatial_bounds['lon_min']:.2f}"
            f"_{spatial_bounds['lat_max']:.2f}_{spatial_bounds['lon_max']:.2f}"
        )

        if cache_key in self.search_cache:
            return self.search_cache[cache_key]

        granules = search_modis_l2_data(date, spatial_bounds)
        self.search_cache[cache_key] = granules
        return granules

    def generate_datacube_for_event(self, event_row):
        """Create, save and return the datacube file for one event.

        Args:
            event_row: mapping/Series with 'STABLE_EVENT_ID', 'SAMPLE_DATE',
                'LATITUDE', 'LONGITUDE', 'HAB_EVENT' and 'CELLCOUNT'.

        Returns:
            Path of the pickle (existing or newly written), or ``None`` when
            fewer than 80% of the days yielded data and nothing was saved.
        """
        stable_event_id = event_row['STABLE_EVENT_ID']
        output_file = self._datacube_path(stable_event_id)

        # skip if already processed
        if self.check_existing_datacube(stable_event_id):
            print(f"\nDatacube for event {stable_event_id} already exists, skipping")
            return output_file

        event_date = event_row['SAMPLE_DATE']
        event_lat = event_row['LATITUDE']
        event_lon = event_row['LONGITUDE']
        hab_label = event_row['HAB_EVENT']

        print(f"\nGenerating datacube for event: {stable_event_id}")
        print(f"Date: {event_date.strftime('%Y-%m-%d')}, Location: ({event_lat:.3f}, {event_lon:.3f})")
        print(f"HAB Event: {hab_label}")

        # spatial and temporal bounds
        spatial_bounds = calculate_spatial_bounds(event_lat, event_lon, self.config['spatial_extent_km'])
        temporal_bounds = calculate_temporal_bounds(event_date, self.config['temporal_extent_days'] - 1)
        total_days = (temporal_bounds['end_date'] - temporal_bounds['start_date']).days

        grid_size = self.config['spatial_extent_km'] // self.config['spatial_resolution_km']
        datacube_shape = (grid_size, grid_size, self.config['temporal_extent_days'])

        # one NaN-initialised cube per modality
        modality_datacubes = {
            modality: np.full(datacube_shape, np.nan) for modality in HABNET_MODALITIES
        }

        successful_days = 0

        # process each day in time window
        for day_idx in range(total_days):
            current_date = temporal_bounds['start_date'] + timedelta(days=day_idx)
            print(f"  Day {day_idx+1}/{self.config['temporal_extent_days']} ({current_date.strftime('%Y-%m-%d')}):", end=" ")

            granules = self.cached_search_modis_data(current_date, spatial_bounds)
            if not granules:
                print(f"No data found")
                continue

            # download (or reuse the cached copy of) the first granule
            file_path = download_and_cache_granule(granules[0], self.raw_dir)
            if not file_path:
                print(f"Download failed")
                continue

            daily_data = extract_habnet_modalities_from_l2(file_path, spatial_bounds)
            if not (daily_data and daily_data['modalities']):
                print(f"No valid modality data")
                continue

            # project each modality to grid; the day only counts as successful
            # if at least one gridded slice actually contains a finite value
            day_has_data = False
            for modality, values in daily_data['modalities'].items():
                if len(values) > 0:
                    gridded_data = self._reproject_to_regular_grid(
                        daily_data['lats'], daily_data['lons'], values, spatial_bounds
                    )
                    modality_datacubes[modality][:, :, day_idx] = gridded_data
                    if np.any(np.isfinite(gridded_data)):
                        day_has_data = True

            if day_has_data:
                successful_days += 1
                print(f"SUCCESS: {len(daily_data['modalities'])} modalities processed")
            else:
                print(f"No valid modality data")

        # check data completeness
        data_completeness = successful_days / self.config['temporal_extent_days']
        print(f"\nDatacube generation complete: {successful_days}/{self.config['temporal_extent_days']} days ({data_completeness:.1%})")

        # only save if we have decent data completeness (4 days for now)
        if data_completeness < 0.8:
            print(f"Insufficient data ({data_completeness:.1%}) - datacube not saved")
            return None

        datacube_data = {
            'datacubes': modality_datacubes,
            'metadata': {
                'stable_event_id': stable_event_id,
                'date': event_date,
                'lat': event_lat,
                'lon': event_lon,
                'hab_label': hab_label,
                'cell_count': event_row['CELLCOUNT'],
                'spatial_bounds': spatial_bounds,
                'temporal_bounds': temporal_bounds,
                'config': self.config,
                'data_completeness': data_completeness,
                'successful_days': successful_days,
                'modalities': list(modality_datacubes.keys()),
                'generation_date': datetime.now()
            }
        }

        # Write to a temporary name and rename afterwards: an interrupted run
        # then never leaves a truncated .pkl that a later run would treat as
        # "already exists".
        temp_file = output_file.with_name(output_file.name + '.tmp')
        with open(temp_file, 'wb') as f:
            pickle.dump(datacube_data, f)
        temp_file.replace(output_file)

        print(f"Datacube saved: {output_file}")
        return output_file

    def _reproject_to_regular_grid(self, lats, lons, values, spatial_bounds):
        """Interpolate scattered samples onto the regular datacube grid.

        The grid has ``spatial_extent_km // spatial_resolution_km`` points per
        axis, evenly spaced between the bounds (edges included). Linear
        interpolation is used first; cells it leaves as NaN are filled with
        the nearest sample.

        Args:
            lats, lons: coordinates of the samples.
            values: sample values, expected to be paired by position with
                ``lats`` / ``lons``; NaN or masked entries are ignored.
            spatial_bounds: dict with 'lat_min', 'lat_max', 'lon_min', 'lon_max'.

        Returns:
            ndarray of shape (grid_size, grid_size), rows = latitude and
            columns = longitude; all NaN if fewer than 4 usable samples exist
            or interpolation raised.
        """
        grid_size = self.config['spatial_extent_km'] // self.config['spatial_resolution_km']

        # target grid
        target_lats = np.linspace(spatial_bounds['lat_min'], spatial_bounds['lat_max'], grid_size)
        target_lons = np.linspace(spatial_bounds['lon_min'], spatial_bounds['lon_max'], grid_size)
        target_lons_grid, target_lats_grid = np.meshgrid(target_lons, target_lats)
        target_points = np.column_stack([target_lons_grid.ravel(), target_lats_grid.ravel()])

        # plain float arrays; masked entries become NaN
        lats = np.ma.filled(np.ma.asarray(lats, dtype=float), np.nan).ravel()
        lons = np.ma.filled(np.ma.asarray(lons, dtype=float), np.nan).ravel()
        values = np.ma.filled(np.ma.asarray(values, dtype=float), np.nan).ravel()

        # Values and coordinates are paired by position. When the lengths
        # differ that pairing cannot be verified; keep the historical
        # behaviour (pair the leading entries) but say so instead of
        # truncating silently.
        n_pairs = min(len(values), len(lats), len(lons))
        if not (len(values) == len(lats) == len(lons)):
            print(f"- warning: {len(values)} values for {len(lats)} coordinates, "
                  f"pairing the first {n_pairs} by position")
        lats, lons, values = lats[:n_pairs], lons[:n_pairs], values[:n_pairs]

        # remove invalid values (and samples without a usable coordinate)
        valid_mask = np.isfinite(values) & np.isfinite(lats) & np.isfinite(lons)
        if np.sum(valid_mask) < 4:  # Need at least 4 points for triangulation
            return np.full((grid_size, grid_size), np.nan)

        source_points_valid = np.column_stack([lons[valid_mask], lats[valid_mask]])
        source_values_valid = values[valid_mask]

        try:
            # interpolate to grid with lerp
            interpolated = griddata(
                source_points_valid, source_values_valid, target_points,
                method='linear', fill_value=np.nan
            )

            # fill leftover NaNs with nearest neighbor
            nan_mask = np.isnan(interpolated)
            if np.any(nan_mask):
                interpolated_nn = griddata(
                    source_points_valid, source_values_valid, target_points,
                    method='nearest'
                )
                interpolated[nan_mask] = interpolated_nn[nan_mask]

            return interpolated.reshape(target_lons_grid.shape)
        except Exception as e:
            print(f"- interpolation failed: {e}")
            return np.full((grid_size, grid_size), np.nan)

In [7]:
# nasa earth auth
def setup_nasa_earthdata():
    """Log in to NASA Earthdata via ``earthaccess.login()``.

    Returns:
        bool: True when the login call returned a truthy value; False when it
        returned a falsy value or raised (the error is printed, not re-raised).
    """
    print("Setting up NASA Earthdata auth")
    try:
        logged_in = bool(earthaccess.login())
        print(" - NASA Earthdata auth successful!" if logged_in else " - Auth failed")
        return logged_in
    except Exception as err:
        print(f"  Authentication error: {err}")
        return False

def run_datacube_generation(n_events=1000, csv_file='habsos_20240430.csv'):
    """Run the whole pipeline: directories, auth, event sample, datacubes.

    Args:
        n_events: size of the balanced event sample to process.
        csv_file: path of the HAB events CSV.

    Returns:
        ``(results, base_dir)`` where ``results`` is a list with one dict per
        event ('event_id', 'hab_label', 'output_file', 'success',
        'processing_time_min' and, on failure, 'error'), or ``None`` when
        Earthdata auth fails or the CSV is missing. Callers must check for
        ``None`` before unpacking.
    """
    print("HAB detection Chl-a Datacube Pipeline")
    print("=" * 50)

    # setup directories
    base_dir, raw_dir, processed_dir = setup_directories()
    print(f"Data directories created in: {base_dir}")

    # setup NASA authentication
    if not setup_nasa_earthdata():
        print("Cannot proceed without NASA Earthdata auth")
        return None

    # load and filter HAB events
    try:
        hab_events = load_and_filter_hab_events(csv_file)
    except FileNotFoundError:
        print(f"HAB events CSV file not found. Please ensure '{csv_file}' is available.")
        return None

    mvp_events = create_mvp_sample(hab_events, n_events=n_events)

    # Save MVP events for reference
    mvp_file = base_dir / 'mvp_events.csv'
    mvp_events.to_csv(mvp_file, index=False)
    print(f"MVP events saved to: {mvp_file}")

    # init optimized datacube generator
    generator = OptimizedHABNetDatacubeGenerator(raw_dir, processed_dir)

    print(f"\nDatacube Config:")
    print(f"  Spatial: {generator.config['spatial_extent_km']}km x {generator.config['spatial_extent_km']}km")
    print(f"  Temporal: {generator.config['temporal_extent_days']} days")
    print(f"  Spatial resolution: {generator.config['spatial_resolution_km']}km")
    print(f"  Modalities: {HABNET_MODALITIES}")

    # Generate datacubes for events
    print(f"\nGenerating datacubes for {len(mvp_events)} MVP events...")
    print("=" * 60)

    results = []
    start_time = datetime.now()

    for idx, (event_idx, event) in enumerate(mvp_events.iterrows()):
        print(f"\n{'='*50}")
        print(f"Processing event {idx+1}/{len(mvp_events)}: ID {event_idx}")

        event_start = datetime.now()
        try:
            output_file = generator.generate_datacube_for_event(event)
            event_time = (datetime.now() - event_start).total_seconds() / 60

            results.append({
                'event_id': event_idx,
                'hab_label': event['HAB_EVENT'],
                'output_file': output_file,
                'success': output_file is not None,
                'processing_time_min': event_time
            })

            print(f"Event completed in {event_time:.1f} minutes")

        except Exception as e:
            # one bad event must not abort the batch; record it and move on
            event_time = (datetime.now() - event_start).total_seconds() / 60
            print(f"Failed to process event {event_idx}: {e}")
            results.append({
                'event_id': event_idx,
                'hab_label': event['HAB_EVENT'],
                'output_file': None,
                'success': False,
                'processing_time_min': event_time,
                'error': str(e)
            })

    # summary
    total_time = (datetime.now() - start_time).total_seconds() / 60
    successful_results = [r for r in results if r['success']]

    print(f"\n" + "="*60)
    print(f"DATACUBE GENERATION SUMMARY")
    print(f"Total processing time: {total_time:.1f} minutes")
    # an empty sample would otherwise divide by zero here
    if results:
        print(f"Average time per event: {total_time/len(results):.1f} minutes")
    print(f"Total events processed: {len(results)}")
    print(f"Successful: {len(successful_results)}")
    print(f"Failed: {len(results) - len(successful_results)}")

    if successful_results:
        print(f"\nSuccessful datacubes saved in: {processed_dir}")
    return results, base_dir

if __name__ == "__main__":
    print("Starting HABNet MVP Pipeline")
    print(f"Configuration: {DATACUBE_CONFIG}")
    print(f"Modalities: {HABNET_MODALITIES}")

    # run_datacube_generation() returns None (not a tuple) when auth fails or
    # the CSV is missing; unpacking it directly would raise TypeError.
    outcome = run_datacube_generation()
    if outcome is None:
        results, data_dir = None, None
        print("Pipeline did not run - see the messages above")
    else:
        results, data_dir = outcome

Starting HABNet MVP Pipeline
Configuration: {'spatial_extent_km': 30, 'temporal_extent_days': 5, 'spatial_resolution_km': 2, 'temporal_resolution_days': 1}
Modalities: ['chlor_a']
HAB detection Chl-a Datacube Pipeline
Data directories created in: habnet_mvp_data
Setting up NASA Earthdata auth


Enter your Earthdata Login username:  <redacted>
Enter your Earthdata password:  ········


 - NASA Earthdata auth successful!
Loading ground truth data
Total HAB events: 86,160
Positive: 8,736
Negative: 77,424
Dates range: 2003-01-02 00:00:00 to 2018-12-31 00:00:00

Creating  sample of 1000 events
MVP sample created: 1000 events
  Positive (HAB): 500
  Negative (No HAB): 500
  Date range: 2015-01-18 00:00:00 to 2018-12-17 00:00:00
MVP events saved to: habnet_mvp_data\mvp_events.csv

Datacube Config:
  Spatial: 30km x 30km
  Temporal: 5 days
  Spatial resolution: 2km
  Modalities: ['chlor_a']

Generating datacubes for 1000 MVP events...

Processing event 1/1000: ID 0

Datacube for event 20150118_24.7967_-80.7839_0 already exists, skipping
Event completed in 0.0 minutes

Processing event 2/1000: ID 1

Generating datacube for event: 20150119_27.3338_-82.5794_0
Date: 2015-01-19, Location: (27.334, -82.579)
HAB Event: 0
  Day 1/5 (2015-01-14): - using cached file: MODISA_L2_OC_AQUA_MODIS.20150114T182001.L2.OC.nc_2
- no valid modality data
No valid modality data
  Day 2/5 (2015-01

KeyboardInterrupt: 