Skip to content

Commit

Permalink
Add dataset ref by name feature.
Browse files Browse the repository at this point in the history
  • Loading branch information
sgillies committed Feb 10, 2015
1 parent aa5c81f commit c6d8bfe
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 52 deletions.
126 changes: 74 additions & 52 deletions rasterio/rio/calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

@cli.command(short_help="Raster data calculator.")
@click.argument('command')
@files_inout_arg
@click.argument(
'files',
nargs=-1,
type=click.Path(resolve_path=False),
required=True,
metavar="INPUTS... OUTPUT")
@click.option('--dtype',
type=click.Choice([
'ubyte', 'uint8', 'uint16', 'int16', 'uint32',
Expand Down Expand Up @@ -58,57 +63,74 @@ def calc(ctx, command, files, dtype):
kwargs['transform'] = kwargs.pop('affine')
kwargs['dtype'] = dtype

# Using the class method instead of instance method.
# Latter raises
# TypeError: astype() got an unexpected keyword argument 'copy'
# Possibly something to do with the instance being a masked
# array.
sources = np.ma.asanyarray([np.ndarray.astype(
rasterio.open(path).read(),
'float64',
copy=False
) for path in files])

parts = command.split(';')
_prev = None

for part in filter(lambda p: p.strip(), parts):

# TODO: implement a real parser for calc expressions,
# perhaps using numexpr's parser as a guide, instead
# eval'ing any string.

cmd = re.sub(
r'{(\d)\s*,\s*(\d)}',
lambda m: 'sources[%d,%d]' % (
int(m.group(1))-1, int(m.group(2))-1),
part)

# Translates, eg, '{1}' to 'sources[0]'.
cmd = re.sub(
r'{(\d)}',
lambda m: 'sources[%d]' % (int(m.group(1))-1),
cmd)

# Translate '{}' to '_prev'.
cmd = re.sub(r'{}', '_prev', cmd)

logger.debug("Translated cmd: %r", cmd)

res = eval(cmd)
_prev = res

if isinstance(res, tuple) or len(res.shape) == 3:
results = np.asanyarray([
np.ndarray.astype(r, dtype, copy=False
) for r in res])
else:
results = np.asanyarray(
[np.ndarray.astype(res, dtype, copy=False)])

kwargs['count'] = results.shape[0]
with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results)
names = []
sources = []
for path in files:
with rasterio.open(path) as src:
names.append(src.name)
# Using the class method instead of instance method.
# Latter raises
# TypeError: astype() got an unexpected keyword argument 'copy'
# Possibly something to do with the instance being a masked
# array.
sources.append(
np.ndarray.astype(src.read(), 'float64', copy=False))

#sources = np.ma.asanyarray([s for s in sources])

parts = command.split(';')
_prev = None

def cmd_sources(match):
text = match.group(1)
parts = text.split(',')
v = parts.pop(0)
if v in names:
a = names.index(v)
s = 'sources[%d]' % a
if parts:
s += '[%d]' % (int(parts.pop(0)) - 1)
return s

for part in filter(lambda p: p.strip(), parts):

# TODO: implement a real parser for calc expressions,
# perhaps using numexpr's parser as a guide, instead
# eval'ing any string.

# Translate '{}' to '_prev'.
cmd = re.sub(r'{}', '_prev', part)

cmd = re.sub(
r'{(\d+),(\d+)}',
lambda m: 'sources[%d][%d]' % (
int(m.group(1))-1,
int(m.group(2))-1),
cmd)

cmd = re.sub(
r'{(\d+)}',
lambda m: 'sources[%d]' % (int(m.group(1))-1),
cmd)

cmd = re.sub(r'{(.+)}', cmd_sources, cmd)

logger.debug("Translated cmd: %r", cmd)

res = eval(cmd)
_prev = res

if isinstance(res, tuple) or len(res.shape) == 3:
results = np.asanyarray([
np.ndarray.astype(r, dtype, copy=False
) for r in res])
else:
results = np.asanyarray(
[np.ndarray.astype(res, dtype, copy=False)])

kwargs['count'] = results.shape[0]
with rasterio.open(output, 'w', **kwargs) as dst:
dst.write(results)

sys.exit(0)
except Exception:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_rio_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,22 @@ def test_singleband_calc(tmpdir):
assert data.min() == 125


def test_singleband_calc_by_name(tmpdir):
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
result = runner.invoke(calc, [
'0.10*{tests/data/shade.tif,1} + 125',
'tests/data/shade.tif',
outfile],
catch_exceptions=False)
assert result.exit_code == 0
with rasterio.open(outfile) as src:
assert src.count == 1
assert src.meta['dtype'] == 'float64'
data = src.read()
assert data.min() == 125


def test_parts_calc(tmpdir):
# Producing an RGB output from the hill shade.
# Red band has bumped up values. Other bands are unchanged.
Expand Down Expand Up @@ -137,3 +153,20 @@ def test_copy_rgb_tempval(tmpdir):
assert src.meta['dtype'] == 'uint8'
data = src.read()
assert round(data.mean(), 1) == 60.6


def test_copy_rgb_by_name(tmpdir):
outfile = str(tmpdir.join('out.tif'))
runner = CliRunner()
result = runner.invoke(calc, [
'{tests/data/RGB.byte.tif}',
'--dtype', 'uint8',
'tests/data/RGB.byte.tif',
outfile],
catch_exceptions=False)
assert result.exit_code == 0
with rasterio.open(outfile) as src:
assert src.count == 3
assert src.meta['dtype'] == 'uint8'
data = src.read()
assert round(data.mean(), 1) == 60.6

0 comments on commit c6d8bfe

Please sign in to comment.