In [None]:
import dask
from dask.distributed import Client, LocalCluster, Lock
import geopandas as gpd
import matplotlib.pyplot as plt
import utils
import importlib
from glob import glob
import os
import xdem
import numpy as np
import rioxarray as rio
from itertools import product
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import pandas as pd

In [None]:
# Read the area-of-interest file and keep its first geometry; `aoi` is what
# gets handed to utils.Elevation in a later cell.
aoi_gdf = gpd.read_file('../is_elevation_aoi.geojson')
aoi = aoi_gdf.geometry[0]

In [None]:
# Start a local dask cluster with default settings and connect a client to it.
cluster = LocalCluster()
client = Client(cluster)

In [None]:
# Tear the dask client and cluster down again.
# NOTE(review): in top-to-bottom order this cell runs before the `.compute()`
# in the trend cell further down, so on Restart & Run All that computation
# would not use this cluster - presumably this cell is meant to be run by
# hand at the end of the session; confirm and consider moving it.
client.shutdown()
cluster.close()

In [None]:
# Re-import utils so edits to the module are picked up without restarting the
# kernel, then build the Elevation helper for the area of interest.
_ = importlib.reload(utils)
e = utils.Elevation(aoi)

In [None]:
# NOTE(review): presumably builds the DEM stack that the trend cell later reads
# from ../data/arcticDEM/coregd/dem_stack.nc - confirm in utils.Elevation.stack.
e.stack()

In [None]:
import seaborn as sns
def plot_stat(df, stat, ax):
    '''
    draw one arrow per row of `df`, pointing from the `before_<stat>`
    value to the `after_<stat>` value at that row's `to_register_date`
    df - table with columns `to_register_date`, `reference_date`,
         `before_<stat>` and `after_<stat>`
    stat - statistic name; used to build the column names and the y-label
    ax - axes to draw on (modified in place, nothing is returned)
    '''
    col_after, col_before = f'after_{stat}', f'before_{stat}'
    date_col = 'to_register_date'

    # arrow tail sits on the "before" value, arrow head on the "after" value
    for _, record in df.iterrows():
        when = record[date_col]
        ax.annotate("",
                    xy=(when, record[col_after]),
                    xytext=(when, record[col_before]),
                    arrowprops={'arrowstyle': '->'})

    # axis limits are set explicitly: x is padded by 90 days on either side,
    # y spans exactly the range of all before/after values
    first_date, last_date = df[date_col].agg(['min', 'max'])
    pad = pd.Timedelta('90d')
    ax.set_xlim(first_date - pad, last_date + pad)

    all_values = df[[col_after, col_before]].melt()['value']
    lowest, highest = all_values.agg(['min', 'max'])
    ax.set_ylim(lowest, highest)

    # zero line, plus a dotted marker at the (first) reference date
    ax.axhline(0, c='lightgrey', lw=0.5)
    reference = df.reference_date.unique()[0]
    ax.axvline(reference, ls=':', c='k')

    ax.set_ylabel(f'{stat} (m)')
    for tick in ax.get_xticklabels(which='major'):
        tick.set(rotation=30, horizontalalignment='center')

def plot_coreg(df, ax):
    '''
    grouped violin plot comparing the before/after distributions of the
    median and nmad columns
    df - table with columns `median_after`, `median_before`,
         `nmad_after` and `nmad_before`
    ax - axes to draw on (modified in place, nothing is returned)
    '''
    stat_cols = ['median_after', 'median_before', 'nmad_after', 'nmad_before']
    long_df = df[stat_cols].melt()

    # melted names look like `<stat>_<when>`: split once, keep both halves
    name_parts = long_df['variable'].str.split('_')
    long_df['when'] = name_parts.apply(lambda parts: parts[1])
    long_df['variable'] = name_parts.apply(lambda parts: parts[0])

    # last two colours of seaborn's colorblind palette, before then after
    pair_colours = sns.palettes.color_palette('colorblind')[-2:]
    sns.violinplot(
        data=long_df,
        x='variable',
        y='value',
        hue='when',
        hue_order=['before', 'after'],
        palette=pair_colours,
        ax=ax,
    )

    ax.set_ylabel('metres')
    ax.set_xlabel(None)
    ax.axhline(0, c='lightgrey', lw=0.5, zorder=0)
    sns.move_legend(ax, loc='best', title=None)
    

In [None]:
from scipy.stats import theilslopes
def robust_slope(y, t):
    '''
    robust linear trend of `y` against time, using theilslopes
    y - 1-d array of the variable of concern; may contain nan
    t - array of corresponding timestamps, same length as `y`
        converted to years (365.25 days) since the earliest timestamp
    returns a float array of length 4:
        slope, intercept, low slope bound, high slope bound
        (bounds are theilslopes' default confidence interval)
    all four values are nan when `y` holds fewer than two non-nan
    values, since no pairwise slope can be formed from them
    '''
    valid = ~np.isnan(y)

    # A trend needs at least two observations. Returning early covers the
    # empty and all-nan cases as well as the single-observation case, and
    # avoids calling theilslopes on input it cannot fit.
    if valid.sum() < 2:
        return np.full(4, np.nan)

    # time axis in (Julian) years relative to the first timestamp of the
    # full series, so the intercept refers to the same epoch for every call
    years = (t - t.min()) / pd.Timedelta('365.25D')

    slope, intercept, low, high = theilslopes(y[valid], years[valid])
    return np.array([slope, intercept, low, high], dtype=float)


def make_robust_trend(ds, inp_core_dim='time'):
    '''
    apply robust_slope along `inp_core_dim` for every position of the
    remaining dimensions of `ds`, via xr.apply_ufunc
    ds - DataArray holding the variable of concern, with a coordinate
         named `inp_core_dim` that provides the timestamps
    inp_core_dim - name of the dimension the trend is fitted along
    returns a Dataset with variables `slope`, `intercept`, `low_slope`
        and `high_slope`, each on the dimensions of `ds` minus
        `inp_core_dim` (any number of them, not only y/x)
    with a dask-backed `ds` this is a lazy operation
    --> very helpful SO
    https://stackoverflow.com/questions/58719696/how-to-apply-a-xarray-u-function-over-netcdf-and-return-a-2d-array-multiple-new/62012973#62012973
    --> also helpful:
    https://stackoverflow.com/questions/71413808/understanding-xarray-apply-ufunc
    --> and this:
    https://docs.xarray.dev/en/stable/examples/apply_ufunc_vectorize_1d.html#apply_ufunc
    '''
    # order matches the array returned by robust_slope; the list is also
    # the single source for the size of the new `result` dimension
    labels = ['slope', 'intercept', 'low_slope', 'high_slope']

    output = xr.apply_ufunc(robust_slope,
                            ds,
                            ds[inp_core_dim],
                            input_core_dims=[[inp_core_dim],
                                             [inp_core_dim]],
                            output_core_dims=[['result']],
                            exclude_dims={inp_core_dim},
                            vectorize=True,
                            dask='parallelized',
                            output_dtypes=[float],
                            dask_gufunc_kwargs={
                                'allow_rechunk': True,
                                'output_sizes': {'result': len(labels)}
                                }
                            )

    # split the `result` dimension into one named variable per estimate;
    # selecting by dimension name keeps this independent of how many
    # other dimensions `ds` has
    arrs = [output.isel(result=i, drop=True).rename(label)
            for i, label in enumerate(labels)]

    return xr.merge(arrs)

In [None]:
# Open the co-registered DEM stack with dask-backed (chunked) variables.
downsampled = xr.open_dataset('../data/arcticDEM/coregd/dem_stack.nc', chunks='auto')

# The trend is fitted along `time`, so that dimension is kept in a single
# chunk (-1) and the work is split into 500 x 500 tiles over x and y.
trends = make_robust_trend(
    downsampled['z'].chunk({'time':-1, 'x':500, 'y':500})
    ).compute()

# Plain two-line description: no source-code indentation ends up in the
# attribute if the dataset is written to disk.
trends.attrs = {
    'description': (
        'theilslope estimates of surface elevation change\n'
        'high_slope and low_slope are 0.95 confidence interval'
    ),
}

# trends.to_netcdf('../data/arcticDEM/coregd/sec_trend.nc')

In [None]:
# Map of the fitted slope (rate of surface elevation change) on its own figure.
fig, ax = plt.subplots()
trends['slope'].plot(
    ax=ax,
    cmap='RdBu_r',
    robust=True,
    cbar_kwargs=dict(label='sec (m/yr)'),
)
ax.set_title('surface elevation change')

In [None]:
# ## plotting failed to coregister DEMs
# with rio.open_rasterio(e.ref, chunks='auto') as ref:
#     _fillval = ref.attrs['_FillValue']
#     ref_tmp = xr.where(ref.squeeze() != _fillval, ref, np.nan)
#     ref_tmp = ref_tmp.coarsen({'x':100, 'y':100}, boundary='trim').median()
#     for f in list(e.failed.keys()):
#         fig, axs = plt.subplots(ncols=2, subplot_kw={'aspect':'equal'})
        
#         with rio.open_rasterio(f, chunks='auto') as dem:
#             _fillval = dem.attrs['_FillValue']
#             tmp = xr.where(dem.squeeze() != _fillval, dem, np.nan)
#             tmp = tmp.coarsen({'x':100, 'y':100}, boundary='trim').median()
            
#             tmp.plot(ax=axs[0], vmin=200, vmax=2000, cbar_kwargs={'shrink':0.5})
#             axs[0].set_title(os.path.basename(f))
            
#             ref_tmp.plot(ax=axs[1], vmin=200, vmax=2000, cbar_kwargs={'shrink':0.5})
#             axs[1].set_title('reference')
            