In [None]:
import numpy as np
from numpy.testing import assert_array_almost_equal
import pandas as pd
import xarray as xr
from itertools import permutations
import pyproj
from typing import Union
from osgeo import gdal, ogr, osr
from pathlib import Path
from pypism.extract_profile import normal, tangential, Profile
from pypism.interpolation import InterpolationMatrix

In [None]:
def linear_function(x: np.ndarray, y: np.ndarray, z: np.ndarray) -> np.ndarray:
    """Evaluate a test field that varies linearly with x, y, and z.

    The value is ``10.0 + 0.01 * x + 0.02 * y + 0.03 + 0.04 * z``; the
    ``0.03`` term is a constant offset, not a coefficient. Arguments may be
    scalars or arrays that broadcast against each other. Used as the
    reference field when checking the interpolation scheme.
    """
    # Accumulate the terms in the same left-to-right order as the closed-form
    # expression so floating-point rounding is unchanged.
    total = 10.0 + 0.01 * x
    total = total + 0.02 * y
    total = total + 0.03
    total = total + 0.04 * z
    return total


def create_dummy_input_dataset(F) -> xr.Dataset:
    """Create an in-memory input dataset for testing.

    ``F`` is sampled on a regular grid and stored once for every permutation
    of the dimension sets ``(x, y)``, ``(time, x, y)``, ``(x, y, z)`` and
    ``(time, x, y, z)``. Variables are named ``test_2D_<dims>`` or
    ``test_3D_<dims>``, where ``<dims>`` is the storage order joined by
    underscores (for example ``test_3D_time_z_y_x``). Variables without a
    ``z`` dimension hold ``F`` evaluated at ``z = 0``. Does not use unlimited
    dimensions and creates one time record only.

    Parameters
    ----------
    F
        Callable ``F(x, y, z)``. It is called with the 2D coordinate arrays
        produced by ``np.meshgrid(x, y)`` and a scalar ``z``, and must return
        an array of the same 2D shape.

    Returns
    -------
    xr.Dataset
        Dataset with the test variables, the coordinates ``time``, ``z``,
        ``y``, ``x`` and the global attributes ``description``, ``proj`` and
        ``proj4``.
    """

    Mx = 88
    My = 152
    Mz = 11
    Mt = 1

    # use X and Y ranges corresponding to a grid covering Greenland
    x = np.linspace(-669650.0, 896350.0, Mx)
    y = np.linspace(-3362600.0, -644600.0, My)
    z = np.linspace(0, 4000.0, Mz)

    # both arrays have shape (My, Mx)
    xx, yy = np.meshgrid(x, y)

    dim_map = {"x": Mx, "y": My, "z": Mz, "time": Mt}

    def write(dimensions, name: str):
        """Build the ``(dims, data, attrs)`` tuple for one test variable
        stored using the given dimension order."""

        slices = {
            "x": slice(0, Mx),
            "y": slice(0, My),
            "time": 0,
            "z": None,
        }

        # set indexes for all dimensions (z index will be re-set below)
        indexes = [slices[d] for d in dimensions]

        # F(xx, yy, ...) is laid out as (y, x): transpose it when the
        # variable stores x before y
        if dimensions.index("y") < dimensions.index("x"):

            def T(a):
                return a

        else:
            T = np.transpose

        variable = np.zeros([dim_map[d] for d in dimensions])
        if "z" in dimensions:
            z_axis = dimensions.index("z")
            for k in range(Mz):
                indexes[z_axis] = k
                # tuple(...) is equivalent to variable[*indexes] but does not
                # require the Python 3.11 subscript-unpacking syntax
                variable[tuple(indexes)] = T(F(xx, yy, z[k]))
        else:
            variable[tuple(indexes)] = T(F(xx, yy, 0))

        return (dimensions, variable, {"long_name": name + " (make it long!)"})

    def P(x):
        return list(permutations(x))

    data_vars = {}
    for d in sorted(P(["x", "y"]) + P(["time", "x", "y"])):
        name = "test_2D_" + "_".join(d)
        data_vars[name] = write(d, name)

    for d in sorted(P(["x", "y", "z"]) + P(["time", "x", "y", "z"])):
        name = "test_3D_" + "_".join(d)
        data_vars[name] = write(d, name)

    ds = xr.Dataset(
        data_vars=data_vars,
        coords={
            "time": (["time"], [0], {}),
            "z": (["z"], z, {"_FillValue": False, "units": "m"}),
            "y": (
                ["y"],
                y,
                {
                    "_FillValue": False,
                    "units": "m",
                    "axis": "Y",
                    "standard_name": "projection_y_coordinate",
                },
            ),
            "x": (
                ["x"],
                x,
                {
                    "_FillValue": False,
                    "units": "m",
                    "axis": "X",
                    "standard_name": "projection_x_coordinate",
                },
            ),
        },
        attrs={"description": "Test data.", "proj": "epsg:3413", "proj4": "epsg:3413"},
    )
    return ds


def fixture_create_dummy_input_dataset_xyz():
    """Build the dummy dataset whose variables sample ``linear_function``."""
    dataset = create_dummy_input_dataset(linear_function)
    return dataset

def fixture_create_dummy_profile(dummy_input_dataset):
    """Build a four-point test profile running corner to corner across the grid.

    The profile goes from the first to the last ``x``/``y`` coordinate of
    ``dummy_input_dataset``, pulled in by a small offset at both ends. Its
    points are converted to longitude/latitude with the projection named in
    the dataset's ``proj`` attribute before being handed to ``Profile``.
    """
    grid_x = dummy_input_dataset["x"]
    grid_y = dummy_input_dataset["y"]
    projection = pyproj.Proj(str(dummy_input_dataset.attrs["proj"]))

    n_points = 4
    # keep the end points strictly inside the grid so they can be interpolated
    epsilon = 0.1

    def inset_line(coord):
        """Evenly spaced points between the two ends of ``coord``, moved inward."""
        return np.linspace(coord[0] + epsilon, coord[-1] - epsilon, n_points)

    x_profile = inset_line(grid_x)
    y_profile = inset_line(grid_y)

    # midpoint between the first and the last profile point
    x_center = 0.5 * (x_profile[0] + x_profile[-1])
    y_center = 0.5 * (y_profile[0] + y_profile[-1])

    # inverse projection: projected x/y -> lon/lat
    lon, lat = projection(x_profile, y_profile, inverse=True)
    clon, clat = projection(x_center, y_center, inverse=True)

    flightline, glaciertype, flowtype = 2, 4, 2

    return Profile(
        0,
        "test profile",
        lat,
        lon,
        clat,
        clon,
        flightline,
        glaciertype,
        flowtype,
        projection,
    )


In [None]:
# Build the shared test inputs once: the synthetic dataset and a profile
# running across its grid. Later cells read these two module-level names.
dummy_input_dataset = fixture_create_dummy_input_dataset_xyz()
dummy_profile = fixture_create_dummy_profile(dummy_input_dataset)

In [None]:
def extract_profile(variable, profile, xdim: str = "x", ydim: str = "y", zdim: str = "z", tdim: str = "time"):
    """Extract values of a variable along a profile.

    Parameters
    ----------
    variable
        Array exposing ``dims``, ``shape`` and ``coords`` and supporting
        positional indexing with a tuple (an ``xarray.DataArray`` in this
        notebook). It must have the ``xdim`` and ``ydim`` dimensions with
        coordinates; ``zdim`` and ``tdim`` are optional.
    profile
        Object providing the profile point coordinates as ``profile.x`` and
        ``profile.y``. Interpolation weights are cached on it in the
        attributes ``A``, ``x_slice``, ``y_slice`` and ``grid_shape``.
    xdim, ydim, zdim, tdim
        Names of the x, y, vertical and time dimensions of ``variable``.

    Returns
    -------
    result : np.ndarray
        Values along the profile with shape ``(time, profile, z)``,
        ``(time, profile)``, ``(profile, z)`` or ``(profile,)`` depending on
        which of ``tdim`` and ``zdim`` are dimensions of ``variable``.
    dim_names : list of str
        Names of the axes of ``result``, in order.
    """
    x = variable.coords[xdim].to_numpy()
    y = variable.coords[ydim].to_numpy()

    dim_length = dict(zip(variable.dims, variable.shape))

    def init_interpolation():
        """Initialize interpolation weights. Takes care of the transpose."""
        if variable.dims.index(ydim) < variable.dims.index(xdim):
            A = InterpolationMatrix(x, y, profile.x, profile.y)
            return A, slice(A.c_min, A.c_max + 1), slice(A.r_min, A.r_max + 1)
        else:
            A = InterpolationMatrix(y, x, profile.y, profile.x)
            return A, slice(A.r_min, A.r_max + 1), slice(A.c_min, A.c_max + 1)

    # Try to get the matrix we (possibly) pre-computed earlier. Only the
    # attribute lookups are guarded, so an AttributeError raised while
    # building a new matrix is not mistaken for "nothing cached yet".
    try:
        cached = (profile.grid_shape, profile.A, profile.x_slice, profile.y_slice)
    except AttributeError:
        cached = None

    if cached is None:
        A, x_slice, y_slice = init_interpolation()
        profile.A = A
        profile.x_slice = x_slice
        profile.y_slice = y_slice
        profile.grid_shape = variable.shape
    elif cached[0] == variable.shape:
        # Check if we are extracting from the grid of the same shape
        # as before. This will make sure that we re-compute weights if
        # one variable is stored as (x,y) and a different as (y,x),
        # but will not catch grids that are of the same shape, but
        # with different extents and spacings. We'll worry about this
        # case later -- if we have to.
        _, A, x_slice, y_slice = cached
    else:
        # different shape: use fresh weights for this call only and keep the
        # cached ones untouched
        A, x_slice, y_slice = init_interpolation()

    def read_subset(t=0, z=0):
        """Assemble the indexing tuple and get a subset from a variable."""
        indexes = {xdim: x_slice, ydim: y_slice, zdim: z, tdim: t}
        index = tuple(indexes.get(dim, Ellipsis) for dim in variable.dims)
        return variable[index]

    n_points = len(profile.x)

    # Decide using the variable's *dimensions*: a scalar (non-dimension)
    # coordinate has no entry in dim_length, and a dimension without a
    # coordinate variable still has to be looped over.
    has_time = tdim in variable.dims
    has_z = zdim in variable.dims

    if has_time and has_z:
        dim_names = ["time", "profile", "z"]
        result = np.zeros((dim_length[tdim], n_points, dim_length[zdim]))
        for j in range(dim_length[tdim]):
            for k in range(dim_length[zdim]):
                result[j, :, k] = A.apply_to_subset(read_subset(t=j, z=k))
    elif has_time:
        dim_names = ["time", "profile"]
        result = np.zeros((dim_length[tdim], n_points))
        for j in range(dim_length[tdim]):
            result[j, :] = A.apply_to_subset(read_subset(t=j))
    elif has_z:
        dim_names = ["profile", "z"]
        result = np.zeros((n_points, dim_length[zdim]))
        for k in range(dim_length[zdim]):
            result[:, k] = A.apply_to_subset(read_subset(z=k))
    else:
        dim_names = ["profile"]
        result = A.apply_to_subset(read_subset())

    return result, dim_names


In [None]:
    # Check extract_profile() against the analytic values of linear_function
    # for every storage order present in the dummy dataset.
    n_points = len(dummy_profile.x)
    z = dummy_input_dataset["z"]

    # expected values for variables without a z dimension (written at z = 0)
    desired_result = linear_function(dummy_profile.x, dummy_profile.y, 0.0)

    # expected values for variables with a z dimension: one column per level
    desired_3d_result = np.zeros((n_points, len(z)))
    for k, level in enumerate(z):
        desired_3d_result[:, k] = linear_function(dummy_profile.x, dummy_profile.y, level.to_numpy())

    def P(x):
        return list(permutations(x))

    # 2D variables
    for d in P(["x", "y"]) + P(["time", "x", "y"]):
        variable_name = "test_2D_" + "_".join(d)
        variable = dummy_input_dataset[variable_name]
        result, _ = extract_profile(variable, dummy_profile)

        # variables with a time dimension come back with a leading length-1
        # axis; drop it so the shapes match the expected values
        assert_array_almost_equal(np.squeeze(result), desired_result)

    # 3D variables
    for d in P(["x", "y", "z"]) + P(["time", "x", "y", "z"]):
        variable_name = "test_3D_" + "_".join(d)
        variable = dummy_input_dataset[variable_name]

        result, _ = extract_profile(variable, dummy_profile)

        # compare against the per-level expected values, not the 2D ones
        assert_array_almost_equal(np.squeeze(result), desired_3d_result)


In [None]:
# Scratch check: shape of the expected values for the 2D variables.
desired_result.shape

In [None]:
# Scratch check: first entry along the leading axis of the most recently
# extracted result (whatever the test cell assigned to `result` last).
result[0]

In [None]:
# Scratch check: shape of the most recent result once length-1 axes are dropped.
np.squeeze(result).shape