Skip to content

Commit

Permalink
more work on Raster class
Browse files Browse the repository at this point in the history
  • Loading branch information
Will Bierbower authored and Will Bierbower committed Mar 31, 2015
1 parent bce8d31 commit 5e42fdd
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 52 deletions.
148 changes: 118 additions & 30 deletions fauxgeo/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,20 @@ def __init__(self, uri, driver):
self.dataset = None

@classmethod
def from_array(self, output_uri, array, affine, proj, datatype, nodata_val, driver='GTiff'):
def from_array(self, array, affine, proj, datatype, nodata_val, driver='GTiff'):
if len(array.shape) is 2:
num_bands = 1
elif len(array.shape) is 3:
num_bands = len(array)
else:
raise ValueError

if not os.path.isabs(output_uri):
output_uri = os.path.join(os.getcwd(), output_uri)

dataset_uri = pygeo.geoprocessing.temporary_filename()
rows = array.shape[0]
cols = array.shape[1]

driver = gdal.GetDriverByName(driver)
dataset = driver.Create(output_uri, cols, rows, num_bands, datatype)
dataset = driver.Create(dataset_uri, cols, rows, num_bands, datatype)
dataset.SetGeoTransform((affine.to_gdal()))

for band_num in range(num_bands):
Expand All @@ -56,32 +54,92 @@ def from_array(self, output_uri, array, affine, proj, datatype, nodata_val, driv
dataset = None
driver = None

return Raster(output_uri, driver=driver)
return Raster(dataset_uri, driver=driver)

@classmethod
def from_file(self, uri, driver='GTiff'):
dataset_uri = pygeo.geoprocessing.temporary_filename()
if not os.path.isabs(uri):
uri = os.path.join(os.getcwd(), uri)
# assert existence
shutil.copyfile(uri, dataset_uri)
return Raster(dataset_uri, driver)

@classmethod
def from_tempfile(self, uri, driver='GTiff'):
if not os.path.isabs(uri):
uri = os.path.join(os.getcwd(), uri)
return Raster(uri, driver)

@classmethod
def simple_affine(self, top_left_x, top_left_y, pix_width, pix_height):
return Affine(pix_width, 0, top_left_x, 0, -(pix_height), top_left_y)

def __del__(self):
self._delete()

def __exit__(self):
self._delete()

def _delete(self):
os.remove(self.uri)

def __str__(self):
return self.uri

def __len__(self):
return self.band_count()

def __mul__(self, raster):
def mul(x, y):
return x * y
return self.local_op(b, mul)

def __eq__(self):
pass # false if different shape, element-by-element if same shape
def mul_closer(nodata):
def mul(x, y):
if nodata in [x, y]:
return nodata
return x * y
return mul
return self.local_op(raster, mul_closer)

def __div__(self, raster):
def div_closer(nodata):
def div(x, y):
if nodata in [x, y]:
return nodata
return x / y
return div
return self.local_op(raster, div_closer)

def __add__(self, raster):
def add_closer(nodata):
def add(x, y):
if nodata in [x, y]:
return nodata
return x + y
return add
return self.local_op(raster, add_closer)

def __sub__(self, raster):
def sub_closer(nodata):
def sub(x, y):
if nodata in [x, y]:
return nodata
return x - y
return sub
return self.local_op(raster, sub_closer)

def __pow__(self, raster):
def pow_closer(nodata):
def pow(x, y):
if nodata in [x, y]:
return nodata
return x**y
return pow
return self.local_op(raster, pow_closer)

def __eq__(self, raster):
if self.is_aligned(raster) and (self.get_shape() == raster.get_shape()):
return (self.get_bands() == raster.get_bands())
else:
return False

def __getitem__(self):
pass # return numpy slice? Raster object with sliced numpy array?
Expand All @@ -107,6 +165,9 @@ def __repr__(self):
def _repr_png_(self):
raise NotImplementedError

def save_raster(self, uri):
shutil.copyfile(self.uri, uri)

def band_count(self):
self._open_dataset()
count = self.dataset.RasterCount
Expand Down Expand Up @@ -223,13 +284,7 @@ def set_band(self, masked_array):
assert(len(masked_array[0]) == self.get_cols())
assert(self.band_count() == 1)

uri = self.uri
array = masked_array.data
affine = self.get_affine()
proj = self.get_projection()
datatype = self.get_datatype(1)
nodata_val = masked_array.fill_value
Raster.from_array(uri, array, affine, proj, datatype, nodata_val)
raise NotImplementedError

def set_bands(self, array):
# if self.exists:
Expand All @@ -253,7 +308,7 @@ def copy(self, uri):
if not os.path.isabs(uri):
uri = os.path.join(os.getcwd(), uri)
shutil.copyfile(self.uri, uri)
return Raster.from_file(uri)
return Raster.from_tempfile(uri)

def is_aligned(self, raster):
try:
Expand Down Expand Up @@ -290,16 +345,45 @@ def dataset_pixel_op(x, y): return y
assert_datasets_projected=False,
vectorize_op=False)

return Raster.from_file(dataset_out_uri)
# temp_raster = Raster.from_file(dataset_out_uri)
return Raster.from_tempfile(dataset_out_uri)
# temp_raster = Raster.from_tempfile(dataset_out_uri)
# temp_raster.copy(raster.uri)
# os.remove(dataset_out_uri)

def align_to(self, raster, resample_method):
'''Currently aligns other raster to this raster - later: union/intersection
'''
assert(self.get_projection() == raster.get_projection())

def dataset_pixel_op(x, y): return y
dataset_uri_list = [raster.uri, self.uri]
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
datatype_out = pygeo.geoprocessing.get_datatype_from_uri(raster.uri)
nodata_out = pygeo.geoprocessing.get_nodata_from_uri(raster.uri)
pixel_size_out = pygeo.geoprocessing.get_cell_size_from_uri(self.uri)
bounding_box_mode = "dataset"

pygeo.geoprocessing.vectorize_datasets(
dataset_uri_list,
dataset_pixel_op,
dataset_out_uri,
datatype_out,
nodata_out,
pixel_size_out,
bounding_box_mode,
resample_method_list=[resample_method]*2,
dataset_to_align_index=0,
dataset_to_bound_index=0,
assert_datasets_projected=False,
vectorize_op=False)

return Raster.from_tempfile(dataset_out_uri)

def clip(self, aoi_uri):
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
pygeo.geoprocessing.clip_dataset_uri(
self.uri, aoi_uri, dataset_out_uri)
return Raster.from_file(dataset_out_uri)
return Raster.from_tempfile(dataset_out_uri)

def reproject(self, proj, resample_method, pixel_size=None):
if pixel_size is None:
Expand All @@ -313,7 +397,7 @@ def reproject(self, proj, resample_method, pixel_size=None):
pygeo.geoprocessing.reproject_dataset_uri(
self.uri, pixel_size, wkt, resample_method, dataset_out_uri)

return Raster.from_file(dataset_out_uri)
return Raster.from_tempfile(dataset_out_uri)

def reclass(self, reclass_table, out_nodata=None):
if out_nodata is None:
Expand All @@ -329,23 +413,27 @@ def reclass(self, reclass_table, out_nodata=None):
out_datatype,
out_nodata)

return Raster.from_file(dataset_out_uri)
return Raster.from_tempfile(dataset_out_uri)

def overlay(self, raster):
raise NotImplementedError

def to_vector(self):
raise NotImplementedError

def local_op(self, raster, pixel_op):
def local_op(self, raster, pixel_op_closer):
assert(self.is_aligned(raster))
assert(self.get_nodata(1) == raster.get_nodata(1))

nodata = self.get_nodata(1)
pixel_op = pixel_op_closer(nodata)
dataset_uri_list = [self.uri, raster.uri]
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
datatype_out = pygeo.geoprocessing.get_datatype_from_uri(raster.uri)
nodata_out = pygeo.geoprocessing.get_nodata_from_uri(raster.uri)
datatype_out = pygeo.geoprocessing.get_datatype_from_uri(self.uri)
nodata_out = pygeo.geoprocessing.get_nodata_from_uri(self.uri)
pixel_size_out = pygeo.geoprocessing.get_cell_size_from_uri(self.uri)
bounding_box_mode = "dataset"
resample_method = "nearest"

pygeo.geoprocessing.vectorize_datasets(
dataset_uri_list,
Expand All @@ -359,9 +447,9 @@ def local_op(self, raster, pixel_op):
dataset_to_align_index=0,
dataset_to_bound_index=0,
assert_datasets_projected=False,
vectorize_op=False)
vectorize_op=True)

return Raster.from_file(dataset_out_uri)
return Raster.from_tempfile(dataset_out_uri)

def _open_dataset(self):
self.dataset = gdal.Open(self.uri)
Expand Down
2 changes: 1 addition & 1 deletion fauxgeo/temp_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def from_array(array, affine, proj, datatype, nodata_val, driver='GTiff'):
dataset = None
driver = None

return Raster(uri, driver=driver)
return TempRaster(uri, driver=driver)

@classmethod
def from_file():
Expand Down
38 changes: 17 additions & 21 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,13 @@ def setUp(self):
self.proj = 4326
self.datatype = gdal.GDT_Float64
self.nodata_val = -9999
t = tempfile.NamedTemporaryFile(mode='w+')
t_aligned = tempfile.NamedTemporaryFile(mode='w+')
t_misaligned = tempfile.NamedTemporaryFile(mode='w+')
t.close()
t_aligned.close()
t_misaligned.close()

self.raster = Raster.from_array(
t.name, self.array, self.affine, self.proj, self.datatype, self.nodata_val)
self.array, self.affine, self.proj, self.datatype, self.nodata_val)
self.aligned_raster = Raster.from_array(
t_aligned.name, np.zeros(self.shape), self.affine, self.proj, self.datatype, self.nodata_val)
np.zeros(self.shape), self.affine, self.proj, self.datatype, self.nodata_val)
self.misaligned_raster = Raster.from_array(
t_misaligned.name, self.array, self.misaligned_affine, self.proj, self.datatype, self.nodata_val)
self.array, self.misaligned_affine, self.proj, self.datatype, self.nodata_val)

def test_get_functions(self):
assert(self.raster.get_rows() == self.shape[0])
Expand All @@ -56,26 +51,29 @@ def test_get_functions(self):
assert(self.raster.get_projection() == 4326)

def test_set_functions(self):
a = np.ma.masked_equal(np.zeros((self.shape)), 1)
self.raster.set_band(a)
np.testing.assert_array_equal(self.raster.get_band(1), a)
pass

def test_is_aligned(self):
assert(self.raster.is_aligned(self.aligned_raster) == True)
assert(self.raster.is_aligned(self.misaligned_raster) == False)

def test_align(self):
print self.misaligned_raster
assert(self.raster.is_aligned(self.misaligned_raster) == False)
self.raster.align(self.misaligned_raster, "nearest")
print self.misaligned_raster
assert(self.raster.is_aligned(self.misaligned_raster) == True)
new_raster = self.raster.align(self.misaligned_raster, "nearest")
assert(self.raster.is_aligned(new_raster) == True)

def test_align_to(self):
assert(self.raster.is_aligned(self.misaligned_raster) == False)
new_raster = self.misaligned_raster.align_to(self.raster, "nearest")
assert(self.raster.is_aligned(new_raster) == True)

def test_clip(self):
# self.raster.clip()
pass

def test_project(self):
pass
# def test_reproject(self):
# reprojected_raster = self.raster.reproject(26917, "nearest")
# print reprojected_raster.get_band(1)

def test_reclass(self):
pass
Expand All @@ -90,9 +88,7 @@ def test_to_vector(self):
pass

def tearDown(self):
os.remove(self.raster.uri)
os.remove(self.aligned_raster.uri)
os.remove(self.misaligned_raster.uri)
pass


if __name__ == '__main__':
Expand Down

0 comments on commit 5e42fdd

Please sign in to comment.