Skip to content

Commit

Permalink
Addition of bands(), map(), and partial()
Browse files Browse the repository at this point in the history
  • Loading branch information
Sean Gillies committed Feb 18, 2015
1 parent a47904c commit 361f4f5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
47 changes: 26 additions & 21 deletions rasterio/rio/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
from rasterio.rio.cli import cli


def get_band(inputs, d, i):
def get_bands(inputs, d, i=None):
"""Get a rasterio.Band object from calc's inputs"""
path = inputs[d] if d in dict(inputs) else inputs[int(d)-1][1]
return rasterio.band(rasterio.open(path), i)
if i:
return rasterio.band(rasterio.open(path), i)
else:
src = rasterio.open(path)
return [rasterio.band(src, i) for i in src.indexes]


@cli.command(short_help="Raster data calculator.")
Expand All @@ -30,13 +34,13 @@ def get_band(inputs, d, i):
required=True,
metavar="INPUTS... OUTPUT")
@click.option('--name', multiple=True,
help='Specify an input file with a unique short (alphas only) name '
'for use in commands like "a=tests/data/RGB.byte.tif".')
@click.option('--dtype',
type=click.Choice([
'ubyte', 'uint8', 'uint16', 'int16', 'uint32',
'int32', 'float32', 'float64']),
default='float64',
help='Specify an input file with a unique short (alphas only) '
'name for use in commands like '
'"a=tests/data/RGB.byte.tif".')
@click.option('--dtype',
type=click.Choice(['ubyte', 'uint8', 'uint16', 'int16', 'uint32',
'int32', 'float32', 'float64']),
default='float64',
help="Output data type (default: float64).")
@click.pass_context
def calc(ctx, command, files, name, dtype):
Expand Down Expand Up @@ -85,12 +89,11 @@ def calc(ctx, command, files, name, dtype):
logger = logging.getLogger('rio')

try:
with rasterio.drivers(CPL_DEBUG=verbosity>2):
with rasterio.drivers(CPL_DEBUG=verbosity > 2):
output = files[-1]

inputs = (
[tuple(n.split('=')) for n in name] +
[(None, n) for n in files[:-1]])
inputs = ([tuple(n.split('=')) for n in name] +
[(None, n) for n in files[:-1]])

with rasterio.open(inputs[0][1]) as first:
kwargs = first.meta
Expand All @@ -100,18 +103,20 @@ def calc(ctx, command, files, name, dtype):
ctxkwds = {}
for i, (name, path) in enumerate(inputs):
with rasterio.open(path) as src:
# Using the class method instead of instance method.
# Latter raises
# Using the class method instead of instance
# method. Latter raises
#
# TypeError: astype() got an unexpected keyword argument 'copy'
#
# possibly something to do with the instance being a masked
# array.
# TypeError: astype() got an unexpected keyword
# argument 'copy'
#
# possibly something to do with the instance being
# a masked array.
ctxkwds[name or '_i%d' % (i+1)] = np.ndarray.astype(
src.read(), 'float64', copy=False)
src.read(), 'float64', copy=False)

# Extend snuggs.
snuggs.func_map['band'] = lambda d, i: get_band(inputs, d, i)
snuggs.func_map['band'] = lambda d, i: get_bands(inputs, d, i)
snuggs.func_map['bands'] = lambda d: get_bands(inputs, d)
snuggs.func_map['fillnodata'] = lambda *args: fillnodata(*args)

res = snuggs.eval(command, **ctxkwds)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_rio_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,20 @@ def test_fillnodata(tmpdir):
assert src.meta['dtype'] == 'uint8'
data = src.read()
assert round(data.mean(), 1) == 58.6


def test_fillnodata_map(tmpdir):
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
result = runner.invoke(calc, [
'(asarray (map fillnodata (bands 1)))',
'--dtype', 'uint8',
'tests/data/RGB.byte.tif',
outfile],
catch_exceptions=False)
assert result.exit_code == 0
with rasterio.open(outfile) as src:
assert src.count == 3
assert src.meta['dtype'] == 'uint8'
data = src.read()
assert round(data.mean(), 1) == 58.6

0 comments on commit 361f4f5

Please sign in to comment.