In [None]:
#Module versions are based on the OpenSARLab Python environment.
#pip install earthaccess

In [None]:
pip install cartopy

In [None]:
pip install geopy

In [None]:
!pip install --upgrade xarray --user

In [None]:
import xarray as xr
from xarray import open_datatree
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import cartopy.crs as ccrs
import cartopy.feature as cfeature

from datetime import datetime, timedelta, time
from geopy.distance import geodesic

from collections import defaultdict
import re
import io
import gc
import os
import sys

In [None]:
#Load the reference sample dataset from CSV (decoded as latin1 text).
file_path = "ReferenceDataset_filtered_full.csv"
samples_df = pd.read_csv(filepath_or_buffer=file_path, encoding="latin1")

In [None]:
#A NASA Earthdata login is required to open/download data with the earthaccess module.
import earthaccess

auth = earthaccess.login()
#Fail fast: the extraction loop catches and logs errors per granule and per day, so an
#unauthenticated session would otherwise produce empty results that the batch loop then
#checkpoints as null rows (and skips on resume).
if auth is None or not auth.authenticated:
    raise RuntimeError("earthaccess login failed; check Earthdata credentials (.netrc or environment variables).")

In [None]:
#Define custom functions

#This function returns a string of 19 dates for each sample date: 
# i.e. the sample date, 9 preceding dates, and 9 following dates.

def get_previous_dates(center_date_str, days=10, date_format="%Y-%m-%d", both_sides=False):
    center_date = datetime.strptime(center_date_str, date_format)
    date_list = []

    if both_sides:
        for offset in range(-days, days + 1):
            date_str = (center_date + timedelta(days=offset)).strftime(date_format)
            date_list.append((date_str, offset))
    else:
        for offset in range(-days, 1): 
            date_str = (center_date + timedelta(days=offset)).strftime(date_format)
            date_list.append((date_str, offset))

    return date_list



#This function converts each date in local South Australian time to UTC in order to filter NASA granules. 
#Maximal first and last light in Australian Central Standard Time were used as the bounds for each day.
#The satellite capture times in late morning and early afternoon mean that this is still valid for daylight savings.

def local_to_utc(date_str):
  local_date = datetime.strptime(date_str, "%Y-%m-%d").date()
  first_light_local = datetime.combine(local_date, time(5, 45))
  last_light_local = datetime.combine(local_date, time(19, 15))
  # Convert to UTC
  offset = timedelta(hours=9, minutes=30)
  first_light_utc = first_light_local - offset
  last_light_utc = last_light_local - offset
  # Return time range string
  return first_light_utc, last_light_utc


#This is the function used to find the corresponding pixel index for each sample point.
#To improve efficiency, a cheap planar distance in degrees pre-selects the closest n_candidates pixels.
#Geodesic distance is then used to pick the closest pixel among those candidates.

def find_pixel_index_precise(lat_arr, lon_arr, target_lat, target_lon, n_candidates=9):
    """Return the ``(row, col)`` index of the grid pixel closest to a target point.

    Parameters
    ----------
    lat_arr, lon_arr : array-like
        Latitude / longitude grids of identical shape, in degrees (xarray DataArrays or
        numpy arrays). Non-finite entries (e.g. masked fill values) are ignored.
    target_lat, target_lon : float
        Target point in degrees.
    n_candidates : int
        How many pre-filtered pixels to compare by exact geodesic distance. Clamped to
        at least 1 and at most the number of finite pixels.

    Returns
    -------
    tuple
        ``(i, j)`` index into the grids' shape.

    Raises
    ------
    ValueError
        If the grids contain no finite coordinates.
    """
    lat_np = np.asarray(lat_arr)
    lon_np = np.asarray(lon_arr)
    lat_flat = lat_np.ravel()
    lon_flat = lon_np.ravel()

    # Approximate distance in degrees. A degree of longitude shrinks with cos(latitude), so scale
    # it to keep the pre-filter close to ground distance; non-finite pixels are pushed to +inf.
    lon_scale = np.cos(np.radians(target_lat))
    dist_approx = np.hypot(lat_flat - target_lat, (lon_flat - target_lon) * lon_scale)
    dist_approx = np.where(np.isfinite(dist_approx), dist_approx, np.inf)

    n_valid = int(np.isfinite(dist_approx).sum())
    if n_valid == 0:
        raise ValueError("No finite latitude/longitude values in the pixel grid.")
    k = max(1, min(int(n_candidates), n_valid))
    # argpartition puts the k smallest distances (unordered) in the first k slots; kth=k-1 is
    # always in bounds, unlike kth=n_candidates on grids with <= n_candidates pixels.
    candidate_indices = np.argpartition(dist_approx, k - 1)[:k]

    # Exact geodesic distance on the (all-finite) candidates only.
    target_coord = (target_lat, target_lon)
    geod_distances = [
        geodesic(target_coord, (lat_flat[idx], lon_flat[idx])).meters
        for idx in candidate_indices
    ]
    best_candidate_idx = candidate_indices[int(np.argmin(geod_distances))]
    i, j = np.unravel_index(best_candidate_idx, lat_np.shape)

    return i, j

In [None]:
#Define overall extraction function.
#Values are split over separate but aligned IOP, AOP and BGC granules: each product is extracted
#separately and the per-granule rows are joined on the granule timestamp.

#Columns present on every per-product row and used as the join key.
_JOIN_KEYS = ["granule_id", "timestep", "date"]


def _granule_id(granule):
    """Return the second '.'-separated field of an opened granule's file name (the merge key across products)."""
    granule_name = granule.url().split('/')[-1]
    return granule_name.split(".")[1]


def _granule_label(granule):
    """Best-effort printable name for a granule in log messages; falls back to repr() if url() fails."""
    try:
        return granule.url()
    except Exception:
        return repr(granule)


def _iop_row(ds, i, j):
    """One-row frame of the aph spectrum at pixel (i, j), one ``aph_<wavelength>`` column per band."""
    spectrum = ds["aph"][i, j, :]
    aph_df = spectrum.to_dataframe(name="aph").T
    aph_df.columns = [f"aph_{int(w)}" for w in aph_df.columns]
    return aph_df.reset_index(drop=True)


def _aop_row(ds, i, j):
    """One-row frame with nflh and the integer l2_flags value at pixel (i, j)."""
    return pd.DataFrame([{
        "nflh": ds["nflh"][i, j].item(),
        "l2_flags": int(ds["l2_flags"][i, j].item()),
    }])


def _bgc_row(ds, i, j):
    """One-row frame with chlor_a, carbon_phyto and poc at pixel (i, j)."""
    return pd.DataFrame([{
        "chlor_a": ds["chlor_a"][i, j].item(),
        "carbon_phyto": ds["carbon_phyto"][i, j].item(),
        "poc": ds["poc"][i, j].item(),
    }])


#(collection short_name, log label, datatree groups to merge, row builder), in extraction order.
#navigation_data provides latitude/longitude; for IOP, sensor_band_parameters presumably supplies
#the wavelength labels used for the aph_<w> columns (TODO confirm against the product spec).
_PRODUCTS = [
    ("PACE_OCI_L2_IOP_NRT", "IOP", ["geophysical_data", "navigation_data", "sensor_band_parameters"], _iop_row),
    ("PACE_OCI_L2_AOP_NRT", "AOP", ["geophysical_data", "navigation_data"], _aop_row),
    ("PACE_OCI_L2_BGC_NRT", "BGC", ["geophysical_data", "navigation_data"], _bgc_row),
]


def _extract_product(short_name, label, groups, row_builder, tspan, bbox, version_num,
                     pt_lon, pt_lat, timestep, date_string):
    """Search one product for one day and return one row per granule, sampled at the nearest pixel.

    Errors in an individual granule are logged and that granule is skipped; errors from the
    search or from ``earthaccess.open`` propagate to the caller.
    """
    results = earthaccess.search_data(
        short_name=short_name,
        temporal=tspan,
        bounding_box=bbox,
        version=version_num,
        cloud_hosted=True
    )
    print(f"[{date_string}] Found {len(results)} {label} granules")

    rows = []
    for granule in earthaccess.open(results):
        # Pre-bind so the cleanup in `finally` is safe even when opening/merging fails part-way;
        # deleting an unbound name there would raise and abort the whole day.
        dt = ds = None
        try:
            granule_id = _granule_id(granule)
            print(f"Processing {label} granule: {granule_id}")
            dt = open_datatree(granule)
            ds = xr.merge([dt[group].ds for group in groups])

            i, j = find_pixel_index_precise(ds["latitude"], ds["longitude"], pt_lat, pt_lon)
            row_df = row_builder(ds, i, j)
            row_df["granule_id"] = granule_id
            row_df["timestep"] = timestep
            row_df["date"] = date_string
            rows.append(row_df)

        except Exception as e:
            print(f"Skipping {label} granule {_granule_label(granule)} due to error: {e}")
        finally:
            # Release file handles held by the datatree (close() is available on recent xarray DataTree).
            close = getattr(dt, "close", None)
            if close is not None:
                close()
            del dt, ds
            gc.collect()

    #All granules for the day for this product.
    return pd.concat(rows, ignore_index=True) if rows else pd.DataFrame()


def extract_past_values_joined(pt_lon, pt_lat, sample_time, days, version_num):
    """Extract PACE OCI L2 IOP/AOP/BGC values at a point for every day in a window around a sample date.

    Parameters
    ----------
    pt_lon, pt_lat : float
        Sample location in degrees; used as a zero-area search bounding box and as the
        nearest-pixel target.
    sample_time : str
        Centre (sample) date in ``%Y-%m-%d`` format.
    days : int
        Number of days before and after the sample date to search (window of ``2*days+1`` days).
    version_num : str
        Collection version passed to ``earthaccess.search_data`` (e.g. ``"3.1"``).

    Returns
    -------
    pandas.DataFrame
        One row per granule, outer-joined across products on ``granule_id``/``timestep``/``date``,
        with data columns first followed by those three. Empty if nothing was extracted.
        Days whose search/open fails are logged and skipped.
    """
    granule_dates = get_previous_dates(sample_time, days, "%Y-%m-%d", both_sides=True)
    bbox = (pt_lon, pt_lat, pt_lon, pt_lat)
    combined_rows = []

    for date_string, timestep in granule_dates:
        try:
            tspan = local_to_utc(date_string)
            # Products are extracted in order; a search/open failure abandons the rest of this day.
            product_dfs = [
                _extract_product(short_name, label, groups, row_builder, tspan, bbox,
                                 version_num, pt_lon, pt_lat, timestep, date_string)
                for short_name, label, groups, row_builder in _PRODUCTS
            ]

            # Merge on granule_id for each day
            dfs_to_merge = [df for df in product_dfs if not df.empty]
            if dfs_to_merge:
                merged_df = dfs_to_merge[0]
                for df in dfs_to_merge[1:]:
                    merged_df = pd.merge(merged_df, df, on=_JOIN_KEYS, how="outer")

                # Data columns first, join keys last.
                data_cols = [c for c in merged_df.columns if c not in _JOIN_KEYS]
                combined_rows.append(merged_df[data_cols + _JOIN_KEYS])

        except Exception as e:
            print(f"Skipping {date_string} due to error: {e}")

    #Add to dataframe in memory
    final_df = pd.concat(combined_rows, ignore_index=True) if combined_rows else pd.DataFrame()
    print(f"[{sample_time}] Returning {len(final_df)} joined granule rows for past {days} day(s)")
    return final_df


In [None]:
#Smoke test on a single point: expected to return valid results on the day of sampling.
#days=2 searches a 5-day window (2 days either side of the sample date).
extract_past_values_joined(
    pt_lon=138.049266,
    pt_lat=-34.530923,
    sample_time="2025-07-16",
    days=2,
    version_num="3.1",
)

In [None]:
#Loop for running the extraction for all samples. Can take 1-2 minutes per sample with current settings.
#Results are appended to a CSV checkpoint in batches so an interrupted run can resume.

#Columns of the final dataframe (and of the checkpoint CSV), in output order.
columns = [
    'aph_400', 'aph_413', 'aph_425', 'aph_442', 'aph_460', 'aph_475', 'aph_490', 'aph_510', 
    'aph_532', 'aph_555', 'aph_583', 'aph_618', 'aph_640', 'aph_655', 'aph_665', 'aph_678', 
    'aph_701', 'nflh', 'l2_flags', 'chlor_a', 'carbon_phyto', 'poc', 
    'granule_id', 'timestep', 'date', 'sample_time', 'longitude', 'latitude', 
    'OBJECTID', 'SiteName', 'Result_Value_Numeric', 'SiteNumber', 'ReferenceID', 'Depth'
]
results_list = []
#Number of samples to buffer before appending to the checkpoint: 1 for testing, 5 for the final run.
batch_size = 5
#Checkpoint CSV; samples whose ReferenceID already appears in it are skipped on restart.
checkpoint_file = "results.csv"


def _append_checkpoint(frames, path):
    """Append buffered per-sample frames to the checkpoint CSV and return the written frame.

    The header is written only when the file is missing or empty, so every save path
    produces exactly one header line.
    """
    batch = pd.concat(frames, ignore_index=True)
    write_header = not os.path.exists(path) or os.path.getsize(path) == 0
    batch.to_csv(path, mode='a', index=False, header=write_header)
    return batch


#An existing but empty checkpoint would make read_csv raise, so treat it like a fresh start.
if os.path.exists(checkpoint_file) and os.path.getsize(checkpoint_file) > 0:
    existing_df = pd.read_csv(checkpoint_file)
    processed_ids = set(existing_df["ReferenceID"].unique())
    print(f"Resuming from checkpoint. {len(processed_ids)} samples already processed.")
else:
    existing_df = pd.DataFrame()
    processed_ids = set()

try:
    for i, row in samples_df.iterrows():
        object_id = row["OBJECTID"]
        SiteName = row['SiteName']
        Result = row['Result_Value_Numeric']
        Site = row['SiteNumber']
        RefID = row['ReferenceID']
        Depth = row['Depth']
        #Check and skip already processed samples in case of code restart.
        if RefID in processed_ids:
            print(f"Skipping already processed sample {i+1}/{len(samples_df)} with ReferenceID {RefID}")
            continue

        pt_lon = row["Longitude"]
        pt_lat = row["Latitude"]
        sample_time = row["date_string"]

        print(f"Processing sample {i+1}/{len(samples_df)}: {sample_time}, ({pt_lat}, {pt_lon})")
        #Run extraction function for each sample. days=9 gives a 19-day window.
        #Current version is 3.1. Needs to be updated as new versions are released, because old versions are removed.
        try:
            df_result = extract_past_values_joined(pt_lon, pt_lat, sample_time, days=9, version_num="3.1")

            # Empty results: keep one all-null row so the sample is still recorded.
            if df_result.empty:
                df_result = pd.DataFrame({col: [np.nan] for col in columns})
                print(f"No data for sample {i+1}, inserted null row.")

            # Add metadata. assign() replaces whole columns, avoiding the deprecated
            # incompatible-dtype .loc setitem of strings into the null row's float NaN columns.
            df_result = df_result.assign(**{
                'sample_time': sample_time,
                'longitude': pt_lon,
                'latitude': pt_lat,
                'OBJECTID': object_id,
                'SiteName': SiteName,
                'Result_Value_Numeric': Result,
                'SiteNumber': Site,
                'ReferenceID': RefID,
                'Depth': Depth,
            }).reindex(columns=columns)

            results_list.append(df_result)

            # Save in batches
            if len(results_list) >= batch_size:
                batch_df = _append_checkpoint(results_list, checkpoint_file)
                print(f"Checkpoint: Saved {len(results_list)} samples to {checkpoint_file}")
                results_list.clear()

        except Exception as e:
            print(f"Error processing sample {i+1}: {e}")

except KeyboardInterrupt:
    print("Processing interrupted. Saving current batch...")

# Save remaining results (after completion or interruption), using the same header rule as batch saves.
if results_list:
    batch_df = _append_checkpoint(results_list, checkpoint_file)
    print(f"Final checkpoint: Saved remaining {len(results_list)} samples to {checkpoint_file}")
else:
    print("No results to save")