In [None]:
# %pip install earthaccess

#SETUP
import re

import earthaccess
import numpy as np
import pandas as pd
import xarray as xr
from PIL import Image, ImageEnhance

In [None]:
# PARAMETERS: search constraints passed to earthaccess.search_data in the next cell.
tspan = ("2024-09-22", "2024-09-28")  # temporal range as (start, end) date strings
bbox = (-125, 32, -116, 38)  # presumably (west, south, east, north) in degrees -- confirm against earthaccess docs
clouds = (0, 50)  # cloud-cover range; presumably (min, max) percent -- TODO confirm

In [None]:
# DATA SEARCH: find PACE OCI L2 AOP granules matching the parameters above.
results = earthaccess.search_data(
    short_name="PACE_OCI_L2_AOP",
    temporal=tspan,
    bounding_box=bbox,
    cloud_cover=clouds,
)

# Fail early with a readable message instead of an IndexError on paths[0].
if not results:
    raise RuntimeError(
        "No PACE_OCI_L2_AOP granules matched the search; "
        "widen tspan, bbox, or clouds."
    )

paths = earthaccess.open(results)

# Only the first granule is inspected here: its groups are merged into a
# single flat Dataset for convenient browsing.
datatree = xr.open_datatree(paths[0])
dataset = xr.merge(datatree.to_dict().values())

In [None]:
#results

In [None]:
# Number of granules returned by the search above.
len(results)

In [None]:
# Display the first search result.
results[0]

In [None]:
#pip install https://seabass.gsfc.nasa.gov/wiki/seabass_tools/sb_utilities-0.0.2.tar.gz


In [None]:
# List of l2 flag names, later built into a dict (L2_FLAGS).
# Order matters: L2_FLAGS assigns each name the bit value `1 << position`,
# so the position in this list is the bit number of the flag.
# "SPARE" appears several times -- presumably placeholders for unused bits.
l2_flags_list = [
    "ATMFAIL",
    "LAND",
    "PRODWARN",
    "HIGLINT",
    "HILT",
    "HISATZEN",
    "COASTZ",
    "SPARE",
    "STRAYLIGHT",
    "CLDICE",
    "COCCOLITH",
    "TURBIDW",
    "HISOLZEN",
    "SPARE",
    "LOWLW",
    "CHLFAIL",
    "NAVWARN",
    "ABSAER",
    "SPARE",
    "MAXAERITER",
    "MODGLINT",
    "CHLWARN",
    "ATMWARN",
    "SPARE",
    "SEAICE",
    "NAVFAIL",
    "FILTER",
    "SPARE",
    "BOWTIEDEL",
    "HIPOL",
    "PRODFAIL",
    "SPARE",
]

In [None]:
# Map each flag name to its bit value (1 << position in l2_flags_list).
# NOTE: "SPARE" occurs several times in the list, so this dict keeps only the
# bit value of its last occurrence.
L2_FLAGS = {flag: 1 << idx for idx, flag in enumerate(l2_flags_list)}

# Flags whose presence excludes a pixel (used by get_fivebyfive).
# Modeled on the Bailey and Werdell 2006 exclusion criteria, but only CLDICE
# is currently active -- every other entry below is commented out.
EXCLUSION_FLAGS = [
    #"LAND",
    #"HIGLINT",
    #"HILT",
    #"STRAYLIGHT",
    "CLDICE",
    #"ATMFAIL",
    #"LOWLW",
    #"FILTER",
    #"NAVFAIL",
    #"NAVWARN",
]


In [None]:
# Maps the `sat` argument of get_sat_ts_matchups to the earthaccess
# `short_name` used in the granule search.
SAT_LOOKUP = {
    "PACE_AOP": "PACE_OCI_L2_AOP",
}


In [None]:
##---------------------------------------------------------------------------##
#                 Load the OCI sensor file and return F0.                     #
##---------------------------------------------------------------------------##

def get_f0(wavelengths=None, window_size=10, sensor_file=None):
    """Load the OCI sensor file and return F0.

    Defaults to returning the full table stored in the file. When
    ``wavelengths`` is given, F0 is band-averaged around each requested
    wavelength with ``bandpass_avg``.

    Parameters
    ----------
    wavelengths : array-like, optional
        Wavelengths at which to compute the average irradiance.
        If None, returns the full wavelength and irradiance table.
    window_size : int, optional
        Bandpass filter size for mean filtering to selected wavelengths, in nm.
    sensor_file : str or pathlib.Path, optional
        Path to the OCI satellite sensor file containing ``Nbands``,
        ``Lambda(i)`` and ``F0(i)`` entries written as ``key = value`` lines.
        If None, the notebook-level constant ``OCI_SENSOR_FILE`` is used; that
        constant must be defined by the caller before calling this function.

    Returns
    -------
    tuple of np.ndarray
        A tuple containing:
        - f0_spectra : np.ndarray
            The extraterrestrial solar irradiance as stored in the sensor file
            (documented by the original author as uW/cm^2/nm).
        - f0_wave : np.ndarray
            The corresponding wavelengths, in nm.

    Raises
    ------
    ValueError
        If the sensor file has no ``Nbands`` entry.

    """
    if sensor_file is None:
        # Backward-compatible default: the constant the original code relied on.
        sensor_file = OCI_SENSOR_FILE

    # Read the file once and keep only the "key = value" lines.
    with open(sensor_file, "r") as file_in:
        entries = [line.split("=", 1) for line in file_in if "=" in line]

    # The band count sizes the output arrays, so it must be found first.
    nbands = None
    for key, value in entries:
        if "Nbands" in key:
            nbands = int(value)
            break
    if nbands is None:
        raise ValueError(f"No 'Nbands' entry found in sensor file: {sensor_file}")

    wl = np.zeros(nbands, dtype=float)
    f0 = np.zeros(nbands, dtype=float)
    for key, value in entries:
        if "Lambda" in key:
            # "Lambda(12)" -> digits ["12"]; band numbers are 1-based.
            idx = re.findall(r"\d+", key)
            wl[int(idx[0]) - 1] = float(value)
        if "F0" in key:
            # "F0(12)" -> digits ["0", "12"]; the band number is the second
            # group because the "0" of "F0" is matched first.
            idx = re.findall(r"\d+", key)
            f0[int(idx[1]) - 1] = float(value)

    if wavelengths is not None:
        f0_wave = np.array(wavelengths)
        f0_spectra = bandpass_avg(f0, wl, window_size, f0_wave)
    else:
        f0_wave = wl
        f0_spectra = f0

    return f0_spectra, f0_wave


In [None]:
##---------------------------------------------------------------------------##
#                 Apply a band-pass filter to the data.                       #
##---------------------------------------------------------------------------##
def bandpass_avg(
        data,
        input_wavelengths,
        window_size=10,
        target_wavelengths=None
        ):
    """Apply a band-pass (boxcar mean) filter to the data.

    For every target wavelength, the NaN-ignoring mean is taken over all input
    columns whose wavelength lies within ``window_size / 2`` of the target
    (both edges inclusive). Targets with no input wavelength in range are NaN.

    Parameters
    ----------
    data : array-like
        1D or 2D array containing the spectral data (samples x wavelengths).
        If 1D, it's assumed to be a single sample.
    input_wavelengths : array-like
        1D sequence of wavelength values corresponding to the columns of data.
        Lists and tuples are accepted as well as arrays.
    window_size : int, optional
        Size of the window to use for averaging. Default is 10 nm.
    target_wavelengths : array-like, optional
        1D sequence of target wavelengths for filtered values.
        If None, the input wavelengths are used.

    Returns
    -------
    np.ndarray
        2D array (samples x targets) when more than one sample was given,
        otherwise a 1D array with one value per target wavelength.

    """
    data = np.atleast_2d(data)
    # Coerce to arrays so plain lists work in the vectorized comparisons below.
    input_wavelengths = np.asarray(input_wavelengths)
    if target_wavelengths is None:
        target_wavelengths = input_wavelengths
    target_wavelengths = np.atleast_1d(np.asarray(target_wavelengths))

    half_window = window_size / 2
    num_samples = data.shape[0]

    # Start from all-NaN so targets without any input band in range stay NaN.
    filtered_data = np.full((num_samples, len(target_wavelengths)), np.nan)

    for idx, target_wl in enumerate(target_wavelengths):
        start = target_wl - half_window
        end = target_wl + half_window
        cols_in_range = np.where(
            (input_wavelengths >= start) & (input_wavelengths <= end)
        )[0]
        if cols_in_range.size > 0:
            filtered_data[:, idx] = np.nanmean(data[:, cols_in_range], axis=1)

    # A single sample is returned as a flat spectrum, matching 1D input.
    return filtered_data if num_samples > 1 else filtered_data.flatten()

In [None]:
##---------------------------------------------------------------------------##
#         Process a dataframe to create a dictionary of data products.        #
##---------------------------------------------------------------------------##

def get_column_prods(df, type_prefix):
    """Group a dataframe's product columns by product name.

    Column names are expected to look like ``<type_prefix>_<product>`` with an
    optional trailing wavelength, e.g. ``aoc_rrs412`` or ``aoc_rrs443.5``.

    Parameters
    ----------
    df : pandas DataFrame
        Extracted dataframes from read_extract_file
    type_prefix : str
        Prefix to identify the product columns, e.g. "aoc"

    Returns
    -------
    dict
        Maps each product name to ``{"wavelengths": [...], "columns": [...]}``.
        Wavelengths are ints unless the column name contains a decimal point,
        in which case they are floats. Products without a trailing number get
        an empty wavelength list.

    """
    matcher = re.compile(rf"{type_prefix}_(\w+?)(\d*\.?\d+)?$")
    products = {}

    for column in df.columns:
        found = matcher.match(column)
        if found is None:
            continue

        name, wl_text = found.groups()
        entry = products.setdefault(name, {"wavelengths": [], "columns": []})
        entry["columns"].append(column)

        # Products without a trailing number contribute no wavelength.
        if wl_text:
            number = float(wl_text) if "." in wl_text else int(wl_text)
            entry["wavelengths"].append(number)

    return products


In [None]:
##---------------------------------------------------------------------------##
#                Read SeaBASS file and returns just the data.                 #
##---------------------------------------------------------------------------##

import pandas as pd
import builtins

def read_sb(filename_sb):
    """Read a SeaBASS ``.sb`` file and return its data with profile metadata.

    The header (``/key=value`` lines up to ``/end_header``) is parsed into a
    dict, the data section is read into a DataFrame using the header's
    ``fields`` (column names), ``missing`` (NaN marker) and ``delimiter``
    entries, ``get_sb_datetime`` is applied to build a datetime index when
    possible, and the station metadata is attached as constant columns.

    Parameters
    ----------
    filename_sb : str or pathlib.Path
        Path to the SeaBASS file.

    Returns
    -------
    pandas.DataFrame
        The data section, plus the columns ``profile_lat`` (from
        ``north_latitude``), ``profile_lon`` (from ``east_longitude``) and
        ``profile_time`` (from ``start_date`` + ``start_time``).

    Raises
    ------
    ValueError
        If the file contains no ``/end_header`` line.
    KeyError
        If a required header (``fields``, ``north_latitude``,
        ``east_longitude``, ``start_date``, ``start_time``) is missing.
    """
    # 1) Load all lines. Trailing whitespace is stripped so CRLF files and
    #    padded header lines are still recognized.
    with builtins.open(filename_sb, "r") as f:
        lines = [raw.rstrip() for raw in f]

    # 2) Find where the header ends
    endh = next(
        (i for i, text in enumerate(lines)
         if text.strip().startswith("/end_header")),
        None,
    )
    if endh is None:
        raise ValueError(f"No '/end_header' line found in {filename_sb}")

    # 3) Parse header into a dict, but only lines with "/" **and** "="
    headers = {}
    for line in lines[:endh]:
        if not line.startswith("/") or "=" not in line:
            continue
        key, val = line[1:].split("=", 1)  # strip "/" then split
        headers[key.strip()] = val.strip()

    # 4) Read the data portion into a DataFrame, honoring the declared
    #    delimiter (falls back to comma, the original behavior).
    separators = {"comma": ",", "space": r"\s+", "tab": "\t"}
    sep = separators.get(headers.get("delimiter", "comma").lower(), ",")
    df = pd.read_csv(
        filename_sb,
        skiprows=endh + 1,
        sep=sep,
        names=[name.strip() for name in headers["fields"].split(",")],
        na_values=headers.get("missing", "")
    )

    # 5) Build the datetime index in place (no-op if no date/time columns)
    get_sb_datetime(df)

    # 6) Extract & clean metadata from headers
    #    Strip off any "[...]" unit suffix before converting to float
    lat = float(headers["north_latitude"].split("[", 1)[0])
    lon = float(headers["east_longitude"].split("[", 1)[0])

    #    Strip the "[GMT]"-style suffix from the time field
    time_str = headers["start_time"].split("[", 1)[0]
    dt0 = pd.to_datetime(headers["start_date"] + " " + time_str)

    # 7) Attach them as new columns on every row
    df["profile_lat"] = lat
    df["profile_lon"] = lon
    df["profile_time"] = dt0

    return df



In [None]:
##---------------------------------------------------------------------------##
#     Parse datetime from different combinations of dates and times.          #
##---------------------------------------------------------------------------##

def get_sb_datetime(df):
    """Build a ``datetime`` index in place from the available date/time columns.

    The column combinations are tried in a fixed order; the first one fully
    present in ``df`` is used. If none is present a notice is printed and
    ``df`` is left untouched. Returns None in every case.
    """
    def has(*names):
        return all(name in df.columns for name in names)

    if has("year", "month", "day", "hour", "minute", "second"):
        stamps = pd.to_datetime(
            df[["year", "month", "day", "hour", "minute", "second"]]
        )
    elif has("year", "month", "day", "time"):
        # Concatenate to "YYYYMMDD <time>" with zero-padded month and day.
        ymd = (
            df["year"].astype(str)
            + df["month"].astype(str).str.zfill(2)
            + df["day"].astype(str).str.zfill(2)
        )
        stamps = pd.to_datetime(ymd + " " + df["time"])
    elif has("date", "time"):
        stamps = pd.to_datetime(df["date"].astype(str) + " " + df["time"])
    elif has("year", "month", "day"):
        stamps = pd.to_datetime(df[["year", "month", "day"]])
    elif has("date", "hour", "minute", "second"):
        # Assemble "<date> HH:MM:SS" from zero-padded clock parts.
        clock = (
            df["hour"].astype(str).str.zfill(2)
            + ":" + df["minute"].astype(str).str.zfill(2)
            + ":" + df["second"].astype(str).str.zfill(2)
        )
        stamps = pd.to_datetime(df["date"].astype(str) + " " + clock)
    else:
        print("Unrecognized date/time format in DataFrame columns."
              "\nMay be a profile, but doublecheck.")
        return

    # Store the timestamps, then promote them to the index (mutates df).
    df["datetime"] = stamps
    df.set_index("datetime", inplace=True)


In [None]:

# Path to one SeaBASS above-water radiometry file on the shared hackweek
# volume, assembled piecewise only to keep the lines short.
file_path = '/home/jovyan/shared-public/pace-hackweek/SeePACE/'
file_path += 'Hackweek_PACE-PAX_Rrs/NRL/PACE-PAX/'
file_path += 'PACE-PAX_Shearwater/archive/'
file_path += 'PVST_POL_PACE-PAX_Shearwater_above_water_radiometry_nflh_NRL_20240906_St_1_R1.sb'
# Parse the file (header + data) and display the resulting DataFrame.
df = read_sb(file_path)
df

In [None]:
# Transpose the single-file table so that each wavelength becomes a column.
# Requires df (from the previous cell) to have a "Wavelength" column.
df_wide = df.set_index("Wavelength").T

print(df_wide.shape)   # (remaining columns, number of wavelengths)
df_wide.index.name = None          # drop the index name if you like
df_wide.columns.name = "λ (nm)"    # optional: name the wavelength axis

df_wide


In [None]:
import glob
import os
import pandas as pd

# 1) Point this at your "archive" folder containing all the .sb files
archive_dir = "/home/jovyan/shared-public/pace-hackweek/SeePACE/" \
            + "Hackweek_PACE-PAX_Rrs/NRL/PACE-PAX/PACE-PAX_Shearwater/archive"

# 2) Grab a sorted list of all the .sb paths
sb_files = sorted(glob.glob(os.path.join(archive_dir, "*.sb")))
if not sb_files:
    # pd.concat on an empty list fails with a much less helpful message.
    raise FileNotFoundError(f"No .sb files found in {archive_dir}")

# 3) Read every file with the metadata-aware reader. A comprehension is used
#    so the notebook-level `df` from the earlier cells is not overwritten.
df_list = [read_sb(sb_path) for sb_path in sb_files]

# 4) Stack them into a single long-form DataFrame (one row per wavelength)
all_profiles = pd.concat(df_list, ignore_index=True)

# 5) Inspect
print(all_profiles.shape)   # → (total number of data rows, number_of_columns)

# 6) Save to disk for fast reload later (CSV for reading, pickle for speed)
all_profiles.to_csv("all_SeaBASS_profiles.csv", index=False)
all_profiles.to_pickle("all_SeaBASS_profiles.pkl")



In [None]:
import pandas as pd

# Reload the combined profiles from CSV (human-readable, but a little slower
# to load than the pickle written alongside it) and display them.
df_csv = pd.read_csv("all_SeaBASS_profiles.csv")
df_csv


In [None]:
# assume df is your 26862×12 DataFrame loaded from CSV/pickle
import pandas as pd

# 0) Load the combined long-form SeaBASS CSV (one row per wavelength)
df = pd.read_csv("all_SeaBASS_profiles.csv")

# 1-3) One row per profile (time, lat, lon) and one column per Wavelength
#      holding Rrs; then rename the keys, convert 'datetime' to real
#      Timestamps and derive date/time strings from it.
wide = (
    df.pivot(
        index=["profile_time", "profile_lat", "profile_lon"],
        columns="Wavelength",
        values="Rrs",
    )
    .reset_index()
    .rename(columns={
        "profile_time": "datetime",
        "profile_lat": "lat",
        "profile_lon": "lon",
    })
    .assign(datetime=lambda t: pd.to_datetime(t["datetime"]))
    .assign(
        date=lambda t: t["datetime"].dt.strftime("%Y%m%d"),
        time=lambda t: t["datetime"].dt.strftime("%H:%M:%S"),
    )
)

# 4) Metadata columns first, then the numeric wavelength columns ascending
wls = sorted(c for c in wide.columns if isinstance(c, (int, float)))
wide = wide[["datetime", "date", "time", "lat", "lon"] + wls]

# 5) Inspect
print(wide.shape)   # → (number_of_profiles, 5 + number_of_wavelengths)
wide

In [None]:

# (Re)build the wide table: one row per profile, one column per wavelength.
# NOTE: `df` is first the long-form CSV and is rebound at the end of this cell
# to the wide table, which keeps a default integer index ("datetime" is a
# regular column, not the index).
df = pd.read_csv("all_SeaBASS_profiles.csv", parse_dates=["profile_time"])
df_wide = df.pivot(
    index=["profile_time","profile_lat","profile_lon"],
    columns="Wavelength",
    values="Rrs"
).reset_index().rename(columns={
    "profile_time":"datetime",
    "profile_lat":  "lat",
    "profile_lon":  "lon"
})
# split out date/time strings
df_wide["date"] = df_wide["datetime"].dt.strftime("%Y%m%d")
df_wide["time"] = df_wide["datetime"].dt.strftime("%H:%M:%S")
# reorder columns: metadata first, then numeric wavelength columns ascending
wls = sorted(c for c in df_wide.columns if isinstance(c,(int,float)))
df = df_wide[["datetime","date","time","lat","lon"] + wls]
df

In [None]:
##---------------------------------------------------------------------------##
#                             Satellite Utilities                             #
##---------------------------------------------------------------------------##
def parse_quality_flags(flag_value):
    """Decode a combined bitwise flag value into the names of its set flags.

    Parameters
    ----------
    flag_value : int
        The integer representing the combined bitwise quality flags.

    Returns
    -------
    list of str
        Names from L2_FLAGS whose bit is set in flag_value, in the
        iteration order of L2_FLAGS.

    """
    set_flags = []
    for name, bit in L2_FLAGS.items():
        if (flag_value & bit) != 0:
            set_flags.append(name)
    return set_flags

In [None]:
def get_fivebyfive(file, latitude, longitude, rrs_wavelengths):
    """Get stats on 5x5 box around station coordinates of a satellite granule.

    Finds the pixel closest to the station, extracts the surrounding 5x5 box
    (clipped at the granule edges), drops pixels carrying any flag listed in
    EXCLUSION_FLAGS, removes spectra lying more than 1.5 standard deviations
    from the mean at any wavelength, and averages what is left.

    Parameters
    ----------
    file : earthaccess file object
        Opened satellite granule from ``earthaccess.open``; its ``granule``
        attribute is read for the start time.
    latitude : float
        In decimal degrees for the station used for matchups
    longitude : float
        In decimal degrees (negative West) for the station used for matchups
    rrs_wavelengths : numpy array
        Rrs wavelengths (from wavelength_3d for OCI)

    Returns
    -------
    dict
        A dictionary of the processed 5x5 box with:
            - "sat_datetime": pandas Timestamp (UTC)
                Datetime of the overall granule start time
            - "sat_cv": float
                Median coefficient of variation of Rrs(405nm - 570nm)
            - "sat_latitude": float
                Latitude of center pixel
            - "sat_longitude": float
                Longitude of center pixel
            - "sat_pixel_valid": int
                Number of pixels in the box that pass the l2 flag check
                (counted before the standard-deviation filter)
            - "sat_rrs<wavelength>": float
                Mean Rrs per wavelength, keyed by the integer part of the
                wavelength; NaN when no pixel survives the filters

    Raises
    ------
    ValueError
        If the granule's navigation data holds no finite coordinates.

    Notes
    -----
    This is set to use just Rrs data for the demo. As an exercise, make this
    function more generalized by adding an input for the desired product and
    removing the wavelength dependency (if not needed) as well as the cv
    calculation. This will also require refactoring the `match_data` function.
    """
    with xr.open_dataset(file, group="navigation_data") as ds_nav:
        sat_lat = ds_nav["latitude"].values
        sat_lon = ds_nav["longitude"].values

    # Planar (Euclidean) distance in degrees for the 2D lat/lon arrays
    distances = np.sqrt((sat_lat - latitude) ** 2 + (sat_lon - longitude) ** 2)
    if np.isnan(distances).all():
        raise ValueError("Granule navigation data has no valid coordinates.")

    # Find the index of the minimum distance; dimensions are (lines, pixels).
    # nanargmin is required: plain argmin returns the first NaN position when
    # the navigation arrays contain missing values.
    center_line, center_pixel = np.unravel_index(
        np.nanargmin(distances), distances.shape
    )

    # Get indices for a 5x5 box around the center pixel, clipped to the granule
    line_start = max(center_line - 2, 0)
    line_end = min(center_line + 2 + 1, sat_lat.shape[0])
    pixel_start = max(center_pixel - 2, 0)
    pixel_end = min(center_pixel + 2 + 1, sat_lat.shape[1])

    # Extract the data
    # NOTE: This is hard-coded to Rrs from an L2 AOP file.
    box = dict(
        number_of_lines=slice(line_start, line_end),
        pixels_per_line=slice(pixel_start, pixel_end),
    )
    with xr.open_dataset(file, group="geophysical_data") as ds_data:
        rrs_data = ds_data["Rrs"].isel(**box).values
        flags_data = ds_data["l2_flags"].isel(**box).values

    # Bitwise OR of all flags in EXCLUSION_FLAGS (OR, not a sum, so a flag
    # listed twice cannot corrupt the mask)
    exclude_mask = 0
    for flag in EXCLUSION_FLAGS:
        exclude_mask |= L2_FLAGS[flag]

    # True means the flag value does not contain any of the EXCLUSION_FLAGS
    valid_mask = np.bitwise_and(flags_data, exclude_mask) == 0

    # Defaults used when no pixel survives the filters
    rrs_cv_median = np.nan
    rrs_mean = np.full(np.shape(rrs_wavelengths), np.nan)

    if valid_mask.any():
        # Boolean indexing flattens the box to (valid pixels, wavelengths)
        rrs_valid = rrs_data[valid_mask]
        rrs_std_initial = np.std(rrs_valid, axis=0)
        rrs_mean_initial = np.mean(rrs_valid, axis=0)

        # Exclude spectra > 1.5 stdevs away at any wavelength
        std_mask = np.all(
            np.abs(rrs_valid - rrs_mean_initial) <= 1.5 * rrs_std_initial,
            axis=1
        )
        if std_mask.any():
            rrs_std = np.std(rrs_valid[std_mask], axis=0)
            rrs_mean = np.mean(rrs_valid[std_mask], axis=0).flatten()

            # Matchup criteria uses cv as median of 405-570nm
            rrs_cv = rrs_std / rrs_mean
            rrs_cv_median = np.median(
                rrs_cv[(rrs_wavelengths >= 405) & (rrs_wavelengths <= 570)]
            )

    # Put in dictionary of the row. utc=True guarantees a tz-aware timestamp,
    # which match_data relies on when it drops the timezone.
    row = {
        "sat_datetime": pd.to_datetime(
            file.granule["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"],
            utc=True
        ),
        "sat_cv": rrs_cv_median,
        "sat_latitude": sat_lat[center_line, center_pixel],
        "sat_longitude": sat_lon[center_line, center_pixel],
        "sat_pixel_valid": np.sum(valid_mask),
    }

    # Add mean spectra to the row dictionary
    for wavelength, mean_value in zip(rrs_wavelengths, rrs_mean):
        row[f"sat_rrs{int(wavelength)}"] = mean_value

    return row


In [None]:
# Used by the matchup cells below to build the satellite rows for a station.


def get_sat_ts_matchups(
    start_date,
    end_date,
    latitude,
    longitude,
    sat="PACE_AOP",
    selected_dates=None
):
    """Make satellite timeseries of matchups from single station.

    Caution: If the date or coordinates aren't formatted correctly, it might
    pull a huge granule list and take forever to run. If it takes more than 45
    seconds to print the number of granules, just kill the process.

    Uses the earthaccess package. Defaults to the PACE OCI L2 AOP dataset,
    but other satellites can be used if they have a corresponding short_name
    in the SAT_LOOKUP dictionary.

    Workflow:
        1. Get list of matchup granules
        2. Loop through each file and:
            2a. Find closest pixel to station, extract 5x5 pixel box
            2b. Exclude pixels based on l2_flags
            2c. Filtered mean to get single spectra
            2d. Compute statistics and save data row
        3. Organize output pandas dataframe

    Parameters
    ----------
    start_date : datetime or str
        First day to search; formatted into ``"{start_date}T00:00:00"``.
    end_date : datetime or str
        Last day to search; formatted into ``"{end_date}T23:59:59"``.
    latitude : float
        In decimal degrees for the station used for matchups
    longitude : float
        In decimal degrees (negative West) for the station used for matchups
    sat : str
        Name of satellite to search. Must be in SAT_LOOKUP dict constant.
    selected_dates : list of str, optional
        If given, only pull granules whose start date (``YYYY-MM-DD``) is in
        this list

    Returns
    -------
    pandas DataFrame object
        Flattened table of all satellite granule matchups, one row per
        granule (see ``get_fivebyfive``). When no granule is found, an empty
        table that still carries the non-spectral ``sat_*`` columns.

    Raises
    ------
    ValueError
        If ``sat`` is not a key of SAT_LOOKUP.

    """
    # Look up short name from constants
    if sat not in SAT_LOOKUP:
        raise ValueError(
            f"{sat} is not in the lookup dictionary. Available "
            f"sats are: {', '.join(SAT_LOOKUP)}"
        )
    short_name = SAT_LOOKUP[sat]

    # Format search parameters
    time_bounds = (f"{start_date}T00:00:00", f"{end_date}T23:59:59")

    # Run Earthaccess data search
    results = earthaccess.search_data(
        point=(longitude, latitude),
        temporal=time_bounds,
        short_name=short_name
    )
    print(f"Found {len(results)} Granules.")

    if selected_dates is not None:
        # Keep granules whose start timestamp begins with a selected date
        results = [
            result
            for result in results
            if result["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"][:10]
            in selected_dates
        ]
        print(f"Filtered to {len(results)} Granules.")

    # Nothing to open: return an empty table that match_data can still filter,
    # instead of failing with an IndexError on files[0] below.
    if not results:
        return pd.DataFrame(
            columns=[
                "sat_datetime",
                "sat_cv",
                "sat_latitude",
                "sat_longitude",
                "sat_pixel_valid",
            ]
        )

    files = earthaccess.open(results)

    # Pull out Rrs wavelengths once (from the first granule) for easier processing
    with xr.open_dataset(files[0], group="sensor_band_parameters") as ds_bands:
        rrs_wavelengths = ds_bands["wavelength_3d"].values

    # Loop through files and process
    sat_rows = []
    for file in files:
        granule_date = pd.to_datetime(
            file.granule["umm"]["TemporalExtent"]["RangeDateTime"]["BeginningDateTime"]
        )
        print(f"Running Granule: {granule_date}")
        sat_rows.append(
            get_fivebyfive(file, latitude, longitude, rrs_wavelengths)
        )

    return pd.DataFrame(sat_rows)

In [None]:
# New match up data (We do not use this just for test)
import pandas as pd

def match_data(
    df_sat,
    df_field,
    cv_max=0.15,
    senz_max=60.0,
    min_percent_valid=55.0,
    max_time_diff=180,
    std_max=1.5,
):
    """Create matchup dataframe based on selection criteria.

    Parameters
    ----------
    df_sat : pandas DataFrame
        Satellite rows with "sat_datetime", "sat_cv", "sat_latitude",
        "sat_longitude" and "sat_pixel_valid" columns.
    df_field : pandas DataFrame
        Field data indexed by datetime, with "lat" and "lon" columns.
    cv_max : float
        Maximum allowed "sat_cv" for a satellite row.
    senz_max : float
        Accepted for interface compatibility; not used by this function.
    min_percent_valid : float
        Minimum percentage of the 25 box pixels that must be valid.
    max_time_diff : int
        Maximum field/satellite time difference, in minutes.
    std_max : float
        Field spectra further than this many standard deviations from the
        mean (at any 400-700 nm "field_rrs*" column) are dropped; applied
        only when more than 5 field rows match.

    Returns
    -------
    pandas DataFrame
        One row per satellite row that has a field match: the field row
        closest in time merged with the satellite row.
    """
    # Setup
    time_window = pd.Timedelta(minutes=max_time_diff)
    df_match_list = []

    df_field_filtered = df_field.copy()

    # 1) Pull the datetime (may be tz-aware) out of the index, made tz-naive
    df_field_filtered["field_datetime"] = df_field_filtered.index
    df_field_filtered["field_datetime"] = (
        pd.to_datetime(df_field_filtered["field_datetime"])
          .dt.tz_localize(None)
    )

    # 2) Copy lat/lon into the names used below
    df_field_filtered["field_latitude"] = df_field_filtered["lat"]
    df_field_filtered["field_longitude"] = df_field_filtered["lon"]

    # Filter satellite data based on cv threshold
    df_sat_filtered = df_sat[df_sat["sat_cv"] <= cv_max]
    # Filter satellite data based on percent good pixels (of a 25-pixel box)
    df_sat_filtered = df_sat_filtered[
        df_sat_filtered["sat_pixel_valid"] >= min_percent_valid * 25 / 100
    ]

    for _, sat_row in df_sat_filtered.iterrows():
        # Drop the timezone once so every comparison below is naive-vs-naive;
        # tz-naive satellite timestamps are accepted as they are.
        sat_time = pd.Timestamp(sat_row["sat_datetime"])
        if sat_time.tzinfo is not None:
            sat_time = sat_time.tz_convert(None)

        time_diff = abs(df_field_filtered["field_datetime"] - sat_time)
        time_mask = time_diff <= time_window
        lat_mask = abs(df_field_filtered["field_latitude"] - sat_row["sat_latitude"]) <= 0.2
        lon_mask = abs(df_field_filtered["field_longitude"] - sat_row["sat_longitude"]) <= 0.2

        field_matches = df_field_filtered[time_mask & lat_mask & lon_mask]

        if field_matches.shape[0] > 5:
            # Filter by standard deviation over the visible rrs columns.
            # str() keeps this safe for numeric (wavelength) column labels.
            rrs_cols = [
                col for col in field_matches.columns
                if str(col).startswith("field_rrs")
                and 400 <= float(str(col).rsplit("_rrs")[1]) <= 700
            ]
            if rrs_cols:
                mean_spectra = field_matches[rrs_cols].mean(axis=0)
                std_spectra = field_matches[rrs_cols].std(axis=0)
                within_std = (
                    abs(field_matches[rrs_cols] - mean_spectra)
                    <= std_max * std_spectra
                )
                field_matches = field_matches[within_std.all(axis=1)]

        if not field_matches.empty:
            # Select the best match based on time delta, using the tz-naive
            # sat_time (subtracting the tz-aware value raises TypeError).
            # Positional lookup stays unambiguous with duplicate timestamps.
            time_diff = abs(field_matches["field_datetime"] - sat_time)
            best_match = field_matches.iloc[time_diff.to_numpy().argmin()]
            df_match_list.append({**best_match.to_dict(), **sat_row.to_dict()})

    df_match = pd.DataFrame(df_match_list)
    return df_match


In [None]:
# Pull the first station's coords out of the in-situ (SeaBASS) DataFrame.
# Only the first row is used, so every granule is searched at this one point.
station_lat = df["lat"].iloc[0]
station_lon = df["lon"].iloc[0]

# Commented-out alternative kept by the author: derive the sampling days from
# a datetime index and pass them as `selected_dates`.
#unique_days = df.index.date             # e.g. array([datetime.date(2024,9,22), ...])
#unique_days_str = sorted({d.strftime("%Y-%m-%d") for d in unique_days}) # e.g. ['2024-09-22', '2024-09-23', ...]
#df_satellite = get_sat_ts_matchups(
   # start_date="2024-03-01",
   # end_date="2024-03-05",
   # latitude=station_lat,    # e.g. 34.2163
    #longitude=station_lon,    # e.g. -119.5980
   # sat="PACE_AOP",            # only if you want to override the default
   # selected_dates=unique_days_str
#)

# Satellite rows for a single day at the station coordinates.
df_satellite = get_sat_ts_matchups(
    start_date= "2024-09-26",
    end_date=   "2024-09-26",
    latitude=   station_lat,
    longitude=  station_lon,
    sat="PACE_AOP"
)


In [None]:
# match_data reads the field timestamps from the DataFrame index, so the
# in-situ table must be indexed by its "datetime" column. With the default
# integer index the timestamps would be interpreted as 1970 epoch values and
# nothing would ever match.
field_table = (
    df if isinstance(df.index, pd.DatetimeIndex) else df.set_index("datetime")
)

matchups = match_data(
    df_satellite,   # <-- DataFrame of sat rows, not the metadata list
    field_table,    # in-situ SeaBASS DataFrame, indexed by datetime
    cv_max=0.60,
    senz_max=60.0,
    min_percent_valid=55.0,
    max_time_diff=240,
    std_max=1.5,
)
matchups

In [None]:
#Final matchup file
import pandas as pd

def match_data(
    df_sat,
    df_field,
    cv_max=0.15,
    senz_max=60.0,
    min_percent_valid=55.0,
    max_time_diff=180,
    std_max=1.5,
):
    """Create matchup dataframe based on selection criteria.

    Satellite rows are first screened by coefficient of variation and by the
    share of valid pixels in the 5x5 box. For every remaining satellite row,
    field rows within ``max_time_diff`` minutes and 0.2 degrees in both
    latitude and longitude are collected, optionally screened for spectral
    outliers, and the one closest in time is kept.

    Parameters
    ----------
    df_sat : pandas DataFrame
        Satellite rows with "sat_datetime" (tz-aware or tz-naive),
        "sat_cv", "sat_latitude", "sat_longitude" and "sat_pixel_valid".
    df_field : pandas DataFrame
        Field data indexed by datetime, with "lat" and "lon" columns.
    cv_max : float
        Maximum allowed "sat_cv" for a satellite row.
    senz_max : float
        Accepted for interface compatibility; not used by this function.
    min_percent_valid : float
        Minimum percentage of the 25 box pixels that must be valid.
    max_time_diff : int
        Maximum field/satellite time difference, in minutes.
    std_max : float
        Field spectra further than this many standard deviations from the
        mean (at any 400-700 nm "field_rrs*" column) are dropped; applied
        only when more than 5 field rows match.

    Returns
    -------
    pandas DataFrame
        One row per matched satellite row: the best field row merged with
        the satellite row (satellite values win on duplicate keys).
    """
    time_window = pd.Timedelta(minutes=max_time_diff)
    df_match_list = []

    # 1) prepare the field table
    df_field_filtered = df_field.copy()

    # pull the datetimes out of the index and make them tz-naive
    df_field_filtered["field_datetime"] = df_field_filtered.index
    df_field_filtered["field_datetime"] = (
        pd.to_datetime(df_field_filtered["field_datetime"])
          .dt.tz_localize(None)
    )

    # copy lat/lon into the names used below
    df_field_filtered["field_latitude"] = df_field_filtered["lat"]
    df_field_filtered["field_longitude"] = df_field_filtered["lon"]

    # 2) filter satellite rows by cv and by valid pixels (of a 25-pixel box)
    df_sat_filtered = df_sat[df_sat["sat_cv"] <= cv_max]
    df_sat_filtered = df_sat_filtered[
        df_sat_filtered["sat_pixel_valid"] >= min_percent_valid * 25 / 100
    ]

    for _, sat_row in df_sat_filtered.iterrows():
        # Drop the timezone once. tz_convert(None) raises on tz-naive
        # timestamps, so those are passed through unchanged.
        sat_time = pd.Timestamp(sat_row["sat_datetime"])
        if sat_time.tzinfo is not None:
            sat_time = sat_time.tz_convert(None)

        # mask by time and by distance from the satellite center pixel
        time_diff = abs(df_field_filtered["field_datetime"] - sat_time)
        time_mask = time_diff <= time_window
        lat_mask = abs(df_field_filtered["field_latitude"] - sat_row["sat_latitude"]) <= 0.2
        lon_mask = abs(df_field_filtered["field_longitude"] - sat_row["sat_longitude"]) <= 0.2

        field_matches = df_field_filtered[time_mask & lat_mask & lon_mask]

        # spectral outlier screening, only with enough rows to be meaningful
        if field_matches.shape[0] > 5:
            # str() keeps this safe for numeric (wavelength) column labels,
            # which have no .startswith method.
            rrs_cols = [
                c for c in field_matches.columns
                if str(c).startswith("field_rrs")
                and 400 <= float(str(c).rsplit("_rrs")[1]) <= 700
            ]
            if rrs_cols:
                mean_spectra = field_matches[rrs_cols].mean(axis=0)
                std_spectra = field_matches[rrs_cols].std(axis=0)
                mask = (
                    abs(field_matches[rrs_cols] - mean_spectra)
                    <= std_max * std_spectra
                )
                field_matches = field_matches[mask.all(axis=1)]

        if not field_matches.empty:
            # Closest field row in time. Positional lookup stays unambiguous
            # even when several field rows share the same timestamp.
            time_diff = abs(field_matches["field_datetime"] - sat_time)
            best_match = field_matches.iloc[time_diff.to_numpy().argmin()]
            df_match_list.append({**best_match.to_dict(), **sat_row.to_dict()})

    return pd.DataFrame(df_match_list)


In [None]:
# Run code only for one date as test

import pandas as pd

# -------------------------------
# 1) Rebuild the wide in-situ table
# -------------------------------
# (a) load the long-form CSV (one row per wavelength)
df_long = pd.read_csv("all_SeaBASS_profiles.csv",
                      parse_dates=["profile_time"])

# (b) pivot to one row per cast, one column of Rrs per wavelength
df_wide = (
    df_long
    .pivot(index=["profile_time","profile_lat","profile_lon"],
           columns="Wavelength",
           values="Rrs")
    .reset_index()
    .rename(columns={
        "profile_time":"datetime",
        "profile_lat":"lat",
        "profile_lon":"lon"
    })
)

# (c) split date & time into string columns
df_wide["date"] = df_wide["datetime"].dt.strftime("%Y%m%d")
df_wide["time"] = df_wide["datetime"].dt.strftime("%H:%M:%S")

# (d) reorder so metadata comes first, then numeric wavelengths ascending
wls     = sorted(c for c in df_wide.columns if isinstance(c,(int,float)))
df_wide = df_wide[["datetime","date","time","lat","lon"] + wls]

# confirm
print(df_wide.shape)          
print(df_wide.columns.tolist())  
# → ['datetime','date','time','lat','lon', <wavelength columns...>]


# -------------------------------
# 2) Make sure datetime is the index
# -------------------------------
# match_data takes the field timestamps from the index.
df_wide = df_wide.set_index("datetime")


# -------------------------------
# 3) Run the satellite matchup
# -------------------------------
# station coords are taken from the very first cast only
station_lat = df_wide["lat"].iloc[0]
station_lon = df_wide["lon"].iloc[0]

# fetch the PACE_AOP series at that location for a single date
df_satellite = get_sat_ts_matchups(
    start_date="2024-09-26",
    end_date  ="2024-09-26",
    latitude  = station_lat,
    longitude = station_lon,
    sat       ="PACE_AOP"
)

# match satellite rows to field rows
# (senz_max is accepted by match_data but not used by it)
matchups = match_data(
    df_satellite,
    df_wide,
    cv_max            = 0.60,
    senz_max          = 60.0,
    min_percent_valid = 55.0,
    max_time_diff     = 240,
    std_max           = 1.5,
)

matchups


In [None]:
# 1) Requires df_wide from the previous cell: index=datetime, plus lat & lon
#    columns. If you need to rebuild it, see the prior pivot code.

# Station coords are taken from the first cast only.
# NOTE(review): this assumes all casts share one location -- confirm, since
# granules are searched and the 5x5 box is extracted at this single point.
station_lat = df_wide["lat"].iloc[0]
station_lon = df_wide["lon"].iloc[0]

# 2) collect the unique sampling days as "YYYY-MM-DD" strings
unique_days = sorted({ts.strftime("%Y-%m-%d") for ts in df_wide.index})
print("Sampling days:", unique_days)

# 3) search from the first to the last sampling day at the station lat/lon,
#    keeping only granules that start on one of the sampling days
df_sat_all = get_sat_ts_matchups(
    start_date     = unique_days[0],
    end_date       = unique_days[-1],
    latitude       = station_lat,
    longitude      = station_lon,
    sat            = "PACE_AOP",
    selected_dates = unique_days
)

print("Found", len(df_sat_all), "satellite rows across all days.")

# 4) now do the full matchup pass in one go
#    (senz_max is accepted by match_data but not used by it)
matchups_all = match_data(
    df_sat_all,
    df_wide,
    cv_max            = 0.60,
    senz_max          = 60.0,
    min_percent_valid = 55.0,
    max_time_diff     = 240,
    std_max           = 1.5,
)

print("Got", len(matchups_all), "total matchups:")
matchups_all


In [None]:
# Write the matchup table to CSV in the working directory (row index omitted)
# and report the absolute path of the file.
matchups_all.to_csv("matchups_all.csv", index=False)
print("Saved to", os.path.abspath("matchups_all.csv"))


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# 1-2. Load the matchup CSV (parsing "date" and "time" as datetimes) and index
#      it by a single timestamp built from the date part of "date" and the
#      clock part of "time".
df = (
    pd.read_csv("matchups_all.csv", parse_dates=["date", "time"])
    .assign(
        datetime=lambda tbl: pd.to_datetime(
            tbl["date"].dt.strftime("%Y-%m-%d")
            + " "
            + tbl["time"].dt.strftime("%H:%M:%S")
        )
    )
    .set_index("datetime")
)

# 3. Wavelengths present as in-situ columns.
#    In-situ columns are named with plain digits ("350", "351", …, "719");
#    the satellite means use a "sat_rrs" prefix ("sat_rrs350", …).
wls = sorted(set(int(name) for name in df.columns if name.isdigit()))


In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# 1-2) Load the CSV and build a datetime index.
#      "date" is read as text so it can be parsed with an explicit YYYYMMDD
#      format, then combined with the "time" column. The frame is built in one
#      chain (no in-place mutation), so re-running the cell is idempotent.
df = (
    pd.read_csv("matchups_all.csv", dtype={"date": str})
    .assign(date=lambda tbl: pd.to_datetime(tbl["date"], format="%Y%m%d"))
    .assign(
        datetime=lambda tbl: pd.to_datetime(
            tbl["date"].dt.strftime("%Y-%m-%d") + " " + tbl["time"]
        )
    )
    .set_index("datetime")
)

# 3) In-situ columns are plain digits; satellite columns start with "sat_rrs".
insitu_cols = [c for c in df.columns if c.isdigit()]
sat_cols = [c for c in df.columns if c.startswith("sat_rrs")]

# 4) Strip the "sat_rrs" prefix and keep wavelengths present on both sides.
sat_wls = [c.replace("sat_rrs", "") for c in sat_cols]
common_wls = sorted(int(w) for w in set(insitu_cols).intersection(sat_wls))

print("Will plot these wavelengths:", common_wls)

# 5) One time-series figure and one 1:1 scatter figure per common wavelength.
for wl in common_wls:
    insitu_col = str(wl)
    sat_col = f"sat_rrs{wl}"

    # Keep only rows where both measurements exist.
    d = df[[insitu_col, sat_col]].dropna()
    if d.empty:
        continue

    # — Time series plot —
    fig, ax = plt.subplots(figsize=(8, 3.5))
    ax.plot(d.index, d[insitu_col], "-o", label="In-situ", linewidth=1)
    ax.plot(d.index, d[sat_col], "-s", label="Satellite", linewidth=1)
    ax.set_title(f"Rrs @ {wl} nm")
    ax.set_xlabel("Date")
    ax.set_ylabel("Rrs (sr⁻¹)")
    ax.legend()
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; this loop can create many

    # — 1:1 scatter plot —
    mn = min(d[insitu_col].min(), d[sat_col].min())
    mx = max(d[insitu_col].max(), d[sat_col].max())

    fig, ax = plt.subplots(figsize=(4, 4))
    ax.scatter(d[insitu_col], d[sat_col], s=30, alpha=0.7)
    ax.plot([mn, mx], [mn, mx], "k--", lw=1)  # identity line
    ax.axis("equal")
    ax.set_title(f"In-situ vs. Sat Rrs @ {wl} nm")
    ax.set_xlabel("In-situ Rrs")
    ax.set_ylabel("Satellite Rrs")
    fig.tight_layout()
    plt.show()
    plt.close(fig)


In [None]:
import matplotlib.pyplot as plt

# 1) Apply the base style sheet FIRST. A style sheet overwrites every rcParam
#    it defines, so applying it after the overrides below would undo some of
#    them (the grid settings in particular).
plt.style.use("seaborn-v0_8-whitegrid")

# 2) Layer the notebook-specific overrides (small fonts, dashed grid, …) on top.
plt.rcParams.update({
    "font.size":        8,
    "axes.titlesize":   9,
    "axes.labelsize":   8,
    "xtick.labelsize":  6,
    "ytick.labelsize":  6,
    "legend.fontsize":  6,
    "lines.linewidth":  1.0,
    "lines.markersize": 4,
    "figure.dpi":       300,
    "figure.figsize":  (6, 3),
    "axes.grid":        True,
    "grid.linestyle":   "--",
    "grid.linewidth":   0.4,
    "grid.alpha":       0.7,
})


# 3) Per-wavelength time-series and 1:1 scatter figures, each saved as a PNG
#    in the working directory.
for wl in common_wls:
    insitu_col = str(wl)
    sat_col    = f"sat_rrs{wl}"
    d = df[[insitu_col, sat_col]].dropna()
    if d.empty:
        continue

    # — Time series plot —
    fig, ax = plt.subplots()
    ax.plot(d.index, d[insitu_col], "-o", label="In-situ")
    ax.plot(d.index, d[sat_col],    "-s", label="Satellite")
    ax.set_title(f"Rrs @ {wl} nm")
    ax.set_xlabel("Date")
    ax.set_ylabel("Rrs (sr⁻¹)")
    ax.legend(frameon=False)
    fig.autofmt_xdate(rotation=30, ha="right")
    fig.tight_layout()
    fig.savefig(f"Rrs_timeseries_{wl}nm.png")
    plt.show()
    plt.close(fig)  # release the figure; this loop can create many

    # — 1:1 scatter plot —
    # Shared limits across both columns so the identity line spans the data.
    mn, mx = d.min().min(), d.max().max()
    fig, ax = plt.subplots()
    ax.scatter(d[insitu_col], d[sat_col], alpha=0.7)
    ax.plot([mn, mx], [mn, mx], "--", linewidth=0.8)
    ax.set_aspect("equal", "box")
    ax.set_title(f"In-situ vs Sat Rrs @ {wl} nm")
    ax.set_xlabel("In-situ Rrs")
    ax.set_ylabel("Satellite Rrs")
    fig.tight_layout()
    fig.savefig(f"Rrs_scatter_{wl}nm.png")
    plt.show()
    plt.close(fig)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import linregress, spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error

# Inputs expected from earlier cells (nothing here defines them):
#   wl          e.g. 443
#   insitu_col  = str(wl)
#   sat_col     = f"sat_rrs{wl}"
#   d           = df[[insitu_col, sat_col]].dropna()
# NOTE(review): after Run All these hold whatever the previous plotting loop
# left behind in its last iteration — confirm that is the intended band.

# 1) Bland–Altman statistics (difference is satellite minus in-situ)
diff        = d[sat_col] - d[insitu_col]
paired_mean = (d[insitu_col] + d[sat_col]) / 2

bias = diff.mean()
sd   = diff.std(ddof=1)
loa_l, loa_u = bias - 1.96 * sd, bias + 1.96 * sd

# rank correlation shown in the Bland–Altman annotation
rank_corr = spearmanr(d[insitu_col], d[sat_col]).correlation

# 2) OLS regression of satellite on in-situ; RMSE/MAE are measured between
#    the satellite values and the fitted line.
lr = linregress(d[insitu_col], d[sat_col])
slope, intercept, r_lin = lr.slope, lr.intercept, lr.rvalue
rmse = np.sqrt(
    mean_squared_error(d[sat_col], intercept + slope*d[insitu_col])
)
mae = mean_absolute_error(d[sat_col], intercept + slope*d[insitu_col])

# 3) Two-panel figure: Bland–Altman on the left, scatter on the right
fig, (ax_ba, ax_sc) = plt.subplots(nrows=1, ncols=2, figsize=(10, 4), dpi=300)

# — Bland–Altman panel —
ax_ba.scatter(paired_mean, diff, s=20, alpha=0.6)
ax_ba.axhline(bias, color="red", lw=1)
for loa_level in (loa_u, loa_l):
    ax_ba.axhline(loa_level, color="forestgreen", linestyle="--", lw=1)

ax_ba.set_title("Bland–Altman Plot")
ax_ba.set_xlabel("Paired Mean (in-situ + sat)/2")
ax_ba.set_ylabel(f"Bias (sat–in-situ) [{wl} nm]")

# annotation box, top-left of the panel
ba_text = "\n".join([
    f"Number of Points: {len(d)}",
    f"Mean Bias: {bias:.2e}",
    "Limits of Agreement:",
    f" [{loa_l:.2e}, {loa_u:.2e}]",
    f"Rank Corr: {rank_corr:.3f}",
])
ax_ba.text(
    0.05, 0.95, ba_text,
    transform=ax_ba.transAxes,
    va="top",
    fontsize=6,
    bbox={"boxstyle": "round,pad=0.3", "fc": "white", "alpha": 0.7},
)

# — Scatter + regression panel —
ax_sc.scatter(d[insitu_col], d[sat_col], s=20, alpha=0.6)
mn = min(d[insitu_col].min(), d[sat_col].min())
mx = max(d[insitu_col].max(), d[sat_col].max())

# 1:1 identity line
ax_sc.plot([mn, mx], [mn, mx], color="black", lw=1)

# best-fit line, evaluated on an even grid across the data range
xx = np.linspace(mn, mx, 100)
yy = intercept + slope * xx
ax_sc.plot(xx, yy, color="red", linestyle="--", lw=1)

ax_sc.set_title("Scatter Plot")
ax_sc.set_xlabel("In-situ Rrs")
ax_sc.set_ylabel("Satellite Rrs")

# annotation box, bottom-left of the panel
sc_text = "\n".join([
    f"Slope: {slope:.3f}",
    f"Intercept: {intercept:.2e}",
    f"Linear Corr: {r_lin:.3f}",
    f"RMSE: {rmse:.2e}",
    f"MAE: {mae:.2e}",
])
ax_sc.text(
    0.05, 0.05, sc_text,
    transform=ax_sc.transAxes,
    va="bottom",
    fontsize=6,
    bbox={"boxstyle": "round,pad=0.3", "fc": "white", "alpha": 0.7},
)

# final layout, then display
fig.tight_layout()
plt.show()

# write the figure to the working directory
fig.savefig(f"BlandAltman_and_Scatter_{wl}nm.png", dpi=300)


In [None]:
# Wavelengths that appear as plain-digit (in-situ) column names.
insitu_wls = set(int(name) for name in df.columns if name.isdigit())

# Wavelengths that appear as "sat_rrs###" (satellite) column names.
sat_wls = set(
    int(name.replace("sat_rrs", ""))
    for name in df.columns
    if name.startswith("sat_rrs")
)

# Bands available from both sources, ascending.
common_wls = sorted(insitu_wls.intersection(sat_wls))
print(f"Can plot these wavelengths: {common_wls}")


In [None]:
# Wavelengths to plot; set `plot_wls = common_wls` to plot every common band.
plot_wls = [560, 615, 719]


In [None]:


# One two-panel figure (Bland–Altman + scatter with OLS fit) per wavelength in
# `plot_wls`. Relies on np, plt, linregress, spearmanr, mean_squared_error and
# mean_absolute_error imported by an earlier cell.
for wl in plot_wls:
    insitu_col = str(wl)
    sat_col    = f"sat_rrs{wl}"
    d = df[[insitu_col, sat_col]].dropna()
    if d.empty:
        continue

    # A regression line is undefined with fewer than two distinct in-situ
    # values; skip the band instead of letting the statistics below fail.
    if d[insitu_col].nunique() < 2:
        print(f"Skipping {wl} nm: need at least two distinct matchups for regression")
        continue

    # — Bland–Altman statistics (difference is satellite minus in-situ) —
    paired_mean = 0.5*(d[insitu_col] + d[sat_col])
    diff        = d[sat_col] - d[insitu_col]
    bias        = diff.mean()
    sd          = diff.std(ddof=1)
    loa_u       = bias + 1.96*sd
    loa_l       = bias - 1.96*sd
    rank_corr   = spearmanr(d[insitu_col], d[sat_col]).correlation

    # — OLS of satellite on in-situ; RMSE/MAE are measured against the fit —
    lr          = linregress(d[insitu_col], d[sat_col])
    slope, inter, r_lin = lr.slope, lr.intercept, lr.rvalue
    fitted      = inter + slope*d[insitu_col]
    rmse        = np.sqrt(mean_squared_error(d[sat_col], fitted))
    mae         = mean_absolute_error(d[sat_col], fitted)

    fig, (ax_ba, ax_sc) = plt.subplots(1, 2, figsize=(10, 4), dpi=300)

    # Bland–Altman panel: mean bias (solid) and ±1.96·sd limits (dashed)
    ax_ba.scatter(paired_mean, diff, s=20, alpha=0.6)
    ax_ba.axhline(bias,      color="red",    lw=1)
    ax_ba.axhline(loa_u,     color="forestgreen", ls="--", lw=1)
    ax_ba.axhline(loa_l,     color="forestgreen", ls="--", lw=1)
    ax_ba.set_title(f"Bland–Altman @ {wl} nm")
    ax_ba.set_xlabel("(in-situ + sat)/2")
    ax_ba.set_ylabel(f"sat−in-situ [{wl} nm]")
    ba_txt = (
        f"N: {len(d)}\n"
        f"Bias: {bias:.2e}\n"
        f"LoA: [{loa_l:.2e}, {loa_u:.2e}]\n"
        f"ρ: {rank_corr:.3f}"
    )
    ax_ba.text(0.03, 0.97, ba_txt, transform=ax_ba.transAxes,
               va="top", fontsize=6,
               bbox=dict(facecolor="white", alpha=0.7, pad=2))

    # Scatter panel: identity line (solid) and OLS fit (dashed)
    ax_sc.scatter(d[insitu_col], d[sat_col], s=20, alpha=0.6)
    mn, mx = d.min().min(), d.max().max()
    ax_sc.plot([mn, mx], [mn, mx], "k-", lw=1)
    xx = np.linspace(mn, mx, 100)
    ax_sc.plot(xx, inter + slope*xx, "r--", lw=1)
    ax_sc.set_title(f"Scatter @ {wl} nm")
    ax_sc.set_xlabel("in-situ Rrs")
    ax_sc.set_ylabel("satellite Rrs")
    sc_txt = (
        f"Slope: {slope:.3f}\n"
        f"Int: {inter:.2e}\n"
        f"r: {r_lin:.3f}\n"
        f"RMSE: {rmse:.2e}\n"
        f"MAE: {mae:.2e}"
    )
    ax_sc.text(0.03, 0.05, sc_txt, transform=ax_sc.transAxes,
               va="bottom", fontsize=6,
               bbox=dict(facecolor="white", alpha=0.7, pad=2))

    fig.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure; one is created per wavelength


In [None]:
# All wavelengths

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import linregress, spearmanr
from sklearn.metrics import mean_squared_error, mean_absolute_error

# ─── 1) Load & preprocess ─────────────────────────────────────────────────────
# Requires 'matchups_all.csv' in the working directory. "date" is read as text
# (YYYYMMDD), parsed, and combined with the "time" column into a datetime
# index. The frame is built in one chain so re-running the cell is idempotent.
df = (
    pd.read_csv("matchups_all.csv", dtype={"date": str})
    .assign(date=lambda tbl: pd.to_datetime(tbl["date"], format="%Y%m%d"))
    .assign(
        datetime=lambda tbl: pd.to_datetime(
            tbl["date"].dt.strftime("%Y-%m-%d") + " " + tbl["time"]
        )
    )
    .set_index("datetime")
)

# ─── 2) Discover common wavelengths ─────────────────────────────────────────────
# In-situ bands are plain digits; satellite bands begin with 'sat_rrs'
insitu_wls = {int(c) for c in df.columns if c.isdigit()}
sat_wls    = {int(c.replace("sat_rrs", "")) for c in df.columns if c.startswith("sat_rrs")}
common_wls = sorted(insitu_wls & sat_wls)

print("Common wavelengths:", common_wls)

# ─── 3) Choose which wavelengths to plot ────────────────────────────────────────
# Default: plot every common wavelength.
plot_wls = common_wls

# To plot only a subset, replace the line above with e.g.:
# plot_wls = [353, 442, 447, 560]

# ─── 4) Set up the publication‐style defaults ───────────────────────────────────
# Style sheet first, overrides second, so the overrides are not overwritten.
plt.style.use("seaborn-v0_8-whitegrid")
plt.rcParams.update({
    "figure.dpi":        300,
    "figure.figsize":    (10, 4),
    "font.size":         8,
    "axes.titlesize":    10,
    "axes.labelsize":    9,
    "xtick.labelsize":   7,
    "ytick.labelsize":   7,
    "legend.fontsize":   7,
    "lines.linewidth":   1.0,
    "lines.markersize":  4,
    "axes.grid":         True,
    "grid.linestyle":    "--",
    "grid.linewidth":    0.5,
    "grid.color":        "0.7",
})

# ─── 5) Loop through wavelengths and plot ───────────────────────────────────────
for wl in plot_wls:
    insitu_col = str(wl)
    sat_col    = f"sat_rrs{wl}"

    # Subset and drop any missing pairs
    d = df[[insitu_col, sat_col]].dropna()
    if d.empty:
        continue

    # Bland–Altman statistics (difference is satellite minus in-situ)
    paired_mean = 0.5 * (d[insitu_col] + d[sat_col])
    diff        = d[sat_col] - d[insitu_col]

    # Both regressions below are undefined without at least two distinct
    # x values; skip the band instead of letting the statistics fail.
    if d[insitu_col].nunique() < 2 or paired_mean.nunique() < 2:
        print(f"Skipping {wl} nm: need at least two distinct matchups for regression")
        continue

    bias        = diff.mean()
    sd          = diff.std(ddof=1)
    loa_u       = bias + 1.96 * sd
    loa_l       = bias - 1.96 * sd
    # regression of diff vs. mean for BA trend line
    m_ba, b_ba, *_ = linregress(paired_mean, diff)
    # rank correlation
    rank_corr = spearmanr(d[insitu_col], d[sat_col]).correlation

    # Scatter/regression statistics; RMSE/MAE are measured against the fit
    lr     = linregress(d[insitu_col], d[sat_col])
    slope  = lr.slope
    inter  = lr.intercept
    r_lin  = lr.rvalue
    fitted = inter + slope * d[insitu_col]
    rmse   = np.sqrt(mean_squared_error(d[sat_col], fitted))
    mae    = mean_absolute_error(d[sat_col], fitted)

    # Create figure with two panels (size/dpi come from the rcParams above)
    fig, (ax_ba, ax_sc) = plt.subplots(1, 2)

    # — Bland–Altman panel —
    ax_ba.scatter(paired_mean, diff, s=20, alpha=0.6)
    ax_ba.axhline(bias,      color="black", lw=1)
    ax_ba.axhline(loa_u,     color="black", ls="--", lw=1)
    ax_ba.axhline(loa_l,     color="black", ls="--", lw=1)
    ax_ba.plot(paired_mean, m_ba*paired_mean + b_ba,
               ls="--", color="crimson", lw=1)

    ax_ba.set_title(f"Bland–Altman @ {wl} nm")
    ax_ba.set_xlabel("Paired Mean (in-situ + sat)/2")
    ax_ba.set_ylabel(f"sat–in-situ [{wl} nm]")

    ba_txt = (
        f"N: {len(d)}\n"
        f"Bias: {bias:.2e}\n"
        f"LoA: [{loa_l:.2e}, {loa_u:.2e}]\n"
        f"ρ: {rank_corr:.3f}"
    )
    ax_ba.text(0.03, 0.97, ba_txt,
               transform=ax_ba.transAxes,
               va="top", fontsize=7,
               bbox=dict(facecolor="white", alpha=0.8, pad=2))

    # — Scatter + regression panel —
    ax_sc.scatter(d[insitu_col], d[sat_col], s=20, alpha=0.6)
    mn, mx = d.min().min(), d.max().max()
    ax_sc.plot([mn, mx], [mn, mx], color="black", lw=1)  # identity line
    xx = np.linspace(mn, mx, 200)
    ax_sc.plot(xx, inter + slope*xx,
               ls="--", color="crimson", lw=1)

    ax_sc.set_title(f"Scatter @ {wl} nm")
    ax_sc.set_xlabel("in-situ Rrs")
    ax_sc.set_ylabel("satellite Rrs")

    sc_txt = (
        f"Slope: {slope:.3f}\n"
        f"Int: {inter:.2e}\n"
        f"r: {r_lin:.3f}\n"
        f"RMSE: {rmse:.2e}\n"
        f"MAE: {mae:.2e}"
    )
    ax_sc.text(0.03, 0.05, sc_txt,
               transform=ax_sc.transAxes,
               va="bottom", fontsize=7,
               bbox=dict(facecolor="white", alpha=0.8, pad=2))

    fig.tight_layout()
    # Save each figure to the working directory before displaying it.
    fig.savefig(f"BlandAltman_Scatter_{wl}nm.png", dpi=300)
    plt.show()
    plt.close(fig)  # release the figure; one is created per wavelength


In [None]:
# (intentionally empty — the stray "-" that was here is a SyntaxError and stopped Run All)

In [None]:
import datetime
import os
import re
from pathlib import Path

import earthaccess
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import matplotlib.style as style
import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from matplotlib.ticker import FuncFormatter
from scipy import stats



In [None]:
# Bland-Altman / scatterplot styling constants.
# Colors are taken from seaborn's "colorblind" palette:
# entry 0 -> scatter points, entry 1 -> fit line, entry 2 -> limits of agreement.
COLOR_PALETTE = sns.color_palette("colorblind")
COLOR_SCATTER, COLOR_FITLINE, COLOR_LOA = COLOR_PALETTE[:3]
COLOR_LINE = "black"  # reference (zero / identity) lines

# Font sizes
SIZE_TITLE = 24
SIZE_AXLABEL = 20
SIZE_TEXTLABEL = 14

# Whether the plotting helpers draw a legend
SHOW_LEGEND = False

# Global matplotlib / seaborn defaults
plt.rcParams["figure.dpi"] = 300
sns.set_style("ticks", rc={"figure.dpi": 300})
sns.set_context("notebook", font_scale=1.45)


In [None]:
##---------------------------------------------------------------------------##
#                              Plotting Utilities                             #
##---------------------------------------------------------------------------##


def compute_bland_altman_metrics(xx, yy, xx_unc_modl, yy_unc_modl):
    """Derive the quantities needed for a Bland-Altman plot.

    Parameters
    ----------
    xx : array
        X data values.
    yy : array
        Y data values, paired with ``xx``.
    xx_unc_modl : float
        Uncertainty in X.
    yy_unc_modl : float
        Uncertainty in Y.

    Returns
    -------
    dict
        ``jj`` (paired means), ``kk`` (differences ``yy - xx`` divided by the
        combined uncertainty), ``count``, ``meanbias``, ``LOAlow``/``LOAhgh``
        (mean bias minus/plus one standard deviation of ``kk``), ``ba_stat``
        and ``ba_p`` (Spearman correlation of ``jj`` vs ``kk`` and its
        p-value) and ``ba_independ`` (``ba_p > 0.05``).

    """
    # Uncertainties combined in quadrature scale the raw difference.
    combined_unc = np.sqrt((xx_unc_modl**2) + (yy_unc_modl**2))
    paired_mean = (xx + yy) / 2
    scaled_diff = (yy - xx) / combined_unc

    bias_mean = np.mean(scaled_diff)
    bias_spread = np.std(scaled_diff)

    # Rank correlation between paired mean and bias; a p-value above 0.05 is
    # taken to mean the bias does not depend on magnitude.
    ba_stat, ba_p = stats.spearmanr(paired_mean, scaled_diff)

    metrics = {
        "count": scaled_diff.shape[0],
        "jj": paired_mean,
        "kk": scaled_diff,
        "meanbias": bias_mean,
        "LOAlow": bias_mean - bias_spread,
        "LOAhgh": bias_mean + bias_spread,
        "ba_stat": ba_stat,
        "ba_p": ba_p,
        "ba_independ": ba_p > 0.05,
    }
    return metrics



def compute_regression_metrics(xx, yy, is_type2=False):
    """Compute ordinary least-squares regression and agreement metrics.

    Parameters
    ----------
    xx : array-like
        X data values.
    yy : array-like
        Y data values, paired with ``xx``.
    is_type2 : bool, default False
        Accepted for call compatibility and ignored; the fit is always OLS.

    Returns
    -------
    dict
        ``count``, ``slope``, ``intercept``, ``r_pear`` (Pearson r from the
        OLS fit), ``r_spear`` (Spearman rank correlation), ``rmse`` and
        ``mae``. RMSE and MAE compare ``yy`` with ``xx`` directly, not with
        the fitted line.

    """
    # Accept lists and other sequences as well as arrays.
    xx = np.asarray(xx, dtype=float)
    yy = np.asarray(yy, dtype=float)

    # 1) OLS fit (uses `stats`, the same SciPy alias as the sibling helpers)
    fit = stats.linregress(xx, yy)

    # 2) rank (Spearman) correlation
    spear = stats.spearmanr(xx, yy).correlation

    # 3) error metrics of yy against xx
    residual = yy - xx
    rmse = np.sqrt(np.mean(residual ** 2))
    mae = np.mean(np.abs(residual))

    return {
        "count":     len(xx),
        "slope":     fit.slope,
        "intercept": fit.intercept,
        "r_pear":    fit.rvalue,
        "r_spear":   spear,
        "rmse":      rmse,
        "mae":       mae,
    }



def add_text_annotations(ax, text_lines, position="top right", fontsize=SIZE_TEXTLABEL):
    """Add a block of text to a corner of the plot.

    Parameters
    ----------
    ax : Axes
        The axis to add text to.
    text_lines : list of str
        Strings to display, one per line.
    position : str, default 'top right'
        Corner to place the text in: 'top right', 'top left',
        'bottom left' or 'bottom right'.
    fontsize : int, default SIZE_TEXTLABEL
        Font size of the text.

    Raises
    ------
    ValueError
        If ``position`` is not one of the four supported corners.

    """
    # (x, y, horizontal alignment, vertical alignment) in axes coordinates
    placements = {
        "top right": (0.95, 0.95, "right", "top"),
        "top left": (0.05, 0.95, "left", "top"),
        "bottom left": (0.05, 0.05, "left", "bottom"),
        "bottom right": (0.95, 0.05, "right", "bottom"),
    }
    try:
        x, y, ha, va = placements[position]
    except KeyError:
        # Fail with a clear message instead of an unbound-variable error.
        raise ValueError(
            f"Unknown position {position!r}; expected one of {sorted(placements)}"
        ) from None

    text = "\n".join(text_lines)
    ax.text(
        x,
        y,
        text,
        transform=ax.transAxes,
        fontsize=fontsize,
        verticalalignment=va,
        horizontalalignment=ha,
        bbox=dict(facecolor="white", alpha=0.6, edgecolor="none"),
    )


def setup_plot(label):
    """Create the two-panel figure used for the paired plots.

    Applies the "seaborn-v0_8-whitegrid" style sheet (a global matplotlib
    setting) before creating the figure.

    Parameters
    ----------
    label : str
        Title placed above both panels.

    Returns
    -------
    tuple
        ``(fig, ax1, ax2)``: the figure and its left and right axes.

    """
    style.use("seaborn-v0_8-whitegrid")
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(16, 6), layout="constrained")
    fig.suptitle(label, fontsize=22)
    left_ax, right_ax = axes
    return fig, left_ax, right_ax


def format_ticks(ax):
    """Show tick labels with three significant digits and thicken ticks and spines."""

    def _three_sig(value, _pos):
        return f"{value:.3g}"

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_formatter(FuncFormatter(_three_sig))

    ax.tick_params(axis="both", which="major", width=2, length=6)

    for side in ("top", "right", "left", "bottom"):
        ax.spines[side].set_linewidth(2)


def plot_bland_altman(
    ax1,
    metrics,
    binscale,
    xx_unc_modl,
    x_label="x",
    y_label="y"
):
    """Draw the Bland-Altman panel.

    Parameters
    ----------
    ax1 : Axes
        Axis to draw on.
    metrics : dict
        Bland-Altman metrics; the keys ``jj``, ``kk``, ``count``,
        ``meanbias``, ``LOAlow``, ``LOAhgh``, ``ba_independ`` and
        ``ba_stat`` are read.
    binscale : float
        Scaling factor for bin size (not used in this function).
    xx_unc_modl : float
        Uncertainty in X; only used to choose the y-axis label.
    x_label : string, default "x"
        Label text for the x data.
    y_label : string, default "y"
        Label text for the y data.

    """
    jj, kk = metrics["jj"], metrics["kk"]
    meanbias = metrics["meanbias"]
    LOAlow, LOAhgh = metrics["LOAlow"], metrics["LOAhgh"]
    ba_independ = metrics["ba_independ"]

    # Axis ranges: x spans the paired means, y spans five standard deviations
    # of the bias on either side of its mean.
    half_range = 5 * np.std(kk)
    min_jj, max_jj = np.min(jj), np.max(jj)

    ax1.scatter(jj, kk, color=COLOR_SCATTER)
    ax1.set_xlim([min_jj, max_jj])
    ax1.set_ylim([meanbias - half_range, meanbias + half_range])

    ax1.set_title("Bland-Altman plot", fontsize=SIZE_TITLE)
    # With the sqrt(0.5) uncertainty the axis is labelled as a plain bias;
    # any other value is labelled as uncertainty-normalized.
    if xx_unc_modl != np.sqrt(0.5):
        ylabel = "Uncertainty normalized bias"
    else:
        ylabel = f"Bias, ${y_label}-{x_label}$"
    ax1.set_ylabel(ylabel, fontsize=SIZE_AXLABEL)
    ax1.set_xlabel(
        f"Paired mean, $({x_label}+{y_label})/2$", fontsize=SIZE_AXLABEL
    )

    # Zero-bias reference line.
    ax1.plot(
        [min_jj, max_jj], [0, 0],
        color=COLOR_LINE, linestyle="solid", linewidth=4.0,
    )

    if ba_independ:
        # Bias independent of magnitude: horizontal mean bias and LOA band.
        horizontal_lines = (
            (meanbias, COLOR_FITLINE, 3.0, "Mean Bias"),
            (LOAlow, COLOR_LOA, 2.0, "Lower LOA"),
            (LOAhgh, COLOR_LOA, 2.0, "Upper LOA"),
        )
        for level, colour, width, tag in horizontal_lines:
            ax1.plot(
                [min_jj, max_jj],
                [level, level],
                color=colour,
                linestyle="dashed",
                linewidth=width,
                label=tag,
            )
        ax1.fill_between(
            [min_jj, max_jj], LOAlow, LOAhgh,
            color=COLOR_LOA, alpha=0.1,
        )
    else:
        # Bias depends on magnitude: draw a linear fit of bias vs paired mean.
        trend = stats.linregress(jj, kk)
        fit_at_min = trend.slope * min_jj + trend.intercept
        fit_at_max = trend.slope * max_jj + trend.intercept
        ax1.plot(
            [min_jj, max_jj],
            [fit_at_min, fit_at_max],
            color=COLOR_FITLINE,
            linestyle="dashed",
            linewidth=3.0,
            label="Linear Fit",
        )

    if SHOW_LEGEND:
        ax1.legend()
    ax1.grid(True)
    format_ticks(ax1)

    summary = [
        f"Number of Points: {metrics['count']}",
        f"Mean Bias: {meanbias:.2e}",
        f"Limits of Agreement: [{LOAlow:.2e}, {LOAhgh:.2e}]",
        f"Rank Correlation: {metrics['ba_stat']:.3f}",
    ]
    summary.append("Bias Independent" if ba_independ else "Bias Dependent")
    add_text_annotations(ax1, summary, position="bottom right")


def plot_scatter(
    ax2, xx, yy, regress_metrics, binscale, x_label="x", y_label="y"
):
    """Draw the scatter panel with identity and regression lines.

    Parameters
    ----------
    ax2 : Axes
        Axis to draw on.
    xx : array
        X data values.
    yy : array
        Y data values.
    regress_metrics : dict
        Regression metrics; the keys ``slope``, ``intercept``, ``r_pear``,
        ``r_spear``, ``rmse`` and ``mae`` are read.
    binscale : float
        Scaling factor for bin size (not used in this function).
    x_label : string, default "x"
        Label text for the x data.
    y_label : string, default "y"
        Label text for the y data.

    """
    # Both axes share one range so the identity line runs corner to corner.
    lo = min(np.min(xx), np.min(yy))
    hi = max(np.max(xx), np.max(yy))
    span = [lo, hi]

    ax2.scatter(xx, yy, color=COLOR_SCATTER)
    ax2.set_xlim(span)
    ax2.set_ylim(span)

    ax2.set_title("Scatterplot", fontsize=SIZE_TITLE)
    ax2.set_xlabel(f"${x_label}$", fontsize=SIZE_AXLABEL)
    ax2.set_ylabel(f"${y_label}$", fontsize=SIZE_AXLABEL)

    # Identity (1:1) line.
    ax2.plot(
        span,
        span,
        color=COLOR_LINE,
        linestyle="solid",
        linewidth=4.0,
    )

    # Regression line evaluated at both ends of the shared range.
    slope = regress_metrics["slope"]
    intercept = regress_metrics["intercept"]
    ax2.plot(
        span,
        [slope * lo + intercept, slope * hi + intercept],
        color=COLOR_FITLINE,
        linestyle="dashed",
        linewidth=3.0,
        label="Regression Line",
    )

    if SHOW_LEGEND:
        ax2.legend()
    ax2.grid(True)
    format_ticks(ax2)

    summary = [
        f"Slope: {slope:.3f}",
        f"Intercept: {intercept:.2e}",
        f"Linear Correlation: {regress_metrics['r_pear']:.3f}",
        f"Rank Correlation: {regress_metrics['r_spear']:.3f}",
        f"RMSE: {regress_metrics['rmse']:.2e}",
        f"MAE: {regress_metrics['mae']:.2e}",
    ]
    add_text_annotations(ax2, summary, position="bottom right")


def plot_BAvsScat(
    x_input,
    y_input,
    label="",
    saveplot=None,
    binscale=1.0,
    xx_unc_modl=np.sqrt(0.5),
    yy_unc_modl=np.sqrt(0.5),
    x_label="x",
    y_label="y",
    is_type2=True,
):
    """Routine to plot paired data as Bland-Altman and scatter plot.

    Parameters
    ----------
    x_input : array-like
        Array of X data values.
    y_input : array-like
        Corresponding array of Y data values (paired by position).
    label : string, default ''
        Text label for plotting.
    saveplot : string, default None
        Set to save plot in ../output/ with the string as the filename.
        The directory is created if it does not exist.
    binscale : float, default 1.0
        Scaling factor for how many bins to include in a 2D histogram.
    xx_unc_modl : float, default np.sqrt(0.5)
        Uncertainty in X.
    yy_unc_modl : float, default np.sqrt(0.5)
        Uncertainty in Y.
    x_label : string, default "x"
        String for labels for x data
    y_label : string, default "y"
        String for labels for y data
    is_type2 : bool, default True
        Forwarded to ``compute_regression_metrics``.

    Returns
    -------
    dict
        Dictionary of computed statistics.

    Raises
    ------
    ValueError
        If the inputs differ in shape, or no pair is left after removing
        non-finite values and the -999 fill value.

    """
    # Convert once and use the arrays everywhere, so lists work and pandas
    # inputs are paired by position rather than by index label.
    xx = np.asarray(x_input, dtype=float)
    yy = np.asarray(y_input, dtype=float)
    if xx.shape != yy.shape:
        raise ValueError(
            f"x_input and y_input must have the same shape, got {xx.shape} and {yy.shape}"
        )

    # Keep pairs where both values are finite and neither is the -999 fill.
    valid_indices = (
        np.isfinite(xx)
        & np.isfinite(yy)
        & (xx != -999)
        & (yy != -999)
    )
    xx = xx[valid_indices]
    yy = yy[valid_indices]
    if xx.size == 0:
        raise ValueError("No valid (finite, non -999) data pairs to plot.")

    ba_metrics = compute_bland_altman_metrics(xx, yy, xx_unc_modl, yy_unc_modl)
    regress_metrics = compute_regression_metrics(xx, yy, is_type2=is_type2)

    fig, ax1, ax2 = setup_plot(label)
    plot_bland_altman(ax1, ba_metrics, binscale, xx_unc_modl, x_label, y_label)
    plot_scatter(ax2, xx, yy, regress_metrics, binscale, x_label, y_label)

    if saveplot is not None:
        figpath = Path("../output") / saveplot
        # savefig does not create missing directories.
        figpath.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(figpath)
        print("Saved figure to:", figpath)

    plt.show()

    # Limits of agreement are reported only when the bias is independent of
    # magnitude; otherwise they are NaN.
    independent = ba_metrics["ba_independ"]
    return {
        "Number_of_Points": ba_metrics["count"],
        "Scale_Independence": independent,
        "Mean_Bias": ba_metrics["meanbias"],
        "Limits_of_Agreement_low": (
            ba_metrics["LOAlow"] if independent else float("nan")
        ),
        "Limits_of_Agreement_high": (
            ba_metrics["LOAhgh"] if independent else float("nan")
        ),
        "Linear_Slope": regress_metrics["slope"],
        "Linear_Intercept": regress_metrics["intercept"],
        "Linear_Correlation": regress_metrics["r_pear"],
        "Rank_Correlation": regress_metrics["r_spear"],
        "RMSE": regress_metrics["rmse"],
        "MAE": regress_metrics["mae"],
    }