-
Notifications
You must be signed in to change notification settings - Fork 17
/
xrimage.py
1499 lines (1256 loc) · 58.4 KB
/
xrimage.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2017-2018
#
# Author(s):
#
# Martin Raspaud <martin.raspaud@smhi.se>
# Adam Dybbroe <adam.dybbroe@smhi.se>
# Esben S. Nielsen <esn@dmi.dk>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""This module defines the XRImage class.
It overlaps largely with the PIL library, but has the advantage of using
:class:`~xarray.DataArray` objects backed by :class:`dask arrays
<dask.array.Array>` as pixel arrays. This allows for invalid values to
be tracked, metadata to be assigned, and stretching to be lazy
evaluated. With the optional ``rasterio`` library installed dask array
chunks can be saved in parallel.
"""
import logging
import os
import threading
import warnings
from contextlib import suppress
import dask
import dask.array as da
import numpy as np
import xarray as xr
from PIL import Image as PILImage
from dask.delayed import delayed
from trollimage.image import check_image_format
# rasterio is an optional dependency: without it GeoTIFF/JPEG2000 output
# (rio_save) is unavailable and ``save`` falls back to PIL.
try:
    import rasterio
    from rasterio.enums import Resampling
except ImportError:
    rasterio = None
try:
    # rasterio 1.0+
    from rasterio.windows import Window
except ImportError:
    # rasterio 0.36.0
    # remove this once rasterio 1.0+ is officially available
    def Window(x_off, y_off, x_size, y_size):
        """Replace the missing Window object in rasterio < 1.0."""
        # pre-1.0 rasterio expressed a window as
        # ((row_start, row_stop), (col_start, col_stop)) tuples
        return (y_off, y_off + y_size), (x_off, x_off + x_size)
logger = logging.getLogger(__name__)
class RIOFile(object):
    """Rasterio wrapper to allow da.store to do window saving."""

    def __init__(self, path, mode='w', **kwargs):
        """Store the target path, open mode and rasterio creation keywords.

        The underlying rasterio dataset is opened lazily (see :meth:`open`);
        a lock serializes access so dask workers can write safely.
        """
        self.path = path
        self.mode = mode
        self.kwargs = kwargs
        self.rfile = None
        self.lock = threading.Lock()

    @property
    def width(self):
        """Width of the band images."""
        return self.kwargs['width']

    @property
    def height(self):
        """Height of the band images."""
        return self.kwargs['height']

    @property
    def closed(self):
        """True when no rasterio dataset is currently open."""
        handle = self.rfile
        return handle is None or handle.closed

    def open(self, mode=None):
        """Open the underlying file unless it is already open."""
        open_mode = mode or self.mode
        if self.closed:
            self.rfile = rasterio.open(self.path, open_mode, **self.kwargs)

    def close(self):
        """Close the file if it is open."""
        with self.lock:
            if not self.closed:
                self.rfile.close()

    def __enter__(self):
        """Open the file on context entry."""
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file on context exit."""
        self.close()

    def __del__(self):
        """Best-effort close when the wrapper is garbage collected."""
        with suppress(IOError, OSError):
            self.close()

    @property
    def colorinterp(self):
        """Return the color interpretation of the image."""
        return self.rfile.colorinterp

    @colorinterp.setter
    def colorinterp(self, val):
        """Set the color interpretation (rasterio 1.0+ only)."""
        if rasterio.__version__.startswith("0."):
            # not supported in older versions, set by PHOTOMETRIC tag
            logger.warning("Rasterio 1.0+ required for setting colorinterp")
        else:
            self.rfile.colorinterp = val

    def write(self, *args, **kwargs):
        """Write data to the file, re-opening it in append mode if needed."""
        with self.lock:
            self.open('a')
            return self.rfile.write(*args, **kwargs)

    def build_overviews(self, *args, **kwargs):
        """Build overviews, re-opening the file in append mode if needed."""
        with self.lock:
            self.open('a')
            return self.rfile.build_overviews(*args, **kwargs)

    def update_tags(self, *args, **kwargs):
        """Update tags, re-opening the file in append mode if needed."""
        with self.lock:
            self.open('a')
            return self.rfile.update_tags(*args, **kwargs)
class RIOTag:
    """Rasterio wrapper to allow da.store on tag."""

    def __init__(self, rfile, name):
        """Remember the target file wrapper and the name of the tag to set."""
        self.rfile = rfile
        self.name = name

    def __setitem__(self, key, item):
        """Store ``item`` (a 0-d array) as the tag value; ``key`` is ignored."""
        self.rfile.update_tags(**{self.name: item.item()})

    def close(self):
        """Close the underlying file."""
        return self.rfile.close()
class RIODataset:
    """A wrapper for a rasterio dataset."""

    def __init__(self, rfile, overviews=None, overviews_resampling=None,
                 overviews_minsize=256):
        """Init the rasterio dataset.

        Args:
            rfile (RIOFile): File wrapper that chunks will be written to.
            overviews (list or None): Overview reduction factors. An empty
                list requests auto-computed levels (see :meth:`close`);
                ``None`` (default) disables overviews.
            overviews_resampling (str or None): Name of a
                :class:`rasterio.enums.Resampling` member; defaults to
                'nearest'.
            overviews_minsize (int): Minimum pixel size of the smallest
                auto-computed overview level.
        """
        self.rfile = rfile
        self.overviews = overviews
        if overviews_resampling is None:
            overviews_resampling = 'nearest'
        self.overviews_resampling = Resampling[overviews_resampling]
        self.overviews_minsize = overviews_minsize

    def __setitem__(self, key, item):
        """Put the data chunk in the image.

        ``key`` is the tuple of slices provided by ``da.store``: either
        ``(bands, y, x)`` for 3D data or ``(y, x)`` for a single band.
        """
        if len(key) == 3:
            # rasterio band indexes are 1-based while slice starts are 0-based
            indexes = list(range(
                key[0].start + 1,
                key[0].stop + 1,
                key[0].step or 1
            ))
            y = key[1]
            x = key[2]
        else:
            indexes = 1
            y = key[0]
            x = key[1]
        # window offsets/sizes of this chunk within the full image
        chy_off = y.start
        chy = y.stop - y.start
        chx_off = x.start
        chx = x.stop - x.start
        # band indexes
        self.rfile.write(item, window=Window(chx_off, chy_off, chx, chy),
                         indexes=indexes)

    def close(self):
        """Build any requested overviews, then close the file."""
        if self.overviews is not None:
            overviews = self.overviews
            # it's an empty list: auto-compute power-of-two levels until
            # the smallest overview drops below ``overviews_minsize``
            if len(overviews) == 0:
                from rasterio.rio.overview import get_maximum_overview_level
                width = self.rfile.width
                height = self.rfile.height
                max_level = get_maximum_overview_level(
                    width, height, self.overviews_minsize)
                overviews = [2 ** j for j in range(1, max_level + 1)]
            logger.debug('Building overviews %s with %s resampling',
                         str(overviews), self.overviews_resampling.name)
            self.rfile.build_overviews(overviews, resampling=self.overviews_resampling)
        return self.rfile.close()
def color_interp(data):
    """Get the color interpretation for this image.

    The interpretation is derived from the image's ``bands`` coordinate:
    known whole modes are looked up first, otherwise each band letter is
    mapped individually.
    """
    from rasterio.enums import ColorInterp as ci
    whole_modes = {'L': [ci.gray],
                   'LA': [ci.gray, ci.alpha],
                   'YCbCr': [ci.Y, ci.Cb, ci.Cr],
                   'YCbCrA': [ci.Y, ci.Cb, ci.Cr, ci.alpha]}
    band_letters = data['bands'].values
    try:
        return whole_modes[''.join(band_letters)]
    except KeyError:
        per_band = {'R': ci.red,
                    'G': ci.green,
                    'B': ci.blue,
                    'A': ci.alpha,
                    'C': ci.cyan,
                    'M': ci.magenta,
                    'Y': ci.yellow,
                    'H': ci.hue,
                    'S': ci.saturation,
                    'L': ci.lightness,
                    'K': ci.black,
                    }
        return [per_band[band] for band in band_letters]
def combine_scales_offsets(*args):
    """Combine ``(scale, offset)`` tuples in one, considering they are applied from left to right.

    For example, if we have our base data called ``orig_data`` and apply to it
    ``(scale_1, offset_1)``, then ``(scale_2, offset_2)`` such that::

        data_1 = orig_data * scale_1 + offset_1
        data_2 = data_1 * scale_2 + offset_2

    this function will return the tuple ``(scale, offset)`` such that::

        data_2 = orig_data * scale + offset

    given the arguments ``(scale_1, offset_1), (scale_2, offset_2)``.
    """
    combined_scale, combined_offset = 1, 0
    for next_scale, next_offset in args:
        # composing two affine maps: scales multiply, the accumulated
        # offset is itself rescaled before the new offset is added
        combined_scale = combined_scale * next_scale
        combined_offset = combined_offset * next_scale + next_offset
    return combined_scale, combined_offset
def invert_scale_offset(scale, offset):
    """Invert scale and offset to allow reverse transformation.

    Ie, it will return ``rscale, roffset`` such that::

        orig_data = rscale * data + roffset

    if::

        data = scale * orig_data + offset
    """
    inverse_scale = 1 / scale
    inverse_offset = -offset / scale
    return inverse_scale, inverse_offset
@delayed(nout=1, pure=True)
def delayed_pil_save(img, *args, **kwargs):
    """Dask delayed saving of PIL Image object.

    Special wrapper to handle `fill_value` try/except catch and provide a
    more useful error message.
    """
    try:
        img.save(*args, **kwargs)
    except OSError as e:
        # ex: cannot write mode LA as JPEG
        if "A as JPEG" not in str(e):
            raise
        new_msg = ("Image mode not supported for this format. Specify "
                   "`fill_value=0` to set invalid values to black.")
        raise OSError(new_msg) from e
class XRImage(object):
"""Image class using an :class:`xarray.DataArray` as internal storage.
It can be saved to a variety of image formats, but if Rasterio is
installed, it can save to geotiff and jpeg2000 with geographical
information.
The enhancements functions are recording some parameters in the image's
data attribute called `enhancement_history`.
"""
    def __init__(self, data):
        """Initialize the image with a :class:`~xarray.DataArray`.

        The array is normalized to have ``bands``, ``y`` and ``x``
        dimensions (see :meth:`_correct_dims`) and its values are wrapped
        in a dask array if they are not one already.
        """
        data = self._correct_dims(data)
        # 'data' is an XArray, get the data from it as a dask array
        if not isinstance(data.data, da.Array):
            logger.debug("Convert image data to dask array")
            # NOTE(review): the chunk tuple assumes ('bands', y, x) dimension
            # order — confirm _correct_dims always yields bands-first data here
            data.data = da.from_array(data.data, chunks=(data.sizes['bands'], 4096, 4096))
        self.data = data
        self.height, self.width = self.data.sizes['y'], self.data.sizes['x']
        self.palette = None
@staticmethod
def _correct_dims(data):
"""Standardize dimensions to bands, y, and x."""
if not hasattr(data, 'dims'):
raise TypeError("Data must have a 'dims' attribute.")
if 'y' not in data.dims or 'x' not in data.dims:
if data.ndim != 2:
raise ValueError("Data must have a 'y' and 'x' dimension")
# rename dimensions so we can use them
# don't rename 'x' or 'y' if they already exist
if 'y' not in data.dims:
# find a dimension that isn't 'x'
old_dim = [d for d in data.dims if d != 'x'][0]
data = data.rename({old_dim: 'y'})
if 'x' not in data.dims:
# find a dimension that isn't 'y'
old_dim = [d for d in data.dims if d != 'y'][0]
data = data.rename({old_dim: 'x'})
if "bands" not in data.dims:
if data.ndim <= 2:
data = data.expand_dims('bands')
data['bands'] = ['L']
else:
raise ValueError("No 'bands' dimension provided.")
# doesn't actually copy the data underneath
# we don't want our operations to change the user's data
# we do this last in case `expand_dims` made the data read only
data = data.copy()
return data
@property
def mode(self):
"""Mode of the image."""
return ''.join(self.data['bands'].values)
def save(self, filename, fformat=None, fill_value=None, compute=True,
keep_palette=False, cmap=None, **format_kwargs):
"""Save the image to the given *filename*.
Args:
filename (str): Output filename
fformat (str): File format of output file (optional). Can be
one of many image formats supported by the
`rasterio` or `PIL` libraries ('jpg', 'png',
'tif'). By default this is determined by the
extension of the provided filename.
If the format allows, geographical information will
be saved to the ouput file, in the form of grid
mapping or ground control points.
fill_value (float): Replace invalid data values with this value
and do not produce an Alpha band. Default
behavior is to create an alpha band.
compute (bool): If True (default) write the data to the file
immediately. If False the return value is either
a `dask.Delayed` object or a tuple of
``(source, target)`` to be passed to
`dask.array.store`.
keep_palette (bool): Saves the palettized version of the image if
set to True. False by default.
cmap (Colormap or dict): Colormap to be applied to the image when
saving with rasterio, used with
keep_palette=True. Should be uint8.
format_kwargs: Additional format options to pass to `rasterio`
or `PIL` saving methods. Any format argument passed
at this stage would be superseeded by `fformat`.
Returns:
Either `None` if `compute` is True or a `dask.Delayed` object or
``(source, target)`` pair to be passed to `dask.array.store`.
If compute is False the return value depends on format and how
the image backend is used. If ``(source, target)`` is provided
then target is an open file-like object that must be closed by
the caller.
"""
kwformat = format_kwargs.pop('format', None)
fformat = fformat or kwformat or os.path.splitext(filename)[1][1:]
if fformat in ('tif', 'tiff', 'jp2') and rasterio:
return self.rio_save(filename, fformat=fformat,
fill_value=fill_value, compute=compute,
keep_palette=keep_palette, cmap=cmap,
**format_kwargs)
return self.pil_save(filename, fformat, fill_value,
compute=compute, **format_kwargs)
    def rio_save(self, filename, fformat=None, fill_value=None,
                 dtype=np.uint8, compute=True, tags=None,
                 keep_palette=False, cmap=None, overviews=None,
                 overviews_minsize=256, overviews_resampling=None,
                 include_scale_offset_tags=False,
                 scale_offset_tags=None,
                 **format_kwargs):
        """Save the image using rasterio.

        Args:
            filename (string): The filename to save to.
            fformat (string): The format to save to. If not specified
                (default), it will be inferred from the file extension.
            fill_value (number): The value to fill the missing data with.
                Default is ``None``, translating to trying to keep the data
                transparent.
            dtype (np.dtype): The type to save the data to. Defaults to
                np.uint8.
            compute (bool): Whether (default) or not to compute the lazy data.
            tags (dict): Tags to include in the file.
            keep_palette (bool): Whether or not (default) to keep the image in
                P mode.
            cmap (colormap): The colormap to use for the data.
            overviews (list): The reduction factors of the overviews to include
                in the image, eg::

                    img.rio_save('myfile.tif', overviews=[2, 4, 8, 16])

                If provided as an empty list, then levels will be
                computed as powers of two until the last level has less
                pixels than `overviews_minsize`. Default is to not add
                overviews.
            overviews_minsize (int): Minimum number of pixels for the smallest
                overview size generated when `overviews` is auto-generated.
                Defaults to 256.
            overviews_resampling (str): Resampling method
                to use when generating overviews. This must be the name of an
                enum value from :class:`rasterio.enums.Resampling` and
                only takes effect if the `overviews` keyword argument is
                provided. Common values include `nearest` (default),
                `bilinear`, `average`, and many others. See the rasterio
                documentation for more information.
            include_scale_offset_tags (bool): Deprecated, use
                ``scale_offset_tags`` instead.
            scale_offset_tags (Tuple[str, str] or None): If set to a
                ``(str, str)`` tuple, scale and offset will be stored in
                GDALMetaData tags. Those can then be used to retrieve the
                original data values from pixel values.

        Returns:
            The delayed or computed result of the saving.
        """
        fformat = fformat or os.path.splitext(filename)[1][1:]
        drivers = {'jpg': 'JPEG',
                   'png': 'PNG',
                   'tif': 'GTiff',
                   'tiff': 'GTiff',
                   'jp2': 'JP2OpenJPEG'}
        # unknown extensions are passed straight through as the driver name
        driver = drivers.get(fformat, fformat)
        if include_scale_offset_tags:
            warnings.warn(
                "include_scale_offset_tags is deprecated, please use "
                "scale_offset_tags to indicate tag labels",
                DeprecationWarning)
            scale_offset_tags = scale_offset_tags or ("scale", "offset")
        if tags is None:
            tags = {}
        # fill/alpha handling and dtype scaling happen in finalize()
        data, mode = self.finalize(fill_value, dtype=dtype,
                                   keep_palette=keep_palette)
        data = data.transpose('bands', 'y', 'x')
        crs = None
        gcps = None
        transform = None
        if driver in ['GTiff', 'JP2OpenJPEG']:
            if not np.issubdtype(data.dtype, np.floating):
                format_kwargs.setdefault('compress', 'DEFLATE')
            photometric_map = {
                'RGB': 'RGB',
                'RGBA': 'RGB',
                'CMYK': 'CMYK',
                'CMYKA': 'CMYK',
                'YCBCR': 'YCBCR',
                'YCBCRA': 'YCBCR',
            }
            if mode.upper() in photometric_map:
                format_kwargs.setdefault('photometric',
                                         photometric_map[mode.upper()])
            try:
                area = data.attrs['area']
                # WKT2 output needs GDAL 3+
                if rasterio.__gdal_version__ >= '3':
                    wkt_version = 'WKT2_2018'
                else:
                    wkt_version = 'WKT1_GDAL'
                if hasattr(area, 'crs'):
                    crs = rasterio.crs.CRS.from_wkt(area.crs.to_wkt(version=wkt_version))
                else:
                    crs = rasterio.crs.CRS(data.attrs['area'].proj_dict)
                west, south, east, north = data.attrs['area'].area_extent
                height, width = data.sizes['y'], data.sizes['x']
                transform = rasterio.transform.from_bounds(west, south,
                                                           east, north,
                                                           width, height)
            except KeyError:  # No area
                logger.info("Couldn't create geotransform")
            except AttributeError:
                # no grid-like area: fall back to ground control points
                try:
                    gcps = data.attrs['area'].lons.attrs['gcps']
                    crs = data.attrs['area'].lons.attrs['crs']
                except KeyError:
                    logger.info("Couldn't create geotransform")
            stime = data.attrs.get("start_time")
            if stime:
                stime_str = stime.strftime("%Y:%m:%d %H:%M:%S")
                tags.setdefault('TIFFTAG_DATETIME', stime_str)
        elif driver == 'JPEG' and 'A' in mode:
            raise ValueError('JPEG does not support alpha')
        if scale_offset_tags:
            # store the inverse transform so readers can recover the
            # original physical values from the saved pixel values
            scale_label, offset_label = scale_offset_tags
            scale, offset = self.get_scaling_from_history(data.attrs.get('enhancement_history', []))
            tags[scale_label], tags[offset_label] = invert_scale_offset(scale, offset)
        # FIXME add metadata
        r_file = RIOFile(filename, 'w', driver=driver,
                         width=data.sizes['x'], height=data.sizes['y'],
                         count=data.sizes['bands'],
                         dtype=dtype,
                         nodata=fill_value,
                         crs=crs,
                         transform=transform,
                         gcps=gcps,
                         **format_kwargs)
        r_file.open()
        if not keep_palette:
            r_file.colorinterp = color_interp(data)
        if keep_palette and cmap is not None:
            if data.dtype != 'uint8':
                raise ValueError('Rasterio only supports 8-bit colormaps')
            try:
                from trollimage.colormap import Colormap
                cmap = cmap.to_rio() if isinstance(cmap, Colormap) else cmap
                r_file.rfile.write_colormap(1, cmap)
            except AttributeError:
                raise ValueError("Colormap is not formatted correctly")
        # dask-backed tag values are written chunk-wise alongside the data
        tags, da_tags = self._split_regular_vs_lazy_tags(tags, r_file)
        r_file.rfile.update_tags(**tags)
        r_dataset = RIODataset(r_file, overviews,
                               overviews_resampling=overviews_resampling,
                               overviews_minsize=overviews_minsize)
        to_store = (data.data, r_dataset)
        if da_tags:
            to_store = list(zip(*([to_store] + da_tags)))
        if compute:
            # write data to the file now
            res = da.store(*to_store)
            to_close = to_store[1]
            if not isinstance(to_close, tuple):
                to_close = [to_close]
            for item in to_close:
                item.close()
            return res
        # provide the data object and the opened file so the caller can
        # store them when they would like. Caller is responsible for
        # closing the file
        return to_store
@staticmethod
def _split_regular_vs_lazy_tags(tags, r_file):
"""Split tags into regular vs lazy (dask) tags."""
da_tags = []
for key, val in list(tags.items()):
try:
if isinstance(val.data, da.Array):
da_tags.append((val.data, RIOTag(r_file, key)))
tags.pop(key)
else:
tags[key] = val.item()
except AttributeError:
continue
return tags, da_tags
def pil_save(self, filename, fformat=None, fill_value=None,
compute=True, **format_kwargs):
"""Save the image to the given *filename* using PIL.
For now, the compression level [0-9] is ignored, due to PIL's
lack of support. See also :meth:`save`.
"""
fformat = fformat or os.path.splitext(filename)[1][1:]
fformat = check_image_format(fformat)
if fformat == 'png':
# Take care of GeoImage.tags (if any).
format_kwargs['pnginfo'] = self._pngmeta()
img = self.pil_image(fill_value, compute=False)
delay = delayed_pil_save(img, filename, fformat, **format_kwargs)
if compute:
return delay.compute()
return delay
def get_scaling_from_history(self, history=None):
"""Merge the scales and offsets from the history.
If ``history`` isn't provided, the history of the current image will be
used.
"""
if history is None:
history = self.data.attrs.get('enhancement_history', [])
try:
scaling = [(item['scale'], item['offset']) for item in history]
except KeyError as err:
raise NotImplementedError('Can only get combine scaling from a list of scaling operations: %s' % str(err))
return combine_scales_offsets(*scaling)
    @delayed(nout=1, pure=True)
    def _delayed_apply_pil(self, fun, pil_image, fun_args, fun_kwargs,
                           image_metadata=None, output_mode=None):
        """Apply ``fun`` to a PIL image and return a normalized numpy array.

        Runs lazily as a dask delayed task (see :meth:`apply_pil`).  The
        result is converted to ``output_mode`` (when given) and divided by
        255 to map the PIL value range back to 0-1 in this image's dtype.
        """
        if fun_args is None:
            fun_args = tuple()
        if fun_kwargs is None:
            fun_kwargs = dict()
        if image_metadata is None:
            image_metadata = dict()
        # the image metadata is handed to ``fun`` as its second positional
        # argument, after the PIL image itself
        new_img = fun(pil_image, image_metadata, *fun_args, **fun_kwargs)
        if output_mode is not None:
            new_img = new_img.convert(output_mode)
        return np.array(new_img) / self.data.dtype.type(255.0)
    def apply_pil(self, fun, output_mode, pil_args=None, pil_kwargs=None, fun_args=None, fun_kwargs=None):
        """Apply a function `fun` on the pillow image corresponding to the instance of the XRImage.

        The function shall take a pil image as first argument, and is then
        passed ``fun_args`` and ``fun_kwargs``.  In addition, the current
        image's metadata is passed as the second positional argument.  It
        is expected to return the modified pil image.

        This function returns a new XRImage instance with the modified
        image data.

        The ``pil_args`` and ``pil_kwargs`` are passed to the `pil_image`
        method of the XRImage instance.
        """
        if pil_args is None:
            pil_args = tuple()
        if pil_kwargs is None:
            pil_kwargs = dict()
        pil_image = self.pil_image(*pil_args, compute=False, **pil_kwargs)
        # HACK: aggdraw.Font objects cause segmentation fault in dask tokenize
        # Remove this when aggdraw is either updated to allow type(font_obj)
        # or pycoast is updated to not accept Font objects
        # See https://github.com/pytroll/pycoast/issues/43
        # The last positional argument to the _burn_overlay function in Satpy
        # is the 'overlay' dict. This could include aggdraw.Font objects so we
        # completely remove it.
        delayed_kwargs = {}
        if fun.__name__ == "_burn_overlay":
            from dask.base import tokenize
            from dask.utils import funcname
            func = self._delayed_apply_pil
            if fun_args is None:
                fun_args = tuple()
            if fun_kwargs is None:
                fun_kwargs = dict()
            # tokenize everything except the trailing overlay dict
            tokenize_args = (fun, pil_image, fun_args[:-1], fun_kwargs,
                             self.data.attrs, output_mode)
            dask_key_name = "%s-%s" % (
                funcname(func),
                tokenize(func.key, *tokenize_args, pure=True),
            )
            delayed_kwargs['dask_key_name'] = dask_key_name
        new_array = self._delayed_apply_pil(fun, pil_image, fun_args, fun_kwargs,
                                            self.data.attrs, output_mode,
                                            **delayed_kwargs)
        # the delayed result is a (y, x, bands) numpy array in 0-1 range
        bands = len(output_mode)
        arr = da.from_delayed(new_array, dtype=self.data.dtype,
                              shape=(self.data.sizes['y'], self.data.sizes['x'], bands))
        new_data = xr.DataArray(arr, dims=['y', 'x', 'bands'],
                                coords={'y': self.data.coords['y'],
                                        'x': self.data.coords['x'],
                                        'bands': list(output_mode)},
                                attrs=self.data.attrs)
        return XRImage(new_data)
def _pngmeta(self):
"""Return GeoImage.tags as a PNG metadata object.
Inspired by:
public domain, Nick Galbreath
http://blog.modp.com/2007/08/python-pil-and-png-metadata-take-2.html
"""
reserved = ('interlace', 'gamma', 'dpi', 'transparency', 'aspect')
try:
tags = self.tags
except AttributeError:
tags = {}
# Undocumented class
from PIL import PngImagePlugin
meta = PngImagePlugin.PngInfo()
# Copy from tags to new dict
for k__, v__ in tags.items():
if k__ not in reserved:
meta.add_text(k__, v__, 0)
return meta
def _create_alpha(self, data, fill_value=None):
"""Create an alpha band DataArray object.
If `fill_value` is provided and input data is an integer type
then it is used to determine invalid "null" pixels instead of
xarray's `isnull` and `notnull` methods.
The returned array is 1 where data is valid, 0 where invalid.
"""
not_alpha = [b for b in data.coords['bands'].values if b != 'A']
null_mask = data.sel(bands=not_alpha)
if np.issubdtype(data.dtype, np.integer) and fill_value is not None:
null_mask = null_mask != fill_value
else:
null_mask = null_mask.notnull()
# if any of the bands are valid, we don't want transparency
null_mask = null_mask.any(dim='bands')
null_mask = null_mask.expand_dims('bands')
null_mask['bands'] = ['A']
# changes to null_mask attrs should not effect the original attrs
# XRImage never uses them either
null_mask.attrs = {}
return null_mask
def _add_alpha(self, data, alpha=None):
"""Create an alpha channel and concatenate it to the provided data.
If ``data`` is an integer type then the alpha band will be scaled
to use the smallest (min) value as fully transparent and the largest
(max) value as fully opaque. If a `_FillValue` attribute is found for
integer type data then it is used to identify null values in the data.
Otherwise xarray's `isnull` is used.
For float types the alpha band spans 0 to 1.
"""
fill_value = data.attrs.get('_FillValue', None) # integer fill value
null_mask = alpha if alpha is not None else self._create_alpha(data, fill_value)
# if we are using integer data, then alpha needs to be min-int to max-int
# otherwise for floats we want 0 to 1
if np.issubdtype(data.dtype, np.integer):
# xarray sometimes upcasts this calculation, so cast again
null_mask = self._scale_to_dtype(null_mask, data.dtype).astype(data.dtype)
attrs = data.attrs.copy()
data = xr.concat([data, null_mask], dim="bands")
data.attrs = attrs
return data
def _get_dtype_scale_offset(self, dtype, fill_value):
dinfo = np.iinfo(dtype)
scale = dinfo.max - dinfo.min
offset = dinfo.min
if fill_value is not None:
if fill_value == dinfo.min:
# leave the lowest value for fill value only
offset = offset + 1
scale = scale - 1
elif fill_value == dinfo.max:
# leave the top value for fill value only
scale = scale - 1
else:
warnings.warn(
"Specified fill value will overlap with valid "
"data. To avoid this warning specify a fill_value "
"that is the minimum or maximum for the data type "
"being saved to.")
return scale, offset
    def _scale_to_dtype(self, data, dtype, fill_value=None):
        """Scale provided data to dtype range assuming a 0-1 range.

        Float input data is assumed to be normalized to a 0 to 1 range.
        Integer input data is not scaled, only clipped. A float output
        type is not scaled since both outputs and inputs are assumed to
        be in the 0-1 range already.
        """
        attrs = data.attrs.copy()
        if np.issubdtype(dtype, np.integer):
            if np.issubdtype(data, np.integer):
                # preserve integer data type
                data = data.clip(np.iinfo(dtype).min, np.iinfo(dtype).max)
            else:
                # scale float data (assumed to be 0 to 1) to full integer space
                # leave room for fill value if needed
                scale, offset = self._get_dtype_scale_offset(dtype, fill_value)
                data = data.clip(0, 1) * scale + offset
                # record the scaling so it can be inverted later
                # (see get_scaling_from_history)
                attrs.setdefault('enhancement_history', list()).append({'scale': scale, 'offset': offset})
            # round rather than truncate so values map evenly onto the range
            data = data.round()
        data.attrs = attrs
        return data
def _check_modes(self, modes):
"""Check that the image is in one of the given *modes*, raise an exception otherwise."""
if not isinstance(modes, (tuple, list, set)):
modes = [modes]
if self.mode not in modes:
raise ValueError("Image not in suitable mode, expected: %s, got: %s" % (modes, self.mode))
    def _from_p(self, mode):
        """Convert the image from P or PA to RGB or RGBA.

        Each palette index is looked up in ``self.palette``.  A 4-column
        palette forces RGBA output (the palette's alpha wins over any data
        alpha); otherwise an existing 'A' band is carried over to the
        result.
        """
        self._check_modes(("P", "PA"))
        if not self.palette:
            raise RuntimeError("Can't convert palettized image, missing palette.")
        pal = np.array(self.palette)
        pal = da.from_array(pal, chunks=pal.shape)
        if pal.shape[1] == 4:
            # colormap's alpha overrides data alpha
            mode = "RGBA"
            alpha = None
        elif self.mode.endswith("A"):
            # add a new/fake 'bands' dimension to the end
            alpha = self.data.sel(bands="A").data[..., None]
            mode = mode + "A" if not mode.endswith("A") else mode
        else:
            alpha = None
        # fancy-index the palette with the flattened pixel values
        flat_indexes = self.data.sel(bands='P').data.ravel().astype('int64')
        dim_sizes = ((key, val) for key, val in self.data.sizes.items() if key != 'bands')
        dims, new_shape = zip(*dim_sizes)
        dims = dims + ('bands',)
        new_shape = new_shape + (pal.shape[1],)
        # restore the spatial shape with the palette columns as bands
        new_data = pal[flat_indexes].reshape(new_shape)
        coords = dict(self.data.coords)
        coords["bands"] = list(mode)
        if alpha is not None:
            new_arr = da.concatenate((new_data, alpha), axis=-1)
            data = xr.DataArray(new_arr, coords=coords, attrs=self.data.attrs, dims=dims)
        else:
            data = xr.DataArray(new_data, coords=coords, attrs=self.data.attrs, dims=dims)
        return data
def _l2rgb(self, mode):
"""Convert from L (black and white) to RGB."""
self._check_modes(("L", "LA"))
bands = ["L"] * 3
if mode[-1] == "A":
bands.append("A")
data = self.data.sel(bands=bands)
data["bands"] = list(mode)
return data
    def convert(self, mode):
        """Convert image to *mode*.

        Supported target modes are P, PA, L, LA, RGB and RGBA.  Alpha
        bands are added or removed first, then mode conversions
        (P→RGB, L→RGB) are dispatched, recursing once when both steps
        are needed.
        """
        if mode == self.mode:
            return self.__class__(self.data)
        if mode not in ["P", "PA", "L", "LA", "RGB", "RGBA"]:
            raise ValueError("Mode %s not recognized." % (mode))
        if mode == self.mode + "A":
            # same bands, just add an alpha band
            data = self._add_alpha(self.data).data
            coords = dict(self.data.coords)
            coords["bands"] = list(mode)
            data = xr.DataArray(data, coords=coords, attrs=self.data.attrs, dims=self.data.dims)
            new_img = XRImage(data)
        elif mode + "A" == self.mode:
            # Remove the alpha band from our current image
            no_alpha = self.data.sel(bands=[b for b in self.data.coords["bands"].data if b != "A"]).data
            coords = dict(self.data.coords)
            coords["bands"] = list(mode)
            data = xr.DataArray(no_alpha, coords=coords, attrs=self.data.attrs, dims=self.data.dims)
            new_img = XRImage(data)
        elif mode.endswith("A") and not self.mode.endswith("A"):
            # add the alpha band first, then convert the color part
            img = self.convert(self.mode + "A")
            new_img = img.convert(mode)
        elif self.mode.endswith("A") and not mode.endswith("A"):
            # drop the alpha band first, then convert the color part
            img = self.convert(self.mode[:-1])
            new_img = img.convert(mode)
        else:
            # dispatch table for the remaining supported conversions
            cases = {
                "P": {"RGB": self._from_p},
                "PA": {"RGBA": self._from_p},
                "L": {"RGB": self._l2rgb},
                "LA": {"RGBA": self._l2rgb}
            }
            try:
                data = cases[self.mode][mode](mode)
                new_img = XRImage(data)
            except KeyError:
                raise ValueError("Conversion from %s to %s not implemented !"
                                 % (self.mode, mode))
        if self.mode.startswith('P') and new_img.mode.startswith('P'):
            # need to copy the palette
            new_img.palette = self.palette
        return new_img
def final_mode(self, fill_value=None):
"""Get the mode of the finalized image when provided this fill_value."""
if fill_value is None and not self.mode.endswith('A'):
return self.mode + 'A'
return self.mode
def _add_alpha_and_scale(self, data, ifill, dtype):
alpha = self._create_alpha(data, fill_value=ifill)
data = self._scale_to_dtype(data, dtype)
data = data.astype(dtype)
data = self._add_alpha(data, alpha=alpha)
return data
def _replace_fill_value(self, data, ifill, fill_value, dtype):
# Add fill_value after all other calculations have been done to
# make sure it is not scaled for the data type
if ifill is not None and fill_value is not None:
# cast fill value to output type so we don't change data type
fill_value = dtype(fill_value)
# integer fields have special fill values
data = data.where(data != ifill, dtype(fill_value))
elif fill_value is not None:
data = data.fillna(dtype(fill_value))
return data
def _get_input_fill_value(self, data):
# if the data are integers then this fill value will be used to check for invalid values
if np.issubdtype(data, np.integer):
return data.attrs.get('_FillValue')
return None
def _scale_and_replace_fill_value(self, data, input_fill_value, fill_value, dtype):
# scale float data to the proper dtype
# this method doesn't cast yet so that we can keep track of NULL values
data = self._scale_to_dtype(data, dtype, fill_value)
data = self._replace_fill_value(data, input_fill_value, fill_value, dtype)
return data
def _scale_alpha_or_fill_data(self, data, fill_value, dtype):
input_fill_value = self._get_input_fill_value(data)
needs_alpha = fill_value is None and not self.mode.endswith('A')
if needs_alpha:
# We don't have a fill value or an alpha, let's add an alpha
return self._add_alpha_and_scale(data, input_fill_value, dtype)
return self._scale_and_replace_fill_value(data, input_fill_value, fill_value, dtype)
def finalize(self, fill_value=None, dtype=np.uint8, keep_palette=False):
"""Finalize the image to be written to an output file.
This adds an alpha band or fills data with a fill_value (if
specified). It also scales float data to the output range of the
data type (0-255 for uint8, default). For integer input data
this method assumes the data is already scaled to the proper
desired range. It will still fill in invalid values and add an
alpha band if needed. Integer input data's fill value is
determined by a special ``_FillValue`` attribute in the
``DataArray`` ``.attrs`` dictionary.
Args: