Skip to content

Commit

Permalink
grdtrack: Fix the bug when profile is given (GenericMappingTools#1867)
Browse files Browse the repository at this point in the history
Co-authored-by: Wei Ji <23487320+weiji14@users.noreply.github.com>
  • Loading branch information
2 people authored and Josh Sixsmith committed Dec 21, 2022
1 parent 9a7badb commit 17ac269
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 15 deletions.
73 changes: 58 additions & 15 deletions pygmt/src/grdtrack.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
"""
grdtrack - Sample grids at specified (x,y) locations.
"""
import warnings

import pandas as pd
import xarray as xr
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
Expand All @@ -11,6 +14,7 @@
kwargs_to_strings,
use_alias,
)
from pygmt.src.which import which

__doctest_skip__ = ["grdtrack"]

Expand Down Expand Up @@ -43,7 +47,7 @@
w="wrap",
)
@kwargs_to_strings(R="sequence", S="sequence", i="sequence_comma", o="sequence_comma")
def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
def grdtrack(grid, points=None, newcolname=None, outfile=None, **kwargs):
r"""
Sample grids at specified (x,y) locations.
Expand All @@ -67,14 +71,14 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
Parameters
----------
points : str or {table-like}
Pass in either a file name to an ASCII data table, a 2D
{table-classes}.
grid : xarray.DataArray or str
Gridded array from which to sample values from, or a filename (netcdf
format).
points : str or {table-like}
Pass in either a file name to an ASCII data table, a 2D
{table-classes}.
newcolname : str
Required if ``points`` is a :class:`pandas.DataFrame`. The name for the
new column in the track :class:`pandas.DataFrame` table where the
Expand Down Expand Up @@ -283,26 +287,65 @@ def grdtrack(points, grid, newcolname=None, outfile=None, **kwargs):
... points=points, grid=grid, newcolname="bathymetry"
... )
"""
# pylint: disable=too-many-branches
if points is not None and kwargs.get("E") is not None:
raise GMTInvalidInput("Can't set both 'points' and 'profile'.")

if points is None and kwargs.get("E") is None:
raise GMTInvalidInput("Must give 'points' or set 'profile'.")

if hasattr(points, "columns") and newcolname is None:
raise GMTInvalidInput("Please pass in a str to 'newcolname'")

# Backward compatibility with old parameter order "points, grid".
# deprecated_version="0.7.0", remove_version="v0.9.0"
is_a_grid = True
if not isinstance(grid, (xr.DataArray, str)):
is_a_grid = False
elif isinstance(grid, str):
try:
xr.open_dataarray(which(grid, download="a"), engine="netcdf4").close()
is_a_grid = True
except (ValueError, OSError):
is_a_grid = False
if not is_a_grid:
msg = (
"Positional parameters 'points, grid' of pygmt.grdtrack() has changed "
"to 'grid, points=None' since v0.7.0. It's likely that you're NOT "
"passing a valid grid as the first positional argument or "
"are passing an invalid grid to the 'grid' parameter. "
"Please check the order of arguments with the latest documentation. "
"This warning will be removed in v0.9.0."
)
grid, points = points, grid
warnings.warn(msg, category=FutureWarning, stacklevel=1)

with GMTTempFile(suffix=".csv") as tmpfile:
with Session() as lib:
# Choose how data will be passed into the module
table_context = lib.virtualfile_from_data(check_kind="vector", data=points)
# Store the xarray.DataArray grid in virtualfile
grid_context = lib.virtualfile_from_data(check_kind="raster", data=grid)

# Run grdtrack on the temporary (csv) points table
# and (netcdf) grid virtualfile
with table_context as csvfile:
with grid_context as grdfile:
kwargs.update({"G": grdfile})
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name
with grid_context as grdfile:
kwargs.update({"G": grdfile})
if outfile is None: # Output to tmpfile if outfile is not set
outfile = tmpfile.name

if points is not None:
# Choose how data will be passed into the module
table_context = lib.virtualfile_from_data(
check_kind="vector", data=points
)
with table_context as csvfile:
lib.call_module(
module="grdtrack",
args=build_arg_string(
kwargs, infile=csvfile, outfile=outfile
),
)
else:
lib.call_module(
module="grdtrack",
args=build_arg_string(kwargs, infile=csvfile, outfile=outfile),
args=build_arg_string(kwargs, outfile=outfile),
)

# Read temporary csv output to a pandas table
Expand Down
54 changes: 54 additions & 0 deletions pygmt/tests/test_grdtrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,29 @@ def test_grdtrack_input_csvfile_and_ncfile_to_dataframe(expected_array):
npt.assert_allclose(np.array(output), expected_array)


def test_grdtrack_profile(dataarray):
"""
Run grdtrack by passing a profile.
"""
output = grdtrack(grid=dataarray, profile="-51/-17/-54/-19")
assert isinstance(output, pd.DataFrame)
npt.assert_allclose(
np.array(output),
np.array(
[
[-51.0, -17.0, 669.671875],
[-51.42430204, -17.28838525, 847.40745877],
[-51.85009439, -17.57598444, 885.30534844],
[-52.27733766, -17.86273467, 829.85423488],
[-52.70599151, -18.14857333, 776.83702212],
[-53.13601473, -18.43343819, 631.07867839],
[-53.56736521, -18.7172675, 504.28037216],
[-54.0, -19.0, 486.10351562],
]
),
)


def test_grdtrack_wrong_kind_of_points_input(dataarray, dataframe):
"""
Run grdtrack using points input that is not a pandas.DataFrame (matrix) or
Expand Down Expand Up @@ -137,3 +160,34 @@ def test_grdtrack_without_outfile_setting(dataarray, dataframe):
"""
with pytest.raises(GMTInvalidInput):
grdtrack(points=dataframe, grid=dataarray)


def test_grdtrack_no_points_and_profile(dataarray):
"""
Run grdtrack but don't set 'points' and 'profile'.
"""
with pytest.raises(GMTInvalidInput):
grdtrack(grid=dataarray)


def test_grdtrack_set_points_and_profile(dataarray, dataframe):
"""
Run grdtrack but set both 'points' and 'profile'.
"""
with pytest.raises(GMTInvalidInput):
grdtrack(grid=dataarray, points=dataframe, profile="BL/TR")


def test_grdtrack_old_parameter_order(dataframe, dataarray, expected_array):
"""
Run grdtrack with the old parameter order 'points, grid'.
This test should be removed in v0.9.0.
"""
for points in (POINTS_DATA, dataframe):
for grid in ("@static_earth_relief.nc", dataarray):
with pytest.warns(expected_warning=FutureWarning) as record:
output = grdtrack(points, grid)
assert len(record) == 1
assert isinstance(output, pd.DataFrame)
npt.assert_allclose(np.array(output), expected_array)

0 comments on commit 17ac269

Please sign in to comment.