In [None]:
import xarray as xr
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from pyproj import Transformer, CRS, Proj

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

# Key into `ak_locations` naming the point that the comparison is run for.
SELECTED_LOCATION = "Cordova"
YEAR = 1991 # picked to avoid any leap year stuff possibly confounding the analysis
# Directory that is searched for this year's source files.
SOURCE_DIR = Path(f"/beegfs/CMIP6/wrf_era5/04km/{YEAR}")

### change this to where your outputs are
# Processed NetCDF file that the source data is compared against.
PROCESSED_FILE = Path(f"/beegfs/CMIP6/cparr4/daily_downscaled_era5_for_rasdaman/t2_mean/t2_mean_{YEAR}_daily_era5_4km_3338.nc")

# Named points as (latitude, longitude) in decimal degrees.
# NOTE(review): the "WRF" entry looks like a domain reference point rather than
# a community — confirm what it is meant to represent.
ak_locations = {
    "Anchorage": (61.2181, -149.9003),
    "Fairbanks": (64.8378, -147.7164),
    "Utqiaġvik": (71.2906, -156.7886),
    "Bethel": (60.7922, -161.7558),
    "Cordova": (60.5438, -145.7573),
    "Nome": (64.5011, -165.4064),
    "Seward": (60.1044, -149.4458),
    "WRF": (64.0, -152.0)
}

# Source files for YEAR, sorted by filename (presumably chronological given the
# "{YEAR}-*" suffix — confirm against the directory listing).
source_files = sorted(SOURCE_DIR.glob(f"era5_wrf_dscale_4km_{YEAR}-*.nc"))
if not source_files:
    # Fail here with an actionable message instead of an IndexError on
    # source_files[0] in a later cell.
    raise FileNotFoundError(
        f"No source files matching era5_wrf_dscale_4km_{YEAR}-*.nc found in {SOURCE_DIR}"
    )

# Processed dataset that the source data is compared against.
ds_processed = xr.open_dataset(PROCESSED_FILE)

In [None]:
def find_nearest_grid_indices(ds, locations):
    """Find the nearest grid indices for a set of lat/lon locations.

    Nearness is measured with the haversine term (monotonic in great-circle
    distance) rather than with raw differences in degrees. In degree space a
    degree of longitude counts as much as a degree of latitude, which at high
    latitudes can select a cell that is not the closest one on the ground, and
    differences across the +/-180 degree meridian come out near 360 degrees.

    Parameters
    ----------
    ds
        Object where ``ds['XLAT'].values`` and ``ds['XLONG'].values`` are
        equally-shaped arrays of latitude and longitude in degrees whose last
        two axes are (south_north, west_east).
    locations : dict
        Maps a location name to a ``(lat, lon)`` pair in degrees.

    Returns
    -------
    dict
        Maps each location name to ``{'south_north': int, 'west_east': int}``,
        the indices of the grid point closest to that location.
    """
    ak_grid_indices = {}
    # float64 so that small coordinate differences are not lost when the file
    # stores the grid as float32.
    lats_rad = np.radians(np.asarray(ds['XLAT'].values, dtype=float))
    lons_rad = np.radians(np.asarray(ds['XLONG'].values, dtype=float))
    cos_lats = np.cos(lats_rad)
    for name, (lat, lon) in locations.items():
        lat_rad = np.radians(lat)
        lon_rad = np.radians(lon)
        # Haversine "a" term; sin**2 of the half-difference is periodic, so
        # longitudes on either side of the antimeridian compare correctly.
        hav = (
            np.sin((lats_rad - lat_rad) / 2.0) ** 2
            + cos_lats * np.cos(lat_rad) * np.sin((lons_rad - lon_rad) / 2.0) ** 2
        )
        idx = np.unravel_index(np.argmin(hav), hav.shape)
        # The last two axes are the horizontal grid; plain ints keep the result
        # readable when printed and are accepted by ``isel``.
        ak_grid_indices[name] = {'south_north': int(idx[-2]), 'west_east': int(idx[-1])}
    return ak_grid_indices

def project_locations(locations_lat_lon):
    """Reproject named (lat, lon) points from EPSG:4326 to EPSG:3338.

    Parameters
    ----------
    locations_lat_lon : dict
        Maps a location name to a ``(lat, lon)`` pair.

    Returns
    -------
    dict
        Maps each location name to the tuple returned by the transformer for
        that point, i.e. its coordinates in EPSG:3338.
    """
    transformer = Transformer.from_crs("EPSG:4326", "EPSG:3338", always_xy=True)
    projected = {}
    for name, (lat, lon) in locations_lat_lon.items():
        # always_xy=True makes the transformer take x (lon) before y (lat),
        # so the stored (lat, lon) order has to be flipped here.
        projected[name] = transformer.transform(lon, lat)
    return projected

# Project every named location to EPSG:3338 once so later cells can look the
# projected point up by name.
ak_locations_3338 = project_locations(ak_locations)
# Bare expression: shown as the cell's output for a quick visual check.
ak_locations_3338

In [None]:
print("Finding nearest grid cell in the first source file...")
with xr.open_dataset(source_files[0]) as sample_ds:
    grid_indices = find_nearest_grid_indices(sample_ds, ak_locations)

    source_loc_indices = grid_indices[SELECTED_LOCATION]
    print(source_loc_indices)
    print("get the **exact** lat-lon for that grid cell")
    # Read the coordinates while the file is still open. Once the `with` block
    # exits the file is closed, and reading from `sample_ds` would then depend
    # on the values already being held in memory from an earlier access.
    exact_lat = float(sample_ds['XLAT']
                      .isel(south_north=source_loc_indices['south_north'],
                            west_east=source_loc_indices['west_east']))
    exact_lon = float(sample_ds['XLONG']
                      .isel(south_north=source_loc_indices['south_north'],
                            west_east=source_loc_indices['west_east']))
print(exact_lat, exact_lon)
print("notice that this is different than our initial input lat-lon!")
print(ak_locations[SELECTED_LOCATION])
print("project both to 3338, and compute the distance between the two points")
# always_xy=True: the transformer takes (lon, lat) and returns (x, y).
to_ak_albers = Transformer.from_crs("EPSG:4326", "EPSG:3338", always_xy=True)
x_exact_3338, y_exact_3338 = to_ak_albers.transform(exact_lon, exact_lat)
def distance_m_3338(p1, p2):
    """Straight-line (planar) distance between two EPSG:3338 points.

    Parameters
    ----------
    p1, p2
        Indexable *(x, y)* pairs in metres (Alaska Albers). Only elements
        0 and 1 of each are read.

    Returns
    -------
    float
        Separation of the two points in metres.
    """
    import math

    x1, y1 = p1[0], p1[1]
    x2, y2 = p2[0], p2[1]
    return math.hypot(x2 - x1, y2 - y1)

# Separation between the requested location and the XLAT/XLONG point of its
# matched grid cell, both expressed in EPSG:3338.
dist_m  = distance_m_3338(ak_locations_3338[SELECTED_LOCATION], (x_exact_3338, y_exact_3338))
dist_km = dist_m / 1000
print(f"{dist_km:.1f} km")

In [None]:
print(f"Processing comparison for location: {SELECTED_LOCATION}")
# Brute-force search of the grid cells surrounding the matched cell to find the
# one that agrees best with the processed data; it should be in there somewhere!
#
# Keys are (di, dj) offsets: di shifts the west_east index, dj shifts the
# south_north index. A 3x3 neighbourhood is enough now; wider 5x5 and 7x7
# neighbourhoods were only needed previously, when the projection was off.
daily_means_dict = { (di,dj): [] for di in [-1,0,1] for dj in [-1,0,1] }

print(f"Processing {len(source_files)} source files in a loop for multiple offsets...")
for f in source_files:
    # could use mf data open here, but this is fast enough
    with xr.open_dataset(f) as ds:
        # Dataset.sizes is the dimension-name -> length mapping; mapping-style
        # access through Dataset.dims is deprecated in xarray.
        max_we = ds.sizes['west_east'] - 1
        max_sn = ds.sizes['south_north'] - 1
        for di, dj in daily_means_dict:
            # Clamp to the grid so an offset at the domain edge reuses the edge
            # cell instead of wrapping around via a negative index.
            we = max(0, min(source_loc_indices['west_east'] + di, max_we))
            sn = max(0, min(source_loc_indices['south_north'] + dj, max_sn))

            source_raw = ds['T2'].isel(west_east=we, south_north=sn)
            # Daily mean of the file's samples, shifted by 273.15 (K -> degC).
            daily_mean = source_raw.resample(Time="1D").mean() - 273.15
            daily_means_dict[(di,dj)].append(daily_mean)
            
print("Combining daily means for each offset...")
# One continuous series per (di, dj) offset, built from the per-file pieces.
offset_series = {}
for key, lst in daily_means_dict.items():
    series = xr.concat(lst, dim="Time")
    # Rename the dimension to 'time' so it can be aligned with the processed
    # data, then name the variable itself.
    series = series.rename({'Time':'time'})
    series = series.rename("t2_mean_source")
    offset_series[key] = series

print("Extracting data from processed file...")
processed_loc_coords = ak_locations_3338[SELECTED_LOCATION]
processed_daily_mean = ds_processed["t2_mean"].sel(
    {
        # Swap the active and commented-out entries to sample at the originally
        # requested location instead of the matched grid cell's own point, and
        # see which offset then has the smallest delta.
        "x": x_exact_3338,
        "y": y_exact_3338,
        # "x": processed_loc_coords[0],
        # "y": processed_loc_coords[1],
    },
    method="nearest",
)

print("\nOffset summary (mean absolute delta °C):")
# Mean absolute difference (processed - source) for each candidate offset.
delta_dict = {}
for key, src_series in offset_series.items():
    # Inner join keeps only the timestamps present in both series.
    aligned_src, aligned_proc = xr.align(src_series, processed_daily_mean, join="inner")
    d = aligned_proc - aligned_src
    mean_abs_delta = float(np.abs(d).mean())
    if not np.isfinite(mean_abs_delta):
        # A NaN score would make the min() below depend on dict order and would
        # only fail later, when the plot limits are set.
        raise ValueError(
            f"Offset {key}: mean absolute delta is not finite "
            f"({d.size} overlapping values). Check that the source and processed "
            "time coordinates overlap and that the processed data is not all-NaN "
            "at the selected point."
        )
    delta_dict[key] = mean_abs_delta
    print(f"  offset {key}: {delta_dict[key]:.2f}")

# pick best offset (minimum mean abs delta)
best_offset = min(delta_dict, key=delta_dict.get)
print(f"\nBest offset: {best_offset} with mean abs delta {delta_dict[best_offset]:.2f} °C")

# Detailed comparison uses the source series from the best-scoring offset,
# restricted to the timestamps it shares with the processed series.
aligned_source, aligned_processed = xr.align(
    offset_series[best_offset], processed_daily_mean, join="inner"
)
delta = (aligned_processed - aligned_source).rename("t2_mean_delta")

print("Generating plot for best offset...")

# Drop the non-index coordinates from all three series before plotting.
aligned_source, aligned_processed, delta = (
    arr.reset_coords(drop=True)
    for arr in (aligned_source, aligned_processed, delta)
)

# Three stacked panels sharing the time axis: source, processed, difference.
fig, axes = plt.subplots(nrows=3, ncols=1, figsize=(12, 10), sharex=True)

# Panel 1: source series. Title and label are set after plotting so they
# replace the ones the plot call puts on the axes.
aligned_source.plot(ax=axes[0], label='Source (Daily Mean)')
axes[0].set(
    title=f'Source Daily Mean Temperature at {SELECTED_LOCATION}',
    ylabel='Temperature (°C)',
)
axes[0].grid(True)

# Panel 2: processed series.
aligned_processed.plot(ax=axes[1], color='orange', label='Processed (Daily Mean)')
axes[1].set(
    title=f'Processed Daily Mean Temperature at {SELECTED_LOCATION}',
    ylabel='Temperature (°C)',
)
axes[1].grid(True)

# Panel 3: processed minus source, with a zero reference line and the y-range
# pinned to the observed extremes.
delta.plot(ax=axes[2], color='green', label='Delta (Processed - Source)')
axes[2].axhline(0, color='red', linestyle='--')
delta_min, delta_max = float(delta.min()), float(delta.max())
axes[2].set_ylim(delta_min, delta_max)
axes[2].set(
    title='Difference (Processed - Source)',
    ylabel='Temperature Delta (°C)',
)
axes[2].grid(True)

fig.suptitle(f'Temperature Comparison for {SELECTED_LOCATION} - {YEAR}', fontsize=16)
# rect leaves room at the top for the suptitle.
fig.tight_layout(rect=[0, 0.03, 1, 0.95])

output_fig_path = f"qc_comparison_{SELECTED_LOCATION}_{YEAR}.png"
fig.savefig(output_fig_path)
print(f"Plot saved to {output_fig_path}")
plt.close(fig)

# Summary statistics of the (processed - source) difference.
print(f"Mean Difference: {delta.mean().item():.2f} °C")
print(f"Max Difference: {delta.max().item():.2f} °C")
print(f"Min Difference: {delta.min().item():.2f} °C")