Skip to content

Commit

Permalink
work on raster class
Browse files Browse the repository at this point in the history
  • Loading branch information
Will Bierbower authored and Will Bierbower committed Mar 25, 2015
1 parent ea4a965 commit 6c68559
Show file tree
Hide file tree
Showing 7 changed files with 236 additions and 158 deletions.
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ A python library that generates simple OSGeo-supported rasters and vectors. The
Requirements
------------

fauxgeo 0.1.1 requires
fauxgeo 0.1.3 requires

* NumPy
* Matplotlib
* GDAL
* affine == 1.0

Expand Down
2 changes: 1 addition & 1 deletion fauxgeo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

from raster import *

__version__ = '0.1.2'
__version__ = '0.1.3'
241 changes: 90 additions & 151 deletions fauxgeo/raster.py
Original file line number Diff line number Diff line change
@@ -1,72 +1,40 @@
'''
Raster, TestRaster, and RasterFactory Classes
Raster Class
'''

import tempfile
import os
import gdal
import osr
import numpy as np
from affine import Affine
from StringIO import StringIO
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from matplotlib import image


class Raster(object):
# any global variables here

def __init__(self, uri):
self.exists = False
def __init__(self, uri, driver):
self.uri = uri
self.driver = driver
self.dataset = None
self.tmp = False

if uri is not None:
if os.path.exists(uri):
self.exists = True
else:
self.exists = False
# os.makedirs(uri)
f = open(uri, 'w')
f.close()
else:
# make temporary raster file
self.exists = False
self.tmp = True
tmpfile = tempfile.NamedTemporaryFile(mode='r')
self.uri = tmpfile.name
tmpfile.close()

def __str__(self):
return "raster at %s" % (self.uri)

def __del__(self):
self._close()

def __exit__(self):
self._close()

# Should probably be in TestRaster class
def _close(self):
if self.tmp is True:
os.remove(self.uri)

def init(self, array, affine, proj, datatype, nodata_val):
# if self.exists == True:
# raise Exception

if len(array.shape) == 2:
@classmethod
def from_array(self, output_uri, array, affine, proj, datatype, nodata_val, driver='GTiff'):
if len(array.shape) is 2:
num_bands = 1
elif len(array.shape) == 3:
elif len(array.shape) is 3:
num_bands = len(array)
else:
raise ValueError

rows = array.shape[0]
cols = array.shape[1]

driver = gdal.GetDriverByName('GTiff')
dataset = driver.Create(self.uri, cols, rows, num_bands, datatype)
dataset.SetGeoTransform((
affine.c, affine.a, affine.b, affine.f, affine.d, affine.e))
driver = gdal.GetDriverByName(driver)
dataset = driver.Create(output_uri, cols, rows, num_bands, datatype)
dataset.SetGeoTransform((affine.to_gdal()))

for band_num in range(num_bands):
band = dataset.GetRasterBand(band_num + 1) # Get only raster band
Expand All @@ -80,47 +48,66 @@ def init(self, array, affine, proj, datatype, nodata_val):
dataset.SetProjection(dataset_srs.ExportToWkt())
band.FlushCache()

self.exists = True
band = None
dataset_srs = None
dataset = None
driver = None

def init_2(self, array, bot_left_x, bot_left_y, pix_width, proj, datatype, nodata_val):
return Raster(output_uri, driver=driver)

if len(array.shape) == 2:
num_bands = 1
elif len(array.shape) == 3:
num_bands = len(array)
else:
raise ValueError
@classmethod
def from_file(self, uri, driver='GTiff'):
if not os.path.isabs(uri):
uri = os.path.join(os.getcwd(), uri)
# assert existence
return Raster(uri, driver)

rows = array.shape[0]
cols = array.shape[1]
vertical_offset = rows * pix_width
@classmethod
def simple_affine(self, top_left_x, top_left_y, pix_width, pix_height):
return Affine(pix_width, 0, top_left_x, 0, -(pix_height), top_left_y)

driver = gdal.GetDriverByName('GTiff')
dataset = driver.Create(self.uri, cols, rows, num_bands, datatype)
dataset.SetGeoTransform((
bot_left_x, pix_width, 0, (bot_left_y + vertical_offset), 0, -(pix_width)))
def __str__(self):
return self.uri

for band_num in range(num_bands):
band = dataset.GetRasterBand(band_num + 1) # Get only raster band
band.SetNoDataValue(nodata_val)
if num_bands > 1:
band.WriteArray(array[band_num])
else:
band.WriteArray(array)
dataset_srs = osr.SpatialReference()
dataset_srs.ImportFromEPSG(proj)
dataset.SetProjection(dataset_srs.ExportToWkt())
band.FlushCache()
def __len__(self):
return self.band_count()

self.exists = True
band = None
dataset_srs = None
dataset = None
driver = None
def __eq__(self):
pass # false if different shape, element-by-element if same shape

def __getitem__(self):
pass # return numpy slice? Raster object with sliced numpy array?

def __setitem__(self):
pass # set numpy values to raster

def __getslice__(self):
pass

def __setslice__(self):
pass

def __iter__(self):
pass # iterate over bands?

def __contains__(self):
pass # test numpy raster against all bands?

def _figure_data(self, format):
f = StringIO()
array = self.get_band(1)
plt.imsave(f, array, cmap=cm.Greys_r)
f.seek(0)
return f.read()

def _repr_png_(self):
return self._figure_data('png')

def band_count(self):
self._open_dataset()
count = self.dataset.RasterCount
self._close_dataset()
return count

def get_band(self, band_num):
a = None
Expand All @@ -129,6 +116,8 @@ def get_band(self, band_num):
if band_num >= 1 and band_num <= self.dataset.RasterCount:
band = self.dataset.GetRasterBand(band_num)
a = band.ReadAsArray()
nodata_val = band.GetNoDataValue()
a = np.ma.masked_equal(a, nodata_val)
band = None
else:
pass
Expand All @@ -142,10 +131,13 @@ def get_bands(self):
if self.dataset.RasterCount == 0:
return None

a = np.zeros((self.dataset.RasterYSize, self.dataset.RasterXSize, self.dataset.RasterCount))
for num in arange(self.dataset.RasterCount):
a = np.zeros((self.dataset.RasterCount, self.dataset.RasterYSize, self.dataset.RasterXSize))
for num in np.arange(self.dataset.RasterCount):
band = self.dataset.GetRasterBand(num+1)
a[:, :, num] = band.ReadAsArray()
b = band.ReadAsArray()
nodata_val = band.GetNoDataValue()
b = np.ma.masked_equal(b, nodata_val)
a[num] = b

self._close_dataset()
return a
Expand Down Expand Up @@ -242,82 +234,29 @@ def set_bands(self, array):
else:
raise Exception

def _open_dataset(self):
self.dataset = gdal.Open(self.uri)

def _close_dataset(self):
self.dataset = None

def clip(self, aoi):
raise NotImplementedError

class TestRaster(Raster):
def reproject(self, proj, resample_method, pixel_size):
raise NotImplementedError

def __init__(self, array, affine, proj, datatype, nodata_val):
super(TestRaster, self).__init__(None)
self.init(array, affine, proj, datatype, nodata_val)
def reclass(self, reclass_table):
raise NotImplementedError

def __del__(self):
self._close()
def overlay(self, raster):
raise NotImplementedError

def __exit__(self):
self._close()
def is_aligned(self, raster):
raise NotImplementedError

def _close(self):
os.remove(self.uri)
def align(self, raster, resample_method):
raise NotImplementedError

def copy(self, uri):
raise NotImplementedError

class RasterFactory(object):

def __init__(self, proj, datatype, nodata_val, rows, cols, affine=Affine.identity):
self.proj = proj
self.datatype = datatype
self.nodata_val = nodata_val
self.rows = rows
self.cols = cols
self.affine = affine

def get_metadata(self):
meta = {}
meta['proj'] = self.proj
meta['datatype'] = self.datatype
meta['nodata_val'] = self.nodata_val
meta['rows'] = self.rows
meta['cols'] = self.cols
meta['affine'] = self.affine
return meta
def _open_dataset(self):
self.dataset = gdal.Open(self.uri)

def _create_raster(self, array, uri):
if uri is None:
return TestRaster(array, self.affine, self.proj, self.datatype, self.nodata_val)
else:
r = Raster(uri)
r.init(array, self.affine, self.proj, self.datatype, self.nodata_val)
return r

def uniform(self, val, uri=None):
a = np.ones((self.rows, self.cols)) * val
return self._create_raster(a, uri)

def alternating(self, val1, val2, uri=None):
a = np.ones((self.rows, self.cols)) * val2
a[::2, ::2] = val1
a[1::2, 1::2] = val1
return self._create_raster(a, uri)

def random(self, uri=None):
a = np.random.rand(self.rows, self.cols)
return self._create_raster(a, uri)

def horizontal_ramp(self, val1, val2, uri=None):
a = np.zeros((self.rows, self.cols))
col_vals = np.linspace(val1, val2, self.cols)
a[:] = col_vals
return self._create_raster(a, uri)

def vertical_ramp(self, val1, val2, uri=None):
a = np.zeros((self.cols, self.rows))
row_vals = np.linspace(val1, val2, self.rows)
a[:] = row_vals
a = a.T
return self._create_raster(a, uri)

# def bell_shape(self, uri=None):
def _close_dataset(self):
self.dataset = None
67 changes: 67 additions & 0 deletions fauxgeo/raster_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
'''
RasterFactory Class
'''

import numpy as np
from affine import Affine

from fauxgeo.raster import Raster
from fauxgeo.raster import TempRaster


class RasterFactory(object):

def __init__(self, proj, datatype, nodata_val, rows, cols, affine=Affine.identity):
self.proj = proj
self.datatype = datatype
self.nodata_val = nodata_val
self.rows = rows
self.cols = cols
self.affine = affine

def get_metadata(self):
meta = {}
meta['proj'] = self.proj
meta['datatype'] = self.datatype
meta['nodata_val'] = self.nodata_val
meta['rows'] = self.rows
meta['cols'] = self.cols
meta['affine'] = self.affine
return meta

def _create_raster(self, array, uri):
if uri is None:
return TempRaster.from_array(
array, self.affine, self.proj, self.datatype, self.nodata_val)
else:
return Raster.from_file(
uri, array, self.affine, self.proj, self.datatype, self.nodata_val)

def uniform(self, val, uri=None):
a = np.ones((self.rows, self.cols)) * val
return self._create_raster(a, uri)

def alternating(self, val1, val2, uri=None):
a = np.ones((self.rows, self.cols)) * val2
a[::2, ::2] = val1
a[1::2, 1::2] = val1
return self._create_raster(a, uri)

def random(self, uri=None):
a = np.random.rand(self.rows, self.cols)
return self._create_raster(a, uri)

def horizontal_ramp(self, val1, val2, uri=None):
a = np.zeros((self.rows, self.cols))
col_vals = np.linspace(val1, val2, self.cols)
a[:] = col_vals
return self._create_raster(a, uri)

def vertical_ramp(self, val1, val2, uri=None):
a = np.zeros((self.cols, self.rows))
row_vals = np.linspace(val1, val2, self.rows)
a[:] = row_vals
a = a.T
return self._create_raster(a, uri)

# def bell_shape(self, uri=None):

0 comments on commit 6c68559

Please sign in to comment.