Skip to content

Commit

Permalink
implementing some higher-level GIS functions in the Raster class
Browse files Browse the repository at this point in the history
  • Loading branch information
Will Bierbower authored and Will Bierbower committed Mar 31, 2015
1 parent 28faae4 commit bce8d31
Show file tree
Hide file tree
Showing 9 changed files with 338 additions and 149 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ before_install:
- sudo apt-get install -y python-numpy

# command to install dependencies, e.g. pip install -r requirements.txt --use-mirrors
install: pip install wheel affine
install: pip install wheel affine pygeoprocessing

# command to run tests, e.g. python setup.py test
script: python setup.py test
Expand Down
2 changes: 2 additions & 0 deletions fauxgeo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-

from raster import *
from temp_raster import *
from raster_factory import *

__version__ = '0.1.3'
217 changes: 154 additions & 63 deletions fauxgeo/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
'''

import os
import tempfile
import shutil

import gdal
import osr
import numpy as np
from affine import Affine
from StringIO import StringIO
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import image
import pygeoprocessing as pygeo


class Raster(object):
Expand All @@ -29,6 +29,9 @@ def from_array(self, output_uri, array, affine, proj, datatype, nodata_val, driv
else:
raise ValueError

if not os.path.isabs(output_uri):
output_uri = os.path.join(os.getcwd(), output_uri)

rows = array.shape[0]
cols = array.shape[1]

Expand Down Expand Up @@ -72,6 +75,11 @@ def __str__(self):
def __len__(self):
return self.band_count()

def __mul__(self, raster):
def mul(x, y):
return x * y
return self.local_op(b, mul)

def __eq__(self):
pass # false if different shape, element-by-element if same shape

Expand All @@ -93,25 +101,11 @@ def __iter__(self):
def __contains__(self):
pass # test numpy raster against all bands?

def _figure_data(self, format):
f = StringIO()
array = self.get_band(1)
# fig = plt.figure()
# ax = plt.subplot(1, 1, 1)
# plt.savefig(f, bbox_inches='tight', format=format)
# plt.imsave(f, array, cmap=cm.Greys_r)
# plt.imshow(array, interpolation="nearest")
fig = plt.figure()
ax = fig.add_subplot(111)
# cax = ax.matshow(array, interpolation='nearest')
# fig.colorbar(cax)
plt.savefig(f, bbox_inches='tight', format=format)
# f.seek(0)
# return f.read()
return None
def __repr__(self):
return self.get_bands().__repr__()

def _repr_png_(self):
return self._figure_data('png')
raise NotImplementedError

def band_count(self):
self._open_dataset()
Expand Down Expand Up @@ -201,6 +195,12 @@ def get_shape(self):
cols = self.get_cols()
return (rows, cols)

def get_projection(self):
self._open_dataset()
RasterSRS = osr.SpatialReference()
RasterSRS.ImportFromWkt(self.dataset.GetProjectionRef())
return int(RasterSRS.GetAttrValue("AUTHORITY", 1))

def get_geotransform(self):
geotransform = None
self._open_dataset()
Expand All @@ -214,64 +214,155 @@ def get_affine(self):
geotransform = self.get_geotransform()
return Affine.from_gdal(*geotransform)

def set_band(self, band_num, array):
if self.exists:
assert(len(array) == self.get_rows)
assert(len(array[0]) == self.get_cols)

self._open_dataset()
def get_bbox(self):
pass

if band_num >= 1 and band_num <= self.dataset.RasterCount:
band = self.dataset.GetRasterBand(band_num)
band.WriteArray(array)
band.FlushCache()
band = None
def set_band(self, masked_array):
'''Currently works for rasters with only one band'''
assert(len(masked_array) == self.get_rows())
assert(len(masked_array[0]) == self.get_cols())
assert(self.band_count() == 1)

self._close_dataset()
else:
raise Exception
uri = self.uri
array = masked_array.data
affine = self.get_affine()
proj = self.get_projection()
datatype = self.get_datatype(1)
nodata_val = masked_array.fill_value
Raster.from_array(uri, array, affine, proj, datatype, nodata_val)

def set_bands(self, array):
if self.exists:
self._open_dataset()
band_count = self.dataset.RasterCount
self._close_dataset()

if band_count == 1 and len(array.shape) == 2:
assert(len(array) == self.get_rows)
assert(len(array[0]) == self.get_cols)
self.set_band(1, array)

elif len(array.shape) == 3 and array.shape[0] == band_count:
for band_num in range(band_count):
self.set_band(band_num + 1, array[band_num])
else:
raise Exception

def clip(self, aoi):
raise NotImplementedError

def reproject(self, proj, resample_method, pixel_size):
raise NotImplementedError

def reclass(self, reclass_table):
# if self.exists:
# self._open_dataset()
# band_count = self.dataset.RasterCount
# self._close_dataset()

# if band_count == 1 and len(array.shape) == 2:
# assert(len(array) == self.get_rows)
# assert(len(array[0]) == self.get_cols)
# self.set_band(1, array)

# elif len(array.shape) == 3 and array.shape[0] == band_count:
# for band_num in range(band_count):
# self.set_band(band_num + 1, array[band_num])
# else:
# raise Exception
raise NotImplementedError

def overlay(self, raster):
raise NotImplementedError
def copy(self, uri):
if not os.path.isabs(uri):
uri = os.path.join(os.getcwd(), uri)
shutil.copyfile(self.uri, uri)
return Raster.from_file(uri)

def is_aligned(self, raster):
raise NotImplementedError
try:
this_affine = self.get_affine()
other_affine = raster.get_affine()
return (this_affine == other_affine)
except:
raise TypeError

def align(self, raster, resample_method):
raise NotImplementedError
'''Currently aligns other raster to this raster - later: union/intersection
'''
assert(self.get_projection() == raster.get_projection())

def dataset_pixel_op(x, y): return y
dataset_uri_list = [self.uri, raster.uri]
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
datatype_out = pygeo.geoprocessing.get_datatype_from_uri(raster.uri)
nodata_out = pygeo.geoprocessing.get_nodata_from_uri(raster.uri)
pixel_size_out = pygeo.geoprocessing.get_cell_size_from_uri(self.uri)
bounding_box_mode = "dataset"

pygeo.geoprocessing.vectorize_datasets(
dataset_uri_list,
dataset_pixel_op,
dataset_out_uri,
datatype_out,
nodata_out,
pixel_size_out,
bounding_box_mode,
resample_method_list=[resample_method]*2,
dataset_to_align_index=0,
dataset_to_bound_index=0,
assert_datasets_projected=False,
vectorize_op=False)

return Raster.from_file(dataset_out_uri)
# temp_raster = Raster.from_file(dataset_out_uri)
# temp_raster.copy(raster.uri)
# os.remove(dataset_out_uri)

def clip(self, aoi_uri):
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
pygeo.geoprocessing.clip_dataset_uri(
self.uri, aoi_uri, dataset_out_uri)
return Raster.from_file(dataset_out_uri)

def reproject(self, proj, resample_method, pixel_size=None):
if pixel_size is None:
pixel_size = self.get_affine().a

dataset_out_uri = pygeo.geoprocessing.temporary_filename()
srs = osr.SpatialReference()
srs.ImportFromEPSG(proj)
wkt = srs.ExportToWkt()

pygeo.geoprocessing.reproject_dataset_uri(
self.uri, pixel_size, wkt, resample_method, dataset_out_uri)

return Raster.from_file(dataset_out_uri)

def reclass(self, reclass_table, out_nodata=None):
if out_nodata is None:
out_nodata = pygeo.geoprocessing.get_nodata_from_uri(self.uri)

out_datatype = pygeo.geoprocessing.get_datatype_from_uri(self.uri)
dataset_out_uri = pygeo.geoprocessing.temporary_filename()

pygeo.geoprocessing.reclassify_dataset_uri(
self.uri,
reclass_table,
dataset_out_uri,
out_datatype,
out_nodata)

return Raster.from_file(dataset_out_uri)

def copy(self, uri):
def overlay(self, raster):
raise NotImplementedError

def to_vector(self):
raise NotImplementedError

def local_op(self, raster, pixel_op):
assert(self.is_aligned(raster))

dataset_uri_list = [self.uri, raster.uri]
dataset_out_uri = pygeo.geoprocessing.temporary_filename()
datatype_out = pygeo.geoprocessing.get_datatype_from_uri(raster.uri)
nodata_out = pygeo.geoprocessing.get_nodata_from_uri(raster.uri)
pixel_size_out = pygeo.geoprocessing.get_cell_size_from_uri(self.uri)
bounding_box_mode = "dataset"

pygeo.geoprocessing.vectorize_datasets(
dataset_uri_list,
pixel_op,
dataset_out_uri,
datatype_out,
nodata_out,
pixel_size_out,
bounding_box_mode,
resample_method_list=[resample_method]*2,
dataset_to_align_index=0,
dataset_to_bound_index=0,
assert_datasets_projected=False,
vectorize_op=False)

return Raster.from_file(dataset_out_uri)

def _open_dataset(self):
self.dataset = gdal.Open(self.uri)

Expand Down
2 changes: 1 addition & 1 deletion fauxgeo/raster_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from affine import Affine

from fauxgeo.raster import Raster
from fauxgeo.raster import TempRaster
from fauxgeo.temp_raster import TempRaster


class RasterFactory(object):
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ wheel==0.23.0
numpy
affine
gdal
pygeoprocessing
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
"affine",
"gdal",
"wheel",
"matplotlib"
"pygeoprocessing"
]

test_requirements = [
# TODO: put package test requirements here
"numpy",
"affine",
"gdal",
"matplotlib",
"pygeoprocessing",
"wheel",
"nose",
"coverage"
Expand Down

0 comments on commit bce8d31

Please sign in to comment.