Skip to content

Commit

Permalink
Update Issue 6 refactoring some raster methods to use vectorize_datas…
Browse files Browse the repository at this point in the history
…ets instead of ReadAsArray
  • Loading branch information
wbierbower committed Apr 29, 2015
1 parent 0c9b8aa commit fc34c48
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 26 deletions.
138 changes: 112 additions & 26 deletions fauxgeo/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,15 +339,19 @@ def get_heatmap_image(self):
raise NotImplementedError

def ones(self):
return self.zeros() + 1
def ones_closure(nodata):
def ones(x):
return np.where(x == x, 1, nodata)
return ones
return self.local_op(0, ones_closure, broadcast=True)


def zeros(self):
array = self.get_band(1).data * 0
affine = self.get_affine()
proj = self.get_projection()
datatype = self.get_datatype(1)
nodata_val = self.get_nodata(1)
return Raster.from_array(array, affine, proj, datatype, nodata_val)
def zeros_closure(nodata):
def zeros(x):
return np.where(x == x, 0, nodata)
return zeros
return self.local_op(0, zeros_closure, broadcast=True)

def band_count(self):
self._open_dataset()
Expand Down Expand Up @@ -546,30 +550,112 @@ def set_bands(self, array):
raise NotImplementedError

def set_datatype(self, datatype):
array = self.get_band(1)
affine = self.get_affine()
proj = self.get_projection()
nodata_val = self.get_nodata(1)
return Raster.from_array(array, affine, proj, datatype, nodata_val)

def pixel_op_closure(nodata):
def copy(x):
return np.where(x == x, x, nodata)
return copy

bounding_box_mode = "dataset"
resample_method = "nearest"

dataset_uri_list = [self.uri]
resample_list = [resample_method]

nodata = self.get_nodata(1)
pixel_op = pixel_op_closure(nodata)
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
datatype_out = datatype
nodata_out = nodata
pixel_size_out = pygeo.geoprocessing.get_cell_size_from_uri(self.uri)

pygeo.geoprocessing.vectorize_datasets(
dataset_uri_list,
pixel_op,
dataset_out_uri,
datatype_out,
nodata_out,
pixel_size_out,
bounding_box_mode,
resample_method_list=resample_list,
dataset_to_align_index=0,
dataset_to_bound_index=0,
assert_datasets_projected=False,
vectorize_op=False)

return Raster.from_tempfile(dataset_out_uri)

def set_nodata(self, nodata_val):
array = self.get_band(1).data
src_nodata_val = self.get_nodata(1)
array[array == src_nodata_val] = nodata_val

affine = self.get_affine()
proj = self.get_projection()
datatype = self.get_datatype(1)
return Raster.from_array(array, affine, proj, datatype, nodata_val)
def pixel_op_closure(old_nodata, new_nodata):
def copy(x):
return np.where(x == old_nodata, new_nodata, x)
return copy

bounding_box_mode = "dataset"
resample_method = "nearest"

dataset_uri_list = [self.uri]
resample_list = [resample_method]

old_nodata = self.get_nodata(1)
pixel_op = pixel_op_closure(old_nodata, nodata_val)
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
datatype_out = pygeo.geoprocessing.get_datatype_from_uri(self.uri)
nodata_out = nodata_val
pixel_size_out = pygeo.geoprocessing.get_cell_size_from_uri(self.uri)

pygeo.geoprocessing.vectorize_datasets(
dataset_uri_list,
pixel_op,
dataset_out_uri,
datatype_out,
nodata_out,
pixel_size_out,
bounding_box_mode,
resample_method_list=resample_list,
dataset_to_align_index=0,
dataset_to_bound_index=0,
assert_datasets_projected=False,
vectorize_op=False)

return Raster.from_tempfile(dataset_out_uri)

def set_datatype_and_nodata(self, datatype, nodata_val):
array = self.get_band(1).data
src_nodata_val = self.get_nodata(1)
array[array == src_nodata_val] = nodata_val

affine = self.get_affine()
proj = self.get_projection()
return Raster.from_array(array, affine, proj, datatype, nodata_val)
def pixel_op_closure(old_nodata, new_nodata):
def copy(x):
return np.where(x == old_nodata, new_nodata, x)
return copy

bounding_box_mode = "dataset"
resample_method = "nearest"

dataset_uri_list = [self.uri]
resample_list = [resample_method]

old_nodata = self.get_nodata(1)
pixel_op = pixel_op_closure(old_nodata, nodata_val)
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
datatype_out = datatype
nodata_out = nodata_val
pixel_size_out = pygeo.geoprocessing.get_cell_size_from_uri(self.uri)

pygeo.geoprocessing.vectorize_datasets(
dataset_uri_list,
pixel_op,
dataset_out_uri,
datatype_out,
nodata_out,
pixel_size_out,
bounding_box_mode,
resample_method_list=resample_list,
dataset_to_align_index=0,
dataset_to_bound_index=0,
assert_datasets_projected=False,
vectorize_op=False)

return Raster.from_tempfile(dataset_out_uri)

def copy(self, uri=None):
if not uri:
Expand Down Expand Up @@ -810,7 +896,7 @@ def local_op(self, raster, pixel_op_closure, broadcast=False):
assert(self.is_aligned(raster))
try:
assert(self.get_nodata(1) == raster.get_nodata(1))
except:
except AssertionError:
LOGGER.error("Rasters have different nodata values: %f, %f" % (
self.get_nodata(1), raster.get_nodata(1)))
raise AssertionError
Expand Down
87 changes: 87 additions & 0 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,5 +282,92 @@ def test_reclass_masked_values(self):
assert(new_raster.get_band(1)[0, 0] == new_value)


class TestRasterZeros(unittest.TestCase):
def setUp(self):
self.shape = (4, 4)
self.array = np.ones(self.shape)
self.affine = Affine(1, 0, 0, 0, -1, 4)
self.proj = 4326
self.datatype = gdal.GDT_Float64
self.nodata_val = -9999
self.factory = RasterFactory(
self.proj, self.datatype, self.nodata_val, *self.shape, affine=self.affine)

def test_zeros(self):
raster = self.factory.alternating(-9999, 2.0)
zero_raster = raster.zeros()
assert(zero_raster.get_band(1)[0, 0] == 0)


class TestRasterOnes(unittest.TestCase):
def setUp(self):
self.shape = (4, 4)
self.array = np.ones(self.shape)
self.affine = Affine(1, 0, 0, 0, -1, 4)
self.proj = 4326
self.datatype = gdal.GDT_Float64
self.nodata_val = -9999
self.factory = RasterFactory(
self.proj, self.datatype, self.nodata_val, *self.shape, affine=self.affine)

def test_zeros(self):
raster = self.factory.alternating(-9999, 2.0)
ones_raster = raster.ones()
assert(ones_raster.get_band(1)[0, 0] == 1)


class TestRasterSetDatatype(unittest.TestCase):
def setUp(self):
self.shape = (4, 4)
self.array = np.ones(self.shape)
self.affine = Affine(1, 0, 0, 0, -1, 4)
self.proj = 4326
self.datatype = gdal.GDT_Float64
self.nodata_val = -9999
self.factory = RasterFactory(
self.proj, self.datatype, self.nodata_val, *self.shape, affine=self.affine)

def test_set_datatype(self):
raster = self.factory.alternating(-9999, 2.0)
new_raster = raster.set_datatype(gdal.GDT_Int16)
assert(new_raster.get_datatype(1) == gdal.GDT_Int16)


class TestRasterSetNoData(unittest.TestCase):
def setUp(self):
self.shape = (4, 4)
self.array = np.ones(self.shape)
self.affine = Affine(1, 0, 0, 0, -1, 4)
self.proj = 4326
self.datatype = gdal.GDT_Float64
self.nodata_val = -9999
self.factory = RasterFactory(
self.proj, self.datatype, self.nodata_val, *self.shape, affine=self.affine)

def test_set_nodata(self):
raster = self.factory.alternating(-9999, 2.0)
new_raster = raster.set_nodata(100)
assert(new_raster.get_nodata(1) == 100)
assert(new_raster.get_band(1).data[0, 0] == 100.0)


class TestRasterSetDatatypeAndNoData(unittest.TestCase):
def setUp(self):
self.shape = (4, 4)
self.array = np.ones(self.shape)
self.affine = Affine(1, 0, 0, 0, -1, 4)
self.proj = 4326
self.datatype = gdal.GDT_Float64
self.nodata_val = -9999
self.factory = RasterFactory(
self.proj, self.datatype, self.nodata_val, *self.shape, affine=self.affine)

def test_set_nodata(self):
raster = self.factory.alternating(-9999, 2.0)
new_raster = raster.set_datatype_and_nodata(gdal.GDT_Int16, 100)
assert(new_raster.get_datatype(1) == gdal.GDT_Int16)
assert(new_raster.get_nodata(1) == 100)
assert(new_raster.get_band(1).data[0, 0] == 100)

if __name__ == '__main__':
unittest.main()

1 comment on commit fc34c48

@wbierbower
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update Issue #6

Please sign in to comment.