Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use rasterio to save geotiffs when available #252

Merged
merged 5 commits into from
Apr 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion satpy/tests/writer_tests/test_geotiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def test_float_write(self):
datasets = self._get_test_datasets()
w = GeoTIFFWriter()
w.save_datasets(datasets,
floating_point=True,
dtype=np.float32,
enhancement_config=False,
base_dir=self.base_dir)

Expand Down
242 changes: 121 additions & 121 deletions satpy/writers/geotiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,20 @@
LOG = logging.getLogger(__name__)


# Map numpy data types to GDAL data types
NP2GDAL = {
np.float32: gdal.GDT_Float32,
np.float64: gdal.GDT_Float64,
np.uint8: gdal.GDT_Byte,
np.uint16: gdal.GDT_UInt16,
np.uint32: gdal.GDT_UInt32,
np.int16: gdal.GDT_Int16,
np.int32: gdal.GDT_Int32,
np.complex64: gdal.GDT_CFloat32,
np.complex128: gdal.GDT_CFloat64,
}


class GeoTIFFWriter(ImageWriter):

"""Writer to save GeoTIFF images.
Expand All @@ -45,8 +59,8 @@ class GeoTIFFWriter(ImageWriter):

Un-enhanced float geotiff with NaN for fill values:

scn.save_datasets(writer='geotiff', floating_point=True,
enhancement_config=False, fill_value=np.nan)
scn.save_datasets(writer='geotiff', dtype=np.float32,
enhancement_config=False)

"""

Expand All @@ -73,14 +87,12 @@ class GeoTIFFWriter(ImageWriter):
"pixeltype",
"copy_src_overviews", )

def __init__(self, floating_point=False, tags=None, **kwargs):
def __init__(self, dtype=None, tags=None, **kwargs):
ImageWriter.__init__(self,
default_config_filename="writers/geotiff.yaml",
**kwargs)

self.floating_point = bool(self.info.get(
"floating_point", None) if floating_point is None else
floating_point)
self.dtype = self.info.get("dtype") if dtype is None else dtype
self.tags = self.info.get("tags",
None) if tags is None else tags
if self.tags is None:
Expand All @@ -100,97 +112,77 @@ def separate_init_kwargs(cls, kwargs):
# FUTURE: Don't pass Scene.save_datasets kwargs to init and here
init_kwargs, kwargs = super(GeoTIFFWriter, cls).separate_init_kwargs(
kwargs)
for kw in ['floating_point', 'tags']:
for kw in ['dtype', 'tags']:
if kw in kwargs:
init_kwargs[kw] = kwargs.pop(kw)

return init_kwargs, kwargs

def _gdal_write_datasets(self, dst_ds, datasets, opacity):
"""Write *datasets* in a gdal raster structure *dts_ds*, using
*opacity* as alpha value for valid data, and *fill_value*.
"""
def _write_array(bnd, chn):
bnd.WriteArray(chn.values)

# queue up data writes so we don't waste computation time
delayed = []
def _gdal_write_datasets(self, dst_ds, datasets):
"""Write datasets in a gdal raster structure dts_ds"""
for i, band in enumerate(datasets['bands']):
chn = datasets.sel(bands=band)
bnd = dst_ds.GetRasterBand(i + 1)
bnd.SetNoDataValue(0)
delay = dask.delayed(_write_array)(bnd, chn)
delayed.append(delay)
dask.compute(*delayed)

def _create_file(self, filename, img, gformat, g_opts, opacity,
datasets, mode):
raster = gdal.GetDriverByName("GTiff")

if mode == "L":
dst_ds = raster.Create(filename, img.width, img.height, 1,
gformat, g_opts)
self._gdal_write_datasets(dst_ds, datasets, opacity)
elif mode == "LA":
g_opts.append("ALPHA=YES")
dst_ds = raster.Create(filename, img.width, img.height, 2, gformat,
g_opts)
self._gdal_write_datasets(dst_ds, datasets, datasets)
elif mode == "RGB":
dst_ds = raster.Create(filename, img.width, img.height, 3,
gformat, g_opts)
self._gdal_write_datasets(dst_ds, datasets, datasets)

elif mode == "RGBA":
g_opts.append("ALPHA=YES")
dst_ds = raster.Create(filename, img.width, img.height, 4, gformat,
g_opts)

self._gdal_write_datasets(dst_ds, datasets, datasets)
else:
raise NotImplementedError(
"Saving to GeoTIFF using image mode %s is not implemented." %
mode)

# Create raster GeoTransform based on upper left corner and pixel
# resolution ... if not overwritten by argument geotransform.
if "area" not in img.data.attrs:
LOG.warning("No 'area' metadata found in image")
else:
area = img.data.attrs["area"]
bnd.WriteArray(chn.values)

def _gdal_write_geo(self, dst_ds, area):
try:
geotransform = [area.area_extent[0], area.pixel_size_x, 0,
area.area_extent[3], 0, -area.pixel_size_y]
dst_ds.SetGeoTransform(geotransform)
srs = osr.SpatialReference()

srs.ImportFromProj4(area.proj4_string)
srs.SetProjCS(area.proj_id)
try:
srs.SetWellKnownGeogCS(area.proj_dict['ellps'])
except KeyError:
pass
try:
geotransform = [area.area_extent[0], area.pixel_size_x, 0,
area.area_extent[3], 0, -area.pixel_size_y]
dst_ds.SetGeoTransform(geotransform)
srs = osr.SpatialReference()

srs.ImportFromProj4(area.proj4_string)
srs.SetProjCS(area.proj_id)
try:
srs.SetWellKnownGeogCS(area.proj_dict['ellps'])
except KeyError:
pass
try:
# Check for epsg code.
srs.ImportFromEPSG(int(
area.proj_dict['init'].lower().split('epsg:')[1]))
except (KeyError, IndexError):
pass
srs = srs.ExportToWkt()
dst_ds.SetProjection(srs)
except AttributeError:
LOG.warning(
"Can't save geographic information to geotiff, unsupported area type")

tags = self.tags.copy()
if "start_time" in img.data.attrs:
tags.update({'TIFFTAG_DATETIME': img.data.attrs["start_time"].strftime(
"%Y:%m:%d %H:%M:%S")})

dst_ds.SetMetadata(tags, '')

def save_image(self, img, filename=None, floating_point=None,
compute=True, **kwargs):
# Check for epsg code.
srs.ImportFromEPSG(int(
area.proj_dict['init'].lower().split('epsg:')[1]))
except (KeyError, IndexError):
pass
srs = srs.ExportToWkt()
dst_ds.SetProjection(srs)
except AttributeError:
LOG.warning(
"Can't save geographic information to geotiff, unsupported area type")

def _create_file(self, filename, img, gformat, g_opts, datasets, mode):
num_bands = len(mode)
if mode[-1] == 'A':
g_opts.append("ALPHA=YES")

def _delayed_create(create_opts, datasets, area, start_time, tags):
raster = gdal.GetDriverByName("GTiff")
dst_ds = raster.Create(*create_opts)
self._gdal_write_datasets(dst_ds, datasets)

# Create raster GeoTransform based on upper left corner and pixel
# resolution ... if not overwritten by argument geotransform.
if "area" is None:
LOG.warning("No 'area' metadata found in image")
else:
self._gdal_write_geo(dst_ds, area)

if start_time is not None:
tags.update({'TIFFTAG_DATETIME': start_time.strftime(
"%Y:%m:%d %H:%M:%S")})

dst_ds.SetMetadata(tags, '')

create_opts = (filename, img.width, img.height, num_bands, gformat, g_opts)
delayed = dask.delayed(_delayed_create)(
create_opts, datasets, img.data.attrs.get('area'),
img.data.attrs.get('start_time'),
self.tags.copy())
return delayed

def save_image(self, img, filename=None, dtype=None, fill_value=None,
floating_point=None, compute=True, **kwargs):
"""Save the image to the given *filename* in geotiff_ format.
`floating_point` allows the saving of
'L' mode images in floating point format if set to True.
Expand All @@ -205,43 +197,51 @@ def save_image(self, img, filename=None, floating_point=None,
if k in self.GDAL_OPTIONS:
gdal_options[k] = kwargs[k]

floating_point = floating_point if floating_point is not None else self.floating_point
if floating_point is not None:
import warnings
warnings.warn("'floating_point' is deprecated, use"
"'dtype=np.float64' instead.",
DeprecationWarning)
dtype = np.float64
dtype = dtype if dtype is not None else self.dtype
if dtype is None:
dtype = np.uint8

if "alpha" in kwargs:
raise ValueError(
"Keyword 'alpha' is automatically set and should not be specified")
if floating_point:
"Keyword 'alpha' is automatically set based on 'fill_value' "
"and should not be specified")
if np.issubdtype(dtype, np.floating):
if img.mode != "L":
raise ValueError("Image must be in 'L' mode for floating "
"point geotiff saving")
fill_value = np.nan
datasets, mode = img._finalize(fill_value=fill_value,
dtype=np.float64)
gformat = gdal.GDT_Float64
opacity = 0
else:
nbits = int(gdal_options.get("nbits", "8"))
if nbits > 16:
dtype = np.uint32
gformat = gdal.GDT_UInt32
elif nbits > 8:
dtype = np.uint16
gformat = gdal.GDT_UInt16
else:
dtype = np.uint8
gformat = gdal.GDT_Byte
opacity = np.iinfo(dtype).max
datasets, mode = img._finalize(dtype=dtype)

LOG.debug("Saving to GeoTiff: %s", filename)

g_opts = ["{0}={1}".format(k.upper(), str(v))
for k, v in gdal_options.items()]

ensure_dir(filename)
delayed = dask.delayed(self._create_file)(filename, img, gformat,
g_opts, opacity, datasets,
mode)
if compute:
return delayed.compute()
return delayed
if fill_value is None:
LOG.debug("Alpha band not supported for float geotiffs, "
"setting fill value to 'NaN'")
fill_value = np.nan

try:
import rasterio # noqa
# we can use the faster rasterio-based save
return img.save(filename, fformat='tif', fill_value=fill_value,
dtype=dtype, compute=compute, **gdal_options)
except ImportError:
LOG.warning("Using legacy/slower geotiff save method, install "
"'rasterio' for faster saving.")
# force to numpy dtype object
dtype = np.dtype(dtype)
gformat = NP2GDAL[dtype.type]

gdal_options['nbits'] = int(gdal_options.get('nbits',
dtype.itemsize * 8))
datasets, mode = img._finalize(fill_value=fill_value, dtype=dtype)
LOG.debug("Saving to GeoTiff: %s", filename)
g_opts = ["{0}={1}".format(k.upper(), str(v))
for k, v in gdal_options.items()]

ensure_dir(filename)
delayed = self._create_file(filename, img, gformat, g_opts,
datasets, mode)
if compute:
return delayed.compute()
return delayed