Skip to content

Commit

Permalink
adding broadcast operations to raster object, other modifications to …
Browse files Browse the repository at this point in the history
…vector object
  • Loading branch information
Will Bierbower authored and Will Bierbower committed Apr 6, 2015
1 parent c0dab75 commit 10246c9
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 50 deletions.
141 changes: 98 additions & 43 deletions fauxgeo/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,49 +93,95 @@ def __len__(self):
return self.band_count()

def __mul__(self, raster):
def mul_closer(nodata):
def mul(x, y):
if nodata in [x, y]:
return nodata
return x * y
return mul
return self.local_op(raster, mul_closer)
if type(raster) in [float, int]:
def mul_closure(nodata):
def mul(x):
if nodata in [x]:
return nodata
return x * raster
return mul
return self.local_op(raster, mul_closure, broadcast=True)
else:
def mul_closure(nodata):
def mul(x, y):
if nodata in [x, y]:
return nodata
return x * y
return mul
return self.local_op(raster, mul_closure)

def __div__(self, raster):
def div_closer(nodata):
def div(x, y):
if nodata in [x, y]:
return nodata
return x / y
return div
return self.local_op(raster, div_closer)
if type(raster) in [float, int]:
def div_closure(nodata):
def div(x):
if nodata in [x]:
return nodata
return x / raster
return div
return self.local_op(raster, div_closure, broadcast=True)
else:
def div_closure(nodata):
def div(x, y):
if nodata in [x, y]:
return nodata
return x / y
return div
return self.local_op(raster, div_closure)

def __add__(self, raster):
def add_closer(nodata):
def add(x, y):
if nodata in [x, y]:
return nodata
return x + y
return add
return self.local_op(raster, add_closer)
if type(raster) in [float, int]:
def add_closure(nodata):
def add(x):
if nodata in [x]:
return nodata
return x + raster
return add
return self.local_op(raster, add_closure, broadcast=True)
else:
def add_closure(nodata):
def add(x, y):
if nodata in [x, y]:
return nodata
return x + y
return add
return self.local_op(raster, add_closure)

def __sub__(self, raster):
def sub_closer(nodata):
def sub(x, y):
if nodata in [x, y]:
return nodata
return x - y
return sub
return self.local_op(raster, sub_closer)
if type(raster) in [float, int]:
def sub_closure(nodata):
def sub(x):
if nodata in [x]:
return nodata
return x - raster
return sub
return self.local_op(raster, sub_closure, broadcast=True)
else:
def sub_closure(nodata):
def sub(x, y):
if nodata in [x, y]:
return nodata
return x - y
return sub
return self.local_op(raster, sub_closure)

def __pow__(self, raster):
def pow_closer(nodata):
def pow(x, y):
if nodata in [x, y]:
return nodata
return x**y
return pow
return self.local_op(raster, pow_closer)
if type(raster) in [float, int]:
# Implement broadcast operation
def pow_closure(nodata):
def powe(x):
if nodata in [x]:
return nodata
return x**raster
return powe
return self.local_op(raster, pow_closure, broadcast=True)
else:
def pow_closure(nodata):
def powe(x, y):
if nodata in [x, y]:
return nodata
return x**y
return powe
return self.local_op(raster, pow_closure)

def __eq__(self, raster):
if self.is_aligned(raster) and (self.get_shape() == raster.get_shape()):
Expand Down Expand Up @@ -318,6 +364,9 @@ def get_aoi_as_shapefile(self, uri):
# outFeature.SetGeometry(wkb)
# outLayer.CreateFeature(outFeature)

def get_cell_area(self):
raise NotImplementedError

def set_band(self, masked_array):
'''Currently works for rasters with only one band'''
assert(len(masked_array) == self.get_rows())
Expand Down Expand Up @@ -476,19 +525,25 @@ def overlay(self, raster):
def to_vector(self):
raise NotImplementedError

def local_op(self, raster, pixel_op_closer):
assert(self.is_aligned(raster))
assert(self.get_nodata(1) == raster.get_nodata(1))
def local_op(self, raster, pixel_op_closure, broadcast=False):
bounding_box_mode = "dataset"
resample_method = "nearest"

if not broadcast:
assert(self.is_aligned(raster))
assert(self.get_nodata(1) == raster.get_nodata(1))
dataset_uri_list = [self.uri, raster.uri]
resample_list = [resample_method]*2
else:
dataset_uri_list = [self.uri]
resample_list = [resample_method]

nodata = self.get_nodata(1)
pixel_op = pixel_op_closer(nodata)
dataset_uri_list = [self.uri, raster.uri]
pixel_op = pixel_op_closure(nodata)
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
datatype_out = pygeo.geoprocessing.get_datatype_from_uri(self.uri)
nodata_out = pygeo.geoprocessing.get_nodata_from_uri(self.uri)
pixel_size_out = pygeo.geoprocessing.get_cell_size_from_uri(self.uri)
bounding_box_mode = "dataset"
resample_method = "nearest"

pygeo.geoprocessing.vectorize_datasets(
dataset_uri_list,
Expand All @@ -498,7 +553,7 @@ def local_op(self, raster, pixel_op_closer):
nodata_out,
pixel_size_out,
bounding_box_mode,
resample_method_list=[resample_method]*2,
resample_method_list=resample_list,
dataset_to_align_index=0,
dataset_to_bound_index=0,
assert_datasets_projected=False,
Expand Down
17 changes: 10 additions & 7 deletions fauxgeo/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,14 @@ def from_shapely(self, shapely_object, proj, driver='ESRI Shapefile'):

@classmethod
def from_file(self, uri, driver='ESRI Shapefile'):
datasource_uri = pygeo.geoprocessing.temporary_filename()
dst_uri = pygeo.geoprocessing.temporary_filename()
if not os.path.isabs(uri):
uri = os.path.join(os.getcwd(), uri)
# assert existence
shutil.copyfile(uri, datasource_uri)
return Vector(datasource_uri, driver)
src_uri = os.path.splitext(uri)[0]
for ext in ['.shp', '.shx', '.dbf', '.prj']:
if os.path.exists(src_uri+ext):
shutil.copyfile(src_uri+ext, dst_uri+ext)
return Vector(dst_uri, driver)

@classmethod
def from_tempfile(self, uri, driver='ESRI Shapefile'):
Expand Down Expand Up @@ -116,10 +118,11 @@ def _repr_png_(self):
raise NotImplementedError

def save_vector(self, uri):
fp = os.path.splitext(self.uri)[0]
src_uri = os.path.splitext(self.uri)[0]
dst_uri = os.path.splitext(uri)[0]
for ext in ['.shp', '.shx', '.dbf', '.prj']:
if os.path.exists(fp+ext):
shutil.copyfile(fp+ext, uri+ext)
if os.path.exists(src_uri+ext):
shutil.copyfile(src_uri+ext, dst_uri+ext)

def feature_count(self):
self._open_datasource()
Expand Down

0 comments on commit 10246c9

Please sign in to comment.