In [47]:
%load_ext autoreload
%autoreload 2
import sys

# Instead of creating a package with setup.py or building from a docker/singularity file,
# put the sister directory of the src code on the import path so it can be called from this notebook.
# This keeps the notebook free of library code (it only holds visualizations) and is easier to test.
# It also helps keep variable state clean, so cells aren't run out of order with a mysterious state.
# NOTE(review): ".." is resolved against the kernel's current working directory, and
# re-running this cell appends a duplicate ".." entry to sys.path.
sys.path.append("..")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [48]:
# from data import hrrr_data, nam_data
import pandas as pd
import numpy as np
from datetime import timedelta
import os
import matplotlib.pyplot as plt
import gc
import matplotlib.dates as mdates
from datetime import datetime
import shutil
import statistics as st
from pathlib import Path

In [49]:
# Unique station IDs ("stid") listed in the NYSM station metadata CSV.
nysm_df = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/nysm.csv")
stations = nysm_df["stid"].unique()
stations

array(['ADDI', 'ANDE', 'BATA', 'BEAC', 'BELD', 'BELL', 'BELM', 'BERK',
       'BING', 'BKLN', 'BRAN', 'BREW', 'BROC', 'BRON', 'BROO', 'BSPA',
       'BUFF', 'BURD', 'BURT', 'CAMD', 'CAPE', 'CHAZ', 'CHES', 'CINC',
       'CLAR', 'CLIF', 'CLYM', 'COBL', 'COHO', 'COLD', 'COPA', 'COPE',
       'CROG', 'CSQR', 'DELE', 'DEPO', 'DOVE', 'DUAN', 'EAUR', 'EDIN',
       'EDWA', 'ELDR', 'ELLE', 'ELMI', 'ESSX', 'FAYE', 'FRED', 'GABR',
       'GFAL', 'GFLD', 'GROT', 'GROV', 'HAMM', 'HARP', 'HARR', 'HART',
       'HERK', 'HFAL', 'ILAK', 'JOHN', 'JORD', 'KIND', 'LAUR', 'LOUI',
       'MALO', 'MANH', 'MEDI', 'MEDU', 'MORR', 'NBRA', 'NEWC', 'NHUD',
       'OLDF', 'OLEA', 'ONTA', 'OPPE', 'OSCE', 'OSWE', 'OTIS', 'OWEG',
       'PENN', 'PHIL', 'PISE', 'POTS', 'QUEE', 'RAND', 'RAQU', 'REDF',
       'REDH', 'ROXB', 'RUSH', 'SARA', 'SBRI', 'SCHA', 'SCHO', 'SCHU',
       'SCIP', 'SHER', 'SOME', 'SOUT', 'SPRA', 'SPRI', 'STAT', 'STEP',
       'STON', 'SUFF', 'TANN', 'TICO', 'TULL', 'TUPP', 'TYRO', 'VOOR',
      

In [50]:
# Dated run directory of LSTM evaluation outputs to copy from.
q_path = "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/20250508"

# Destination directory that entries matching a station ID are copied into.
end_path = (
    "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/hrrr_prospectus_v2"
)


# Names of everything directly inside q_path (files and sub-directories).
files = os.listdir(q_path)

In [51]:
# Copy every entry of q_path whose name contains a NYSM station ID into end_path,
# never overwriting anything that already exists at the destination.
# Each entry is visited once (rather than once per matching station ID), so an entry
# whose name contains several station IDs is not processed repeatedly.
os.makedirs(end_path, exist_ok=True)  # top-level files below need the destination to exist

for f in files:
    if not any(s in f for s in stations):
        continue

    src = os.path.join(q_path, f)
    dst = os.path.join(end_path, f)

    if os.path.isdir(src):
        # Merge into the destination directory item by item so existing results are kept.
        os.makedirs(dst, exist_ok=True)

        for item in os.listdir(src):
            src_item = os.path.join(src, item)
            dst_item = os.path.join(dst, item)

            if os.path.exists(dst_item):
                print(f"Skipped {dst_item} (already exists)")
            elif os.path.isdir(src_item):
                shutil.copytree(src_item, dst_item)
            else:
                shutil.copy2(src_item, dst_item)
    elif os.path.exists(dst):
        print(f"Skipped {dst} (already exists)")
    else:
        shutil.copy2(src, dst)

In [2]:
# Delete every file and directory under end_path whose name contains "full".
end_path = "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/oksm_hrrr"

# Loop variables are named so they do not clobber the notebook-level `files` list.
for root, dir_names, file_names in os.walk(end_path):
    for name in file_names:
        if "full" in name:
            file_path = os.path.join(root, name)
            os.remove(file_path)
            print(f"Removed file: {file_path}")

    kept_dirs = []
    for name in dir_names:
        if "full" in name:
            dir_path = os.path.join(root, name)
            shutil.rmtree(dir_path)
            print(f"Removed directory: {dir_path}")
        else:
            kept_dirs.append(name)
    # Prune the deleted directories in place so os.walk does not try to descend into them.
    dir_names[:] = kept_dirs

In [None]:
directory_path = "/home/aevans/nwp_bias/src/machine_learning/data/parent_models/HRRR/exclusion_buffer/"

files_ = os.listdir(directory_path)

# Read every .csv file once (in sorted, deterministic order) and concatenate a single
# time; growing a DataFrame with concat inside the loop is quadratic.
csv_frames = [
    pd.read_csv(os.path.join(directory_path, f))
    for f in sorted(files_)
    if f.endswith(".csv")
]
master = pd.concat(csv_frames) if csv_frames else pd.DataFrame()

In [None]:
# Inspect the concatenated exclusion-buffer results.
master

In [None]:
# NOTE: this overwrites the notebook-level `stations` array with the stations in `master`.
stations = master["station"].unique()

fig, ax = plt.subplots(figsize=(10, 6))

for s in stations:
    # Sort by buffer so each station's line runs left to right; rows arrive in
    # directory-listing order, and unsorted x values make the lines zig-zag.
    df_ = master[master["station"] == s].sort_values("buffer")
    ax.plot(df_["buffer"].values, df_["mae"].values, label=s)

ax.set_xlabel("Buffer (km)")
ax.set_ylabel("MAE")
ax.set_title("MAE vs Buffer for Each Station")
# Place the legend outside the plot area, to the right.
ax.legend(title="Station", bbox_to_anchor=(1.05, 1), loc="upper left")
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Select the Oklahoma Mesonet sites in one climate division.
# NOTE: rebinds the notebook-level `df` and `stations` used by many other cells.
c = "East Central"

oksm_clim = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/oksm.csv")
df = oksm_clim[oksm_clim["Climate_division"] == c]
stations = df["stid"].unique()

In [None]:
df

In [None]:
stations

In [None]:
# 2018 hourly Oklahoma Mesonet observations, restricted to the selected stations.
oksmdf = pd.read_parquet(
    "/home/aevans/nwp_bias/data/oksm/oksm_1H_obs_2018.parquet"
).reset_index()
oksmdf = oksmdf[oksmdf["station"].isin(stations)]
oksmdf

In [None]:
# HRRR FH01 data for January 2018 at Oklahoma Mesonet sites, same station subset.
hrrrdf = pd.read_parquet(
    "/home/aevans/nwp_bias/src/machine_learning/data/oksm/hrrr_data/fh01/HRRR_2018_01_direct_compare_to_oksm_sites_mask_water.parquet"
).reset_index()
hrrrdf = hrrrdf[hrrrdf["station"].isin(stations)]
hrrrdf

In [None]:
# NOTE(review): `station_to_climdiv` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel (it relied on a since-deleted cell).
clim_div = station_to_climdiv.get("VOOR")
clim_div

In [None]:
# NOTE(review): `selected_df` is also never defined in this notebook (hidden kernel state).
selected_df.to_csv(
    "/home/aevans/nwp_bias/src/machine_learning/notebooks/random_nysm_by_climdiv.csv"
)

In [None]:
# NOTE(review): the most recent `df` above is the OKSM metadata subset, whose station-ID
# column is "stid", so "station" may not exist here — this looks like it expected a
# different `df`; confirm.
df_filt = df[df["station"] == "ACME"]
# plt.plot(df_filt['time_1H'], df_filt['precip_total'])
df_filt.head()

In [None]:
# Rewrite every Parquet file in `path` in place: convert "time" to
# 'YYYY-MM-DD HH:MM:00' strings, then index and sort by (station, time).
path = "/home/aevans/nwp_bias/src/landtype/NY_cartopy/oksm_v3/"
input_dir = os.listdir(path)

for file in input_dir:
    if not file.endswith(".parquet"):
        continue  # only Parquet files can be read and rewritten here
    file_path = os.path.join(path, file)
    print(file)
    df = pd.read_parquet(file_path).reset_index()
    # Parse to datetime first so both datetime and previously-formatted string values work.
    df["time"] = pd.to_datetime(df["time"]).dt.strftime("%Y-%m-%d %H:%M:00")
    # set_index/sort_index return a new frame; assign it so the saved file is actually indexed.
    df = df.set_index(["station", "time"]).sort_index()
    df.to_parquet(file_path)
    print(file, "done!")

In [None]:
# NOTE(review): other cells read this file with .reset_index(); without it, "station" may
# be an index level rather than a column, which would break column lookups such as
# df["station"] — confirm.
df = pd.read_parquet("/home/aevans/nwp_bias/data/oksm/oksm_1H_obs_2018.parquet")
df

In [None]:
df

In [None]:
# Print each column's dtype on its own line, in column order.
for dtype in df.dtypes:
    print(dtype)

In [None]:
# For each station (in order of first appearance), print max / min / mean of every
# column except the station and time identifiers.
# Series.mean() skips NaN just like max()/min(); statistics.mean returned NaN for any
# column containing a missing value and raised on an empty selection.
for s, df_filt in df.groupby("station", sort=False):
    for c in df_filt.columns:
        if c in ("station", "time_1H"):
            continue
        print(c)
        print("MAX", df_filt[c].max())
        print("MIN", df_filt[c].min())
        print("MEAN", df_filt[c].mean())

In [None]:
directory_path = "/home/aevans/nwp_bias/src/landtype/NY_cartopy/"
files = os.listdir(directory_path)

# Delete every .parquet and .mts file directly inside directory_path.
# WARNING: the next cell converts the .mts files in this same directory, so running this
# cell first leaves nothing for it to convert.
stale = [d for d in files if d.endswith((".parquet", ".mts"))]
for d in stale:
    os.remove(os.path.join(directory_path, d))

In [None]:
import os
import pandas as pd

directory = "/home/aevans/nwp_bias/src/landtype/NY_cartopy/"

# Convert each whitespace-delimited .mts text file in `directory` into a .parquet file next
# to it. Every token stays a string, shorter rows are padded with missing values, and
# columns are named col0, col1, ...
mts_files = [name for name in os.listdir(directory) if name.endswith(".mts")]

for mts_file in mts_files:
    mts_path = os.path.join(directory, mts_file)
    parquet_path = os.path.join(directory, mts_file.replace(".mts", ".parquet"))
    try:
        with open(mts_path, "r") as handle:
            rows = [line.strip().split() for line in handle.readlines()]

        df = pd.DataFrame(rows)
        df.columns = [f"col{i}" for i in range(df.shape[1])]
        df.to_parquet(parquet_path, index=False)

        print(f"✅ Converted {mts_file} → {parquet_path}")
    except Exception as e:
        # Best effort: report the failure and continue with the next file.
        print(f"❌ Failed to convert {mts_file}: {e}")

In [None]:
# 2018 hourly OKSM observations (same file as `df` above, also read without reset_index()).
okdf = pd.read_parquet("/home/aevans/nwp_bias/data/oksm/oksm_1H_obs_2018.parquet")
okdf

In [None]:
def check_missing_files(base_path, required_files=None):
    """Report sub-directories of ``base_path`` that lack any required output file.

    A required entry counts as present when it is a substring of at least one name
    inside the sub-directory, so timestamp prefixes and extensions are tolerated.

    Args:
        base_path (str): Directory whose immediate sub-directories are checked.
        required_files (list[str] | None): Name fragments that must each match some
            entry. Defaults to the FH18 u_total / t2m / tp linear HRRR outputs.

    Returns:
        dict[str, list[str]]: Sub-directory name -> missing fragments, containing only
        sub-directories with at least one missing fragment. The same report is printed.
    """
    if required_files is None:
        required_files = [
            "fh18_u_total_HRRR_ml_output_linear",
            "fh18_t2m_HRRR_ml_output_linear",
            "fh18_tp_HRRR_ml_output_linear",
        ]
    missing_info = {}

    # Sorted so the report order is deterministic (os.listdir order is arbitrary).
    for dir_name in sorted(os.listdir(base_path)):
        dir_path = os.path.join(base_path, dir_name)
        if not os.path.isdir(dir_path):
            continue  # plain files at the top level are not station directories
        names_in_dir = os.listdir(dir_path)
        missing_files = [
            req
            for req in required_files
            if not any(req in name for name in names_in_dir)
        ]
        if missing_files:
            missing_info[dir_name] = missing_files

    if missing_info:
        print("Directories with missing files:")
        for dir_name, missing in missing_info.items():
            print(f"{dir_name}: Missing {', '.join(missing)}")

    return missing_info

In [None]:
# Print the sub-directories of the HRRR prospectus run that lack any required FH18 output.
check_missing_files(
    "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/hrrr_prospectus"
)

In [None]:
# Entries produced so far in the HRRR prospectus run directory.
nysms = os.listdir(
    "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/hrrr_prospectus"
)
len(nysms)

In [None]:
# All NYSM station IDs, to compare against the produced directories.
nysm_clim = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/nysm.csv")
stations = nysm_clim["stid"].unique()

In [None]:
# nysm_clim[nysm_clim['climate_division_name']=='Northern Plateau']

In [None]:
# Print every NYSM station that has no entry in the prospectus run directory yet.
for s in stations:
    if s not in nysms:
        print("Station not yet formulated... ", s)

In [None]:
def clim_div_filter(c, nysm_csv="/home/aevans/nwp_bias/src/landtype/data/nysm.csv"):
    """Return the NYSM station IDs that belong to one climate division.

    Args:
        c (str): Climate division name, matched exactly against the
            ``climate_division_name`` column (e.g. "Hudson Valley").
        nysm_csv (str): Path to the NYSM station metadata CSV.

    Returns:
        list[str]: Unique ``stid`` values for stations in division ``c``, in file order.
    """
    nysm_clim = pd.read_csv(nysm_csv)
    df = nysm_clim[nysm_clim["climate_division_name"] == c]
    # Return the stations actually selected by `c`; a fixed debugging list here would make
    # every division return the same stations.
    return df["stid"].unique().tolist()

In [None]:
# hudson_v = clim_div_filter('Hudson Valley')
# coast = clim_div_filter('Coastal')
# st_ = clim_div_filter('St. Lawrence Valley')
# greats = clim_div_filter('Great Lakes')
# west = clim_div_filter('Western Plateau')
# north = clim_div_filter("Northern Plateau")
# champ = clim_div_filter('Champlain Valley')
# mohawk = clim_div_filter('Mohawk Valley')
# central = clim_div_filter("Central Lakes")
# east = clim_div_filter('Eastern Plateau')


# base_dir = "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/2025035"
# dirs = os.listdir(base_dir)


# for d in dirs:
#     dir_path = os.path.join(base_dir, d)
#     if d in st_ or d in greats:
#         continue
#     else:
#         if os.path.isdir(dir_path):  # Ensure it's a directory before removing
#             shutil.rmtree(dir_path)

In [None]:
def move_directories(base_dir, merge_dir, clim_div):
    """
    Move the files of selected directories in `base_dir` into same-named directories in `merge_dir`.

    For every entry of `base_dir` whose name is in `clim_div` and that is a directory,
    each regular file directly inside it is moved (with shutil.move) to the directory of
    the same name under `merge_dir`, which is created if needed. Sub-directories are not
    moved, and entries that are plain files are skipped.

    Args:
        base_dir (str): Path to the directory containing the source data.
        merge_dir (str): Path to the directory where data should be moved.
        clim_div (Collection[str]): Directory names to move, e.g. the station IDs
            returned by `clim_div_filter`.
    """
    for d in os.listdir(base_dir):
        if d not in clim_div:
            continue

        src_path = os.path.join(base_dir, d)
        # Only directories can be merged; listing a plain file would raise NotADirectoryError.
        if not os.path.isdir(src_path):
            continue

        dest_path = os.path.join(merge_dir, d)
        os.makedirs(dest_path, exist_ok=True)  # ensure the destination directory exists

        for file in os.listdir(src_path):
            src_file = os.path.join(src_path, file)
            if os.path.isfile(src_file):
                shutil.move(src_file, os.path.join(dest_path, file))

In [None]:
base_dir = (
    "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/nam_prospectus/"
)

bases = os.listdir(base_dir)

# In each sub-directory of base_dir, delete the .parquet files whose name contains "full".
for station_dir in bases:
    dir_path = os.path.join(base_dir, station_dir)
    if not os.path.isdir(dir_path):
        continue
    for fname in os.listdir(dir_path):
        is_full_parquet = "full" in fname and fname.endswith(".parquet")
        if is_full_parquet:
            file_path = os.path.join(dir_path, fname)
            os.remove(file_path)
            print(f"Deleted: {file_path}")

In [None]:
# Source run directory and destination for merging per-division results via move_directories.
base = "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/20250413"
target = "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/nam_prospectus/alpha3"
# NOTE(review): champ/north/mohawk are only defined in the commented-out clim_div_filter
# cell, so that cell must be re-enabled first; `champ` is listed twice below.
# move_directories(base, target, champ)
# move_directories(base, target, north)
# move_directories(base, target, champ)
# move_directories(base, target, mohawk)

In [None]:
# Coastal-division t2m error metrics.
error_df = pd.read_parquet(
    "/home/aevans/nwp_bias/src/machine_learning/data/error_visuals/Coastal/Coastal_t2m_error_metrics_master.parquet"
)
error_df

In [None]:
# HRRR FH01, January 2018, at OKSM sites (same file as `hrrrdf`, here without reset_index()).
oksm = pd.read_parquet(
    "/home/aevans/nwp_bias/src/machine_learning/data/oksm/hrrr_data/fh01/HRRR_2018_01_direct_compare_to_oksm_sites_mask_water.parquet"
)
oksm

In [None]:
radiometer_df = pd.read_csv(
    "/home/aevans/nwp_bias/src/machine_learning/notebooks/data/radiometer_network.csv"
)
radiometer_df

In [None]:
# One saved profiler image array for the ALBA site (2018).
img = "/home/aevans/nwp_bias/src/machine_learning/data/profiler_images/2018/PROF_ALBA/PROF_ALBA_2018_010100.npy"

In [None]:
image = np.load(img).astype(np.float32)

In [None]:
nysm_clim = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/nysm.csv")
nysm_clim

In [None]:
# GFS FH009, August 2018, at NYSM sites.
gfs = pd.read_parquet(
    "/home/aevans/nwp_bias/src/machine_learning/data/gfs_data/fh009/GFS_2018_08_direct_compare_to_nysm_sites_mask_water.parquet"
)

In [None]:
gfs.columns

In [None]:
# Load one saved VOOR evaluation output and list its column names.
df = pd.read_parquet(
    "/home/aevans/nwp_bias/src/machine_learning/data/lstm_eval_csvs/20250107/VOOR/01_07_2025_14:58:37_full_VOOR.parquet"
)
for k in df.columns:
    print(k)

In [None]:
# tp_VOOR minus precip_total_VOOR (plain strings: the column names contain no placeholders).
df["diff"] = df["tp_VOOR"] - df["precip_total_VOOR"]

In [None]:
def date_filter(ldf, time1, time2):
    ldf = ldf[ldf["valid_time"] > time1]
    ldf = ldf[ldf["valid_time"] < time2]

    return ldf

In [None]:
# January 2024 window. date_filter uses strict inequalities, so rows at exactly
# 2024-01-01 00:00 and 2024-01-31 23:00 are excluded.
time1 = datetime(2024, 1, 1, 0, 0, 0)
time2 = datetime(2024, 1, 31, 23, 0, 0)

# NOTE: rebinds `df` to the filtered frame.
df = date_filter(df, time1, time2)

In [None]:
def met_output(df, station, fh):
    """Plot the ``u_total`` forecast against the sonic wind-speed observation for one station.

    Day (hours 06-17 of ``valid_time``) and night periods are shaded across the full
    height of the axes.

    Args:
        df (pandas.DataFrame): Needs a datetime ``valid_time`` column plus
            ``u_total_{station}`` (drawn as "NAM Prediction") and
            ``wspd_sonic_mean_{station}`` (drawn as "NYSM Observation").
        station (str): Station ID used to select columns and in the title.
        fh (int | str): Forecast hour, used only in the title.
    """
    fig, ax = plt.subplots(figsize=(24, 6))
    x = df["valid_time"]

    # Matplotlib date numbers, used as x values for the shaded regions.
    x_numeric = mdates.date2num(x)

    # Daytime = hours 6 through 17; adjust the hours to change the day/night definition.
    day_mask = ((x.dt.hour >= 6) & (x.dt.hour < 18)).to_numpy()

    ax.plot(
        np.array(x),
        np.array(df[f"u_total_{station}"]),
        c="mediumseagreen",
        linewidth=3,
        label="NAM Prediction",
    )
    ax.plot(
        np.array(x),
        np.array(df[f"wspd_sonic_mean_{station}"]),
        c="black",
        linewidth=1,
        alpha=0.9,
        label="NYSM Observation",
    )

    # Shade with y in axes coordinates (0 = bottom, 1 = top) so the bands always span the
    # whole plot; a fixed data range of 0-10 left values above 10 unshaded and forced the
    # y-limits to include 0-10.
    band_transform = ax.get_xaxis_transform()
    ax.fill_between(
        x_numeric,
        0,
        1,
        where=day_mask,
        color="white",
        alpha=0.5,
        transform=band_transform,
        label="Daytime",
    )
    ax.fill_between(
        x_numeric,
        0,
        1,
        where=~day_mask,
        color="grey",
        alpha=0.2,
        transform=band_transform,
        label="Nighttime",
    )

    ax.set_title(f"NAM Prediction v NYSM Observation: {station}: FH{fh}", fontsize=28)
    ax.legend()
    plt.show()

In [None]:
met_output(df, "VOOR", 1)

In [None]:
# NYSM station metadata and the list of climate division names.
nysm_clim = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/nysm.csv")
clim_div = nysm_clim["climate_division_name"].unique()
nysm_clim

In [None]:
# nysm_clim = nysm_clim[nysm_clim["climate_division_name"] == "Hudson Valley"]
# nysm_clim

In [None]:
# Central Lakes u_total HRRR lookup table (s2s).
df1 = pd.read_csv(
    "/home/aevans/nwp_bias/src/machine_learning/data/parent_models/HRRR/s2s/Central Lakes_u_total_HRRR_lookup_quad.csv"
)
# df = df[df["station"] == "ADDI"]
df1

In [None]:
nysm_clim = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/nysm.csv")
nysm_clim[nysm_clim["climate_division_name"] == "Eastern Plateau"]

In [None]:
def get_more_fh(fh, station, var, times, reader=None):
    """Fetch one variable at forecast hours ``fh + 2`` and ``fh + 4`` for a station.

    Args:
        fh (int): Base forecast hour; data for ``fh + 2`` and ``fh + 4`` is read.
        station (str): Station ID to select.
        var (str): Column to extract (e.g. "t2m").
        times (Sequence): Valid times to align both forecast hours on.
        reader (Callable[[str], pandas.DataFrame] | None): Loads the frame for a
            zero-padded two-digit forecast-hour string. Defaults to
            ``hrrr_data.read_hrrr_data``.

    Returns:
        tuple[numpy.ndarray, numpy.ndarray]: ``var`` values at ``fh + 2`` and ``fh + 4``.
        Both merges are inner joins, so only valid times present in ``times`` and in both
        forecast-hour frames are kept; the two arrays share length and order.
    """
    if reader is None:
        # NOTE(review): the `from data import hrrr_data` import at the top of the notebook
        # is commented out; this default needs it restored.
        reader = hrrr_data.read_hrrr_data

    def _select(offset):
        # One station's (valid_time, var) rows at fh + offset, with var renamed per offset
        # so the two merges never collide (the old rename only handled var == "t2m").
        frame = reader(str(fh + offset).zfill(2))
        frame = frame[frame["station"] == station][["valid_time", var]]
        return frame.rename(columns={var: f"{var}_{station}_+{offset}"})

    df = pd.DataFrame({"valid_time": times})
    df = df.merge(_select(2), on="valid_time").merge(_select(4), on="valid_time")

    fh2 = df[f"{var}_{station}_+2"].values
    fh4 = df[f"{var}_{station}_+4"].values
    return fh2, fh4

In [None]:
def read_nam_data(
    fh,
    years=("2022", "2023", "2024"),
    base_dir="/home/aevans/nwp_bias/src/machine_learning/data/nam_data",
):
    """
    Read and concatenate the monthly NAM forecast/error Parquet files for one forecast hour.

    Files are looked up as
    ``{base_dir}/fh{fh}/NAM_{year}_{MM}_direct_compare_to_nysm_sites_mask_water.parquet``;
    missing months are skipped, and each file found is printed before it is read.

    Args:
        fh (str): Forecast hour as it appears in the directory name, e.g. "001".
        years (Iterable[str | int]): Years to load. Defaults to 2022-2024.
        base_dir (str): Root directory holding the ``fh{fh}`` sub-directories.

    Returns:
        pandas.DataFrame: All found months concatenated, with rows containing any NaN dropped.

    Raises:
        ValueError: If no file exists for the requested years.
    """
    savedir = os.path.join(base_dir, f"fh{fh}")

    nam_fcast_and_error = []
    for year in years:
        for month in range(1, 13):
            file_path = os.path.join(
                savedir,
                f"NAM_{year}_{month:02d}_direct_compare_to_nysm_sites_mask_water.parquet",
            )
            if not os.path.exists(file_path):
                continue
            print(file_path)
            nam_fcast_and_error.append(pd.read_parquet(file_path))
            gc.collect()  # release memory between large monthly reads

    if not nam_fcast_and_error:
        raise ValueError(f"No NAM files found in {savedir} for years {list(years)}")

    return pd.concat(nam_fcast_and_error).dropna()

In [None]:
def read_gfs_data(
    fh,
    years=("2018", "2019", "2020", "2021", "2022", "2023"),
    base_dir="/home/aevans/nwp_bias/src/machine_learning/data/gfs_data",
):
    """
    Read and concatenate the monthly GFS forecast/error Parquet files for one forecast hour.

    Files are looked up as
    ``{base_dir}/fh{fh}/GFS_{year}_{MM}_direct_compare_to_nysm_sites_mask_water.parquet``;
    missing months are skipped. Progress is printed once per year.

    Args:
        fh (str): Forecast hour as it appears in the directory name, e.g. "009".
        years (Iterable[str | int]): Years to load. Defaults to 2018-2023.
        base_dir (str): Root directory holding the ``fh{fh}`` sub-directories.

    Returns:
        pandas.DataFrame: All found months concatenated (rows with NaN are kept).

    Raises:
        ValueError: If no file exists for the requested years.
    """
    savedir = os.path.join(base_dir, f"fh{fh}")

    gfs_fcast_and_error = []
    for year in years:
        print("compiling", year)
        for month in range(1, 13):
            file_path = os.path.join(
                savedir,
                f"GFS_{year}_{month:02d}_direct_compare_to_nysm_sites_mask_water.parquet",
            )
            if os.path.exists(file_path):
                gfs_fcast_and_error.append(pd.read_parquet(file_path))

    if not gfs_fcast_and_error:
        raise ValueError(f"No GFS files found in {savedir} for years {list(years)}")

    return pd.concat(gfs_fcast_and_error)

In [None]:
def load_nysm_data(nysm_path="/home/aevans/nwp_bias/data/nysm/", years=range(2018, 2025)):
    """
    Load and concatenate hourly NYSM (New York State Mesonet) observations.

    One parquet file per year (``nysm_1H_obs_{year}.parquet``) is read from ``nysm_path``.
    Its index is moved into columns with ``reset_index()``, and all years are stacked.

    Args:
        nysm_path (str): Directory holding the yearly parquet files.
        years (Iterable[int]): Years to load; defaults to 2018 through 2024.

    Returns:
        pd.DataFrame: Concatenated observations with a fresh 0..N-1 index. Missing values
        in 'snow_depth' and 'ta9m' are filled with -999; no rows are dropped.

    Example:
    ```
    nysm_data = load_nysm_data()
    print(nysm_data.head())
    ```

    Note: Ensure that the parquet files are located in the specified path before using this function.
    """
    nysm_1H = [
        pd.read_parquet(os.path.join(nysm_path, f"nysm_1H_obs_{year}.parquet")).reset_index()
        for year in years
    ]

    # ignore_index avoids repeating 0..n labels from every yearly frame.
    nysm_1H_obs = pd.concat(nysm_1H, ignore_index=True)

    # Assign the filled columns back. A chained `df[col].fillna(..., inplace=True)` is not
    # guaranteed to modify `df`, and it has no effect under pandas Copy-on-Write.
    nysm_1H_obs["snow_depth"] = nysm_1H_obs["snow_depth"].fillna(-999)
    nysm_1H_obs["ta9m"] = nysm_1H_obs["ta9m"].fillna(-999)

    return nysm_1H_obs

In [None]:
# Hourly NYSM observations for all stations and loaded years.
df = load_nysm_data()

# df = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/nysm.csv")

# stations_ls = ['MANH', 'VOOR', 'HERK', 'ANDE', 'BUFF', 'SCIP', 'GROV', 'LOUI', 'ESSX', 'GABR']

In [None]:
# NOTE: rebinds df to the TUPP rows only; re-run load_nysm_data() to get all stations back.
df = df[df["station"] == "TUPP"]

In [None]:
df

In [None]:
gfs = read_gfs_data("009")

In [None]:
gfs

In [None]:
df

In [None]:
# df.to_csv("/home/aevans/nwp_bias/src/landtype/data/first_paper_stations_coords.csv")

In [None]:
# import gc

# gfs_df = read_gfs_data("006")

In [None]:
# gfs_df["station"].unique()

In [None]:
# Forecast hour, station and variable used by the NAM/HRRR cells below.
fh = 6
station = "SOUT"
var = "t2m"

In [None]:
df = pd.read_parquet(
    "/home/aevans/nwp_bias/src/machine_learning/data/nam_data/fh001/NAM_2022_04_direct_compare_to_nysm_sites_mask_water.parquet"
)
df

In [None]:
# Directory names are zero-padded to three digits (e.g. fh006).
nam_df = read_nam_data(str(fh).zfill(3))

In [None]:
# hrrr_df = hrrr_data.read_hrrr_data(str(fh).zfill(2))

# # Filter NYSM data to match valid times from HRRR data
# mytimes = hrrr_df["valid_time"].tolist()
# fh2_, fh4_ = get_more_fh(fh, station, var, mytimes)

In [None]:
# len(mytimes)

In [None]:
# A100 benchmark results; all lists are index-aligned by batch size.
a100_mae = [0.07, 0.17]
a100_mse = [0.07, 0.22]
a100_batch = [1000, 5000]
a100_gpu = [8, 30]
a100_runtime = [
    timedelta(minutes=16, seconds=24),
    timedelta(minutes=16, seconds=5),
]

In [None]:
# GH200 benchmark results; all lists are index-aligned by batch size.
gh200_mae = [0.06, 0.06]
gh200_mse = [0.06, 0.07]
gh200_batch = [1000, 10000]
gh200_gpu = [8, 64]
gh200_runtime = [timedelta(minutes=6, seconds=22), timedelta(minutes=6, seconds=51)]

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from datetime import timedelta


def plot_runtime_bar_chart(a100_batch, a100_run_time, gpu_name="gh200", bar_width=1000):
    """Bar chart of run time (in minutes) against batch size.

    Args:
        a100_batch (Sequence[int]): Batch sizes, used as the bars' x positions.
        a100_run_time (Sequence[datetime.timedelta]): Run time for each batch size.
        gpu_name (str): GPU label appended to the title. The default keeps the previous
            hard-coded "gh200"; pass e.g. "a100" when plotting A100 results.
        bar_width (float): Bar width in batch-size units.
    """
    # Convert timedelta objects to total minutes.
    run_time_minutes = [rt.total_seconds() / 60 for rt in a100_run_time]

    fig, ax = plt.subplots()
    ax.bar(a100_batch, run_time_minutes, bar_width, color="orange", label="Run Time")

    ax.set_xlabel("Batch Size")
    ax.set_ylabel("Run Time (minutes)")
    ax.set_title(f"Run Time by Batch Size {gpu_name}")

    plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def plot_metrics_bar(
    a100_mae, a100_mse, a100_batch, gpu_name="a100", ylabel="GPU Memory", ylim=(0, 90)
):
    """Bar chart of one value per batch size, with batch sizes as categorical x labels.

    Despite its name, ``a100_mae`` is the plotted series (this notebook passes GPU memory
    values). ``a100_mse`` is accepted for backward compatibility but is not plotted.

    Args:
        a100_mae (Sequence[float]): Values to plot, one per batch size.
        a100_mse (Sequence[float]): Unused.
        a100_batch (Sequence): Tick labels, one per bar.
        gpu_name (str): GPU label used in the title (defaults to the previous "a100").
        ylabel (str): Y-axis label, also used in the title.
        ylim (tuple[float, float]): Y-axis limits.
    """
    bar_width = 0.35
    index = np.arange(len(a100_mae))

    fig, ax = plt.subplots()
    ax.bar(index, a100_mae, bar_width)

    ax.set_xlabel("Batch Size")
    ax.set_ylabel(ylabel)
    ax.set_ylim(*ylim)
    ax.set_title(f"{ylabel} by Batch Size for {gpu_name}")
    ax.set_xticks(index)
    ax.set_xticklabels(a100_batch)
    # No ax.legend(): the bars carry no label, so it would only warn
    # "No artists with labels found to put in legend".

    plt.show()

In [None]:
# NOTE(review): check the title — plot_runtime_bar_chart's title ends with "gh200" by
# default, which is wrong for this A100 call.
plot_runtime_bar_chart(a100_batch, a100_runtime)

In [None]:
plot_runtime_bar_chart(gh200_batch, gh200_runtime)

In [None]:
# GPU memory values are passed as the plotted (first) series.
plot_metrics_bar(a100_gpu, a100_mse, a100_batch)

In [None]:
# NOTE(review): check the title — plot_metrics_bar's title ends with "a100" by default,
# which is wrong for this GH200 call.
plot_metrics_bar(gh200_gpu, gh200_mse, gh200_batch)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import cartopy.crs as crs
import cartopy.feature as cfeature

# Cluster category -> marker colour used by plurality_plot (also drives its legend).
color_dict = {
    0: "cyan",
    1: "blue",
    2: "yellow",
    3: "green",
    # 4: 'red',
    # 5: 'orange',
    # 6: 'purple',
    # 7: 'black',
    # 8: 'white'
}


def plurality_plot(df, geovar):
    """Map mesonet sites coloured by cluster on a PlateCarree map.

    Args:
        df (pandas.DataFrame): Needs ``lat``, ``lon`` and ``color`` columns; ``color``
            holds one matplotlib colour per site (e.g. mapped through ``color_dict``).
        geovar (str): Name of the clustering variable, shown in the title.
    """
    projPC = crs.PlateCarree()
    # Station bounding box padded by 1 degree on every side
    # (west = minimum longitude, east = maximum longitude).
    latN = df["lat"].max() + 1
    latS = df["lat"].min() - 1
    lonW = df["lon"].min() - 1
    lonE = df["lon"].max() + 1

    fig, ax = plt.subplots(
        figsize=(6, 6), subplot_kw={"projection": projPC}, dpi=400
    )
    ax.set_extent([lonW, lonE, latS, latN], crs=projPC)
    ax.add_feature(cfeature.LAND)
    ax.add_feature(cfeature.COASTLINE)
    ax.add_feature(cfeature.BORDERS, linestyle="--")
    ax.add_feature(cfeature.LAKES, alpha=0.5)
    ax.add_feature(cfeature.STATES)
    gl = ax.gridlines(
        crs=projPC,
        draw_labels=True,
        linewidth=2,
        color="black",
        alpha=0.5,
        linestyle="--",
    )
    # Label only the bottom and left edges. These switches belong to the Gridliner;
    # assigning them as attributes of the GeoAxes has no effect.
    gl.top_labels = False
    gl.right_labels = False
    ax.scatter(
        x=df["lon"],
        y=df["lat"],
        c=df["color"],
        s=40,
        marker="o",
        edgecolor="black",
        transform=projPC,
    )
    ax.set_title(f"Mesonet Site {geovar} Clusters", size=16)
    ax.set_xlabel("Longitude", size=14)
    ax.set_ylabel("Latitude", size=14)
    ax.tick_params(axis="x", labelsize=12)
    ax.tick_params(axis="y", labelsize=12)
    ax.grid()

    # One legend swatch per category in color_dict.
    legend_patches = [
        mpatches.Patch(color=color, label=f"Category {key}")
        for key, color in color_dict.items()
    ]
    ax.legend(
        handles=legend_patches,
        loc="upper left",
        bbox_to_anchor=(1.1, 1),  # place the legend outside the map, to the right
        borderaxespad=0,
        title="Categories",
    )

    plt.show()

In [None]:
cluster_df = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/lstm_clusters.csv")
# NOTE(review): plurality_plot needs lon/lat/color columns. The lines that would add them
# are commented out, and `lons`/`lats` are not defined in this notebook — confirm whether
# the CSV already provides these columns.
# cluster_df["lon"] = lons
# cluster_df["lat"] = lats
# cluster_df["color"] = cluster_df["elev_cat"].map(color_dict)
cluster_df

In [None]:
import geopandas as gpd
import matplotlib.pyplot as plt
import cartopy.crs as crs
import cartopy.feature as cfeature
import numpy as np
import pandas as pd
import random
import matplotlib.patches as mpatches
import matplotlib.image as mpimg
from matplotlib.colors import ListedColormap

In [None]:
def plot(shapefile_path):
    """Draw the polygon outlines from a shapefile on a PlateCarree map.

    Station overlays are not drawn here (this function receives no station data); add
    them to the returned axes instead.

    Args:
        shapefile_path (str): Path to a shapefile readable by ``geopandas.read_file``.

    Returns:
        tuple: ``(fig, ax)`` so callers can draw more layers on the same map.
    """
    fig, ax = plt.subplots(
        figsize=(9, 15), subplot_kw={"projection": crs.PlateCarree()}
    )

    # Outline each polygon (e.g. climate divisions) without filling it.
    climate_divisions = gpd.read_file(shapefile_path)
    climate_divisions.plot(
        ax=ax,
        edgecolor="black",
        facecolor="none",
        transform=crs.PlateCarree(),
        zorder=4,
    )

    # No legend: nothing drawn here has a label, so plt.legend() would only warn.
    return fig, ax

In [None]:
# NOTE: rebinds okdf (previously the OKSM 2018 observations) to the site metadata.
okdf = pd.read_csv("/home/aevans/nwp_bias/src/landtype/data/oksm.csv")
# NOTE(review): presumably `datd` is a decommission date and 20991231 marks still-active
# sites — confirm.
okdf = okdf[okdf["datd"] == 20991231]
okdf

In [None]:
path = "/home/aevans/nwp_bias/src/machine_learning/notebooks/data/GIS.OFFICIAL_CLIM_DIVISIONS.shp"
gdf = gpd.read_file(path)
# NOTE(review): hard-coded row positions select the divisions to map; they depend on the
# row order of this shapefile.
gdf_filtered = pd.concat([gdf.iloc[191:198], gdf.iloc[172:175]])
# Sequential integer category per selected division, used for colouring and legend labels.
gdf_filtered["category"] = np.arange(len(gdf_filtered))

In [None]:
from shapely.geometry import Point

# One shapely Point per station, with x = "elon" (longitude) and y = "nlat" (latitude).
geometry = list(map(Point, okdf["elon"], okdf["nlat"]))

# Station table as a GeoDataFrame in WGS84 longitude/latitude (EPSG:4326).
points_gdf = gpd.GeoDataFrame(okdf, crs="EPSG:4326", geometry=geometry)

In [None]:
# Inspect the selected climate-division polygons and their "category" index.
gdf_filtered

In [None]:
# Oklahoma climate-division legend labels, matched by position to
# gdf_filtered["category"]. The trailing "Nan*" entries are padding labels.
# NOTE(review): presumably no real division name applies to them — confirm.
clim_div = [
    "Northeast",
    "West Central",
    "Central",
    "East Central",
    "Southwest",
    "South Central",
    "Southeast",
    "Panhandle",
    "North Central",
] + ["Nan" + suffix for suffix in ("", "1", "2", "3")]

In [None]:
# NCEI logo image overlaid on the map figures.
logo = str(Path("/home/aevans/nwp_bias/src/landtype/data") / "NCEI_logo.png")

In [None]:
# Map of the Oklahoma climate divisions with the Oklahoma stations overlaid.
fig = plt.figure(figsize=(24, 16))
ax = fig.add_subplot(
    1,
    1,
    1,
    projection=crs.LambertConformal(
        central_longitude=-98.0, standard_parallels=(30, 40)
    ),
)

# Legend swatches for the climate divisions. GeoDataFrame.plot(column=..., cmap=...)
# scales "category" over [0, n - 1], so category i is filled with tab10(i / (n - 1)).
# The swatches use the same scaling; the previous i / n only matched because
# n happens to be 10 here.
# The last category is left out of the legend. NOTE(review): presumably because
# its clim_div label is only a placeholder — confirm.
n_divisions = len(gdf_filtered)
division_patches = [
    mpatches.Patch(
        color=plt.cm.tab10(i / max(n_divisions - 1, 1)),
        alpha=0.3,
        label=clim_div[i],
    )
    for i in range(n_divisions - 1)
]

legend1 = ax.legend(
    handles=division_patches,
    loc="upper right",
    title="Climate Divisions",
    fontsize=12,
)
ax.add_artist(legend1)  # keep this legend even if ax.legend() is called again

ax.set_extent([-103.0, -94.0, 33.0, 38.0], crs=crs.PlateCarree())
ax.add_feature(cfeature.BORDERS.with_scale("50m"), linestyle=":", zorder=1)
ax.add_feature(cfeature.STATES.with_scale("50m"), linestyle=":", zorder=1)
ax.add_feature(cfeature.LAKES.with_scale("50m"), zorder=1)
ax.gridlines(
    crs=crs.PlateCarree(),
    draw_labels=True,
    linewidth=2,
    color="black",
    alpha=0.5,
    linestyle="--",
)
gdf_filtered.plot(
    ax=ax,
    transform=crs.PlateCarree(),
    column="category",
    cmap="tab10",
    alpha=0.3,
    legend=False,
)

# Label each station with its ID, slightly above the marker.
for i, row in okdf.iterrows():
    ax.annotate(
        row["stid"],
        (row["elon"], row["nlat"]),
        textcoords="offset points",
        xytext=(0, 7),
        ha="center",
        fontsize=12,
        color="black",
        transform=crs.PlateCarree(),
    )

# Station markers drawn above the division fills.
ax.scatter(
    okdf["elon"],
    okdf["nlat"],
    c="black",
    s=250,
    edgecolors="black",
    transform=crs.PlateCarree(),
    zorder=10,
)

plt.title(
    "NCEI Oklahoma State Climate Divisions",
    fontsize=24,
)

# Semi-transparent logo at a 50 px offset from the figure's lower-left corner.
logo_img = mpimg.imread(logo)
ax.figure.figimage(logo_img, 50, 50, zorder=20, alpha=0.5)

plt.savefig("/home/aevans/nwp_bias/src/landtype/data/OK_state_clim_div.png")
plt.show()

In [None]:
# Inspect the NYSM station table (``nysm_clim``, built in an earlier cell) that
# the New York map functions below plot.
nysm_clim

In [None]:
# Legend labels for the ten New York climate divisions. They are matched by
# position to the division polygons selected in the map functions below, so
# the order matters.
clim_div = [
    "St. Lawrence Valley",
    "Great Lakes",
    "Northern Plateau",
    "Champlain Valley",
    "Hudson Valley",
    "Mohawk Valley",
    "Western Plateau",
    "Eastern Plateau",
    "Coastal",
    "Central Lakes",
]

In [None]:
# Do not sort clim_div: clim_div[i] labels division category i, so reordering
# the list would mislabel the map legend.
image = str(Path("/home/aevans/nwp_bias/src/landtype/data") / "NCEI_logo.png")

In [None]:
def create_xCITE_gif(
    nysm_clim,
    fh,
    clim_div,
    logo,
    save_path="/home/aevans/nwp_bias/src/landtype/data/NY_state_clim_div.png",
):
    """Draw, save and show a map of the NCEI New York State climate divisions.

    Despite its name, this renders one static figure saved as a PNG, not a GIF.
    Every NYSM station is drawn as a black marker labelled with its ``stid``.

    Parameters
    ----------
    nysm_clim : pandas.DataFrame
        Station table with ``stid``, ``lat [degrees]`` and ``lon [degrees]`` columns.
    fh : int
        Not used by the plot; kept so existing calls keep working.
        NOTE(review): presumably a forecast hour left over from an animated version.
    clim_div : sequence of str
        Legend labels; ``clim_div[i]`` names the i-th selected division polygon.
    logo : str
        Path of an image overlaid semi-transparently near the lower-left corner.
    save_path : str, optional
        Output path of the PNG. Defaults to the previously hard-coded location.
    """
    fig = plt.figure(figsize=(24, 16))
    ax = fig.add_subplot(
        1,
        1,
        1,
        projection=crs.LambertConformal(
            central_longitude=-75.0, standard_parallels=(49, 77)
        ),
    )

    # NCEI climate-division polygons, clipped to the bounding box of the State
    # boundary file and then narrowed to hand-picked row positions.
    # NOTE(review): the positional picks (rows 20-28 and 32) depend on the
    # shapefile row order — confirm they are the ten NY divisions.
    shapefile_path = "/home/aevans/nwp_bias/src/machine_learning/notebooks/data/GIS.OFFICIAL_CLIM_DIVISIONS.shp"
    gdf = gpd.read_file(shapefile_path)

    ny_state_boundaries_path = "/home/aevans/nwp_bias/src/landtype/data/State.shx"
    ny_state_boundaries_geo = gpd.read_file(ny_state_boundaries_path).to_crs(epsg=4326)

    min_lon, min_lat, max_lon, max_lat = ny_state_boundaries_geo.total_bounds
    gdf_filtered = gdf.cx[min_lon:max_lon, min_lat:max_lat]
    gdf_filtered = pd.concat(
        [gdf_filtered.iloc[20:29], gdf_filtered.iloc[[32]]]
    ).assign(category=lambda frame: np.arange(len(frame)))

    # Fill each division by its category; the legend is built manually below.
    gdf_filtered.plot(
        ax=ax,
        transform=crs.PlateCarree(),
        column="category",
        cmap="tab10",
        alpha=0.3,
        legend=False,
    )

    # GeoDataFrame.plot scales "category" over [0, n - 1], so category i is
    # filled with tab10(i / (n - 1)); use the same scaling for the swatches.
    n_divisions = len(gdf_filtered)
    division_patches = [
        mpatches.Patch(
            color=plt.cm.tab10(i / max(n_divisions - 1, 1)),
            alpha=0.3,
            label=clim_div[i],
        )
        for i in range(n_divisions)
    ]
    legend1 = ax.legend(
        handles=division_patches,
        loc="upper right",
        title="Climate Divisions",
        fontsize=12,
    )
    ax.add_artist(legend1)  # keep this legend even if ax.legend() is called again

    ax.set_extent([-80.0, -72.0, 40.0, 45.5], crs=crs.PlateCarree())
    ax.add_feature(cfeature.BORDERS.with_scale("50m"), linestyle=":", zorder=1)
    ax.add_feature(cfeature.STATES.with_scale("50m"), linestyle=":", zorder=1)
    ax.add_feature(cfeature.LAKES.with_scale("50m"), zorder=1)
    ax.gridlines(
        crs=crs.PlateCarree(),
        draw_labels=True,
        linewidth=2,
        color="black",
        alpha=0.5,
        linestyle="--",
    )

    # Label each station with its ID, slightly above the marker.
    for _, row in nysm_clim.iterrows():
        ax.annotate(
            row["stid"],
            (row["lon [degrees]"], row["lat [degrees]"]),
            textcoords="offset points",
            xytext=(0, 7),
            ha="center",
            fontsize=12,
            color="black",
            transform=crs.PlateCarree(),
        )

    # Station markers drawn above the division fills.
    ax.scatter(
        nysm_clim["lon [degrees]"],
        nysm_clim["lat [degrees]"],
        c="black",
        s=250,
        edgecolors="black",
        transform=crs.PlateCarree(),
        zorder=10,
    )

    plt.title(
        "NCEI New York State Climate Divisions",
        fontsize=24,
    )

    # Semi-transparent logo at a 50 px offset from the figure's lower-left corner.
    logo_img = mpimg.imread(logo)
    ax.figure.figimage(logo_img, 50, 50, zorder=20, alpha=0.5)

    plt.savefig(save_path)
    plt.show()

In [None]:
# Render the New York climate-division map (fh is not used by the plot).
create_xCITE_gif(nysm_clim, fh=1, clim_div=clim_div, logo=image)

In [None]:
# Copy of the NYSM station table with short "lat"/"lon" column names, which
# is the naming get_closest_stations reads.
nysm_ = nysm_clim.copy().rename(
    columns={"lat [degrees]": "lat", "lon [degrees]": "lon"}
)
nysm_

In [None]:
import pandas as pd
import numpy as np
from sklearn.neighbors import BallTree
from sklearn import preprocessing
from sklearn import utils


def get_closest_stations(nysm_df, neighbors, target_station, nwp_model):
    """Select the stations grouped with ``target_station`` for a given NWP model.

    Candidates are the ``neighbors`` stations nearest to ``target_station`` by
    great-circle (haversine) distance, found with a BallTree. They are then
    filtered per model:

    - ``"GFS"``: the target plus up to 4 candidates at least 30 km away.
    - ``"NAM"``: the target plus up to 3 candidates at least 12 km away.
    - ``"HRRR"``: all candidates, nearest first. The target itself is included,
      normally first at distance 0.
    - any other value: an empty list.

    Parameters
    ----------
    nysm_df : pandas.DataFrame
        Station table with ``stid``, ``lat`` and ``lon`` (degrees) columns.
    neighbors : int
        Number of nearest candidates to query. BallTree.query rejects values
        larger than the number of stations.
    target_station : str
        Station ID to centre the search on.
    nwp_model : str
        ``"GFS"``, ``"NAM"`` or ``"HRRR"``.

    Returns
    -------
    list
        Selected station IDs; for GFS/NAM the target comes first.

    Raises
    ------
    KeyError
        If ``target_station`` does not appear in ``nysm_df["stid"]``.
    """
    # Earth's radius in kilometers. This is the equatorial value, kept so the
    # 12/30 km thresholds below behave exactly as before.
    EARTH_RADIUS_KM = 6378

    # One row per station, so every (lat, lon) pair stays attached to its stid.
    # De-duplicating "lat" and "lon" independently with .unique() would
    # mis-pair coordinates, or fail on a length mismatch, as soon as two
    # stations share a latitude or longitude value.
    sites = nysm_df.drop_duplicates(subset="stid").reset_index(drop=True)
    stations = sites["stid"].to_numpy()

    target_rows = np.flatnonzero(stations == target_station)
    if target_rows.size == 0:
        raise KeyError(f"station {target_station!r} not found in nysm_df['stid']")

    # BallTree's haversine metric expects [lat, lon] pairs in radians.
    coords_rad = np.deg2rad(sites[["lat", "lon"]].to_numpy(dtype=float))
    ball = BallTree(coords_rad, metric="haversine")

    # Only the target's neighbourhood is needed, so query just that point.
    distances, indices = ball.query(coords_rad[target_rows[:1]], k=neighbors)
    vals = indices[0]
    dists = distances[0] * EARTH_RADIUS_KM  # radians -> km

    if nwp_model == "HRRR":
        return [stations[v] for v in vals]

    # (minimum distance in km, maximum list length including the target).
    # NOTE(review): presumably chosen to skip stations within about one grid
    # spacing of the target for the coarser models — confirm.
    spacing_rules = {"GFS": (30, 5), "NAM": (12, 4)}
    if nwp_model not in spacing_rules:
        return []

    min_km, max_len = spacing_rules[nwp_model]
    utilize_ls = [target_station]
    for v, d in zip(vals, dists):
        if d >= min_km and len(utilize_ls) < max_len:
            utilize_ls.append(stations[v])
    return utilize_ls

In [None]:
# For each station, compute the elevation spread (max - min of "elevation [m]")
# across its GFS neighbour group, chosen from its 15 nearest stations.
# elev_delta is in the same order as nysm_clim["stid"].unique().
stations = nysm_clim["stid"].unique()

elev_delta = []
for s in stations:
    utilize_ls = get_closest_stations(nysm_, 15, s, "GFS")
    selection = nysm_[nysm_["stid"].isin(utilize_ls)]
    max_value = selection["elevation [m]"].max()
    min_value = selection["elevation [m]"].min()
    delta = max_value - min_value
    elev_delta.append(delta)
elev_delta

In [None]:
def create_gfs_learners(nysm_clim, clim_div, learners):
    """Map NYSM stations over the NY climate divisions, highlighting "learner" stations.

    Stations whose ``stid`` is in ``learners`` are drawn green and all others
    black. The figure is shown, not saved.

    Parameters
    ----------
    nysm_clim : pandas.DataFrame
        Station table with ``stid``, ``lat [degrees]`` and ``lon [degrees]`` columns.
    clim_div : sequence of str
        Climate-division legend labels; ``clim_div[i]`` names the i-th selected polygon.
    learners : iterable of str
        Station IDs classified as learners.
    """
    # Membership test by value (a set also makes lookups O(1)).
    learner_ids = set(learners)
    df_ = nysm_clim.copy()
    df_["color"] = ["green" if stid in learner_ids else "black" for stid in df_["stid"]]

    fig = plt.figure(figsize=(24, 16))
    ax = fig.add_subplot(
        1,
        1,
        1,
        projection=crs.LambertConformal(
            central_longitude=-75.0, standard_parallels=(49, 77)
        ),
    )

    # NCEI climate-division polygons, clipped to the bounding box of the State
    # boundary file and then narrowed to hand-picked row positions.
    # NOTE(review): the positional picks (rows 20-28 and 32) depend on the
    # shapefile row order — confirm they are the ten NY divisions.
    shapefile_path = "/home/aevans/nwp_bias/src/machine_learning/notebooks/data/GIS.OFFICIAL_CLIM_DIVISIONS.shp"
    gdf = gpd.read_file(shapefile_path)

    ny_state_boundaries_path = "/home/aevans/nwp_bias/src/landtype/data/State.shx"
    ny_state_boundaries_geo = gpd.read_file(ny_state_boundaries_path).to_crs(epsg=4326)

    min_lon, min_lat, max_lon, max_lat = ny_state_boundaries_geo.total_bounds
    gdf_filtered = gdf.cx[min_lon:max_lon, min_lat:max_lat]
    gdf_filtered = pd.concat(
        [gdf_filtered.iloc[20:29], gdf_filtered.iloc[[32]]]
    ).assign(category=lambda frame: np.arange(len(frame)))

    # Fill each division by its category; the legend is built manually below.
    gdf_filtered.plot(
        ax=ax,
        transform=crs.PlateCarree(),
        column="category",
        cmap="tab10",
        alpha=0.3,
        legend=False,
    )

    # GeoDataFrame.plot scales "category" over [0, n - 1], so category i is
    # filled with tab10(i / (n - 1)); use the same scaling for the swatches.
    n_divisions = len(gdf_filtered)
    division_patches = [
        mpatches.Patch(
            color=plt.cm.tab10(i / max(n_divisions - 1, 1)),
            alpha=0.3,
            label=clim_div[i],
        )
        for i in range(n_divisions)
    ]
    legend1 = ax.legend(
        handles=division_patches,
        loc="lower left",
        title="Climate Divisions",
        fontsize=18,
        title_fontsize=18,
    )
    # The learner legend below would otherwise replace this one.
    ax.add_artist(legend1)

    ax.set_extent([-80.0, -72.0, 40.0, 45.1], crs=crs.PlateCarree())
    ax.add_feature(cfeature.BORDERS.with_scale("50m"), linestyle=":", zorder=1)
    ax.add_feature(cfeature.STATES.with_scale("50m"), linestyle=":", zorder=1)
    ax.add_feature(cfeature.LAKES.with_scale("50m"), zorder=1)
    ax.gridlines(
        crs=crs.PlateCarree(),
        draw_labels=True,
        linewidth=2,
        color="black",
        alpha=0.5,
        linestyle="--",
    )

    # Label each station with its ID, slightly above the marker.
    for _, row in df_.iterrows():
        ax.annotate(
            row["stid"],
            (row["lon [degrees]"], row["lat [degrees]"]),
            textcoords="offset points",
            xytext=(0, 7),
            ha="center",
            fontsize=12,
            color="black",
            transform=crs.PlateCarree(),
        )

    # Station markers coloured by learner status, above the division fills.
    ax.scatter(
        df_["lon [degrees]"],
        df_["lat [degrees]"],
        c=df_["color"],
        s=250,
        edgecolors="black",
        transform=crs.PlateCarree(),
        zorder=10,
    )

    learner_patch = mpatches.Patch(color="green", label="Learner")
    non_learner_patch = mpatches.Patch(color="black", label="Non-Learner")
    ax.legend(
        handles=[learner_patch, non_learner_patch],
        loc="upper left",
        fontsize=18,
        title="Station Classification",
        title_fontsize=18,
    )

    plt.title(
        "GFS T2M Error : NYSM Stations that Can Learn",
        fontsize=24,
    )
    plt.show()

In [None]:
def create_gfs_learners_delta(nysm_clim, clim_div, learners, elev_delta):
    """Map NYSM stations coloured by learner status and sized by ``elev_delta``.

    Stations whose ``stid`` is in ``learners`` are drawn green and all others
    black. Each marker's area (matplotlib ``s``, in points**2) is that
    station's ``elev_delta`` value. The figure is shown, not saved.

    Parameters
    ----------
    nysm_clim : pandas.DataFrame
        Station table with ``stid``, ``lat [degrees]`` and ``lon [degrees]`` columns.
    clim_div : sequence of str
        Climate-division legend labels; ``clim_div[i]`` names the i-th selected polygon.
    learners : iterable of str
        Station IDs classified as learners.
    elev_delta : sequence of float
        One value per row of ``nysm_clim``, in the same order; used as marker area.

    Raises
    ------
    ValueError
        If ``len(elev_delta)`` differs from the number of rows in ``nysm_clim``.
    """
    # Validate before any slow shapefile reading or plotting.
    if len(elev_delta) != len(nysm_clim):
        raise ValueError(
            "Length of elev_delta must match the number of stations in the DataFrame"
        )

    # Membership test by value (a set also makes lookups O(1)).
    learner_ids = set(learners)
    df_ = nysm_clim.copy()
    df_["color"] = ["green" if stid in learner_ids else "black" for stid in df_["stid"]]

    fig = plt.figure(figsize=(24, 16))
    ax = fig.add_subplot(
        1,
        1,
        1,
        projection=crs.LambertConformal(
            central_longitude=-75.0, standard_parallels=(49, 77)
        ),
    )

    # NCEI climate-division polygons, clipped to the bounding box of the State
    # boundary file and then narrowed to hand-picked row positions.
    # NOTE(review): the positional picks (rows 20-28 and 32) depend on the
    # shapefile row order — confirm they are the ten NY divisions.
    shapefile_path = "/home/aevans/nwp_bias/src/machine_learning/notebooks/data/GIS.OFFICIAL_CLIM_DIVISIONS.shp"
    gdf = gpd.read_file(shapefile_path)

    ny_state_boundaries_path = "/home/aevans/nwp_bias/src/landtype/data/State.shx"
    ny_state_boundaries_geo = gpd.read_file(ny_state_boundaries_path).to_crs(epsg=4326)

    min_lon, min_lat, max_lon, max_lat = ny_state_boundaries_geo.total_bounds
    gdf_filtered = gdf.cx[min_lon:max_lon, min_lat:max_lat]
    gdf_filtered = pd.concat(
        [gdf_filtered.iloc[20:29], gdf_filtered.iloc[[32]]]
    ).assign(category=lambda frame: np.arange(len(frame)))

    # Fill each division by its category; the legend is built manually below.
    gdf_filtered.plot(
        ax=ax,
        transform=crs.PlateCarree(),
        column="category",
        cmap="tab10",
        alpha=0.3,
        legend=False,
    )

    # GeoDataFrame.plot scales "category" over [0, n - 1], so category i is
    # filled with tab10(i / (n - 1)); use the same scaling for the swatches.
    n_divisions = len(gdf_filtered)
    division_patches = [
        mpatches.Patch(
            color=plt.cm.tab10(i / max(n_divisions - 1, 1)),
            alpha=0.3,
            label=clim_div[i],
        )
        for i in range(n_divisions)
    ]
    legend1 = ax.legend(
        handles=division_patches,
        loc="lower left",
        title="Climate Divisions",
        fontsize=18,
        title_fontsize=18,
    )
    # The learner legend below would otherwise replace this one.
    ax.add_artist(legend1)

    ax.set_extent([-80.0, -72.0, 40.0, 45.1], crs=crs.PlateCarree())
    ax.add_feature(cfeature.BORDERS.with_scale("50m"), linestyle=":", zorder=1)
    ax.add_feature(cfeature.STATES.with_scale("50m"), linestyle=":", zorder=1)
    ax.add_feature(cfeature.LAKES.with_scale("50m"), zorder=1)
    ax.gridlines(
        crs=crs.PlateCarree(),
        draw_labels=True,
        linewidth=2,
        color="black",
        alpha=0.5,
        linestyle="--",
    )

    # Label each station with its ID, slightly above the marker.
    for _, row in df_.iterrows():
        ax.annotate(
            row["stid"],
            (row["lon [degrees]"], row["lat [degrees]"]),
            textcoords="offset points",
            xytext=(0, 7),
            ha="center",
            fontsize=12,
            color="black",
            transform=crs.PlateCarree(),
        )

    # Marker area follows elev_delta; a value of 0 gives an invisible marker.
    ax.scatter(
        df_["lon [degrees]"],
        df_["lat [degrees]"],
        c=df_["color"],
        s=elev_delta,
        edgecolors="black",
        transform=crs.PlateCarree(),
        zorder=10,
    )

    learner_patch = mpatches.Patch(color="green", label="Learner")
    non_learner_patch = mpatches.Patch(color="black", label="Non-Learner")
    ax.legend(
        handles=[learner_patch, non_learner_patch],
        loc="upper left",
        fontsize=18,
        title="Station Classification",
        title_fontsize=18,
    )

    plt.title(
        "GFS T2M Error : NYSM Stations that Can Learn",
        fontsize=24,
    )
    plt.show()

In [None]:
# Learner map with marker area driven by each station's elevation spread.
create_gfs_learners_delta(
    nysm_clim,
    clim_div=clim_div,
    learners=learners,
    elev_delta=elev_delta,
)