### Supervised classification combined with a segmentation

Last edit: 18.10.2021     
Authors: Yrneh Ulloa-Torrealba, Severin Herzsprung

Set the working directory and the paths to your input and output data.

In [3]:
import os

# Switch the kernel's working directory to the local project checkout.
# Adapt this path to your own machine before running the notebook.
project_dir = r"C:\Users\ulloa-to\git\Advanced_Remote_sensing_HM"
os.chdir(project_dir)

In [21]:
# Global variables
#
# Everything a user is expected to tune lives in this cell: the data root,
# the folder layout below it, the individual input/output files and the
# parameters handed to the SLIC segmentation.

# Root folder of the exercise data. Written as a raw string so that the
# backslashes of the Windows path are never interpreted as escape sequences
# (the non-raw form contains the invalid escapes "\L", "\o" and "\p").
src = r"W:\Lehre\object_based_classification\python"

# paths to all input and output data folders
geotiff_path = os.path.join(src, "input_geotiff\\")
segmented_geotiff_path = os.path.join(src, "segment_img\\")
vector_path = os.path.join(src, "training_shp\\")
class_path = os.path.join(src, "output_geotiff\\")

# paths to all input and output files
geotiff = os.path.join(geotiff_path, "aoi_02.tif")                                # input image
segmented_geotiff = os.path.join(segmented_geotiff_path, "segmented_raster.tif")  # segment ids (output)
truth_shp = os.path.join(vector_path, "truth_shp.shp")                            # labelled reference data (input)
train_shp = os.path.join(vector_path, "train.shp")                                # training subset (output)
test_shp = os.path.join(vector_path, "test.shp")                                  # test subset (output)
class_tif = os.path.join(class_path, "classified.tif")                            # classification (output)

# directory, where our logfile is saved:
log_txt = os.path.join(src, 'log.txt')

# variables for segmentation (passed to skimage.segmentation.slic)
n_segments_var = 10000
sigma_var = 0
compactness_var = 0.5
max_iter_var = 10

In [2]:
import numpy as np
import gdal
from osgeo import ogr
from skimage import exposure
from skimage.segmentation import slic
from sklearn.ensemble import RandomForestClassifier
from sklearn import metrics
import datetime
import time

import geopandas as gpd
import pandas as pd
import scipy


In [18]:
# read geotiff and extract relevant information

driverTiff = gdal.GetDriverByName("GTiff")
geotiff_ds = gdal.Open(geotiff)
# gdal.Open returns None instead of raising when the file cannot be opened
# (unless GDAL exceptions are enabled), so stop here with a clear message
# rather than failing later with an AttributeError on None.
if geotiff_ds is None:
    raise FileNotFoundError("GDAL could not open the input raster: {}".format(geotiff))
nbands = geotiff_ds.RasterCount
print("bands", geotiff_ds.RasterCount, "rows", 
      geotiff_ds.RasterYSize, "columns", geotiff_ds.RasterXSize)

# read the bands into a list of 2-D arrays (GDAL band indices start at 1)
band_data = []

# NOTE(review): range(1, nbands) reads bands 1 .. nbands-1, i.e. the LAST band
# of the file is skipped. That matches "use 4 of 5 bands"; for a 4-band file
# only 3 bands are read. Use the commented variant below to read every band.
for i in range(1, nbands): # for 4 bands
# for i in range(1, nbands+1): for the 5 bands
    band = geotiff_ds.GetRasterBand(i).ReadAsArray()
    band_data.append(band)

# numpy.dstack: Stack arrays in sequence depth wise (concatenation along third axis).
# Result has shape (rows, columns, number of bands read).
band_data = np.dstack(band_data)

# exposure.rescale_intensity: Return image after stretching or shrinking its intensity levels.
# With the default arguments the image minimum/maximum are stretched to the
# value range of the array's dtype (0.0 - 1.0 for non-negative float data).
img = exposure.rescale_intensity(band_data)

bands 4 rows 313 columns 554


In [20]:
# add start datetime to log
# The log file is opened in a `with` block so the handle is closed (and the
# text flushed) deterministically instead of leaking one open file per print.
now = datetime.datetime.now()
with open(log_txt, "a") as log_file:
    print ("\n### Start date and time : ", now.strftime("%Y-%m-%d %H:%M:%S"), file=log_file)

    ## perform segmentation

    print("# Segmentation Start", file=log_file)

# start time measurement for the segmentation
segmentation_start = time.time()

# apply SLIC and extract (approximately) the supplied number of segments
# NOTE(review): newer scikit-image releases renamed `max_iter` to
# `max_num_iter` - confirm against the installed version.
segments = slic(img, n_segments = n_segments_var, sigma = sigma_var,
                compactness = compactness_var, max_iter = max_iter_var)

with open(log_txt, "a") as log_file:
    # add segmentation time to log
    print("# Segmentation done in ", time.time() - segmentation_start, "seconds", file=log_file)

    ## extract metadata

    print("# Extraction of Metadata Start", file=log_file)

# start time measurement for the extraction of Metadata
metadata_start = time.time()

# extract this information for the segments I will use for the classification, 
# after I choose the right parameters with the accuracy assessment
def segment_features(segment_pixels):
    """Summarise the pixels of one segment as a flat list of band statistics.

    Parameters
    ----------
    segment_pixels : 2-D array of shape (npixels, nbands)
        One row per pixel of the segment, one column per image band.

    Returns
    -------
    list
        Six values per band, concatenated band after band, in the order
        [min, max, mean, variance, skewness, kurtosis]
        (as computed by ``scipy.stats.describe``).

    Notes
    -----
    For a segment consisting of a single pixel the sample variance is
    undefined (NaN); it is replaced by 0.0. Skewness and kurtosis are passed
    through exactly as SciPy returns them.
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.stats` is loaded on older SciPy releases.
    from scipy import stats as scipy_stats

    features = []
    npixels, nbands = segment_pixels.shape
    for b in range(nbands):
        stats = scipy_stats.describe(segment_pixels[:, b])
        # describe() yields (nobs, minmax, mean, variance, skewness, kurtosis);
        # drop nobs and unpack the (min, max) tuple into two separate values.
        band_stats = list(stats.minmax) + list(stats)[2:]
        if npixels == 1:
            # in this case the variance will be nan and we don't want that
            band_stats[3] = 0.0
        features += band_stats
    return features

# extract variables from each segment
# `objects[k]` holds the feature vector of the segment `segment_ids[k]`
# (np.unique returns the ids sorted); `object_ids` mirrors `segment_ids`.
segment_ids = np.unique(segments)
objects = []
object_ids = []

# The loop variable is deliberately not called `id`: that would shadow the
# builtin id() for every later cell of the notebook.
for segment_id in segment_ids:
    # boolean mask on the first two axes -> array of shape (npixels, nbands)
    segment_pixels = img[segments == segment_id]
    object_features = segment_features(segment_pixels)
    objects.append(object_features)
    object_ids.append(segment_id)

# add metadata extraction time to log
# (file opened in a `with` block so the handle is closed again)
with open(log_txt, "a") as log_file:
    print("# Extraction of Metadata done in ", time.time() - metadata_start, "seconds", file=log_file)


## rasterize 

# Log files are opened in `with` blocks throughout this section so that no
# file handle is left open after a message has been written.
with open(log_txt, "a") as log_file:
    print("# Rasterization Start", file=log_file)

# start time measurement for rasterization
rasterize_start = time.time()

# save segments to raster (same size, geotransform and projection as the input image)
output_fullpath = segmented_geotiff
segments_ds = driverTiff.Create(output_fullpath, geotiff_ds.RasterXSize,
                                geotiff_ds.RasterYSize, 1, gdal.GDT_Float32)
# Create returns None (no exception) when the file cannot be created
if segments_ds is None:
    raise IOError("GDAL could not create the segment raster: {}".format(output_fullpath))
segments_ds.SetGeoTransform(geotiff_ds.GetGeoTransform())
segments_ds.SetProjection(geotiff_ds.GetProjectionRef())
segments_ds.GetRasterBand(1).WriteArray(segments)
# dropping the reference closes the dataset and flushes it to disk
segments_ds = None

# rasterize the training data

# read shapefile to geopandas geodataframe
gdf = gpd.read_file(truth_shp)

# get names of land cover classes/labels
class_names = gdf['label'].unique()

# create a unique id (integer) for each land cover class/label, starting at 1
class_ids = np.arange(class_names.size) + 1

with open(log_txt, "a") as log_file:
    print('class names: ', class_names, file=log_file)
    print('class ids: ', class_ids, file=log_file)

# add a new column to geodatafame with the id for each class/label
gdf['id'] = gdf['label'].map(dict(zip(class_names, class_ids)))

# split the truth data into training and test data sets and save each to a new shapefile
# A fixed seed makes the 70/30 split reproducible, so accuracies obtained with
# different segmentation parameters are measured on the same test features.
train_test_seed = 42
gdf_train = gdf.sample(frac=0.7, random_state=train_test_seed)
gdf_test = gdf.drop(gdf_train.index)
with open(log_txt, "a") as log_file:
    print('truth data:', gdf.shape, '   train data:', gdf_train.shape,
          '   test data:', gdf_test.shape, file=log_file)
gdf_train.to_file(train_shp)
gdf_test.to_file(test_shp)

# open the points file to use for training data
train_fn = train_shp
train_ds = ogr.Open(train_fn)
# ogr.Open returns None (no exception) when the file cannot be opened
if train_ds is None:
    raise IOError("OGR could not open the training data: {}".format(train_fn))
# keep `train_ds` referenced for as long as `lyr` is used
lyr = train_ds.GetLayer()

# create a new raster layer in memory, aligned with the input image
driver = gdal.GetDriverByName('MEM')
target_ds = driver.Create('', geotiff_ds.RasterXSize, geotiff_ds.RasterYSize, 1, gdal.GDT_UInt16)
target_ds.SetGeoTransform(geotiff_ds.GetGeoTransform())
target_ds.SetProjection(geotiff_ds.GetProjection())

# rasterize the training points: every covered pixel gets the feature's `id` value
options = ['ATTRIBUTE=id']
gdal.RasterizeLayer(target_ds, [1], lyr, options=options)

# add rasterization time to log
with open(log_txt, "a") as log_file:
    print("# Rasterization done in ", time.time() - rasterize_start, "seconds", file=log_file)


## classify

# Log files are opened in `with` blocks throughout this section so that no
# file handle is left open after a message has been written.
with open(log_txt, "a") as log_file:
    print("# Classification Start", file=log_file)

# start time measurement for classification
classify_start = time.time()

# rasterized training data: 0 = no training data, otherwise the class id
ground_truth = target_ds.GetRasterBand(1).ReadAsArray()
classes = np.unique(ground_truth)
# drop the background value explicitly instead of assuming it sorts first
classes = classes[classes != 0]

# Get segments representing each land cover classification type and ensure no segment represents more than one class.
segments_per_class = {}
with open(log_txt, "a") as log_file:
    for klass in classes:
        # one entry per training PIXEL; the set below keeps the distinct segment ids
        segments_of_class = segments[ground_truth == klass]
        segments_per_class[klass] = set(segments_of_class)
        # NOTE: the logged number counts training pixels, not distinct segments
        print("Training segments for class", klass, ":", len(segments_of_class), file=log_file)

intersection = set()
accum = set()

for class_segments in segments_per_class.values():
    intersection |= accum.intersection(class_segments)
    accum |= class_segments
assert len(intersection) == 0, "Segment(s) represent multiple classes"

# Training image: class id for every pixel of a training segment, 0 elsewhere.
# Built in a fresh array so class ids can never be confused with segment ids.
train_img = np.zeros_like(segments)
for klass in classes:
    train_img[np.isin(segments, list(segments_per_class[klass]))] = klass

# collect the feature vectors (and labels) of all training segments
training_objects = []
training_labels = []

with open(log_txt, "a") as log_file:
    for klass in classes:
        class_train_object = [features for seg_id, features in zip(segment_ids, objects)
                              if seg_id in segments_per_class[klass]]
        training_labels += [klass] * len(class_train_object)
        training_objects += class_train_object
        print('Training objects for class', klass, ':', len(class_train_object), file=log_file)

# fit Random Forest classifier
# (fixed random_state so repeated runs give the same model and accuracy)
classifier = RandomForestClassifier(n_jobs=-1, random_state=42)
classifier.fit(training_objects, training_labels)

# predict classifications: one class per segment, in the order of `segment_ids`
predicted = classifier.predict(objects)

# apply prediction to numpy array
# `segment_ids` is sorted (np.unique), so searchsorted maps every pixel's
# segment id to its position in `predicted`. Replacing values in place in a
# copy of `segments` would be wrong: a class id written for one segment can
# equal the id of a segment processed later and would then be overwritten.
segment_index = np.searchsorted(segment_ids, segments)
clf = predicted[segment_index]

# pixels whose bands sum to zero (or less) are treated as no-data
valid_pixels = np.sum(img, axis=2) > 0
clf = np.where(valid_pixels, clf.astype(np.float64), -9999.0)

# Saving classification to raster with gdal
clfds = driverTiff.Create(class_tif, geotiff_ds.RasterXSize, geotiff_ds.RasterYSize,
                          1, gdal.GDT_Float32)
# Create returns None (no exception) when the file cannot be created
if clfds is None:
    raise IOError("GDAL could not create the classification raster: {}".format(class_tif))
clfds.SetGeoTransform(geotiff_ds.GetGeoTransform())
clfds.SetProjection(geotiff_ds.GetProjection())
clfds.GetRasterBand(1).SetNoDataValue(-9999.0)
clfds.GetRasterBand(1).WriteArray(clf)
# dropping the reference closes the dataset and flushes it to disk
clfds = None

# add classification time to log (measured from the start of this section)
with open(log_txt, "a") as log_file:
    print("# Classification done in ", time.time() - classify_start, "seconds", file=log_file)


## accuracy assessment

# Log files are opened in `with` blocks throughout this section so that no
# file handle is left open after a message has been written.
with open(log_txt, "a") as log_file:
    print("# Accuracy assessment Start", file=log_file)

# start time measurement for the accuracy assessment
accu_start = time.time()

# open the points file to use for test data
test_fn = test_shp
test_ds = ogr.Open(test_fn)
# ogr.Open returns None (no exception) when the file cannot be opened
if test_ds is None:
    raise IOError("OGR could not open the test data: {}".format(test_fn))
# keep `test_ds` referenced for as long as `lyr` is used
lyr = test_ds.GetLayer()

# create a new raster layer in memory, aligned with the input image
driver = gdal.GetDriverByName('MEM')
target_ds = driver.Create('', geotiff_ds.RasterXSize, geotiff_ds.RasterYSize, 1, gdal.GDT_UInt16)
target_ds.SetGeoTransform(geotiff_ds.GetGeoTransform())
target_ds.SetProjection(geotiff_ds.GetProjection())

# rasterize the test points: every covered pixel gets the feature's `id` value
options = ['ATTRIBUTE=id']
gdal.RasterizeLayer(target_ds, [1], lyr, options=options)

# set test data as truth (0 = no test data at this pixel)
truth = target_ds.GetRasterBand(1).ReadAsArray()

# open classified image and set as prediction
pred_ds = gdal.Open(class_tif)
# gdal.Open returns None (no exception) when the file cannot be opened
if pred_ds is None:
    raise IOError("GDAL could not open the classified raster: {}".format(class_tif))
pred = pred_ds.GetRasterBand(1).ReadAsArray()

# only pixels covered by test data take part in the assessment
# NOTE(review): test pixels on no-data areas carry the prediction -9999 and
# would show up as an extra label in the confusion matrix - confirm whether
# they should be excluded.
idx = np.nonzero(truth)

# create confusion matrix (scikit-learn: rows = true labels, columns = predicted labels)
cm = metrics.confusion_matrix(truth[idx], pred[idx])

# Per-class value: correctly classified pixels divided by the column sum,
# i.e. by the number of pixels PREDICTED as that class (precision /
# user's accuracy). A class that is never predicted yields nan.
accuracy = cm.diagonal() / cm.sum(axis=0)

with open(log_txt, "a") as log_file:
    # pixel accuracy
    print("Confusion matrix: ",'\n',cm, file=log_file)

    print("Diagonal: ",cm.diagonal(), file=log_file)
    print("Sum: ",cm.sum(axis=0), file=log_file)

    print("Accuracy: ",accuracy, file=log_file)

    # add accuracy assessment time to log
    print("# Accuracy assessment done in ", time.time() - accu_start, "seconds", file=log_file)

    # add end datetime to log
    now = datetime.datetime.now()
    print ("### End date and time: ", now.strftime("%Y-%m-%d %H:%M:%S"), file=log_file)

print("Done")



  


Done
