In [None]:
import numpy as np
import pylab as plt
import xarray as xr
from pismragis.trajectories import compute_trajectory, compute_perturbation
from joblib import Parallel, delayed
import geopandas as gp
from pathlib import Path
from pismragis.processing import tqdm_joblib
import pandas as pd
from pyDOE import lhs
from scipy.stats.distributions import uniform
from tqdm.auto import tqdm

In [None]:
# Define the velocity-mosaic path once and reuse it for the in-notebook dataset
# and for the workers. pathlib does not expand "~" by itself, so expand it here
# to make the Path valid for any consumer (e.g. `data_url.exists()`).
data_url = Path("~/Google Drive/My Drive/data/ITS_LIVE/GRE_G0240_0000.nc").expanduser()
ds = xr.open_dataset(data_url)

In [None]:
# Flux-gate vector data set read by compute_perturbation. Built relative to the
# user's home directory instead of a hard-coded "/Users/andy/..." so the
# notebook resolves the same Google Drive layout for any user.
ogr_url = Path(
    "~/Google Drive/My Drive/data/GreenlandFluxGatesAschwanden/greenland-flux-gates-jibneighbors.shp"
).expanduser()

In [None]:
# Exploratory in-memory copies of the velocity fields and their errors (with
# length-one axes removed by np.squeeze), plus the grid coordinates.
# NOTE(review): compute_perturbation reopens the data file itself, so these
# arrays and `sigma` are not passed to it by the cells below.
sigma = 1
VX, VY, VX_e, VY_e = (
    np.squeeze(ds[var].to_numpy()) for var in ("vx", "vy", "vx_err", "vy_err")
)

x, y = (ds[coord].to_numpy() for coord in ("x", "y"))
nx, ny = len(x), len(y)



In [None]:
n_draw_samples = 20
# One Latin-hypercube row (s_x, s_y) per perturbation.
unif_sample = lhs(2, n_draw_samples)

# Lazily built joblib tasks; they are consumed by Parallel inside the `with`.
draw_jobs = (
    delayed(compute_perturbation)(
        data_url,
        ogr_url,
        perturbation=k,
        sample=unif_sample[k, :],
        total_time=1_000,
        dt=1,
        reverse=True,
    )
    for k in range(n_draw_samples)
)

# tqdm_joblib presumably routes joblib's completion events to this bar.
draw_bar = tqdm(desc="Processing perturbations", total=n_draw_samples, position=0)
with tqdm_joblib(draw_bar) as progress_bar:
    per_draw_frames = Parallel(n_jobs=10)(draw_jobs)
    del progress_bar

all_perturb_glaciers = pd.concat(per_draw_frames).reset_index(drop=True)
# all_perturb_glaciers.to_file("all_perturb_traj.gpkg")

In [None]:
# NOTE(review): interactive docstring lookup (IPython `?`) left over from
# exploration; it does not affect the analysis and can be removed.
ds.argmax?

In [None]:
# NOTE(review): this rebinds `unif_sample` to an n_perturbations-row design and
# replaces the n_draw_samples-row design drawn earlier. Cells that index it with
# range(n_draw_samples) would then run past its last row.
n_perturbations = 10
unif_sample = lhs(2, samples=n_perturbations)


In [None]:
from pathlib import Path
from typing import Tuple, Union

import geopandas as gp
import numpy as np
import pandas as pd
from geopandas import GeoDataFrame
from numpy import ndarray
from osgeo import ogr, osr
from shapely import Point
from tqdm.auto import tqdm
from xarray import DataArray


In [None]:
def compute_perturbation(
    data_url: Union[str, Path],
    ogr_url: Union[str, Path],
    perturbation: int = 0,
    sample: Union[list, ndarray] = [0.5, 0.5],
    sigma: float = 1,
    total_time: float = 10_000,
    dt: float = 1,
    reverse: bool = False,
) -> GeoDataFrame:
    """
    Compute trajectories for all gate points under one perturbed velocity field.

    The velocity components ``vx``/``vy`` are perturbed within
    ``±sigma * (vx_err, vy_err)`` by ``get_perturbed_velocities`` using the
    two values in `sample`. A trajectory is then traced from every vertex of
    every feature in the first layer of the OGR data set.

    It appears OGR objects cannot be pickled by joblib, hence the OGR data set
    is opened here instead of being passed in.

    This variant opens the velocity file with ``chunks="auto"``, so the fields
    are lazy xarray objects. The file is therefore kept open until all
    trajectories have been computed, and closed afterwards.

    Parameters
    ----------
    data_url : str or pathlib.Path
        netCDF file providing ``vx``, ``vy``, ``vx_err``, ``vy_err`` and the
        ``x``/``y`` coordinates.
    ogr_url : str or pathlib.Path
        OGR-readable vector data set. Its first layer is used, geometries are
        transformed to EPSG:3413, and every feature must have a ``name``
        attribute.
    perturbation : int, optional
        Identifier added to each feature's attributes as ``perturbation``.
    sample : list or numpy.ndarray of two values, optional
        Interpolation weights ``(s_x, s_y)`` between ``v - sigma*err`` (0) and
        ``v + sigma*err`` (1). Presumably in [0, 1], e.g. one row of a Latin
        hypercube design.
    sigma : float, optional
        Multiplier on the velocity errors.
    total_time : float, optional
        Passed to ``compute_trajectory``.
    dt : float, optional
        Time step passed to ``compute_trajectory``.
    reverse : bool, optional
        Passed to ``compute_trajectory``.

    Returns
    -------
    geopandas.GeoDataFrame
        Concatenation of the ``trajectories_to_geopandas`` output of every
        feature (called with the feature attributes plus ``perturbation``).
        An empty GeoDataFrame is returned if the layer has no features.
    """
    ogr.UseExceptions()
    # Target CRS for the gate geometries.
    srs_geo = osr.SpatialReference()
    srs_geo.ImportFromEPSG(3413)

    # ogr.Open wants a plain string path.
    if isinstance(ogr_url, Path):
        ogr_url = str(ogr_url.absolute())

    all_glaciers = []
    # The fields are lazy, so all work happens inside the `with`; the dataset
    # handle is released on exit instead of being leaked.
    with xr.open_dataset(data_url, chunks="auto") as ds:
        Vx, Vy = get_perturbed_velocities(
            ds["vx"], ds["vy"], ds["vx_err"], ds["vy_err"], sample=sample, sigma=sigma
        )
        x = ds["x"]
        y = ds["y"]

        # Keep `in_ds` referenced for as long as `layer` is in use.
        in_ds = ogr.Open(ogr_url)
        layer = in_ds.GetLayer(0)

        progress = tqdm(layer, total=len(layer), leave=False)
        for feature in progress:
            geometry = feature.GetGeometryRef()
            geometry.TransformTo(srs_geo)
            points = [Point(p) for p in geometry.GetPoints()]
            attrs = feature.items()
            attrs["perturbation"] = perturbation
            glacier_name = attrs["name"]
            progress.set_description(f"""Processing {glacier_name}""")
            # Only the first value returned by compute_trajectory is kept.
            trajs = [
                compute_trajectory(
                    p, Vx, Vy, x, y, total_time=total_time, dt=dt, reverse=reverse
                )[0]
                for p in points
            ]
            all_glaciers.append(
                trajectories_to_geopandas(trajs, Vx, Vy, x, y, attrs=attrs)
            )

    if not all_glaciers:
        # pd.concat raises ValueError on an empty list.
        return GeoDataFrame()
    return pd.concat(all_glaciers)


In [None]:
def get_perturbed_velocities(VX, VY, VX_e, VY_e, sample, sigma: float = 1.0):
    """
    Draw one perturbed velocity field inside the ``±sigma`` error envelope.

    Each component is interpolated linearly between its lower bound
    ``V - sigma * V_e`` and its upper bound ``V + sigma * V_e``. A weight of 0
    gives the lower bound, 1 the upper bound, and 0.5 the unperturbed field
    (up to rounding).

    Parameters
    ----------
    VX, VY : array-like
        Velocity components. Both numpy arrays and xarray DataArrays are
        passed in by callers in this notebook.
    VX_e, VY_e : array-like
        Errors of the corresponding components.
    sample : sequence of two numbers
        Interpolation weights ``(s_x, s_y)`` for the x and y components.
    sigma : float, optional
        Multiplier on the errors (default 1.0).

    Returns
    -------
    tuple
        ``(Vx, Vy)``, the perturbed components.
    """

    def _between_bounds(center, error, weight):
        # Same arithmetic as lower + weight * (upper - lower).
        lower = center - sigma * error
        upper = center + sigma * error
        return lower + weight * (upper - lower)

    s_x, s_y = sample[0], sample[1]
    return _between_bounds(VX, VX_e, s_x), _between_bounds(VY, VY_e, s_y)


In [None]:
from pismragis.trajectories import trajectories_to_geopandas

In [None]:
# Time n_perturbations trajectory runs on 10 joblib workers. Parallel returns a
# list of per-perturbation frames here (not concatenated).
%time all_perturb_glaciers = Parallel(n_jobs=10)(delayed(compute_perturbation)(data_url, ogr_url, perturbation=perturb, sample=unif_sample[perturb, :], total_time=1_000, dt=1, reverse=True) for perturb in range(n_perturbations))


In [None]:
# NOTE(review): identical to the previous %time cell — presumably a repeat
# timing run; consider keeping only one.
%time all_perturb_glaciers = Parallel(n_jobs=10)(delayed(compute_perturbation)(data_url, ogr_url, perturbation=perturb, sample=unif_sample[perturb, :], total_time=1_000, dt=1, reverse=True) for perturb in range(n_perturbations))

In [1]:
# Memory-profiling magics (e.g. %memit); NOTE(review): not used in the cells shown.
%load_ext memory_profiler

In [2]:
# Line-profiling magic (%lprun); NOTE(review): not used in the cells shown.
%load_ext line_profiler

In [None]:
# You need to load the extension first
# NOTE(review): this cell was never executed and no viztracer magic is used below.
%load_ext viztracer

In [None]:
%time all_perturb_glaciers = Parallel(n_jobs=10)(delayed(compute_perturbation)(data_url, ogr_url, perturbation=perturb, sample=unif_sample[perturb, :], total_time=1_000, dt=1, reverse=True) for perturb in range(n_draw_samples))

In [5]:
# IPython automagic for `%run test.py`; presumably test.py holds the script
# version of the functions defined in the next cell. This execution was
# interrupted (KeyboardInterrupt).
run test.py

KeyboardInterrupt: 

In [16]:
import numpy as np
import pylab as plt
import xarray as xr
from pismragis.trajectories import compute_trajectory, compute_perturbation
from joblib import Parallel, delayed
import geopandas as gp
from pathlib import Path
from pismragis.processing import tqdm_joblib
import pandas as pd
from pyDOE import lhs
from scipy.stats.distributions import uniform
from tqdm.auto import tqdm

from pathlib import Path
from typing import Tuple, Union

import geopandas as gp
import numpy as np
import pandas as pd
from geopandas import GeoDataFrame
from numpy import ndarray
from osgeo import ogr, osr
from shapely import Point
from tqdm.auto import tqdm
from xarray import DataArray
from pismragis.trajectories import trajectories_to_geopandas


def compute_perturbation(
    data_url: Union[str, Path],
    ogr_url: Union[str, Path],
    perturbation: int = 0,
    sample: Union[list, ndarray] = [0.5, 0.5],
    sigma: float = 1,
    total_time: float = 10_000,
    dt: float = 1,
    reverse: bool = False,
) -> GeoDataFrame:
    """
    Compute trajectories for all gate points under one perturbed velocity field.

    The velocity components ``vx``/``vy`` are perturbed within
    ``±sigma * (vx_err, vy_err)`` by ``get_perturbed_velocities`` using the
    two values in `sample`. A trajectory is then traced from every vertex of
    every feature in the first layer of the OGR data set.

    It appears OGR objects cannot be pickled by joblib, hence the OGR data set
    is opened here instead of being passed in.

    This variant reads the velocity fields eagerly into numpy arrays (length-one
    axes squeezed away) and closes the netCDF file immediately afterwards.

    Parameters
    ----------
    data_url : str or pathlib.Path
        netCDF file providing ``vx``, ``vy``, ``vx_err``, ``vy_err`` and the
        ``x``/``y`` coordinates.
    ogr_url : str or pathlib.Path
        OGR-readable vector data set. Its first layer is used, geometries are
        transformed to EPSG:3413, and every feature must have a ``name``
        attribute.
    perturbation : int, optional
        Identifier added to each feature's attributes as ``perturbation``.
    sample : list or numpy.ndarray of two values, optional
        Interpolation weights ``(s_x, s_y)`` between ``v - sigma*err`` (0) and
        ``v + sigma*err`` (1). Presumably in [0, 1], e.g. one row of a Latin
        hypercube design.
    sigma : float, optional
        Multiplier on the velocity errors.
    total_time : float, optional
        Passed to ``compute_trajectory``.
    dt : float, optional
        Time step passed to ``compute_trajectory``.
    reverse : bool, optional
        Passed to ``compute_trajectory``.

    Returns
    -------
    geopandas.GeoDataFrame
        Concatenation of the ``trajectories_to_geopandas`` output of every
        feature (called with the feature attributes plus ``perturbation``).
        An empty GeoDataFrame is returned if the layer has no features.
    """
    # to_numpy() materializes the data, so the file can be closed right away
    # instead of leaking the handle.
    with xr.open_dataset(data_url) as ds:
        VX = np.squeeze(ds["vx"].to_numpy())
        VY = np.squeeze(ds["vy"].to_numpy())
        VX_e = np.squeeze(ds["vx_err"].to_numpy())
        VY_e = np.squeeze(ds["vy_err"].to_numpy())
        x = ds["x"].to_numpy()
        y = ds["y"].to_numpy()

    Vx, Vy = get_perturbed_velocities(VX, VY, VX_e, VY_e, sample=sample, sigma=sigma)

    ogr.UseExceptions()
    # ogr.Open wants a plain string path.
    if isinstance(ogr_url, Path):
        ogr_url = str(ogr_url.absolute())
    # Keep `in_ds` referenced for as long as `layer` is in use.
    in_ds = ogr.Open(ogr_url)
    layer = in_ds.GetLayer(0)

    # Target CRS for the gate geometries.
    srs_geo = osr.SpatialReference()
    srs_geo.ImportFromEPSG(3413)

    all_glaciers = []
    progress = tqdm(layer, total=len(layer), leave=False)
    for feature in progress:
        geometry = feature.GetGeometryRef()
        geometry.TransformTo(srs_geo)
        points = [Point(p) for p in geometry.GetPoints()]
        attrs = feature.items()
        attrs["perturbation"] = perturbation
        glacier_name = attrs["name"]
        progress.set_description(f"""Processing {glacier_name}""")
        # Only the first value returned by compute_trajectory is kept.
        trajs = [
            compute_trajectory(
                p, Vx, Vy, x, y, total_time=total_time, dt=dt, reverse=reverse
            )[0]
            for p in points
        ]
        all_glaciers.append(trajectories_to_geopandas(trajs, Vx, Vy, x, y, attrs=attrs))

    if not all_glaciers:
        # pd.concat raises ValueError on an empty list.
        return GeoDataFrame()
    return pd.concat(all_glaciers)


def get_perturbed_velocities(VX, VY, VX_e, VY_e, sample, sigma: float = 1.0):
    """
    Return ``(Vx, Vy)`` perturbed linearly inside ``V ± sigma * V_e``.

    For each component the result is ``lower + w * (upper - lower)`` with
    ``lower = V - sigma * V_e``, ``upper = V + sigma * V_e``, and
    ``w = sample[0]`` for x and ``w = sample[1]`` for y.

    Parameters
    ----------
    VX, VY : array-like
        Velocity components.
    VX_e, VY_e : array-like
        Errors of the corresponding components.
    sample : sequence of two numbers
        Interpolation weights for the x and y components.
    sigma : float, optional
        Multiplier on the errors (default 1.0).
    """
    perturbed = []
    for center, error, weight in ((VX, VX_e, sample[0]), (VY, VY_e, sample[1])):
        half_width = sigma * error
        lower, upper = center - half_width, center + half_width
        perturbed.append(lower + weight * (upper - lower))
    return tuple(perturbed)



In [8]:
def run_parallel(data_url=None, ogr_url=None, n_perturbations: int = 10, n_jobs: int = 10):
    """
    Trace trajectories for several velocity perturbations in parallel.

    A Latin-hypercube design of `n_perturbations` rows (two values each) is
    drawn with ``lhs``. Each row is handed to ``compute_perturbation`` with
    ``total_time=1_000, dt=1, reverse=True`` on `n_jobs` joblib workers.

    Parameters
    ----------
    data_url : str or pathlib.Path, optional
        Velocity netCDF file; defaults to the ITS_LIVE mosaic used throughout
        this notebook.
    ogr_url : str or pathlib.Path, optional
        Flux-gate vector data set; defaults to the notebook's shapefile.
    n_perturbations : int, optional
        Number of perturbations (rows of the design), default 10.
    n_jobs : int, optional
        Number of joblib workers, default 10.

    Returns
    -------
    list
        One ``compute_perturbation`` result per perturbation, in order
        (previously these were computed and then discarded).
    """
    if data_url is None:
        data_url = Path("~/Google Drive/My Drive/data/ITS_LIVE/GRE_G0240_0000.nc")
    if ogr_url is None:
        ogr_url = Path(
            "/Users/andy/Google Drive/My Drive/data/GreenlandFluxGatesAschwanden/greenland-flux-gates-jibneighbors.shp"
        )

    unif_sample = lhs(2, n_perturbations)

    return Parallel(n_jobs=n_jobs)(
        delayed(compute_perturbation)(
            data_url,
            ogr_url,
            perturbation=perturb,
            sample=unif_sample[perturb, :],
            total_time=1_000,
            dt=1,
            reverse=True,
        )
        for perturb in range(n_perturbations)
    )


def run_serial(data_url=None, ogr_url=None, n_perturbations: int = 10, perturb: int = 0):
    """
    Trace trajectories for a single velocity perturbation in this process.

    A Latin-hypercube design of `n_perturbations` rows is drawn with ``lhs``,
    and only row `perturb` is evaluated by ``compute_perturbation`` with
    ``total_time=1_000, dt=1, reverse=True``. This makes the function a
    single-job baseline for ``%timeit``.

    Parameters
    ----------
    data_url : str or pathlib.Path, optional
        Velocity netCDF file; defaults to the ITS_LIVE mosaic used throughout
        this notebook.
    ogr_url : str or pathlib.Path, optional
        Flux-gate vector data set; defaults to the notebook's shapefile.
    n_perturbations : int, optional
        Number of rows in the drawn design, default 10.
    perturb : int, optional
        Row of the design to evaluate, default 0.

    Returns
    -------
    geopandas.GeoDataFrame
        The ``compute_perturbation`` result (previously discarded).
    """
    if data_url is None:
        data_url = Path("~/Google Drive/My Drive/data/ITS_LIVE/GRE_G0240_0000.nc")
    if ogr_url is None:
        ogr_url = Path(
            "/Users/andy/Google Drive/My Drive/data/GreenlandFluxGatesAschwanden/greenland-flux-gates-jibneighbors.shp"
        )

    unif_sample = lhs(2, n_perturbations)

    return compute_perturbation(
        data_url,
        ogr_url,
        perturbation=perturb,
        sample=unif_sample[perturb, :],
        total_time=1_000,
        dt=1,
        reverse=True,
    )


In [17]:
# Benchmark one serial perturbation. Execution count 17 comes after cell 16, so
# this times the compute_perturbation defined there (numpy-loading variant).
%timeit run_serial()

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

17.9 s ± 240 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
# Earlier benchmark: execution count 13 precedes cell 16, which redefined
# compute_perturbation. It presumably measures a previous version, which may
# explain its slower mean compared with the later run.
%timeit run_serial()

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

  0%|          | 0/6 [00:00<?, ?it/s]

22.1 s ± 220 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
