From fe52b0b1ebdbbca09291352b543147b6bdb273b1 Mon Sep 17 00:00:00 2001 From: Sean Gillies Date: Thu, 3 Dec 2020 14:52:12 -0700 Subject: [PATCH] Add new merge tool arguments to support resolution of #1867 --- CHANGES.txt | 7 ++++-- rasterio/__init__.py | 2 +- rasterio/merge.py | 52 +++++++++++++++++++++++++++++++++++------ rasterio/rio/merge.py | 29 +++-------------------- tests/test_rio_merge.py | 13 +++++++++++ 5 files changed, 67 insertions(+), 36 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index eb7d214a5..f3a2c3d8b 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,8 +1,11 @@ Changes ======= -1.2dev ------- +1.2a1 +----- + +- Add dst_path and dst_kwds parameters to rasterio's merge tool to allow + results to be written directly to a dataset (#1867). 1.1.8 (2020-10-20) ------------------ diff --git a/rasterio/__init__.py b/rasterio/__init__.py index cc87666ae..0498c6594 100644 --- a/rasterio/__init__.py +++ b/rasterio/__init__.py @@ -41,7 +41,7 @@ def emit(self, record): import rasterio.path __all__ = ['band', 'open', 'pad', 'Env'] -__version__ = "1.2dev" +__version__ = "1.2a1" __gdal_version__ = gdal_version() # Rasterio attaches NullHandler to the 'rasterio' logger and its diff --git a/rasterio/merge.py b/rasterio/merge.py index 796a1f1f3..3df3de0af 100644 --- a/rasterio/merge.py +++ b/rasterio/merge.py @@ -1,18 +1,17 @@ """Copy valid pixels from input files to an output file.""" from contextlib import contextmanager -from pathlib import Path import logging import math +from pathlib import Path import warnings import numpy as np import rasterio -from rasterio import windows -from rasterio.enums import Resampling from rasterio.compat import string_types from rasterio.enums import Resampling +from rasterio import windows from rasterio.transform import Affine @@ -21,9 +20,20 @@ MERGE_METHODS = ('first', 'last', 'min', 'max') -def merge(datasets, bounds=None, res=None, nodata=None, dtype=None, precision=10, - indexes=None, output_count=None, resampling=Resampling.nearest, - method='first'): +def merge( + datasets, + bounds=None, + res=None, + nodata=None, + dtype=None, + precision=10, + indexes=None, + output_count=None, + resampling=Resampling.nearest, + method="first", + dst_path=None, + dst_kwds=None, +): """Copy valid pixels from input files to an output file. All files must have the same number of bands, data type, and @@ -90,6 +100,11 @@ def function(old_data, new_data, old_nodata, new_nodata, index=None, roff=None, row offset in base array coff: int column offset in base array + dst_path : str or Pathlike, optional + Path of output dataset + dst_kwds : dict, optional + Dictionary of creation options and other paramters that will be + overlaid on the profile of the output dataset. Returns ------- @@ -124,6 +139,7 @@ def nullcontext(obj): dataset_opener = nullcontext with dataset_opener(datasets[0]) as first: + first_profile = first.profile first_res = first.res nodataval = first.nodatavals[0] dt = first.dtypes[0] @@ -135,6 +151,11 @@ def nullcontext(obj): else: src_count = len(indexes) + try: + first_colormap = first.colormap(1) + except ValueError: + first_colormap = None + if not output_count: output_count = src_count @@ -180,6 +201,16 @@ def nullcontext(obj): dt = dtype logger.debug("Set dtype: %s", dt) + out_profile = first_profile + out_profile.update(**(dst_kwds or {})) + + out_profile["transform"] = output_transform + out_profile["height"] = output_height + out_profile["width"] = output_width + out_profile["count"] = output_count + if nodata is not None: + out_profile["nodata"] = nodata + # create destination array dest = np.zeros((output_count, output_height, output_width), dtype=dt) @@ -296,4 +327,11 @@ def copyto(old_data, new_data, old_nodata, new_nodata, **kwargs): copyto(region, temp, region_nodata, temp_nodata, index=idx, roff=roff, coff=coff) - return dest, output_transform + if dst_path is None: + return dest, output_transform + + else: + with rasterio.open(dst_path, "w", **out_profile) as dst: + dst.write(dest) + if first_colormap: + dst.write_colormap(1, first_colormap) diff --git a/rasterio/rio/merge.py b/rasterio/rio/merge.py index 3955890ce..665a00fe2 100644 --- a/rasterio/rio/merge.py +++ b/rasterio/rio/merge.py @@ -3,7 +3,6 @@ import click -import rasterio from rasterio.enums import Resampling from rasterio.rio import options from rasterio.rio.helpers import resolve_inout @@ -56,7 +55,7 @@ def merge(ctx, files, output, driver, bounds, res, resampling, resampling = Resampling[resampling] with ctx.obj["env"]: - dest, output_transform = merge_tool( + merge_tool( files, bounds=bounds, res=res, @@ -64,28 +63,6 @@ def merge(ctx, files, output, driver, bounds, res, resampling, precision=precision, indexes=(bidx or None), resampling=resampling, + dst_path=output, + dst_kwds=creation_options, ) - - with rasterio.open(files[0]) as first: - profile = first.profile - profile["transform"] = output_transform - profile["height"] = dest.shape[1] - profile["width"] = dest.shape[2] - profile["count"] = dest.shape[0] - profile.pop("driver", None) - if driver: - profile["driver"] = driver - if nodata is not None: - profile["nodata"] = nodata - - profile.update(**creation_options) - - with rasterio.open(output, "w", **profile) as dst: - dst.write(dest) - - # uses the colormap in the first input raster. - try: - colormap = first.colormap(1) - dst.write_colormap(1, colormap) - except ValueError: - pass diff --git a/tests/test_rio_merge.py b/tests/test_rio_merge.py index d57b14fe0..0c9ed8cc9 100644 --- a/tests/test_rio_merge.py +++ b/tests/test_rio_merge.py @@ -606,6 +606,19 @@ def test_merge_pathlib_path(tiffs): merge(inputs, res=2) +def test_merge_output_dataset(tiffs, tmpdir): + """Write to an open dataset""" + inputs = [str(x) for x in tiffs.listdir()] + inputs.sort() + output_file = tmpdir.join("output.tif") + merge(inputs, res=2, dst_path=str(output_file), dst_kwds=dict(driver="PNG")) + + with rasterio.open(str(output_file)) as result: + assert result.count == 1 + assert result.driver == "PNG" + assert result.height == result.width == 2 + + @fixture(scope='function') def test_data_dir_resampling(tmpdir): kwargs = {