In [4]:
import os
# These must be set before any CUDA-backed library is imported.
# Order devices by PCI bus ID so the index below is stable.
# (The original key/value were misspelled, so the setting was silently ignored.)
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose only device 3 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import rasterio
import numpy as np
from rasterio.plot import adjust_band
import matplotlib.pyplot as plt
from rasterio.plot import reshape_as_raster, reshape_as_image
from rasterio.plot import show
from itertools import product
from rasterio.windows import Window
from pyproj import Proj, transform
import random
import math
import itertools

In [5]:
# Raster that supplies the per-pixel class labels, plus its full contents in memory.
label_dataset = rasterio.open('/deep_data/landcover_reproject.tif')
label_image = label_dataset.read()

# Processed Landsat scenes that the tiles are cut from.
image_paths = [
    '/deep_data/processed_landsat/LC08_CU_027012_20170907_20181121_C01_V01_SR_combined.tif',
    '/deep_data/processed_landsat/LC08_CU_028012_20140814_20171017_C01_V01_SR_combined.tif',
    '/deep_data/processed_landsat/LC08_CU_028011_20170907_20181130_C01_V01_SR_combined.tif',
    '/deep_data/processed_landsat/LC08_CU_028012_20171002_20171019_C01_V01_SR_combined.tif',
]

# Keep every scene open; tiles are read from them lazily through windows.
landsat_datasets = [rasterio.open(scene_path) for scene_path in image_paths]


In [40]:
def tile_generator(image_datasets, label_dataset, tile_height, tile_width, pixel_locations, batch_size):
    """Endlessly yield ``(image_batch, label_batch)`` pairs for Keras generator training.

    Tiles are read on the fly from the image datasets around each pixel location and
    paired with a one-hot label looked up in the label dataset at the same coordinate.

    Parameters
    ----------
    image_datasets : list
        Open raster datasets. All are assumed to have the same band count, and the
        LAST band is treated as the QA band: it is used for cloud screening and is
        not part of the returned image data.
    label_dataset
        Open raster dataset holding the class labels (only its first band is used).
        The value 255 is remapped to 1 before lookup.
    tile_height, tile_width : int
        Size of each tile in pixels.
    pixel_locations : sequence
        Entries of the form ``((c, r), dataset_index)``; the sequence is cycled
        indefinitely.
    batch_size : int
        Number of tiles per yielded batch.

    Yields
    ------
    tuple
        ``image_batch`` of shape ``(batch_size, tile_height, tile_width, band_count - 1)``
        and ``label_batch`` of shape ``(batch_size, class_count)``, where ``class_count``
        is the number of distinct values in the label image.

    Raises
    ------
    ValueError
        If ``pixel_locations`` is empty (raised on the first ``next()``).

    Notes
    -----
    A location is skipped when its tile has the wrong shape (e.g. it overlaps the image
    edge), is all zeros, contains NaN or -9999, contains a cloud QA value, or when its
    label is 0 / NaN or falls outside the label raster. If no location passes these
    checks the generator never yields.
    """
    if len(pixel_locations) == 0:
        raise ValueError('pixel_locations must contain at least one location')

    label_image = label_dataset.read()
    label_image[label_image == 255] = 1
    label_rows, label_cols = label_image.shape[1], label_image.shape[2]

    outProj = Proj(label_dataset.crs)
    # one projection object per image dataset, built once instead of once per tile
    inProjs = [Proj(dataset.crs) for dataset in image_datasets]

    # assuming all images have the same num of bands
    band_count = image_datasets[0].count
    qa_band = band_count - 1  # zero-based index of the QA band (last band)
    band_indexes = list(range(1, band_count + 1))
    class_count = len(np.unique(label_image))

    # half-tile offsets, one per axis so that non-square tiles are centred correctly
    row_buffer = math.ceil(tile_height / 2)
    col_buffer = math.ceil(tile_width / 2)
    # windowed reads come back as (bands, rows, cols)
    expected_shape = (band_count, tile_height, tile_width)

    # QA values treated as cloud by this notebook; read more here:
    # https://prd-wret.s3-us-west-2.amazonaws.com/assets/palladium/production/s3fs-public/atoms/files/LSDS-1873_US_Landsat_ARD_DFCB_0.pdf
    cloud_qa_values = [352, 368, 392, 416, 432, 480, 840, 864, 880, 904, 928, 944, 1352]

    i = 0
    while True:
        # take one off the band count because we don't want the QA band
        image_batch = np.zeros((batch_size, tile_height, tile_width, band_count - 1))
        label_batch = np.zeros((batch_size, class_count))
        b = 0
        while b < batch_size:
            # if we're at the end of the data just restart
            if i >= len(pixel_locations):
                i = 0
            c, r = pixel_locations[i][0]
            dataset_index = pixel_locations[i][1]
            i += 1

            dataset = image_datasets[dataset_index]
            window = Window(c - col_buffer, r - row_buffer, tile_width, tile_height)
            tile = dataset.read(band_indexes, window=window)

            # Check the shape first: a window hanging over the image edge comes back
            # truncated (possibly empty), and np.amax would fail on an empty array.
            if tile.shape != expected_shape:
                print('wrong shape')
                print(tile.shape)
                continue
            # don't include if it is part of the image with no pixels
            if np.amax(tile) == 0:
                continue
            # we don't want tiles containing nan or -9999, this comes from edges
            if np.isnan(tile).any() or -9999 in tile:
                continue
            # make sure the tile doesn't contain clouds
            if np.isin(tile[qa_band], cloud_qa_values).any():
                continue

            # drop the QA band, rescale, and go from raster format to image format
            reshaped_tile = reshape_as_image(adjust_band(tile[:qa_band]))

            # find the map coordinate of the sampled pixel within the image
            (x, y) = dataset.xy(r, c)

            # convert the point to the projection of the label dataset if necessary
            inProj = inProjs[dataset_index]
            if inProj != outProj:
                x, y = transform(inProj, outProj, x, y)

            # reference the same coordinate in label_image
            row, col = label_dataset.index(x, y)
            # outside the label raster: a negative index would silently wrap around
            # and a too-large one would raise, so skip the location instead
            if not (0 <= row < label_rows and 0 <= col < label_cols):
                continue

            label = label_image[0, row, col]
            # if this label is part of the unclassified area then ignore
            if label == 0 or np.isnan(label):
                continue

            # add label to the batch in a one hot encoding style
            label_batch[b, int(label)] = 1
            image_batch[b] = reshaped_tile
            b += 1
        yield (image_batch, label_batch)


In [None]:
def gen_balanced_pixel_locations(image_datasets, label_dataset, amount_of_labels, train_amount, val_amount, test_amount, tile_size, sparse_label_count=7):
    """Sample random, roughly class-balanced pixel locations and split them three ways.

    For every image dataset, random ``(c, r)`` points at least half a tile away from the
    image border are drawn until ``ceil(total / len(image_datasets))`` distinct points
    have been accepted. A point is accepted only if its label (looked up in the label
    dataset at the same coordinate) is neither 0 nor 1 and the bucket for that label is
    not yet full. Buckets are shared across all image datasets.

    Parameters
    ----------
    image_datasets : list
        Open raster datasets to sample from.
    label_dataset
        Open raster dataset holding the class labels (only its first band is used).
        The value 255 is remapped to 1 before lookup.
    amount_of_labels : int
        Number of label buckets; every label value must be a valid index below this.
    train_amount, val_amount, test_amount : int
        Requested sizes of the three splits.
    tile_size : int
        Side length of the tiles that will later be cut around each point.
    sparse_label_count : int, optional
        Number of labels assumed to be too rare to fill a bucket. The bucket size is
        ``ceil(total / (amount_of_labels - sparse_label_count))``. Defaults to 7, the
        value that used to be hard-coded.

    Returns
    -------
    tuple
        ``(train_px, val_px, test_px)``; each is a list of ``((c, r), dataset_index)``.
        The test split is taken first, then validation, and training receives the rest
        (which can exceed ``train_amount`` slightly because of per-dataset rounding).

    Raises
    ------
    ValueError
        If ``amount_of_labels - sparse_label_count`` is not positive.

    Notes
    -----
    The sampling loop has no upper bound on attempts: if too few acceptable pixels
    exist (all reachable buckets full), it does not terminate.
    """
    balanced_labels = amount_of_labels - sparse_label_count
    if balanced_labels <= 0:
        raise ValueError('amount_of_labels must be greater than sparse_label_count')

    label_image = label_dataset.read()
    label_image[label_image == 255] = 1
    label_rows, label_cols = label_image.shape[1], label_image.shape[2]
    outProj = Proj(label_dataset.crs)

    buffer = math.ceil(tile_size / 2)
    total_pixels = train_amount + val_amount + test_amount
    total_count_per_dataset = math.ceil(total_pixels / len(image_datasets))
    bucket_size = math.ceil(total_pixels / balanced_labels)
    label_buckets = np.zeros(amount_of_labels)

    pixels = []
    for index, image_dataset in enumerate(image_datasets):
        img_height, img_width = image_dataset.shape
        # the projection is the same for every sample of this dataset: build it once
        inProj = Proj(image_dataset.crs)
        needs_reprojection = inProj != outProj

        points = set()
        while len(points) != total_count_per_dataset:
            aPoint = (random.randint(buffer, img_width - buffer),
                      random.randint(buffer, img_height - buffer))
            if aPoint in points:
                # already sampled; must not be counted towards its bucket again
                continue
            c, r = aPoint
            (x, y) = image_dataset.xy(r, c)
            if needs_reprojection:
                x, y = transform(inProj, outProj, x, y)

            # reference the same coordinate in label_image
            row, col = label_dataset.index(x, y)
            # outside the label raster: a negative index would silently wrap around
            # and pick up an unrelated label, so reject the point
            if not (0 <= row < label_rows and 0 <= col < label_cols):
                continue

            label = int(label_image[0, row, col])
            # 0 and 1 are excluded from sampling; full buckets accept nothing more
            if label in (0, 1) or label_buckets[label] >= bucket_size:
                continue

            points.add(aPoint)
            label_buckets[label] += 1

        pixels.extend((point, index) for point in points)

    random.shuffle(pixels)
    test_px = pixels[:test_amount]
    val_px = pixels[test_amount:(val_amount + test_amount)]
    train_px = pixels[(val_amount + test_amount):]
    return (train_px, val_px, test_px)

In [7]:
# Fresh copy of the labels with the 255 value folded into class 1;
# the number of distinct values that remain is the number of classes.
label_image = label_dataset.read()
np.putmask(label_image, label_image == 255, 1)
num_classes = np.unique(label_image).size

# input image dimensions (square tiles)
tile_side = 64
img_rows = img_cols = tile_side
# every band except the last one (the QA band) is fed to the network
img_bands = landsat_datasets[0].count - 1

input_shape = (img_rows, img_cols, img_bands)
print(input_shape)

(64, 64, 7)


In [28]:
# gen_balanced_pixel_locations reads the label image from label_dataset itself, so
# label_image must not be passed: the extra positional argument raised a TypeError.
(train_px, val_px, test_px) = gen_balanced_pixel_locations(landsat_datasets, label_dataset, num_classes, 150000, 50000, 10000, tile_side)

In [36]:
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

Using TensorFlow backend.


In [37]:
# CNN classifier: three conv -> conv -> max-pool -> dropout stages, then a dense head
# ending in a softmax over the label classes.
model = Sequential([
    # stage 1
    # NOTE(review): the filter count of the first layer is tile_side, and the second
    # layer uses 22 filters -- confirm both are intentional.
    Conv2D(tile_side, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
    Conv2D(22, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),

    # stage 2
    Conv2D(64, (3, 3), padding='same', activation='relu'),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),

    # stage 3
    Conv2D(128, (3, 3), padding='same', activation='relu'),
    Conv2D(128, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),

    # classification head
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])

In [38]:
# Training hyper-parameters.
batch_size, epochs = 25, 100
metrics = ['accuracy']

# SGD with Nesterov momentum and a small learning-rate decay.
sgd = keras.optimizers.SGD(
    lr=0.001,
    decay=1e-6,
    momentum=0.9,
    nesterov=True,
)

# One-hot labels from the generator pair with categorical cross-entropy.
model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=metrics)

In [None]:
# Train from the on-the-fly tile generators; one epoch is one pass over the
# sampled pixel locations, measured in whole batches.
model.fit_generator(
    generator=tile_generator(landsat_datasets, label_dataset, tile_side, tile_side, train_px, batch_size),
    validation_data=tile_generator(landsat_datasets, label_dataset, tile_side, tile_side, val_px, batch_size),
    steps_per_epoch=len(train_px) // batch_size,
    validation_steps=len(val_px) // batch_size,
    epochs=epochs,
    verbose=1,
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100

In [29]:
#USING BALANCED
def count_label_buckets(pixel_locations, image_datasets, label_dataset, label_image, num_classes):
    """Count how many of the given pixel locations fall into each label class.

    Parameters
    ----------
    pixel_locations : sequence
        Entries of the form ``((c, r), dataset_index)``.
    image_datasets : list
        Open raster datasets that the dataset indices refer to.
    label_dataset
        Open raster dataset used to turn a coordinate into a label row/column.
    label_image
        Label array indexed as ``[:, row, col]``.
    num_classes : int
        Number of buckets in the returned array.

    Returns
    -------
    numpy.ndarray
        Array of length ``num_classes`` with the per-label counts.
    """
    buckets = np.zeros(num_classes)
    # the label projection is the same for every location: build it once
    outProj = Proj(label_dataset.crs)
    for (c, r), dataset_index in pixel_locations:
        dataset = image_datasets[dataset_index]
        (x, y) = dataset.xy(r, c)
        inProj = Proj(dataset.crs)
        if inProj != outProj:
            x, y = transform(inProj, outProj, x, y)
        # reference the same coordinate in label_image
        row, col = label_dataset.index(x, y)
        label = label_image[:, row, col]
        buckets[label] += 1
    return buckets


train_buckets = count_label_buckets(train_px, landsat_datasets, label_dataset, label_image, num_classes)
print("TRAINING BUCKETS")
print(train_buckets)
val_buckets = count_label_buckets(val_px, landsat_datasets, label_dataset, label_image, num_classes)
print("VALIDATION BUCKETS")
print(val_buckets)
test_buckets = count_label_buckets(test_px, landsat_datasets, label_dataset, label_image, num_classes)
print("TEST BUCKETS")
print(test_buckets)

TRAINING BUCKETS
[   0.    0. 3591. 8825. 9356. 9430. 9456. 9330. 9378. 9372. 9268. 9392.
 9459. 9365. 9379. 9432.   10.  350. 5329.  662. 9304. 9294.   18.]
VALIDATION BUCKETS
[0.000e+00 0.000e+00 1.244e+03 2.945e+03 3.154e+03 3.097e+03 3.055e+03
 3.182e+03 3.101e+03 3.115e+03 3.215e+03 3.109e+03 3.049e+03 3.111e+03
 3.116e+03 3.086e+03 4.000e+00 1.170e+02 1.812e+03 2.120e+02 3.089e+03
 3.185e+03 2.000e+00]
