In [None]:
import pandas as pd
import geopandas as gpd
import numpy as np
from glob import glob
from pathlib import Path
import re
import warnings
import matplotlib.pyplot as plt
import rasterio as rio
import xarray as xr
import rioxarray as rxr
import sys
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import matplotlib.dates as mdates
import matplotlib.colors as mcolors
import matplotlib.patheffects as pe
import matplotlib.gridspec as gridspec
from matplotlib import patches
import matplotlib.ticker as mtick
sys.path.append('src')
from src.swe_retrievals import *
from src.plotting_functions import *

warnings.simplefilter('ignore', FutureWarning)
warnings.simplefilter('ignore', RuntimeWarning)
from pandas.errors import SettingWithCopyWarning
warnings.simplefilter('ignore', SettingWithCopyWarning)
warnings.simplefilter('ignore', UserWarning)

## Figure 1 - workflow example

In [None]:
# Figure 1 walks through the error workflow for one station-year (Paradise,
# station 679, WY2021) using a single error component.
error_type = 'soil'  # component shown in panels b-d (columns '<type>_error', '<type>_error_cumsum')
window = slice('2020-10-01', '2021-07-01')  # date range shown in panels c-d

# plot_swe_curves() draws onto the axes it is given; only the returned data is
# needed here, so draw onto a throwaway figure and close exactly that figure.
scratch_fig, scratch_ax = plt.subplots(1, 2, figsize=(10, 5))
paradise = plot_swe_curves(station_id=679, wateryear=2021, ax1=scratch_ax[0], ax2=scratch_ax[1],
                           extra_title_text='\n(Deep snowpack)', return_data=True, plot_legends=False)
plt.close(scratch_fig)
# As used below: paradise[0] has 'soil_moisture_pct' and '<type>_error' columns;
# paradise[1] maps '12d_0'..'12d_11' (one entry per 12-day baseline) to tables
# with a '<type>_error_cumsum' column.
daily, baselines = paradise[0], paradise[1]

fig = plt.figure(figsize=(15, 5))

# 2x3 grid: the left column is split into two rows; the middle and right
# panels span both rows.
gs = gridspec.GridSpec(2, 3, figure=fig)
ax0 = fig.add_subplot(gs[0, 0])  # a) soil moisture (top-left)
ax1 = fig.add_subplot(gs[1, 0])  # b) 12-day SWE error (bottom-left)
ax2 = fig.add_subplot(gs[:, 1])  # c) cumulative error, one line per baseline
ax3 = fig.add_subplot(gs[:, 2])  # d) all baselines condensed into one series

daily['soil_moisture_pct'].plot(ax=ax0, color='k')
daily[f'{error_type}_error'].plot(ax=ax1, color='k')

for data in baselines.values():
    data.loc[window, f'{error_type}_error_cumsum'].plot(ax=ax2, marker='o')

# Interleave the twelve 12-day baselines into one time-sorted series.
combined = pd.concat([baselines[f'12d_{i}'][f'{error_type}_error_cumsum'] for i in range(12)]).sort_index()
combined.loc[window].plot(ax=ax3, label='Daily data')
combined.loc[window].rolling('12d').mean().plot(ax=ax3, label='Rolling 12d mean', lw=3)
leg = ax3.legend(fontsize=11, bbox_to_anchor=[0.98, 0.3], loc='upper right')
for legobj in leg.legend_handles:
    legobj.set_linewidth(2.0)

# Dashed zero lines. axhline spans each panel's full width whatever its
# x-limits are, including after ax0/ax1 are re-limited below, and does not
# borrow another panel's x-range.
for axx in (ax1, ax2, ax3):
    axx.axhline(0, color='k', ls='--', lw=2)

ax0.set_ylabel('Soil moisture [%]', fontsize=12, labelpad=24)
ax1.set_ylabel('12-day SWE error [m]', fontsize=12)
ax2.set_ylabel('Cumulative SWE error [m]', fontsize=14)
ax3.set_ylabel('Cumulative SWE error [m]', fontsize=14)

# Reuse panel d's date ticks on the left-column panels so the date axes match.
good_xticks = ax3.get_xticks()
good_xticklabels = ax3.get_xticklabels()

ax0.tick_params(labelbottom=False)
for axx in (ax2, ax3):
    axx.tick_params(axis='x', which='minor', bottom=False, labelsize=12)
    axx.tick_params(axis='both', labelsize=12)

for axx in (ax0, ax1):
    axx.tick_params(axis='x', which='minor', bottom=False)
    axx.set_xlabel('')
    axx.set_xlim(ax2.get_xlim())
    axx.set_xticks(good_xticks)
    axx.set_xticklabels(good_xticklabels, fontsize=12, rotation=0, ha='center')
    # Resize the y tick labels without freezing them to fixed strings.
    axx.tick_params(axis='y', labelsize=12)

ax0.text(0.02, 0.88, 'a)', transform=ax0.transAxes, fontsize=14)
ax1.text(0.02, 0.05, 'b)', transform=ax1.transAxes, fontsize=14)
ax2.text(0.02, 0.94, 'c)', transform=ax2.transAxes, fontsize=14)
ax3.text(0.02, 0.93, 'd)', transform=ax3.transAxes, fontsize=14)

ax0.set_title('Paradise station\nPeak SWE = 2.37 m on 2021-04-19', linespacing=1.5)
ax2.set_title('Cumulative SWE error from soil permittivity changes\nAll possible 12-day baselines', linespacing=1.5)
ax3.set_title('Cumulative errors from all 12-day baselines\ncondensed into one continuous daily timeseries', linespacing=1.5)

plt.tight_layout()
# fig.savefig('local/figs/initial_submission/figure1.png', dpi=600)

## Figure 2 - single season timeseries plots

In [None]:
# Figure 2: SWE curves (left column) and SWE errors (right column) for three
# stations spanning deep, moderate and shallow snowpacks.
fig, ax = plt.subplots(3, 2, figsize=(11.5*0.8, 14*0.8))
paradise = plot_swe_curves(station_id=679, wateryear=2021, ax1=ax[0,0], ax2=ax[0,1], extra_title_text='\n(Deep snowpack)', return_data=True, plot_legends=False)
trial = plot_swe_curves(station_id=828, wateryear=2017, ax1=ax[1,0], ax2=ax[1,1], extra_title_text='\n(Moderate snowpack)', return_data=True) #828 2017
disaster = plot_swe_curves(station_id=445, wateryear=2019, ax1=ax[2,0], ax2=ax[2,1], extra_title_text='\n(Shallow snowpack)', return_data=True, plot_legends=False)

# Shared y-ranges per column so the three stations are directly comparable.
for swe_ax, err_ax in ax:
    swe_ax.set_ylim([-0.6, 3.1])
    err_ax.set_ylim([-0.16, 0.16])

# Dashed zero line across each left-column panel, behind the data.
for swe_ax in ax[:, 0]:
    swe_ax.hlines(0, *swe_ax.get_xlim(), lw=1.5, color='k', ls='--', zorder=0)

plt.tight_layout()
# fig.savefig('local/figs/initial_submission/figure2.png', dpi=600)

## Analysis for Figures 3 and 4 - calculate cumulative rolling errors

In [None]:
# Per-station averages of each cumulative ("rolling") error component on the
# target day of water year, saved for the Figure 4 map.
sites = gpd.read_file('local/median_errors_accum_only_nonabs.csv')
# sites = gpd.read_file('local/median_errors_accum_only_nonabs_simplemedian.csv')

# Blank ecoregions mark stations to skip. Assign the result back: the chained
# `sites['ecoregion'].replace(..., inplace=True)` form stops modifying `sites`
# under pandas Copy-on-Write (and its FutureWarning is silenced at the top).
sites['ecoregion'] = sites['ecoregion'].replace('', np.nan)
sites = sites.dropna(subset='ecoregion')

# Convert every column that parses as numeric; the rest are left unchanged.
for c in sites.columns:
    try:
        sites[c] = pd.to_numeric(sites[c])
    except Exception:
        continue

rolling_cols = ['defo_error_cumsum_rolling', 'soil_error_cumsum_rolling', 'veg_error_cumsum_rolling',
                'dry_atmo_error_cumsum_rolling', 'wet_atmo_error_cumsum_rolling', 'ion_error_cumsum_rolling']

# .copy(): make the column subset an independent frame before adding columns to it.
sites = sites[['ecoregion', 'station_name', 'station_id', 'lat', 'lon', 'elev', 'sand',
               'clay', 'sturm', 'landfire', 'canopy_height', 'state', 'timezone']].copy()
sites[rolling_cols] = np.nan

dowy = 183 # 152=March 1, 183=April1

for i, site in sites.iterrows():
    try:
        gb_tmp = calculate_avg_cumsum_errors(site['station_id'])
        # Group means (presumably across water years), then the row for `dowy`.
        avg_errors = gb_tmp.mean().loc[dowy, rolling_cols + ['swe_m']]
        # SWE is stored under the column name 'swe_m', which the Figure 4 cell reads.
        sites.loc[i, avg_errors.index] = avg_errors
    except Exception as exc:
        # Best effort: report and skip stations whose errors cannot be computed.
        print(f"bad site {site['station_id']}: {exc!r}")
        continue

sites.to_csv('data/median_cumulative_errors_april1.csv', index=False)

In [None]:
# Per-station, per-year values on the target day of water year: cumulative
# non-ionospheric error and SWE, saved as station x year tables for Figure 3.
sites = gpd.read_file('local/median_errors_accum_only_nonabs.csv')

# Assign back instead of the chained `replace(..., inplace=True)`, which stops
# modifying `sites` under pandas Copy-on-Write.
sites['ecoregion'] = sites['ecoregion'].replace('', np.nan)
sites = sites.dropna(subset='ecoregion').copy()

# Map labels assigned by row position of the filtered CSV (13 stations).
# NOTE(review): positional; re-check if the CSV rows or their order change.
sites['map_number'] = [1,3,7,8,4,9,10,2,5,6,11,12,13]

# Convert every column that parses as numeric; the rest are left unchanged.
for c in sites.columns:
    try:
        sites[c] = pd.to_numeric(sites[c])
    except Exception:
        continue

sites = sites[['ecoregion', 'station_name', 'station_id', 'lat', 'lon', 'elev', 'sand',
               'clay', 'sturm', 'landfire', 'canopy_height', 'state', 'timezone','map_number']].copy()

dowy = 183 # 152=March 1, 183=April1
error_df = pd.DataFrame(index=sites['station_name'], columns=np.arange(2016,2026), dtype=float)
swe_df = pd.DataFrame(index=sites['station_name'], columns=np.arange(2016,2026), dtype=float)

for _, site in sites.iterrows():
    # Rows for `dowy` (one per year), re-keyed by calendar year so they align
    # with the year columns of error_df / swe_df.
    on_dowy = calculate_avg_cumsum_errors(site['station_id']).get_group(dowy)
    by_year = on_dowy.set_axis(on_dowy.index.year)
    error_df.loc[site['station_name']] = by_year['cumsum_no_ion']
    swe_df.loc[site['station_name']] = by_year['swe_m']

map_numbers = pd.Series(sites['map_number'].values, index=sites['station_name'])
error_df['map_number'] = map_numbers
swe_df['map_number'] = map_numbers

error_df.to_csv('data/apr1_cumulative_nonion_error_by_year.csv')
swe_df.to_csv('data/apr1_swe_by_year.csv')

## Figure 3 - boxplots

In [None]:
# Figure 3: per-station distribution (across years) of the April 1 cumulative
# non-ionospheric SWE error, one horizontal box per station ordered by map number.
error_df = pd.read_csv('data/apr1_cumulative_nonion_error_by_year.csv', index_col=0).astype(float)
swe_df = pd.read_csv('data/apr1_swe_by_year.csv', index_col=0)
year_cols = [c for c in error_df.columns if 'map' not in c]
# Error as % of April 1 SWE (not drawn in this figure).
rel_error_df = error_df[year_cols] / swe_df[year_cols] * 100

fig, ax = plt.subplots(figsize=(2.5*2, 6.5*2))
(error_df.sort_values('map_number', ascending=False)
         .drop(columns='map_number')
         .T
         .boxplot(ax=ax, showfliers=False, grid=False, vert=False))

ax.set_xlabel('Apr 1 cumulative\nnon-ion SWE error [m]', fontsize=14, labelpad=12)
ax.vlines(0, *ax.get_ylim(), ls='--', color='k', alpha=0.8, lw=1, zorder=0)
ax.tick_params(axis='x', labelsize=12, rotation=0)
ax.tick_params(axis='y', labelsize=14)
ax.set_xlim([-0.21, 0.21])
ax.set_xticks(np.arange(-0.15, 0.16, 0.1), minor=True)
# Break each station name after its first word; pass a list (a sequence of str
# is the documented argument type).
ax.set_yticklabels([lbl.get_text().replace(' ', '\n', 1) for lbl in ax.get_yticklabels()])
plt.tight_layout()
# fig.savefig('local/figs/initial_submission/figure3.png', dpi=600)

## Figure 4 - map

In [None]:
# Figure 4 inputs: per-station April 1 error components (written by the analysis
# cell above), summed into totals and converted to a Web Mercator GeoDataFrame.
sites = gpd.read_file('data/median_cumulative_errors_april1.csv')

# Assign back instead of the chained `replace(..., inplace=True)`, which stops
# modifying `sites` under pandas Copy-on-Write. After this, no '' ecoregions remain.
sites['ecoregion'] = sites['ecoregion'].replace('', np.nan)
sites = sites.dropna(subset='ecoregion').copy()

# Map labels assigned by row position (13 stations), matching Figure 3.
# NOTE(review): positional; re-check if the CSV rows or their order change.
sites['map_number'] = [1,3,7,8,4,9,10,2,5,6,11,12,13]

# Convert every column that parses as numeric; the rest are left unchanged.
for c in sites.columns:
    try:
        sites[c] = pd.to_numeric(sites[c])
    except Exception:
        continue

non_ion_cols = ['wet_atmo_error_cumsum_rolling', 'dry_atmo_error_cumsum_rolling',
                'soil_error_cumsum_rolling', 'veg_error_cumsum_rolling',
                'defo_error_cumsum_rolling']
sites['total_error_non_ion'] = sites[non_ion_cols].sum(axis=1)
sites['total_error'] = sites[non_ion_cols + ['ion_error_cumsum_rolling']].sum(axis=1)

sites['geometry'] = gpd.points_from_xy(sites['lon'], sites['lat'], crs='epsg:4326')
sites = gpd.GeoDataFrame(sites).to_crs('epsg:3857')

# Base map: Web Mercator axes with Natural Earth country and state/province lines.
countries = cfeature.NaturalEarthFeature(
        category='cultural',
        name='admin_0_boundary_lines_land',
        scale='10m',
        facecolor='none')

states_provinces = cfeature.NaturalEarthFeature(
        category='cultural',
        name='admin_1_states_provinces_lines',
        scale='10m',
        facecolor='none')

fig = plt.figure(figsize=(10,10))
ax = plt.axes(projection=ccrs.epsg(3857))
ax.coastlines()
ax.add_feature(countries)
ax.add_feature(states_provinces)

# Pad the station bounding box. Multiplying raw EPSG:3857 coordinates pads or
# shrinks depending on each coordinate's sign, so these factors are specific to
# where the stations are.
minx, miny, maxx, maxy = sites.total_bounds
minx *= 1.03
maxx *= 0.96
miny *= 0.9
maxy *= 1.05
ax.set_extent((minx, maxx, miny, maxy), crs=ccrs.epsg(3857))

# Snow-class colours keyed by the integer code in sites['sturm'] (e.g. 3 = Maritime,
# 6 = Montane Forest, per the snow-class legend built later in this cell).
color_dict = {1:'#8d00bb',2:'#0e8bff',3:'#fb0006',4:'#dfd826',5:'#e96c20',6:'#1bda02',7:'#9a9a9a'}
sites['sturm_colors'] = [color_dict[x] for x in sites['sturm']]

# Discrete colormap/norm for the optional snow-class raster background.
sturm_colors = list(color_dict.values())[:-1]
sturm_cmap = mcolors.ListedColormap(sturm_colors)
bounds = np.arange(1,8)
sturm_norm = mcolors.BoundaryNorm(bounds, sturm_cmap.N)

# Station markers coloured by snow class, drawn above the other map layers.
sites['geometry'].plot(ax=ax, color=sites['sturm_colors'], markersize=100, edgecolor='k', linewidth=2, zorder=100)



# Position of each station in axes-fraction coordinates of the map extent set
# above; each station's inset bar chart is anchored there.
sites['ax_x'] = (sites['geometry'].x - minx) / (maxx-minx)
sites['ax_y'] = (sites['geometry'].y - miny) / (maxy-miny)

# Set True to also show the ionospheric bars (total, total non-ion, ionosphere).
plot_ionosphere = False

tab10 = plt.get_cmap('tab10')
if plot_ionosphere:
    error_order = ['total_error','total_error_non_ion','ion_error_cumsum_rolling','wet_atmo_error_cumsum_rolling',
                   'dry_atmo_error_cumsum_rolling','soil_error_cumsum_rolling','veg_error_cumsum_rolling',
                   'defo_error_cumsum_rolling']
    error_labels = ['Total error','Total non-ion error','Ionosphere','Wet troposphere','Dry troposphere','Soil perm','Veg perm','Deformation']
    # One colour per bar, following the exceedance-curve figure's palette.
    colors = ['r', 'k', tab10(4), tab10(0), tab10(1), tab10(5), tab10(2), tab10(3)]
else:
    error_order = ['total_error_non_ion','wet_atmo_error_cumsum_rolling','dry_atmo_error_cumsum_rolling',
                   'soil_error_cumsum_rolling','veg_error_cumsum_rolling','defo_error_cumsum_rolling']
    error_labels = ['Total non-ion error','Wet troposphere','Dry troposphere','Soil permittivity','Veg permittivity','Deformation']
    colors = ['k', tab10(0), tab10(1), tab10(5), tab10(2), tab10(3)]
n_errors = len(error_order)
bar_locations = np.arange(n_errors)
# Scalar width so any number of bars works (a fixed 6-element list cannot
# broadcast against the 8 bars of the ionosphere variant).
widths = 0.8

# Fixed inset scale: y-range ±ylim (as a fraction of April 1 SWE) with dotted
# guides at `hlines`. Bars beyond the range are drawn at 80 % of the limit and
# flagged with a triangle at 90 %.
ylim = 0.1
hlines = [-0.05, 0.05]

for _, site in sites.iterrows():
    # Error components relative to the station's April 1 SWE.
    rel_errors = site[error_order] / site['swe_m']

    ax_mini = ax.inset_axes([site['ax_x']+0.01, site['ax_y']+0.01, 0.09, 0.09], transform=ax.transAxes)
    ax_mini.patch.set_alpha(1.0)
    ax_mini.spines[['top','right','left']].set_visible(True)

    triangles = []
    for err_type, err in rel_errors.items():
        if err > ylim:
            triangles.append('^')
            rel_errors.loc[err_type] = ylim*0.8
        elif err < -ylim:
            triangles.append('v')
            rel_errors.loc[err_type] = -ylim*0.8
        else:
            triangles.append('')

    ax_mini.bar(bar_locations, rel_errors, color=colors, width=widths)
    for j, tri in enumerate(triangles):
        if tri:
            direction = -1 if tri == 'v' else 1
            ax_mini.plot(j, ylim*direction*0.9, tri, color=colors[j], ms=5.8)

    ax_mini.hlines(0.0, -0.8, n_errors-0.2, lw=1.5, color='k', ls='-', alpha=0.9)
    ax_mini.hlines(hlines, -0.8, n_errors-0.2, lw=1.5, color='k', ls=':', alpha=0.7)
    ax_mini.set_yticks(hlines)
    ax_mini.yaxis.set_major_formatter(mtick.PercentFormatter(1.0, decimals=0))
    ax_mini.tick_params(axis='both', labelbottom=False, labelleft=False, bottom=False, left=False, labelright=False)
    for lbl in ax_mini.get_yticklabels():
        lbl.set_stretch('ultra-condensed')
    # Station number placed just left of the station, in map-axes coordinates.
    ax_mini.text(site['ax_x']-0.015, site['ax_y'], site['map_number'], color='white', fontsize=14, weight='bold', ha='right', va='center', transform=ax.transAxes,  path_effects=[pe.withStroke(linewidth=4, foreground='black')])

    ax_mini.set_ylim([-ylim, ylim])
    ax_mini.set_xlim(-1, n_errors)
                            
# Colour legend for the inset bars: one unit-height bar per error type, with the
# type's name written vertically over it.
ax_legend = ax.inset_axes([0.18, 0.025, 0.28, 0.25], transform=ax.transAxes)
ax_legend.patch.set_alpha(1.0)
ax_legend.spines[['top','right','left','bottom']].set_visible(True)
ax_legend.bar(np.arange(n_errors), np.tile(1, n_errors), color=colors[:n_errors])
ax_legend.text(0.5, 0.05, 'Cumulative error types', fontsize=12, ha='center', transform=ax_legend.transAxes)
ax_legend.set_ylim(-0.25, 1.1)
ax_legend.tick_params(axis='both', labelbottom=False, labelleft=False, bottom=False, left=False)

for k, text in enumerate(error_labels):
    # White text on the 'Total' bars, black on the others.
    ax_legend.text(k-0.18, 0.1, text, rotation=90, c='w' if 'Total' in text else 'k')

# Key for the inset y-scale. The labels below are written out as 0, ±5 % and
# ±10 %, so the key uses matching fixed values instead of relying on the
# `ylim`/`hlines` values left behind by the inset-chart code.
key_ylim = 0.1
key_hlines = [-0.05, 0.05]
ax_key = ax.inset_axes([0.48, 0.025, 0.23, 0.25], transform=ax.transAxes)
ax_key.patch.set_alpha(1.0)
ax_key.tick_params(axis='both', labelbottom=False, labelleft=False, bottom=False, left=False)
ax_key.set_ylim([-key_ylim, key_ylim])
ax_key.set_xlim(-1, n_errors)
ax_key.hlines(0.0, 1.8, n_errors-0.2, lw=1.5, color='k', ls='-', alpha=0.9)
ax_key.hlines(key_hlines, 1.8, n_errors-0.2, lw=1.5, color='k', ls=':', alpha=0.7)
ax_key.text(0.4, 0, '  0%', va='center')
ax_key.text(0.4, 0.05, '$+$5%', va='center')
ax_key.text(0.4, 0.0905, '$+$10%', va='center')
ax_key.text(0.4, -0.05, '$-$5%', va='center')
ax_key.text(0.4, -0.092, '$-$10%', va='center')
ax_key.text(3.8, 0.02, 'Inset axis key', fontsize=9, ha='center', bbox=dict(linewidth=1, edgecolor='k', facecolor='#bababa'))
ax_key.text(-0.2, 0, 'Errors relative to\nApril 1 SWE', fontsize=9, ha='center', va='center', rotation=90)
        
# Background raster: average April 1 SWE from SWANN, divided by 1000 to plot in
# metres (stored units presumably mm; confirm against the file metadata).
swann_avg = xr.open_dataarray('/pl/active/palomaki-sar/insar_swe_errors/data/ancillary/swann_swe/swann_april1_avg.nc', decode_coords='all')

# Alternative background: Sturm and Liston (2021) snow classes.
# sturm = rxr.open_rasterio('/pl/active/palomaki-sar/insar_swe_errors/data/ancillary/SnowClass_NA_300m_10.0arcsec_2021_v01.0.tif').rio.reproject('epsg:3857').rio.clip_box(minx, miny, maxx, maxy)
# sturm = sturm.where(sturm<8)
# mesh = sturm.plot(ax=ax, zorder=0, cmap=sturm_cmap, norm=sturm_norm, add_colorbar=False)

# NOTE(review): no `transform=` is passed to this GeoAxes plot, so the raster's
# x/y coordinates are drawn as if already in EPSG:3857; confirm the file's CRS.
mesh = (swann_avg/1000).plot(ax=ax, zorder=0, cmap='Blues', add_colorbar=False, vmin=0, vmax=2)
# Attach the colour bar to the map axes explicitly, not to pyplot's "current" axes.
cbar = ax.figure.colorbar(mesh, ax=ax, extend='max', orientation='horizontal', shrink=0.895, pad=0.02)
# cbar.ax.xaxis.set_ticks(np.arange(1.5,7,1), labels=['Tundra','Boreal\nForest','Maritime','Ephemeral','Prairie','Montane\nForest'])
# cbar.ax.set_xlabel('Sturm and Liston (2021) snow class (background map)', labelpad=18, fontsize=16)
cbar.ax.set_xlabel('Average April 1 SWE [m]', labelpad=18, fontsize=14)
cbar.ax.tick_params(axis='x', which='major', labelsize=12)
cbar.ax.tick_params(axis='x', which='minor', bottom=False)
ax.set_title('')  # clear any title set by the xarray plot call

# Hand-built snow-class legend (marker + label per class) in its own inset.
ax_snowclass = ax.inset_axes([0.01, 0.025, 0.15, 0.25], transform=ax.transAxes)
ax_snowclass.patch.set_alpha(1.0)
ax_snowclass.spines[['top','right','left','bottom']].set_visible(True)
ax_snowclass.tick_params(axis='both', bottom=False, left=False,
                         labelbottom=False, labelleft=False)

# Labels and colours in display order (keys into color_dict are snow-class codes).
legend_items = [
    ('Maritime',        color_dict[3]),
    ('Montane\nForest', color_dict[6]),
    ('Ephemeral',       color_dict[4]),
    ('Boreal\nForest',  color_dict[2]),
    ('Prairie',         color_dict[5])
]

for row, (label, color) in enumerate(legend_items):
    y = 0.92 - (row+1)*0.17  # vertical spacing between entries
    ax_snowclass.scatter(0.13, y, marker='o', s=100, color=color,
                         edgecolor='k', linewidth=1.5, transform=ax_snowclass.transAxes)
    ax_snowclass.text(0.25, y, label, transform=ax_snowclass.transAxes,
                      va='center', ha='left', fontsize=10)

ax_snowclass.text(0.5, 0.9, 'Snow class', ha='center', va='center', fontsize=12, transform=ax_snowclass.transAxes)
# Lock the inset's view box to unit axes coordinates.
ax_snowclass.set_xlim(0, 1)
ax_snowclass.set_ylim(0, 1)

plt.tight_layout()
# fig.savefig('local/figs/initial_submission/figure4.png', dpi=1200)
# fig.savefig('local/figs/figure4.pdf')

## Figure 5 - seasonal errors

In [None]:
extra_title_dict = {'Paradise':' (Deep snowpack)','Trial Lake':' (Moderate snowpack)','Disaster Peak':' (Shallow snowpack)'}

# Figure 5: one column of six panels per station (Paradise, Trial Lake, Disaster Peak).
fig, ax = plt.subplots(6, 3, figsize=(15*0.8,24*0.6))
for col, station_id in enumerate((679, 828, 445)):
    plot_temporal_variability(station_id=station_id, ax=ax[:, col], extra_title_dict=extra_title_dict, return_data=False)

# Symmetric y-limits shared across each row so the stations are comparable.
row_ylims = (0.011, 0.023, 0.023, 0.044, 0.063, 2.3)
for row_axes, lim in zip(ax, row_ylims):
    for axx in row_axes:
        axx.set_ylim(-lim, lim)

# Dashed zero line and a common x-range (presumably day of water year 1-274) on every panel.
for axx in ax.flat:
    axx.hlines(0, 1, 274, lw=1.5, color='k', ls='--', zorder=0)
    axx.set_xlim(1,274)

for axx in ax[:, 0]:
    axx.set_ylabel('SWE error [m]', fontsize=10)

plt.tight_layout()
# fig.savefig('local/figs/initial_submission/figure5.png', dpi=600)

## Figure 6 - exceedance curves

In [None]:
# Figure 6: exceedance curves of 12-day error magnitudes over Oct 1 - Apr 1,
# (left) in metres for every component and (right) relative to the SWE change
# for several minimum-accumulation thresholds.
ts = pd.read_csv('data/big_timeseries.csv', index_col=[0,1], parse_dates=True)
dowy_thresh = 183 # apr1 = 183
ts = ts.loc[ts['dowy'] <= dowy_thresh]

error_cols = ['total_error','non_ion_error','ion_error','wet_atmo_error','dry_atmo_error','soil_error','veg_error','defo_error']
error_names = ['Total error','Total non-ion','Ionosphere','Wet tropo','Dry tropo','Soil perm','Veg perm','Deformation']

tab10 = plt.get_cmap('tab10')
color_dict = {'total_error':'r','non_ion_error':'k','defo_error':tab10(3), 'soil_error':tab10(5),
              'veg_error':tab10(2), 'dry_atmo_error':tab10(1),
              'wet_atmo_error':tab10(0), 'ion_error':tab10(4)}


def _exceedance_curve(values):
    """Sort |values| ascending and pair each magnitude with its exceedance probability.

    Parameters
    ----------
    values : pandas.Series
        Signed errors; NaNs are dropped.

    Returns
    -------
    (pandas.Series, pandas.Series)
        Magnitudes sorted ascending on a fresh 0..n-1 index, and the matching
        exceedance probability in percent, ``100 * (1 - rank / n)`` with
        ``rank`` running 1..n.
    """
    magnitudes = values.abs().dropna().sort_values().reset_index(drop=True)
    exceedance = (1 - magnitudes.rank(method='first') / len(magnitudes)) * 100
    return magnitudes, exceedance


def _median_index(curves, err):
    """Row label of `curves` whose ``{err}_cdf`` is closest to 50 % exceedance."""
    return (curves[f'{err}_cdf'] - 50).abs().idxmin()


# Absolute errors. Columns with fewer valid values are NaN-padded at the end.
cdfs = pd.DataFrame(index=np.arange(len(ts.index)))
for err in error_cols:
    mags, probs = _exceedance_curve(ts[err])
    cdfs[f'{err}_sorted'] = mags
    cdfs[f'{err}_cdf'] = probs

fig, ax = plt.subplots(1, 2, figsize=(12, 5))
for err, name in zip(error_cols, error_names):
    is_total = err in ('total_error', 'non_ion_error')
    ax[0].plot(cdfs[f'{err}_sorted'], cdfs[f'{err}_cdf'], lw=3.5 if is_total else 1.5,
               color=color_dict[err], label=name, zorder=10 if is_total else 1)

# Median (50 % exceedance) markers for the two totals, each taken from its own
# curve, since the totals may have different numbers of valid samples.
median_x = [cdfs.loc[_median_index(cdfs, e), f'{e}_sorted'] for e in ('total_error', 'non_ion_error')]
ax[0].scatter(median_x, (50, 50), marker='o', c='none', linewidths=2.0, edgecolors='0.4', s=80, zorder=100)

ax[0].legend(ncols=2)
ax[0].set_xlabel('12-day error magnitude [m]', fontsize=14, labelpad=12)
ax[0].set_ylabel('Exceedance probability [%]', fontsize=14, labelpad=6)
ax[0].set_xlim([-0.01, 0.31])
ax[0].set_ylim([-2, 102])
ax[0].set_yticks([0,20,40,50,60,80,100])
ax[0].tick_params(axis='both', labelsize=12)
ax[0].grid(False, which='both')
# NOTE(review): the sample size in this title is hard-coded; confirm it matches the filtered data.
ax[0].set_title('All stations, WYs 2016$-$25, October 1$-$April 1 (n = 23,397)')
ax[0].hlines(50, *ax[0].get_xlim(), ls=':', color='k', lw=2, alpha=0.8, zorder=0)
ax[0].vlines(median_x, ax[0].get_ylim()[0], 50, ls=':', color='k', lw=2, alpha=0.8, zorder=0)

# Errors relative to the SWE change, restricted to accumulation of at least `thresh` m.
thresh_curves = {}
thresh_labels = []
for thresh in [0.01, 0.02, 0.05, 0.1, 0.2]:
    ts_accum = ts.loc[ts['swe_change'] >= thresh]
    curves = pd.DataFrame(index=np.arange(len(ts_accum.index)))
    for err in error_cols:
        mags, probs = _exceedance_curve(ts_accum[err] / ts_accum['swe_change'])
        curves[f'{err}_sorted'] = mags
        curves[f'{err}_cdf'] = probs
    # The ' ' format flag gives a leading space (' 0.01'), i.e. "≥ 0.01m" in the label.
    thresh_str = f'{float(thresh): .2f}'
    thresh_curves[f'thresh_{thresh_str}'] = curves
    # Raw f-string: '\D' is not a valid Python escape sequence.
    thresh_labels.append(rf'$\Delta$SWE ≥{thresh_str}m (n = {curves["non_ion_error_sorted"].notnull().sum()})')

blues = plt.get_cmap('Blues')(np.linspace(0,1,7))

# Fix the y-range first so the median guide lines start at the bottom edge.
ax[1].set_ylim([-2, 102])
for k, (curves, label) in enumerate(zip(thresh_curves.values(), thresh_labels)):
    ax[1].plot(curves['non_ion_error_sorted'], curves['non_ion_error_cdf'], lw=2, color=blues[k+2], label=label)
    # 50 % point of this non-ion curve, taken from its own exceedance column.
    x50 = curves.loc[_median_index(curves, 'non_ion_error'), 'non_ion_error_sorted']
    ax[1].scatter(x50, 50, marker='o', c='none', linewidths=2.0, edgecolors='0.4', s=80, zorder=100)
    ax[1].vlines(x50, ax[1].get_ylim()[0], 50, ls=':', color='k', lw=2, alpha=0.6, zorder=0)

ax[1].legend(ncols=1, loc='upper right')
ax[1].set_xlabel(r'12-day error magnitude [% relative to $\Delta$SWE]', fontsize=14, labelpad=12)
ax[1].set_ylabel('Exceedance probability [%]', fontsize=14, labelpad=6)
ax[1].set_xlim([-0.02, 1.02])
ax[1].set_yticks([0,20,40,50,60,80,100])
ax[1].tick_params(axis='both', labelsize=12)
ax[1].grid(False, which='both')
ax[1].set_title(r'Relative total non-ion error for $\Delta$SWE ≥ threshold')
ax[1].hlines(50, *ax[1].get_xlim(), ls=':', color='k', lw=2, alpha=0.8, zorder=0)
ax[1].xaxis.set_major_formatter(mtick.PercentFormatter(1.0, decimals=0))

plt.tight_layout()
# fig.savefig('local/figs/initial_submission/figure6.png', dpi=600)

## Figure 7 - transects

Figure 7 is produced in a separate notebook: `src/analysis/play/transect_figure.ipynb`.