In [6]:
# ----------------------------------------------------------------- #
#                              MODULES                              #

# Standard Modules
import os
import glob
import pandas as pd
import numpy as np
from typing import Literal

# Third-Party Modules
import h3
import plotly.express as px
import geopandas as gpd
from shapely.geometry import Point
import pandas as pd
import numpy as np
import h3
from datetime import timedelta
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, precision_recall_curve
from xgboost import XGBClassifier
import shap
import matplotlib.pyplot as plt
import geopandas as gpd
from sklearn.metrics import confusion_matrix
from shapely.geometry import Polygon
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import (
    roc_curve,
    roc_auc_score,
    confusion_matrix,
    classification_report,
)
import matplotlib.pyplot as plt
import seaborn as sns

#                                                                   #
# ----------------------------------------------------------------- #

# ----------------------------------------------------------------- #
#                             FUNCTIONS                             #

##############
# COLLECTION


# Load and Process Sightings
def load_and_process_sighting_data(
    directory: str,
    date_col: str,
    lat_col: str,
    lon_col: str,
    id_col: str,
    h3_resolution: int,
    start_date: str = None,
    source: Literal["TMW", "ACARTIA"] = "TMW",
) -> pd.DataFrame:
    """
    Load and process sighting data for TMW or Acartia.

    Reads every ``*.csv`` in ``directory`` and upper-cases the column headers.
    It then assigns each sighting to an H3 cell and counts sightings per
    (day, cell). The result is dense: one row for every day in the observed
    date range x every cell seen. Combinations with no sightings get
    SIGHTING_COUNT = 0.

    Args:
        directory (str): Folder containing the CSV files.
        date_col (str): Column holding a datetime string. Only its first 10
            characters (YYYY-MM-DD) are parsed. Case-insensitive.
        lat_col (str): Latitude column (case-insensitive). Rows whose value
            cannot be parsed as a number are dropped.
        lon_col (str): Longitude column (case-insensitive). Rows whose value
            cannot be parsed as a number are dropped.
        id_col (str): Column whose non-null values are counted per (day, cell).
            Case-insensitive.
        h3_resolution (int): H3 resolution. The output column is
            ``H3_GRID_<h3_resolution>``.
        start_date (str, optional): Keep only records on or after this date.
        source (str): "TMW" or "ACARTIA". Not currently used by this function.

    Returns:
        pd.DataFrame: Columns DATE, H3_GRID_<h3_resolution>, SIGHTING_COUNT.
        When no rows survive parsing/filtering, the frame is empty but still
        has those columns.

    Raises:
        ValueError: If ``directory`` contains no CSV files.
    """
    h3_col = f"H3_GRID_{h3_resolution}"
    empty_result = pd.DataFrame(
        {
            "DATE": pd.Series(dtype="datetime64[ns]"),
            h3_col: pd.Series(dtype=object),
            "SIGHTING_COUNT": pd.Series(dtype=float),
        }
    )

    # Headers are upper-cased below, so normalise the requested names the same way.
    date_col, lat_col, lon_col, id_col = (
        name.upper() for name in (date_col, lat_col, lon_col, id_col)
    )

    # Read & concat all CSVs (sorted for a deterministic row order).
    paths = sorted(glob.glob(os.path.join(directory, "*.csv")))
    if not paths:
        raise ValueError(f"No CSV files found in {directory!r}")
    data = pd.concat((pd.read_csv(path) for path in paths), ignore_index=True)
    data.columns = data.columns.str.upper()

    # Parse coordinates; unparseable values become NaN and those rows are dropped.
    data = data.assign(
        LATITUDE=pd.to_numeric(data[lat_col], errors="coerce"),
        LONGITUDE=pd.to_numeric(data[lon_col], errors="coerce"),
    ).dropna(subset=["LATITUDE", "LONGITUDE"])
    if data.empty:
        return empty_result

    # Keep only the YYYY-MM-DD prefix of the timestamp string.
    data = data.assign(DATE=pd.to_datetime(data[date_col].str[:10]))

    # Filter before the per-row H3 lookup so discarded rows cost nothing.
    if start_date:
        data = data.loc[data["DATE"] >= pd.to_datetime(start_date)]
    if data.empty or data["DATE"].isna().all():
        return empty_result

    # Calculate H3 grid
    data = data.assign(
        **{
            h3_col: [
                h3.latlng_to_cell(lat, lon, h3_resolution)
                for lat, lon in zip(data["LATITUDE"], data["LONGITUDE"])
            ]
        }
    )

    # Aggregate sightings per (day, cell)
    data_agg = data.groupby(["DATE", h3_col], as_index=False).agg(
        SIGHTING_COUNT=(id_col, "count")
    )

    # Build the full date x grid frame so days without sightings exist as rows.
    full_index = pd.MultiIndex.from_product(
        [
            pd.date_range(data["DATE"].min(), data["DATE"].max()),
            data[h3_col].unique(),
        ],
        names=["DATE", h3_col],
    )
    full_data = (
        pd.DataFrame(index=full_index)
        .reset_index()
        .merge(data_agg, on=["DATE", h3_col], how="left")
    )
    full_data["SIGHTING_COUNT"] = full_data["SIGHTING_COUNT"].fillna(0)

    return full_data


# H3 to Polygon Extent
def h3_to_polygon(h3_index):
    """Return the boundary of an H3 cell as a shapely Polygon in (lon, lat) order.

    ``h3.cell_to_boundary`` yields (lat, lng) pairs while shapely expects
    (x, y) = (lon, lat), so every vertex is reversed.
    """
    vertices = h3.cell_to_boundary(h3_index)
    return Polygon([vertex[::-1] for vertex in vertices])


# Clip Sightings to Geometry
def clip_sightings_to_geometry(
    sightings_df, h3_resolution, geometry_gdf, geometry_col="geometry"
):
    """
    Keep only sightings whose H3 cell intersects a geometry (e.g., a marine area).

    Each distinct cell is turned into its polygon in EPSG:4326. The polygons
    are clipped against the union of *all* geometries in
    ``geometry_gdf[geometry_col]``, and sightings are inner-joined on the
    surviving cells.

    Parameters:
    - sightings_df (pd.DataFrame): Sightings data with an H3 column such as 'H3_GRID_7'.
    - h3_resolution (int): H3 resolution of the column to clip on (e.g., 7).
    - geometry_gdf (gpd.GeoDataFrame): Geometry to clip against. It is
      reprojected to EPSG:4326 when it carries a CRS. A geometry without a CRS
      is assumed to already be in lon/lat.
    - geometry_col (str): Geometry column in geometry_gdf (default: 'geometry').

    Returns:
    - pd.DataFrame: Sightings in intersecting cells, with an added 'geometry'
      column holding each cell's clipped polygon.
    """
    h3_col = f"H3_GRID_{h3_resolution}"
    cells = sightings_df[[h3_col]].drop_duplicates().copy()
    cells["geometry"] = cells[h3_col].apply(h3_to_polygon)
    cells = gpd.GeoDataFrame(cells, geometry="geometry", crs="EPSG:4326")

    # Align the mask CRS with the cells. A GeoSeries mask is dissolved by
    # GeoDataFrame.clip, so every polygon in geometry_gdf is honoured, not
    # just the first one.
    mask = geometry_gdf[geometry_col]
    if mask.crs is None:
        mask = mask.set_crs(cells.crs)
    else:
        mask = mask.to_crs(cells.crs)
    cells = cells.clip(mask)

    # Inner merge keeps only the sightings whose cell survived the clip.
    return pd.merge(sightings_df, cells, on=h3_col)


#                                                                   #
# ----------------------------------------------------------------- #

### 1. Open Data

In [7]:
# Parameters
## H3 Grid Resolution for Modeling
h3_resolution = 4

## Project roots: set ORCASALMON_ROOT / FINDMYWHALE_ROOT to run on another machine
## (defaults are the original author's local checkout locations).
orcasalmon_root = os.environ.get(
    "ORCASALMON_ROOT", "/Users/tylerstevenson/Documents/CODE/orcasalmon"
)
findmywhale_root = os.environ.get(
    "FINDMYWHALE_ROOT", "/Users/tylerstevenson/Documents/CODE/FindMyWhale"
)

## TMW Data Path
tmw_directory = os.path.join(orcasalmon_root, "data", "twm")

## Acartia Data Path
acartia_directory = os.path.join(findmywhale_root, "data", "raw", "sightings")

## Marine Area Geometries
marine_geometries_path = os.path.join(
    findmywhale_root, "data", "processed", "GIS", "POLYGONS", "SSEA_REGION_5.parquet"
)

In [None]:
# Data Ingest
## Open Marine Geometries
marine_area = gpd.read_parquet(marine_geometries_path)
## Merge all polygons into a single feature before clipping.
marine_area = marine_area.dissolve()

## Open TMW
tmw_data_cleaned = load_and_process_sighting_data(
    directory=tmw_directory,
    date_col="SIGHTDATE",
    lat_col="LATITUDE",
    lon_col="LONGITUDE",
    id_col="DATE",  # or other proxy for sightings count
    h3_resolution=h3_resolution,
    source="TMW",
)

## Open Acartia
acartia_data_cleaned = load_and_process_sighting_data(
    directory=acartia_directory,
    date_col="CREATED",
    lat_col="LATITUDE",
    lon_col="LONGITUDE",
    id_col="ENTRY_ID",
    h3_resolution=h3_resolution,
    start_date="2022-01-01",
    source="ACARTIA",
)

# Combine Sightings Data
## ignore_index=True: each source frame has its own 0..n-1 index, so a plain
## concat would yield duplicate row labels. Note the two sources may share
## (DATE, H3) pairs wherever their date ranges overlap.
sightings_data_raw = pd.concat(
    [acartia_data_cleaned, tmw_data_cleaned], ignore_index=True
)

## Clip to Marine Area
sightings_data_raw = clip_sightings_to_geometry(
    sightings_data_raw, h3_resolution=h3_resolution, geometry_gdf=marine_area
)

### 2. Feature Creation
#### Preprocessing
- Aggregate to Weekly Sightings
- Convert Sightings to Ratio of Total Sightings

#### Features
- Definition Signals
    - H3 Grid - Categorical Encoding
    - H3 Parent Grid - Categorical Encoding
    - Percent Grid Over Water
- Temporal Signals
    - Total Sightings
    - Number of Days with at Least One Sighting per Week
    - Add Week of Year
    - Add Season
    - Add Month
    - Add Prior Week of Year
    - Add Prior Season
    - Add Prior Month
    - Week Has Any Sighting Boolean Flag
    - Is Holiday Week < need to define a custom whale-likely observation holidays >
    - Is Holiday Week / Holiday Weekend < need to define a custom whale-likely observation holidays >
    - Is School Break
    - Lagged Sightings Ratio to Capture Near Term Auto-Correlation (Lag 1 - Lag n)
    - Presence Lagged Sightings Ratio to Capture Near-Term Autocorrelation (Lag 1 - Lag n) < sighting / no sighting in previous observation >
    - Sighting Count Diff Over Lags
    - Add Some Seasonal Lag Components (These seem to be important but need to check less common grids -> Lag Periods: 28, 29, 52, 56, 57, 113) 
    - Rolling Mean
    - Rolling Std
    - Cumulative Sightings Over Last N Weeks (N = 4)
    - Relative Effort Index (weekly sightings relative to long-term weekly median)
    - Capture Multi-Scale Dependencies
        - Month Sin
        - Month Cos
        - Year Sin
        - Year Cos
    - Fourier Features (sin/cos harmonics) of Week, Month, Year
    - Ratio of Weeks with Observation Over Prior N Weeks
- Spatial Signals
    - Neighbor Lagged Sightings
    - Neighbor Presence Lagged Sightings
    - Has Active Neighbors
    - Centroid Lat/Long
- Data Lapse Check
    - Data Transition Week (e.g. TWM -> Acartia)




#### Feature Enrichment
##### 1. Preprocessing

In [89]:
import pandas as pd
import holidays
from datetime import timedelta
from tqdm import tqdm

# US holiday calendar from the `holidays` package. Membership checks
# (`date in us_holidays`) are made by is_holiday / is_holiday_with_overlap below.
us_holidays = holidays.UnitedStates()


def get_season(month):
    """Map a month number to a meteorological season name.

    Dec-Feb -> "winter", Mar-May -> "spring", Jun-Aug -> "summer".
    Every other value (Sep-Nov, or anything unrecognised) -> "fall".
    """
    season_table = (
        ((12, 1, 2), "winter"),
        ((3, 4, 5), "spring"),
        ((6, 7, 8), "summer"),
    )
    for months, season in season_table:
        if month in months:
            return season
    return "fall"


def add_statistical_week(df, date_col="DATE"):
    """
    Return a copy of ``df`` with Sunday-start week columns.

    Parameters:
    - df (pd.DataFrame): Input frame. It is not modified.
    - date_col (str): Column parsed with ``pd.to_datetime``.

    Returns:
    - pd.DataFrame: Copy with ``date_col`` converted to datetime, plus:
      - WEEK_START: the Sunday on or before each date, normalised to midnight.
      - WEEK_YEAR, WEEK_NUMBER: ISO year and ISO week of WEEK_START. Both come
        from the same ISO calendar, so the pair uniquely identifies a week.
        Pairing the calendar year with the ISO week would label 2023-01-01
        and 2023-12-31 both as (2023, 52).
    """
    df = df.copy()
    df[date_col] = pd.to_datetime(df[date_col])

    # weekday(): Mon=0 ... Sun=6, so (weekday + 1) % 7 is the number of days
    # since the most recent Sunday (0 for a Sunday itself).
    days_since_sunday = (df[date_col].dt.weekday + 1) % 7
    df["WEEK_START"] = (
        df[date_col] - pd.to_timedelta(days_since_sunday, unit="D")
    ).dt.normalize()

    iso = df["WEEK_START"].dt.isocalendar()
    df["WEEK_YEAR"] = iso["year"]
    df["WEEK_NUMBER"] = iso["week"]

    return df


def is_holiday(week_start):
    """Return True if any of the 7 days starting at ``week_start`` is in ``us_holidays``."""
    for day_offset in range(7):
        if week_start + timedelta(days=day_offset) in us_holidays:
            return True
    return False


def is_holiday_with_overlap(week_start):
    """
    Holiday flag that also catches holidays adjacent to weekend days.

    Returns True if any of the 7 days starting at ``week_start`` is in
    ``us_holidays``. It also returns True if, for any Friday/Saturday/Sunday
    in that window, the day before or after is a holiday, even when that
    neighbouring day falls outside the 7-day window.
    """
    weekend_days = (4, 5, 6)  # Fri/Sat/Sun
    for day in (week_start + timedelta(days=i) for i in range(7)):
        offsets = (-1, 0, 1) if day.weekday() in weekend_days else (0,)
        if any(day + timedelta(days=off) in us_holidays for off in offsets):
            return True
    return False


def is_school_break(week_start):
    """Return True when ``week_start`` falls in June, July or August (rough summer-break proxy)."""
    return 6 <= week_start.month <= 8


def add_time_features(df):
    """
    Return a copy of ``df`` with calendar features derived from ``WEEK_START``.

    Adds:
    - WEEK_OF_YEAR (ISO week), MONTH and SEASON of WEEK_START.
    - PRIOR_WEEK_OF_YEAR, PRIOR_MONTH and PRIOR_SEASON for the date 7 days earlier.
    - 0/1 flags IS_HOLIDAY, IS_HOLIDAY_OVERLAP and IS_SCHOOL_BREAK, computed
      by the corresponding predicate on WEEK_START.
    """
    out = df.copy()
    week_start = out["WEEK_START"]
    previous_week = week_start - pd.Timedelta(days=7)

    # Current-week calendar fields
    out["WEEK_OF_YEAR"] = week_start.dt.isocalendar().week
    out["MONTH"] = week_start.dt.month
    out["SEASON"] = out["MONTH"].apply(get_season)

    # Same fields for the preceding week
    out["PRIOR_WEEK_OF_YEAR"] = previous_week.dt.isocalendar().week
    out["PRIOR_MONTH"] = previous_week.dt.month
    out["PRIOR_SEASON"] = out["PRIOR_MONTH"].apply(get_season)

    # Holiday and seasonal break flags, stored as 0/1 integers
    flag_rules = (
        ("IS_HOLIDAY", is_holiday),
        ("IS_HOLIDAY_OVERLAP", is_holiday_with_overlap),
        ("IS_SCHOOL_BREAK", is_school_break),
    )
    for column, rule in flag_rules:
        out[column] = week_start.apply(rule).astype(int)

    return out


def process_weekly_sightings(sightings_data, h3_resolution, fill_missing_weeks=True):
    """
    Aggregate daily per-cell sightings into a weekly, per-H3-cell table.

    The output adds relative ratios, calendar features and H3 spatial context.

    Parameters:
    - sightings_data (pd.DataFrame): Daily sightings with 'DATE', 'SIGHTING_COUNT'
      and the H3 column at the provided resolution (e.g. 'H3_GRID_4'). Repeated
      (DATE, cell) rows, e.g. from concatenating two sources, are summed before
      anything else. A day is therefore counted at most once in N_DAYS_WITH_OBS.
    - h3_resolution (int): The resolution of the H3 grid column to use.
    - fill_missing_weeks (bool): If True (default), add zero-count rows so every
      cell has exactly one row per week from the first to the last week. Row-based
      lags downstream (``shift``) only line up with real weeks when rows are
      contiguous.

    Returns:
    - pd.DataFrame: One row per (WEEK_START, cell) with these columns:
      - SIGHTING_COUNT, TOTAL_SIGHTING_COUNT (all cells that week) and
        N_DAYS_WITH_OBS.
      - SIGHTING_RATIO: the cell's share of the weekly total, 0 when the week
        has no sightings anywhere.
      - Calendar features from add_time_features, and WEEK_HAS_ANY_SIGHTING.
      - CENTROID_LAT/CENTROID_LON and H3_GRID_PARENT_1..3.
    """
    h3_col = f"H3_GRID_{h3_resolution}"

    # 1. Assign each day to its week.
    daily = add_statistical_week(df=sightings_data, date_col="DATE")

    # 2. One row per (day, cell), so duplicated days are not double counted.
    daily = daily.groupby(["WEEK_START", "DATE", h3_col], as_index=False)[
        "SIGHTING_COUNT"
    ].sum()
    daily["SIGHTING_BOOL"] = (daily["SIGHTING_COUNT"] > 0).astype(int)

    # 3. Weekly count and number of days with a sighting per cell.
    weekly = daily.groupby(["WEEK_START", h3_col], as_index=False).agg(
        SIGHTING_COUNT=("SIGHTING_COUNT", "sum"),
        N_DAYS_WITH_OBS=("SIGHTING_BOOL", "sum"),
    )

    # 4. Fill absent (week, cell) combinations with zeros. WEEK_START values
    #    are all Sundays, so a 7-day range from the first one covers every week.
    if fill_missing_weeks and not weekly.empty:
        full_index = pd.MultiIndex.from_product(
            [
                pd.date_range(
                    weekly["WEEK_START"].min(), weekly["WEEK_START"].max(), freq="7D"
                ),
                sorted(weekly[h3_col].unique()),
            ],
            names=["WEEK_START", h3_col],
        )
        weekly = (
            weekly.set_index(["WEEK_START", h3_col])
            .reindex(full_index, fill_value=0)
            .reset_index()
        )

    # 5. Total sightings across all grids per week, and the cell's share of it.
    #    Weeks with no sightings anywhere get ratio 0 instead of 0/0 = NaN.
    weekly.insert(
        3,
        "TOTAL_SIGHTING_COUNT",
        weekly.groupby("WEEK_START")["SIGHTING_COUNT"].transform("sum"),
    )
    has_total = weekly["TOTAL_SIGHTING_COUNT"] > 0
    weekly["SIGHTING_RATIO"] = (
        weekly["SIGHTING_COUNT"] / weekly["TOTAL_SIGHTING_COUNT"]
    ).where(has_total, 0.0)

    # 6. Add time-based features
    weekly = add_time_features(weekly)

    # 7. Weekly flag: did any grid see a sighting this week?
    weekly["WEEK_HAS_ANY_SIGHTING"] = (weekly["TOTAL_SIGHTING_COUNT"] > 0).astype(int)

    # 8. H3 spatial context, computed once per distinct cell.
    cells = weekly[h3_col].unique()
    centroids = {cell: h3.cell_to_latlng(cell) for cell in cells}
    weekly["CENTROID_LAT"] = weekly[h3_col].map(lambda cell: centroids[cell][0])
    weekly["CENTROID_LON"] = weekly[h3_col].map(lambda cell: centroids[cell][1])

    for res in (1, 2, 3):
        parents = {cell: h3.cell_to_parent(cell, res) for cell in cells}
        weekly[f"H3_GRID_PARENT_{res}"] = weekly[h3_col].map(parents)

    return weekly


def add_lag_features_ratio_base(
    df,
    h3_col,
    date_col="WEEK_START",
    sighting_ratio_col="SIGHTING_RATIO",
    max_lag=5,
    seasonal_lags=(28, 29, 52, 56, 57, 113),
    rolling_windows=(3, 5, 7),
    cum_weeks=4,
    obs_flag_col=None,
    prior_weeks_obs_ratio_window=12,
):
    """
    Add per-H3-cell temporal features built from the weekly sighting ratio.

    Every history-based feature uses only earlier rows of the same cell
    (``shift(1)`` / ``shift(lag)``), so a row's own ratio never describes
    itself. Lags are row-based: they assume each cell has one row per
    consecutive week (TODO confirm upstream fills missing weeks).

    Parameters:
    - df (pd.DataFrame): Weekly data with ``h3_col``, a datetime ``date_col``
      and ``sighting_ratio_col``. Not modified.
    - h3_col (str): H3 cell column to group by.
    - date_col (str): Week timestamp used for ordering and calendar features.
    - sighting_ratio_col (str): Column the lag/rolling features are built from.
    - max_lag (int): Consecutive lags 1..max_lag of ratio and presence.
    - seasonal_lags (iterable of int): Additional long-range ratio lags.
    - rolling_windows (iterable of int): Window sizes for past rolling mean/std.
    - cum_weeks (int): Window of the past rolling sum ``cumsum_<cum_weeks>``.
    - obs_flag_col (str | None): 0/1 presence column. If None, ``presence_bool``
      is created as ``ratio > 0``.
    - prior_weeks_obs_ratio_window (int): Window for the share of prior weeks
      with an observation.

    Returns:
    - pd.DataFrame: Copy sorted by (``h3_col``, ``date_col``) with the feature
      columns appended. The original index labels are kept.
    """
    # Average weeks per year, used to turn the ISO week into a 0-1 phase.
    weeks_per_year = 365.25 / 7

    df = df.sort_values([h3_col, date_col])

    # If presence bool flag not provided, create from ratio > 0
    if obs_flag_col is None:
        df["presence_bool"] = (df[sighting_ratio_col] > 0).astype(int)
        obs_flag_col = "presence_bool"

    ratio_by_cell = df.groupby(h3_col)[sighting_ratio_col]
    flag_by_cell = df.groupby(h3_col)[obs_flag_col]

    # Lagged sightings ratio and presence
    for lag in range(1, max_lag + 1):
        df[f"lag_ratio_{lag}"] = ratio_by_cell.shift(lag)
        df[f"presence_lag_{lag}"] = flag_by_cell.shift(lag)

    # Lag diff (lag_ratio_1 - lag_ratio_2); needs at least two lags.
    if max_lag >= 2:
        df["lag_diff_1_2"] = df["lag_ratio_1"] - df["lag_ratio_2"]

    # Seasonal lags on sighting ratio
    for lag in seasonal_lags:
        df[f"seasonal_lag_{lag}"] = ratio_by_cell.shift(lag)

    # Rolling mean and std of past sighting ratios (default arg binds the window)
    for window in rolling_windows:
        df[f"rolling_mean_{window}"] = ratio_by_cell.transform(
            lambda s, w=window: s.shift(1).rolling(w).mean()
        )
        df[f"rolling_std_{window}"] = ratio_by_cell.transform(
            lambda s, w=window: s.shift(1).rolling(w).std()
        )

    # Sum of the ratio over the previous N weeks
    df[f"cumsum_{cum_weeks}"] = ratio_by_cell.transform(
        lambda s: s.shift(1).rolling(cum_weeks).sum()
    )

    # Relative Effort Index: last week's ratio relative to the median of all
    # earlier weeks in the same cell. Dividing the current ratio by a
    # full-history median would leak the current and future values.
    past_median = ratio_by_cell.transform(lambda s: s.shift(1).expanding().median())
    df["relative_effort_index"] = (
        ratio_by_cell.shift(1) / past_median.replace(0, np.nan)
    ).fillna(0)

    # Cyclic month & year features
    df["month"] = df[date_col].dt.month
    df["year"] = df[date_col].dt.year

    df["month_sin"] = np.sin(2 * np.pi * df["month"] / 12)
    df["month_cos"] = np.cos(2 * np.pi * df["month"] / 12)

    # NOTE(review): year is min-max scaled into [0, 1), so the first and last
    # years land on (almost) the same sin/cos point.
    year_min, year_max = df["year"].min(), df["year"].max()
    year_scaled = (df["year"] - year_min) / (year_max - year_min + 1e-9)
    df["year_sin"] = np.sin(2 * np.pi * year_scaled)
    df["year_cos"] = np.cos(2 * np.pi * year_scaled)

    df["week_of_year"] = df[date_col].dt.isocalendar().week.astype(int)

    # Fourier harmonics of a phase in [0, 1). Raw integer week/month numbers
    # would make sin(2*pi*k*n) ~ 0 and cos ~ 1 (constant columns), so each
    # cycle is normalised by its period. Each cycle gets its own name prefix.
    def _fourier(phase, name, n_harmonics=3):
        angle = 2 * np.pi * np.asarray(phase, dtype=float)
        feats = {}
        for k in range(1, n_harmonics + 1):
            feats[f"fourier_{name}_sin_{k}"] = np.sin(k * angle)
            feats[f"fourier_{name}_cos_{k}"] = np.cos(k * angle)
        return feats

    df = df.assign(**_fourier(df["week_of_year"] / weeks_per_year, "week"))
    df = df.assign(**_fourier(df["month"] / 12, "month"))
    df = df.assign(**_fourier(year_scaled, "year"))

    # Ratio of weeks with observation over prior N weeks
    df[f"obs_ratio_prior_{prior_weeks_obs_ratio_window}w"] = flag_by_cell.transform(
        lambda s: s.shift(1).rolling(prior_weeks_obs_ratio_window).mean()
    )

    return df


def add_parent_h3_features(
    df,
    h3_parent_cols=("H3_GRID_PARENT_1", "H3_GRID_PARENT_2", "H3_GRID_PARENT_3"),
    target_col="SIGHTING_RATIO",
    time_col="WEEK_START",
    lag_weeks=(1, 2, 3, 4, 5, 28, 29, 52, 56, 57, 113),
    rolling_windows=(3, 5),
    add_presence_booleans=True,
    add_rolling_means=True,
):
    """
    Add parent-H3-cell aggregates of ``target_col`` and their history features.

    For each parent column (level = its last '_' token, lower-cased):
    - ``<target>_sum_parent_<level>``: sum of ``target_col`` over child rows
      per (parent, week).
      NOTE(review): this column includes the current week's own target; drop
      it from model inputs if ``target_col`` is what is being predicted.
    - ``..._lag_<k>``: that sum ``k`` rows earlier per parent (missing -> 0).
    - ``..._rollmean_<w>``: mean of the previous ``w`` weekly sums, excluding
      the current week, consistent with the child-level rolling features.
    - ``presence_parent_<level>_lag_<k>``: 1 if the lagged sum is > 0.
    The aggregates are left-merged back onto ``df`` on (parent, ``time_col``).

    Returns:
    - pd.DataFrame: New frame; ``df`` itself is not modified.
    """
    df = df.copy()

    for parent_col in h3_parent_cols:
        parent_level = parent_col.split("_")[-1].lower()  # e.g. '1', '2', '3'
        sum_col = f"{target_col}_sum_parent_{parent_level}"

        # Step 1: Aggregate child-level target to parent grid by week
        agg = (
            df.groupby([parent_col, time_col], as_index=False)[target_col]
            .sum()
            .rename(columns={target_col: sum_col})
            .sort_values([parent_col, time_col])
        )
        sums_by_parent = agg.groupby(parent_col)[sum_col]

        # Step 2: Add lags (row-based within each parent)
        for lag in lag_weeks:
            agg[f"{sum_col}_lag_{lag}"] = sums_by_parent.shift(lag).fillna(0)

        # Step 3: Optionally add rolling means of *past* weeks only
        if add_rolling_means:
            for win in rolling_windows:
                agg[f"{sum_col}_rollmean_{win}"] = sums_by_parent.transform(
                    lambda s, w=win: s.shift(1).rolling(window=w, min_periods=1).mean()
                )

        # Step 4: Optionally add presence boolean
        if add_presence_booleans:
            for lag in lag_weeks:
                agg[f"presence_parent_{parent_level}_lag_{lag}"] = (
                    agg[f"{sum_col}_lag_{lag}"] > 0
                ).astype(int)

        # Step 5: Merge back to original df
        df = df.merge(agg, how="left", on=[parent_col, time_col])

    return df

In [None]:
# Feature pipeline: raw daily sightings -> weekly per-cell table
# -> cell-level lag features -> parent-cell features.
# (.pipe passes the running frame as the first argument of each step.)
sightings_data = (
    sightings_data_raw.copy()
    .pipe(process_weekly_sightings, h3_resolution)
    .pipe(add_lag_features_ratio_base, h3_col=f"H3_GRID_{h3_resolution}")
    .pipe(add_parent_h3_features)
)

In [103]:
# TODO:
# - Percent Grid Over Water (Use Major Shorelines from WSDOT, Need to Find it for CA)
# - Data Transition Week (e.g. TWM -> Acartia)??

# - Neighbor Lagged Sightings
# - Neighbor Presence Lagged Sightings

In [None]:
# Categorically Encode H3 Grid - One-Hot Encoding might help xgboost

In [105]:
# Number of columns in the enriched feature table
sightings_data.shape[1]

148