# Extract profiles (Aschwanden, Truffer, and Fahnestock, 2016)

In [None]:
import cartopy.crs as ccrs

import geopandas as gp
from joblib import Parallel, delayed
import numpy as np
import pandas as pd
from pathlib import Path
import pylab as plt
from matplotlib import colors, cm
from matplotlib.colors import LightSource
import matplotlib.ticker as mticker
from tqdm.auto import tqdm
import xarray as xr

from pypism import profiles
from pypism.profiles import process_profile
from pypism.utils import preprocess_nc
from pypism.hillshade import hillshade
from pypism.utils import qgis2cmap, tqdm_joblib, blend_multiply

## Resolution along profiles

In [None]:
# Value handed to `segmentize` below when densifying the profile geometries.
profile_resolution = 200 # m

## Load profiles and segmentize

In [None]:
# Read the flux-gate profiles and rename the "id" column to "profile_id".
profiles_path = Path("../data/greenland-flux-gates.gpkg")
profiles_gp = gp.read_file(profiles_path).rename(columns={"id": "profile_id"})
# Densify every geometry using `profile_resolution` as the segmentize argument,
# then rebuild the GeoDataFrame with the densified geometries as its geometry.
geom = profiles_gp.segmentize(profile_resolution)
profiles_gp = gp.GeoDataFrame(profiles_gp, geometry=geom)

## Load observed velocities

In [None]:
# NOTE(review): absolute, user-specific path; the notebook only runs as-is on
# a machine where this file exists at exactly this location.
obs_file = Path("/Users/andy/Google Drive/My Drive/data/ITS_LIVE/GRE_G0240_0000.nc")
obs_ds = xr.open_dataset(obs_file)

## Load ensemble experiments

In [None]:
# Collect the ensemble output files. `Path.glob` yields matches in an
# unspecified order, so sort them: the files are concatenated along "exp_id"
# in list order, and sorting makes that order reproducible between runs.
# NOTE(review): absolute, user-specific path.
pism_files = sorted(Path("/Users/andy/Google Drive/My Drive/Projects/gris-calib/data").glob("velsurf_mag_gris*.nc"))
# Each file is passed through `preprocess_nc` before being concatenated along
# a new/existing "exp_id" dimension.
sim_ds = xr.open_mfdataset(pism_files, 
                  preprocess=preprocess_nc,
                  concat_dim="exp_id",
                  combine="nested",
                  parallel=True)

In [None]:
from typing import List, Tuple, Union

import geopandas as gp
import numpy as np
import pandas as pd
import pylab as plt
import seaborn as sns
import xarray as xr

@xr.register_dataset_accessor("profiles")
class CustomDatasetMethods:
    """
    Profile helpers registered on ``xr.Dataset`` under the ``profiles`` accessor.

    After this class is defined, the methods below are reachable as
    ``ds.profiles.<method>(...)`` on any dataset.
    """

    def __init__(self, xarray_obj):
        """
        Keep a reference to the dataset the accessor is attached to.
        """
        self._obj = xarray_obj

    def init(self):
        """
        Do-nothing method

        Needed to work with joblib Parallel
        """

    def add_normal_component(
        self,
        x_component: str = "vx",
        y_component: str = "vy",
        normal_name: str = "v_normal",
    ) -> xr.Dataset:
        """
        Add the component of a vector field along the ("nx", "ny") direction.

        Computes ``x_component * nx + y_component * ny`` and stores it in the
        dataset under ``normal_name``.

        Parameters:
        x_component: name of the data variable holding the x-component
        y_component: name of the data variable holding the y-component
        normal_name: name of the variable that receives the result

        Returns:
        The dataset the accessor is attached to, modified in place.

        Raises:
        KeyError: if ``x_component`` or ``y_component`` is not a data variable.
        """
        # The previous check, ``(x_component and y_component) in data_vars``,
        # evaluated the ``and`` first and therefore only tested
        # ``y_component``. Test both names explicitly.
        missing = [
            name
            for name in (x_component, y_component)
            if name not in self._obj.data_vars
        ]
        if missing:
            raise KeyError(f"Data variable(s) not found in dataset: {missing}")

        def func(x, x_n, y, y_n):
            return x * x_n + y * y_n

        self._obj[normal_name] = xr.apply_ufunc(
            func,
            self._obj[x_component],
            self._obj["nx"],
            self._obj[y_component],
            self._obj["ny"],
        )
        return self._obj

    def extract_profile(
        self,
        xs: np.ndarray,
        ys: np.ndarray,
        profile_name: str = "Glacier X",
        data_vars: Union[None, List[str]] = None,
    ) -> xr.Dataset:
        """
        Extract a profile from a dataset given x and y coordinates.

        Every requested variable that has both an "x" and a "y" dimension is
        cropped to the padded bounding box of the profile and interpolated
        onto the profile points. Variables without both dimensions are
        skipped; variables whose interpolation fails are skipped with a
        warning.

        Parameters:
        xs: x-coordinates of the profile
        ys: y-coordinates of the profile
        profile_name: name of the profile
        data_vars: list of data variables to include in the profile. If None, all data variables are included.

        Returns:
        A new xarray Dataset containing the extracted profile.
        """
        import warnings  # local import keeps this block self-contained

        # Padding (in the units of the x/y coordinates) added around the
        # profile's bounding box before cropping.
        buffer = 1_000.0

        # NOTE(review): both slices run from min to max; label-based slicing
        # on a descending coordinate would select nothing - confirm the
        # datasets passed in have ascending x and y.
        extent_slice = {
            "x": slice(np.min(xs) - buffer, np.max(xs) + buffer),
            "y": slice(np.min(ys) - buffer, np.max(ys) + buffer),
        }
        # Straight-line distance of every point from the first point.
        # NOTE(review): this is not the cumulative path length, so it differs
        # from "distance along profile" for profiles that are not straight.
        with np.errstate(invalid="ignore"):
            profile_axis = np.sqrt((xs - xs[0]) ** 2 + (ys - ys[0]) ** 2)

        x = xr.DataArray(
            xs,
            dims="profile_axis",
            coords={"profile_axis": profile_axis},
            attrs=self._obj["x"].attrs,
            name="x",
        )
        y = xr.DataArray(
            ys,
            dims="profile_axis",
            coords={"profile_axis": profile_axis},
            attrs=self._obj["y"].attrs,
            name="y",
        )

        das = []
        # NOTE(review): these attrs describe a distance axis, not a name;
        # they look copied from elsewhere. Left unchanged because downstream
        # consumers may read them - confirm and correct.
        name = xr.DataArray(
            [profile_name],
            dims="profile_id",
            attrs={"units": "m", "long_name": "distance along profile"},
            name="profile_name",
        )
        das.append(name)

        if data_vars is None:
            data_vars = list(self._obj.data_vars)

        with np.errstate(invalid="ignore"):
            for m_var in data_vars:
                da = self._obj[m_var]
                # Variables without both horizontal dimensions cannot be
                # cropped/interpolated; skipping them here replaces the
                # exception the old bare ``except`` used to swallow.
                if "x" not in da.dims or "y" not in da.dims:
                    continue
                try:
                    das.append(da.sel(**extent_slice).interp(x=x, y=y, kwargs={"fill_value": np.nan}))
                except Exception as err:
                    # Stay best-effort (one bad variable must not abort the
                    # profile) but report the failure instead of hiding it,
                    # and let KeyboardInterrupt/SystemExit propagate.
                    warnings.warn(
                        f"Skipping variable {m_var!r} for profile {profile_name!r}: {err}",
                        RuntimeWarning,
                        stacklevel=2,
                    )

        return xr.merge(das)

    def plot(
        self,
        sigma: float = 1,
        title: Union[str, None] = None,
        obs_var: str = "v",
        obs_error_var: str = "v_err",
        sim_var: str = "velsurf_mag",
        palette: str = "Paired",
        obs_kwargs: Union[dict, None] = None,
        obs_error_kwargs: Union[dict, None] = None,
        sim_kwargs: Union[dict, None] = None,
    ) -> plt.Figure:
        """
        Plot observations and simulations along profile.

        Draws a ``obs_var +/- sigma * obs_error_var`` band, the observed
        curve, and one curve per entry of "exp_id" for ``sim_var``. The
        dataset must also contain "profile_axis", "rmsd" and (when ``title``
        is None) "profile_name".

        Parameters:
        sigma: multiplier applied to the observation error for the band
        title: figure title; defaults to the dataset's "profile_name"
        obs_var: name of the observed variable
        obs_error_var: name of the observation-error variable
        sim_var: name of the simulated variable
        palette: seaborn palette name used to colour the simulations
        obs_kwargs: keyword arguments for the observed line; None selects
            {"color": "0", "lw": 1, "marker": "o", "ms": 2}
        obs_error_kwargs: keyword arguments for the error band; None selects
            {"color": "0.75"}
        sim_kwargs: keyword arguments for the simulated lines; None selects
            {"lw": 1, "marker": "o", "ms": 2}

        Returns:
        The matplotlib figure that was created.
        """
        # Build the defaults per call; dict defaults in the signature would
        # be shared (and mutable) across all calls.
        if obs_kwargs is None:
            obs_kwargs = {"color": "0", "lw": 1, "marker": "o", "ms": 2}
        if obs_error_kwargs is None:
            obs_error_kwargs = {"color": "0.75"}
        if sim_kwargs is None:
            sim_kwargs = {"lw": 1, "marker": "o", "ms": 2}

        n_exps = len(self._obj["exp_id"])

        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.fill_between(
            self._obj["profile_axis"],
            self._obj[obs_var] - sigma * self._obj[obs_error_var],
            self._obj[obs_var] + sigma * self._obj[obs_error_var],
            **obs_error_kwargs,
        )
        ax.plot(
            self._obj["profile_axis"],
            self._obj[obs_var],
            label="Observed",
            **obs_kwargs,
        )
        exp_colors = sns.color_palette(palette, n_colors=n_exps)
        # Loop over the data and plot each line with a different color
        for i in range(n_exps):
            exp_label = f"""{self._obj["exp_id"][i].to_numpy()} rmsd={self._obj["rmsd"][i].to_numpy():.0f}m/yr"""
            ax.plot(
                self._obj["profile_axis"],
                self._obj[sim_var].isel(exp_id=i).T,
                color=exp_colors[i],
                label=exp_label,
                **sim_kwargs,
            )
        ax.set_xlabel("Distance along profile (m)")
        ax.set_ylabel("Speed (m/yr)")
        # Frameless, fully transparent legend box.
        legend = ax.legend(loc="upper left")
        legend.get_frame().set_linewidth(0.0)
        legend.get_frame().set_alpha(0.0)

        if title is None:
            title = self._obj["profile_name"].to_numpy()
        # Title the axes created above rather than whatever pyplot currently
        # regards as the active axes.
        ax.set_title(title)
        return fig


In [None]:
# Run the extraction for the first profile only, as a quick check before the
# parallel run over all profiles. The result is displayed, not stored.
profile = profiles_gp.iloc[0]
process_profile(profile, obs_ds, sim_ds)

## Extract all profiles

In [None]:
# Number of joblib workers.
n_jobs = 8
# Process every profile row in parallel; `tqdm_joblib` ties the joblib run to
# the tqdm progress bar.
with tqdm_joblib(tqdm(desc="Processing profiles", total=len(profiles_gp))) as progress_bar:
    result = Parallel(n_jobs=n_jobs)(
        delayed(process_profile)(row, obs_ds, sim_ds) for _, row in profiles_gp.iterrows()
    )
# Each result is indexed as (observed profile, simulated profile, ..., stats):
# first item, second item and last item respectively.
obs_profiles = [item[0] for item in result]
sims_profiles = [item[1] for item in result]
per_profile_stats = [item[-1] for item in result]
stats_profiles = pd.concat(per_profile_stats).reset_index(drop=True)
# Same table, with each geometry replaced by its centroid.
stats_profiles_points = gp.GeoDataFrame(stats_profiles, geometry=stats_profiles.geometry.centroid)

## Plot profiles

In [None]:
def plot_profile(ds: xr.Dataset) -> None:
    """
    Plot one merged profile dataset and save it as ``<profile_name>_profile.pdf``.

    The figure comes from the ``profiles`` accessor's ``plot`` method (palette
    "Greens", sigma 1) and is written to the current working directory.

    Parameters:
    ds: dataset holding observed and simulated values for a single profile;
        must contain "profile_name".
    """
    fig = ds.profiles.plot(palette="Greens", sigma=1)
    profile_name = ds["profile_name"].to_numpy()
    try:
        fig.savefig(f"{profile_name}_profile.pdf")
    finally:
        # Close this specific figure (not just pyplot's current one), and do
        # so even if saving fails, so figures do not pile up in a long loop.
        plt.close(fig)

# Merge each observed profile with the simulated profile at the same list
# position into one dataset, then write one PDF per merged profile.
# NOTE(review): `profiles` rebinds the name imported at the top of the
# notebook (`from pypism import profiles`); the module is no longer reachable
# under that name after this cell runs.
profiles = [xr.merge([obs_profile.squeeze(), sims_profile.squeeze()]) for obs_profile, sims_profile in zip(obs_profiles, sims_profiles)]
for profile in profiles:
    plot_profile(profile)
# Disabled parallel variant of the loop above:
# with tqdm_joblib(tqdm(desc="Processing profiles", total=len(profiles_gp))) as progress_bar:
#     result = Parallel(n_jobs=n_jobs)(
#         delayed(plot_profile)(ds)
#             for ds in profiles
#         )

In [None]:
# Pick the stats row whose index label is 0 as a one-row frame, and build a
# copy whose geometry is the centroid of that row's geometry.
profile = stats_profiles[stats_profiles.index == 0].reset_index(drop=True)
profile_centroid = gp.GeoDataFrame(profile, geometry=profile.geometry.centroid)

In [None]:
def round(x: float, mult: int = 1000):
    """
    Round ``x`` to a multiple of ``mult`` using ``np.round``.

    NOTE: this function shadows Python's builtin ``round`` for every cell
    that runs after it.

    Parameters:
    x: value to round
    mult: the result is an integer multiple of this value

    Returns:
    ``np.round(x / mult) * mult``
    """
    n_multiples = np.round(x / mult)
    return n_multiples * mult

def figure_extent(x_c: float, y_c: float, x_e: float = 50_000, y_e: float = 50_000):
    """
    Build x/y slices for a window of size ``x_e`` by ``y_e`` centred on a point.

    The x slice runs from low to high; the y slice runs from high to low.

    Parameters:
    x_c: x-coordinate of the window centre
    y_c: y-coordinate of the window centre
    x_e: full window width along x
    y_e: full window height along y

    Returns:
    A dict with keys "x" and "y" mapping to ``slice`` objects.
    """
    half_width = x_e / 2
    half_height = y_e / 2
    x_window = slice(x_c - half_width, x_c + half_width)
    y_window = slice(y_c + half_height, y_c - half_height)
    return {"x": x_window, "y": y_window}

In [None]:

# Colormap built by `qgis2cmap` from ../data/speed-colorblind.txt. It is the
# default `cmap` of `plot_glacier` below and is bound when that function is
# defined, so this cell must run first.
qgis_colormap = Path("../data/speed-colorblind.txt")
cmap = qgis2cmap(qgis_colormap, name="speeds")


def plot_glacier(profile: gp.GeoDataFrame, surface: xr.DataArray, overlay: xr.DataArray, cmap=cmap, vmin: float = 10, vmax: float = 1500):
    """
    Plot a shaded map around one profile and save it as a PDF.

    The map window is centred on the rounded centroid of the first row of
    ``profile``. ``overlay`` is colour-mapped, shaded with ``surface`` as the
    elevation, and drawn with the profile line, its centroid coloured by
    "pearson_r", coastlines and gridlines. The figure is written to
    ``<name>_<exp_id>_speed.pdf`` in the current working directory.

    Parameters:
    profile: one-row GeoDataFrame with columns "name", "exp_id" and
        "pearson_r"
    surface: elevation field used for shading; interpolated onto the cropped
        overlay grid. Annotated as a DataArray because ``to_numpy`` is called
        on it and the notebook passes ``gris_ds["surface"]``.
    overlay: field that is colour-mapped (the notebook passes ``obs_ds["v"]``)
    cmap: colormap for the overlay
    vmin: lower limit of the colour normalisation
    vmax: upper limit of the colour normalisation
    """

    def get_extent(ds: xr.DataArray):
        # [first x, last x, last y, first y], the extent order given to imshow.
        return [ds["x"].values[0], ds["x"].values[-1], ds["y"].values[-1], ds["y"].values[0]]

    profile_centroid = gp.GeoDataFrame(profile, geometry=profile.geometry.centroid)
    first_row = profile.iloc[0]
    glacier_name = first_row["name"]
    exp_id = first_row["exp_id"]
    # Positional access: works whether or not the caller reset the index
    # (label-based ``geometry[0]`` needs an index label equal to 0).
    centroid = profile_centroid.geometry.iloc[0]
    # `round` is the notebook's round-to-multiple helper, not the builtin.
    x_c = round(centroid.x)
    y_c = round(centroid.y)
    extent_slice = figure_extent(x_c, y_c)
    crs = ccrs.NorthPolarStereo(central_longitude=-45, true_scale_latitude=70, globe=None)
    # Shade from the northwest, with the sun 45 degrees from horizontal
    light_source = LightSource(azdeg=315, altdeg=45)
    glacier_overlay = overlay.sel(**extent_slice)
    glacier_surface = surface.interp_like(glacier_overlay)

    extent = get_extent(glacier_overlay)
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    mapper = cm.ScalarMappable(norm=norm, cmap=cmap)

    v = mapper.to_rgba(glacier_overlay.to_numpy())
    z = glacier_surface.to_numpy()
    fig = plt.figure(figsize=(6.2, 6.2))
    try:
        ax = fig.add_subplot(111, projection=crs)
        rgb = light_source.shade_rgb(v, elevation=z, vert_exag=0.01, blend_mode=blend_multiply)
        ax.imshow(rgb, extent=extent, origin="upper", transform=crs)
        profile.plot(ax=ax, color="k", lw=1)
        profile_centroid.plot(column="pearson_r", vmin=0, vmax=1, cmap="RdYlGn",
                   markersize=30, legend=True,
                   missing_kwds={},
                   legend_kwds={"shrink": 0.5, "pad": 0.075, "location": "bottom", "label": "Pearson r (1)"}, ax=ax)
        ax.annotate(glacier_name, (x_c, y_c), (10, 10), xycoords="data", textcoords="offset points")
        ax.coastlines(linewidth=0.25, resolution="10m")
        ax.gridlines(draw_labels={"top": "x", "right": "y"},
                     dms=True,
                     xlocs=np.arange(-50, 0, 1),
                     ylocs=np.arange(50, 88, 1),
                     x_inline=False, y_inline=False,
                     rotate_labels=20,
                     ls="dotted", color="k")

        ax.set_extent(extent, crs=crs)
        fig.tight_layout()
        fig.savefig(f"{glacier_name}_{exp_id}_speed.pdf")
    finally:
        # Close this specific figure, even if drawing or saving fails, so
        # figures do not accumulate when the function is called in a loop.
        plt.close(fig)
    


In [None]:
# Map for the single profile selected above.
# NOTE(review): `gris_ds` is first assigned in a later cell of this notebook,
# so that cell has to be run before this one.
plot_glacier(profile, gris_ds["surface"], obs_ds["v"])


In [None]:
# One map per stats row: select the row whose index label equals `index` as a
# one-row frame (index reset to 0) and plot it.
# NOTE(review): `gris_ds` is first assigned in a later cell of this notebook.
for index in range(len(stats_profiles)):
    profile = stats_profiles[stats_profiles.index == index].reset_index(drop=True)
    plot_glacier(profile, gris_ds["surface"], obs_ds["v"])

In [None]:

# Map projection; same parameters as the one created inside `plot_glacier`.
crs = ccrs.NorthPolarStereo(central_longitude=-45, true_scale_latitude=70, globe=None)


# Re-creates the colormap defined in an earlier cell from the same file.
qgis_colormap = Path("../data/speed-colorblind.txt")
cmap = qgis2cmap(qgis_colormap, name="speeds")

# Shade from the northwest, with the sun 45 degrees from horizontal
light_source = LightSource(azdeg=315, altdeg=45)

# sel = .sel(x=slice(-210000, 200000), y=slice(-2100000, -2400000))

# Sub-region used for the `jak_*` variables. The y slice runs from the larger
# to the smaller value - presumably because y is stored descending; confirm.
jak_extent = {"x": slice(-420000, 60000), "y": slice(-1600000, -2350000)}

# The observation dataset defines the target grid; `gris_target` is an alias
# of `obs_ds`, not a copy.
gris_target = obs_ds
jak_target = gris_target.sel(**jak_extent)

# NOTE(review): absolute, user-specific path.
gris_ds = xr.open_dataset(Path("/Users/andy/Google Drive/My Drive/data/MCdataset/BedMachineGreenland-v5.nc"))
# Surface interpolated onto the observation grid, and its hillshade (zf=5).
gris_surface = gris_ds["surface"].interp_like(gris_target)
gris_surface_hs = hillshade(gris_surface, zf=5)

#hs = hillshade(gris_dem_ds["surface"], zf=10)
#hs.plot(cmap="Greys_r", vmin=0, vmax=1, add_colorbar=False)

# Same selection as `jak_target` above.
jak_obs_ds = obs_ds.sel(**jak_extent)

jak_speed = jak_obs_ds["v"]
jak_surface = gris_surface.interp_like(jak_target)
# Repeats the interpolation that produced `jak_surface`; here `hillshade` is
# called without `zf`, unlike the whole-domain call above.
jak_hs = hillshade(gris_surface.interp_like(jak_target))


gris_speed = obs_ds["v"]
gris_hs = gris_surface_hs

# Colour mapping for the "v" field, normalised to the range 10..1500.
norm = colors.Normalize(vmin=10.0, vmax=1500.0)
mapper = cm.ScalarMappable(norm=norm, cmap=cmap)

# RGBA image of the whole-domain "v" field.
speed_img = mapper.to_rgba(gris_speed)

def get_extent(ds: xr.Dataset):
    """
    Return ``[first x, last x, last y, first y]`` from the "x" and "y" values of ``ds``.

    Parameters:
    ds: object exposing "x" and "y" entries that each have a ``values`` array

    Returns:
    A four-element list in the order described above.
    """
    x_values = ds["x"].values
    y_values = ds["y"].values
    return [x_values[0], x_values[-1], y_values[-1], y_values[0]]

def prepare(v, z, mapper):
    """
    Convert ``v`` to RGBA with ``mapper`` and pair it with ``z``.

    Parameters:
    v: values passed to ``mapper.to_rgba``
    z: returned unchanged as the second item
    mapper: object providing a ``to_rgba`` method

    Returns:
    The tuple ``(mapper.to_rgba(v), z)``.
    """
    rgba = mapper.to_rgba(v)
    return rgba, z

# Sub-region inputs.
# NOTE(review): `extent`, `v`, `z`, `nx` and `ny` are all reassigned by the
# whole-domain statements right below, so these three lines have no lasting
# effect when the cell is run as written.
extent = get_extent(jak_target)
v, z = prepare(jak_speed.to_numpy(), jak_surface.to_numpy(), mapper)
# NOTE(review): unpacks in the array's dimension order; the names suggest
# (x, y), but confirm the order against the data's dims.
nx, ny = jak_speed.shape

# Whole-domain inputs (these are the values left in scope).
extent = get_extent(gris_target)
v, z = prepare(gris_speed.to_numpy(), gris_surface.to_numpy(), mapper)
nx, ny = gris_speed.shape



In [None]:
import xarray as xr

In [None]:
# Scratch: open the mosaic directly from its public S3 URL using the
# h5netcdf engine, with chunks="auto".
ds = xr.open_dataset("https://its-live-data.s3-us-west-2.amazonaws.com/velocity_mosaic/landsat/v00.0/static/GRE_G0240_0000.nc", engine="h5netcdf", chunks="auto")

In [None]:
# NOTE(review): rebinds `crs`, which an earlier cell set to a cartopy
# projection object, to a plain EPSG string.
crs="EPSG:3413"

In [None]:
# Scratch: three test points built from paired x/y values.
# NOTE(review): `Point` is imported in a later cell of this notebook, so that
# cell has to be run before this one.
xs = [0.0, 3.3, 40]
ys = [-100, 2.0, -100]
geometry = [Point(p) for p in  zip(xs, ys)]

In [None]:
# Scratch: GeoDataFrame holding only the test geometries, tagged with `crs`.
s = gp.GeoDataFrame(geometry=geometry, crs=crs)

In [None]:
from shapely import Point

In [None]:
Point?

In [None]:
# NOTE(review): rebinds `geom`, which an earlier cell set to the segmentized
# profile geometries.
geom = s.geometry

In [None]:
geom

In [None]:
s

In [None]:
# Scratch: same test points, built with `points_from_xy` instead of a list of
# `Point` objects; replaces the earlier `geometry` list.
geometry = gp.points_from_xy(xs, ys, crs=crs)

In [None]:
geometry.concave_hull

In [None]:
geometry.bounds

In [None]:
np.min(xs)