In [23]:
import os
os.environ["CUDE_DEVICE_ORDER"] = "PCI_B_US_ID"
os.environ["CUDA_VISIBLE_DEVICES"] ="3"
import rasterio
import numpy as np
from rasterio.plot import adjust_band
import matplotlib.pyplot as plt
from rasterio.plot import reshape_as_raster, reshape_as_image
from rasterio.plot import show
from itertools import product
from rasterio.windows import Window
from pyproj import Proj, transform
import random
import math
import os
import skimage.io as io
import skimage.transform as trans
from keras.models import *
from keras.layers import *
from keras.optimizers import *
from keras.callbacks import ModelCheckpoint, LearningRateScheduler
from keras import backend as keras
import datagenerator as dg
from unet_model import unet_model

In [2]:
label_dataset = rasterio.open('/deep_data/landcover_reproject.tif')
label_image = label_dataset.read()

image_paths = ['/deep_data/processed_landsat/LC08_CU_027012_20170907_20181121_C01_V01_SR_combined.tif',
               '/deep_data/processed_landsat/LC08_CU_028012_20140814_20171017_C01_V01_SR_combined.tif',
               '/deep_data/processed_landsat/LC08_CU_028011_20170907_20181130_C01_V01_SR_combined.tif',  
               '/deep_data/processed_landsat/LC08_CU_028012_20171002_20171019_C01_V01_SR_combined.tif']

landsat_datasets = []
for fp in image_paths:
    landsat_datasets.append(rasterio.open(fp))


In [15]:
def fcn_tile_generator(image_datasets, label_dataset, tile_height, tile_width, pixel_locations, batch_size):
    ### this is a keras compatible data generator which generates data and labels on the fly 
    ### from a set of pixel locations, a list of image datasets, and a label dataset
    # pixel locations looks like [r, c, dataset_index]
    label_image = label_dataset.read()
    label_image[label_image == 255] = 1
    c = r = 0
    i = 0
    outProj = Proj(label_dataset.crs)
    # assuming all images have the same num of bands
    band_count = image_datasets[0].count
    class_count = len(np.unique(label_image))
    buffer = math.ceil(tile_height / 2)
  
    while True:
        image_batch = np.zeros((batch_size, tile_height, tile_width, band_count-1)) # take one off because we don't want the QA band
        label_batch = np.zeros((batch_size, class_count, tile_height, tile_width))
        b = 0
        while b < batch_size:
            # if we're at the end  of the data just restart
            if i >= len(pixel_locations):
                i=0
            #GET PIXELS
            c, r = pixel_locations[i][0]
            dataset_index = pixel_locations[i][1]
            i += 1
            #TILE PROCESSING
            tile = image_datasets[dataset_index].read(list(np.arange(1, band_count+1)), window=Window(c-buffer, r-buffer, tile_width, tile_height))
            if np.amax(tile) == 0: # don't include if it is part of the image with no pixels
                pass
            elif np.isnan(tile).any() == True or -9999 in tile: 
                # we don't want tiles containing nan or -999 this comes from edges
                # this also takes a while and is inefficient
                pass
            elif tile.shape != (band_count, tile_width, tile_height):
                print('wrong shape')
                print(tile.shape)
                # somehow we're randomly getting tiles without the correct dimensions
                pass
            elif np.isin(tile[7,:,:], [352, 368, 392, 416, 432, 480, 840, 864, 880, 904, 928, 944, 1352]).any() == True:
                # make sure pixel doesn't contain clouds
                # this is probably pretty inefficient but only checking width x height for each tile
                # read more here: https://prd-wret.s3-us-west-2.amazonaws.com/assets/palladium/production/s3fs-public/atoms/files/LSDS-1873_US_Landsat_ARD_DFCB_0.pdf
                #print('Found some cloud.')
                #print(tile[7,:,:])
                pass
            else:
                tile = adjust_band(tile[0:7])
                # reshape from raster format to image format
                train_tile = reshape_as_image(tile)
                # LABEL TILE PROCESSING
                #Transforms train pixel location to equivalent label pixel location
                outProj = Proj(label_dataset.crs)
                inProj = Proj(image_datasets[dataset_index].crs)
                (x, y) = image_datasets[dataset_index].xy(r, c)
                if inProj != outProj:
                    x,y = transform(inProj,outProj,x,y) 
                #use pixel to create tile
                row, col = label_dataset.index(x, y)
                label_tile = label_dataset.read(1, window=Window(row-buffer, col-buffer, tile_width, tile_height))
                label_masks = np.zeros((tile_height, tile_width, class_count))
                #use tile to make the masks
                for i in range(class_count):
                    for h in range(tile_height):
                        tileRow = row-buffer+h
                        for w in range(tile_width):
                            tileCol = col-buffer+w
                            if(label_image[0, tileRow, tileCol] == i):
                                label_masks[i][h][w] = 1
                            else:
                                label_masks[i][h][w] = 0               
                label_batch[b] = label_masks;
                image_batch[b] = train_tile
                b += 1
                yield (image_batch, label_batch)
                    
        

In [17]:
batch_size = 25
label_image[label_image == 255] = 1
num_classes = len(np.unique(label_image))
epochs = 50

# input image dimensions
tile_side = 64
img_rows, img_cols = tile_side, tile_side
img_bands = landsat_datasets[0].count - 1

input_shape = (img_rows, img_cols, img_bands)
print(input_shape)

(64, 64, 7)


In [18]:
train_px, val_px = dg.gen_pixel_locations(landsat_datasets, 100, 50, tile_side)

In [29]:
weight = 1/23
weight_list = []
for i in range(num_classes):
    weight_list.append(weight)
    
model = unet_model(n_classes=num_classes, im_sz=tile_side, n_channels=7, n_filters_start=32, growth_factor=2, upconv=True, class_weights=weight_list)

In [30]:
model.fit_generator(generator=fcn_tile_generator(landsat_datasets, label_dataset, tile_side, tile_side, train_px, batch_size), 
                    steps_per_epoch=len(train_px) // batch_size, epochs=epochs, verbose=1)

                    
#validation_data=fcn_tile_generator(landsat_datasets, label_dataset, tile_side, tile_side, val_px, batch_size),
#validation_steps=len(val_px) // batch_size)


Epoch 1/50


InvalidArgumentError: Incompatible shapes: [25,64,64,23] vs. [25,23,64,64]
	 [[{{node loss_3/conv2d_97_loss/logistic_loss/mul}} = Mul[T=DT_FLOAT, _class=["loc:@train...ad/Reshape"], _device="/job:localhost/replica:0/task:0/device:GPU:0"](loss_3/conv2d_97_loss/Log-2-0-TransposeNCHWToNHWC-LayoutOptimizer, _arg_conv2d_97_target_0_1/_1883)]]
	 [[{{node loss_3/mul/_1921}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_5906_loss_3/mul", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]