## Shrub island identification
##### We used a slightly modified U-Net architecture, one of the most widely recognized image segmentation algorithms, to identify shrub islands. The approach relies on data augmentation to use the available annotated samples more efficiently. In this model, ResNet34 serves as the backbone of the U-Net.
##### TensorFlow and Keras were used for the training, validation, and prediction of the U-Net.

## .......Python Package.......

In [None]:
import random
import gdal
from pandas import DataFrame
from tqdm import tqdm
import pandas as pd
import os
import argparse
import configparser
import numpy as np
import cv2
import rasterio
import six
import tensorflow as tf
import datetime
import h5py
import io
import math
from keras import Model
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from keras.models import load_model
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, concatenate, Dropout,BatchNormalization
from keras.layers import Conv2D, Concatenate, MaxPooling2D
from keras.layers import UpSampling2D, Dropout, BatchNormalization
from tqdm import tqdm_notebook
from keras import initializers
from keras import regularizers
from keras import constraints
from keras.utils import conv_utils
from keras.utils.data_utils import get_file
from keras.engine.topology import get_source_inputs
from keras.engine import InputSpec
from keras import backend as K
from keras.applications.imagenet_utils import _obtain_input_shape
from keras.regularizers import l2
from keras.engine.topology import Input
from keras.engine.training import Model
from keras.layers.convolutional import Conv2D, UpSampling2D, Conv2DTranspose
from keras.layers.core import Activation, SpatialDropout2D
from keras.layers.merge import concatenate,add
from keras.layers.normalization import BatchNormalization
from keras.layers.pooling import MaxPooling2D
from keras import losses

#### Step1: Sample from remote sensing image and label image
##### A sliding window algorithm was used to generate sample patches from both the original GF-2 image and the annotated raster image. The patch size was set to 128 × 128 pixels and the sampling interval to 80 pixels. Of the patches sampled from the three image blocks, 80% were assigned to the training set and 20% to the validation set.

In [None]:
def smoothed_generate_train_dataset(CropSize,stepsize, images_path,labels_path,masks_path, train_image_path,train_label_path,validation_image_path,validation_label_path,train_csv, validation_csv):
    """Slide a window over each scene and export fully-masked patches.

    For every scene in ``images_path`` the matching label and mask rasters
    (same list position) are read at the scene's width/height.  A
    ``CropSize`` x ``CropSize`` window is moved in steps of ``stepsize``; a
    window is kept only when the minimum mask value inside it equals 1.

    Kept patches are numbered consecutively across all scenes, starting at 1.
    Every patch whose number is a multiple of 5 is written to the validation
    folders, the others to the training folders.  The written file paths are
    saved to ``train_csv`` / ``validation_csv`` with the columns ``image``
    and ``label``.

    Each patch is written with the scene's projection and geotransform
    unchanged (the geotransform is not shifted to the patch origin).
    """
    patch_id = 1
    train_images, train_labels = [], []
    val_images, val_labels = [], []

    for scene_idx, scene_file in enumerate(images_path):
        # Scene raster: also the source of the size and georeferencing.
        dataset_img = gdal.Open(scene_file)
        width = dataset_img.RasterXSize
        height = dataset_img.RasterYSize
        proj = dataset_img.GetProjection()
        geotrans = dataset_img.GetGeoTransform()
        image = dataset_img.ReadAsArray(0, 0, width, height)

        # Label and mask rasters are read with the scene's extent.
        dataset_label = gdal.Open(labels_path[scene_idx])
        label = dataset_label.ReadAsArray(0, 0, width, height)
        dataset_mask = gdal.Open(masks_path[scene_idx])
        mask = dataset_mask.ReadAsArray(0, 0, width, height)

        # NOTE(review): the "- 1" in both bounds skips window origins at
        # size - CropSize - 1 and size - CropSize, which would still fit
        # inside the raster -- confirm whether that margin is intentional.
        for top in np.arange(0, height - CropSize - 1, stepsize):
            rows = slice(top, top + CropSize)
            for left in np.arange(0, width - CropSize - 1, stepsize):
                cols = slice(left, left + CropSize)

                # Discard windows that are not entirely inside the mask.
                if np.min(mask[rows, cols]) != 1:
                    continue

                image_patch = image[:, rows, cols]
                label_patch = label[rows, cols]
                file_name = '%05d.tif' % patch_id

                # Every fifth kept patch goes to the validation split.
                if patch_id % 5 == 0:
                    image_target = validation_image_path + file_name
                    label_target = validation_label_path + file_name
                    image_log, label_log = val_images, val_labels
                else:
                    image_target = train_image_path + file_name
                    label_target = train_label_path + file_name
                    image_log, label_log = train_images, train_labels

                image_log.append(image_target)
                label_log.append(label_target)
                writeTiff(image_patch, geotrans, proj, image_target)
                writeTiff(label_patch, geotrans, proj, label_target)

                patch_id += 1

    print (len(train_images),len(val_images))
    train_table = pd.DataFrame({'image':train_images, 'label':train_labels})  # type: DataFrame
    validation_table = pd.DataFrame({'image':val_images, 'label':val_labels})  # type: DataFrame
    train_table.to_csv(train_csv, index=False)
    validation_table.to_csv(validation_csv, index=False)
                 
#### Save as tif
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a 2-D or band-first 3-D array to ``path`` as a GeoTIFF.

    Parameters:
        im_data: array of shape (height, width) or (bands, height, width).
        im_geotrans: geotransform passed to ``SetGeoTransform``.
        im_proj: projection passed to ``SetProjection``.
        path: output file path.

    Raises:
        ValueError: if ``im_data`` is neither 2-D nor 3-D.
        IOError: if the GTiff driver could not create ``path``.
    """
    # Choose the GDAL pixel type from the numpy dtype name.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Signed 16-bit data must not be stored in an unsigned band.
        datatype = gdal.GDT_Int16
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 2:
        # Promote a single band to (1, height, width).
        im_data = np.array([im_data])
    elif len(im_data.shape) != 3:
        raise ValueError(
            "writeTiff expects a 2-D or 3-D array, got shape %s" % (im_data.shape,))
    im_bands, im_height, im_width = im_data.shape

    # Create the output file.
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # Previously a failed Create() was still dereferenced in the band loop.
        raise IOError("Can not create file: %s" % path)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference closes the file and flushes it to disk.
    del dataset
    
if __name__ == '__main__':
    # Sliding-window settings: patch edge length and step, both in pixels.
    CropSize, stepsize = 128, 80

    # Source rasters: scene, label and mask files are matched by list position.
    images_path = [
        'E:/CNN/GF2NEW/sample/image/zone1_mq_002.tif',
        'E:/CNN/GF2NEW/sample/image/zone1_mq_003.tif',
        'E:/CNN/GF2NEW/sample/image/zone2_nm_002.tif',
    ]
    labels_path = [
        'E:/CNN/GF2NEW/sample/label/zone1_mq_002_label.tif',
        'E:/CNN/GF2NEW/sample/label/zone1_mq_003_label.tif',
        'E:/CNN/GF2NEW/sample/label/zone2_nm_002_label.tif',
    ]
    masks_path = [
        'E:/CNN/GF2NEW/sample/mask/zone1_mq_002_mask.tif',
        'E:/CNN/GF2NEW/sample/mask/zone1_mq_003_mask.tif',
        'E:/CNN/GF2NEW/sample/mask/zone2_nm_002_mask.tif',
    ]

    # Output folders for the patches and the CSV files listing them.
    train_image_path = 'E:/CNN/GF2NEW/sample/sample_result/image/'
    train_label_path = 'E:/CNN/GF2NEW/sample/sample_result/label/'
    validation_image_path = 'E:/CNN/GF2NEW/sample/sample_result/image_val/'
    validation_label_path = 'E:/CNN/GF2NEW/sample/sample_result/label_val/'
    train_csv = "E:/CNN/GF2NEW/sample/sample_result/train.csv"
    validation_csv = "E:/CNN/GF2NEW/sample/sample_result/validation.csv"

    smoothed_generate_train_dataset(
        CropSize=CropSize,
        stepsize=stepsize,
        images_path=images_path,
        labels_path=labels_path,
        masks_path=masks_path,
        train_image_path=train_image_path,
        train_label_path=train_label_path,
        validation_image_path=validation_image_path,
        validation_label_path=validation_label_path,
        train_csv=train_csv,
        validation_csv=validation_csv,
    )
    print ('finished')

#### Step2: Data augmentation
##### Data augmentation is a technique to improve the generalization performance of trained convolutional neural networks. We used five transformations to extend the sample data of the training and validation sets: horizontal flip, vertical flip, diagonal mirroring, and rotation by 90° and 270°.

In [None]:
# read image
def GDALreadTif(fileName, xoff = 0, yoff = 0, data_width = 0, data_height = 0):
    """Read a raster (or a window of it) with GDAL.

    Parameters:
        fileName: path of the raster to open.
        xoff, yoff: pixel offset of the window's upper-left corner.
        data_width, data_height: window size in pixels; a value of 0 means
            "up to the raster edge" along that axis.

    Returns:
        (width, height, bands, data, geotrans, proj): the raster's full
        width and height, its band count, the array returned by
        ``ReadAsArray`` for the window, the geotransform and the projection.

    Raises:
        IOError: if GDAL cannot open ``fileName``.
    """
    dataset = gdal.Open(fileName)
    if dataset is None:
        # Previously this was only printed and the None was dereferenced below.
        raise IOError("Can not open file: " + str(fileName))

    #  width of image
    width = dataset.RasterXSize
    #  height of image
    height = dataset.RasterYSize
    #  band number of image
    bands = dataset.RasterCount

    # Default each missing window dimension independently, so that a caller
    # giving only one of them (or only an offset) still gets a valid window.
    if data_width == 0:
        data_width = width - xoff
    if data_height == 0:
        data_height = height - yoff
    data = dataset.ReadAsArray(xoff, yoff, data_width, data_height)

    #  get GeoTransform
    geotrans = dataset.GetGeoTransform()
    #  get Projection
    proj = dataset.GetProjection()
    return width, height, bands, data, geotrans, proj

#  write an array to a GeoTIFF file with its georeferencing
def writeTiff(im_data, im_geotrans, im_proj, path):
    """Write a 2-D or band-first 3-D array to ``path`` as a GeoTIFF.

    Parameters:
        im_data: array of shape (height, width) or (bands, height, width).
        im_geotrans: geotransform passed to ``SetGeoTransform``.
        im_proj: projection passed to ``SetProjection``.
        path: output file path.

    Raises:
        ValueError: if ``im_data`` is neither 2-D nor 3-D.
        IOError: if the GTiff driver could not create ``path``.
    """
    # Choose the GDAL pixel type from the numpy dtype name.
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # Signed 16-bit data must not be stored in an unsigned band.
        datatype = gdal.GDT_Int16
    elif 'int16' in dtype_name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    if len(im_data.shape) == 2:
        # Promote a single band to (1, height, width).
        im_data = np.array([im_data])
    elif len(im_data.shape) != 3:
        raise ValueError(
            "writeTiff expects a 2-D or 3-D array, got shape %s" % (im_data.shape,))
    im_bands, im_height, im_width = im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    if dataset is None:
        # Previously a failed Create() was still dereferenced in the band loop.
        raise IOError("Can not create file: %s" % path)

    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    # Dropping the last reference closes the file and flushes it to disk.
    del dataset

def rotate(img, angle, scale=1.0):
    """Rotate ``img`` by ``angle`` degrees about its centre with OpenCV.

    A 3-D input is transposed with axes (1, 2, 0) before warping and the
    result is returned in that transposed layout; any other input is warped
    as given.
    """
    canvas = np.transpose(img, axes=(1, 2, 0)) if len(img.shape) == 3 else img

    n_rows, n_cols = canvas.shape[:2]
    pivot = (n_cols / 2, n_rows / 2)
    affine = cv2.getRotationMatrix2D(pivot, angle, scale)

    # NOTE(review): the output size is passed as (rows, cols) while OpenCV
    # documents dsize as (width, height); the two agree only for square
    # inputs -- confirm before using non-square patches.
    return cv2.warpAffine(canvas, affine, (n_rows, n_cols))

# Data augmentation from generated sample patches for training.
# For every patch listed in train.csv, five transformed image/label pairs are
# written next to the originals and the extended listing is saved to
# train_ano.csv.

train_csv = "E:/CNN/GF2NEW/sample/sample_result/train.csv"
train_path = pd.read_csv(train_csv)

imageList = train_path["image"]
labelList = train_path["label"]

image_array = np.array(imageList)
image_list = image_array.tolist()
label_array = np.array(labelList)
label_list = label_array.tolist()

# Number the augmented files after the highest number already in use.
# Starting at len(imageList) + 1 is not safe: every fifth patch number was
# given to the validation split, so training patch numbers run past the
# training count and augmented files would overwrite original patches.
# File names follow the '%05d.tif' pattern, i.e. the number is path[-9:-4].
used_numbers = [int(path[-9:-4]) for path in image_list + label_list]
tran_num = (max(used_numbers) + 1) if used_numbers else 1

for i in range(len(imageList)):
    # read file
    img_file = imageList[i]
    im_width, im_height, im_bands, im_data, im_geotrans, im_proj = GDALreadTif(img_file)
    label_file = labelList[i]
    label = cv2.imread(label_file,-1)

    # Output folders: the original paths without their 9-character file name.
    image_dir = img_file[:-9]
    label_dir = label_file[:-9]

    im_data_vec = np.flip(im_data, axis = 1)
    augmented_pairs = [
        # horizontal flip
        (np.flip(im_data, axis = 2), cv2.flip(label, 1)),
        # vertical flip
        (im_data_vec, cv2.flip(label, 0)),
        # diagonal mirroring (both axes flipped)
        (np.flip(im_data_vec, axis = 2), cv2.flip(label, -1)),
        # rotation 90: rotate() returns channel-last data, so move bands first again
        (np.transpose(rotate(im_data, 90, scale=1.0), axes=(2,0,1)),
         rotate(label, 90, scale=1.0)),
        # rotation 270
        (np.transpose(rotate(im_data, 270, scale=1.0), axes=(2,0,1)),
         rotate(label, 270, scale=1.0)),
    ]

    for aug_image, aug_label in augmented_pairs:
        aug_name = '%05d.tif' %tran_num
        aug_image_path = image_dir + aug_name
        aug_label_path = label_dir + aug_name
        writeTiff(aug_image, im_geotrans, im_proj, aug_image_path)
        writeTiff(aug_label, im_geotrans, im_proj, aug_label_path)
        image_list.append(aug_image_path)
        label_list.append(aug_label_path)
        tran_num += 1

print (len(image_list))
print (len(label_list))
df = pd.DataFrame({'image':image_list, 'label':label_list})  # type: DataFrame
df.to_csv("E:/CNN/GF2NEW/sample/sample_result/train_ano.csv", index=False)

#### Step3: Data processes
##### This step covers reading the data and building the color dictionary (color_dict) for both RGB and gray labels. In our model, colorDict_GRAY = np.array([[0],[1]]).

In [None]:
# Data read
def readTif(fileName):
    """Read all bands of a raster and return them as a channel-last array.

    The file is opened with rasterio inside a ``with`` block so the dataset
    handle is closed as soon as the pixels are read (it was previously left
    open).  The band-first array from ``read()`` is transposed with axes
    (1, 2, 0).
    """
    with rasterio.open(fileName) as img:
        Img_data = img.read()
    data = np.transpose(Img_data, axes=(1,2,0))
    return data

In [None]:
# This code builds the color dictionary of the label images.
# In our model this step was not executed, because the dictionary is known:
# colorDict_GRAY = np.array([[0],[1]])
# labelFolder: folder containing the label images
# classNum: number of classes

def color_dict(labelFolder, classNum):
    """Collect the distinct label colors found in ``labelFolder``.

    Label images are scanned in ``os.listdir`` order until ``classNum``
    distinct colors have been collected.  Each color is packed into one
    integer as ``c0 * 1000000 + c1 * 1000 + c2`` (three digits per channel).

    Returns:
        (colorDict_RGB, colorDict_GRAY): an (n, 3) array of the unpacked
        channel values in ascending packed order, and the result of
        converting those colors to gray with ``cv2.COLOR_BGR2GRAY``.
    """
    colorDict = []
    for image_name in os.listdir(labelFolder):
        ImagePath = labelFolder + "/" + image_name
        img = cv2.imread(ImagePath,-1).astype(np.uint8)

        if(len(img.shape) == 2):
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        # Widen before packing in BOTH branches: a 3-channel image used to
        # stay uint8, so "* 1000" could wrap around (or raise under NumPy's
        # newer promotion rules).  255255255 fits comfortably in uint32.
        img = img.astype(np.uint32)
        img_new = img[:,:,0] * 1000000 + img[:,:,1] * 1000 + img[:,:,2]

        colorDict = sorted(set(colorDict).union(np.unique(img_new).tolist()))
        if(len(colorDict) == classNum):
            break

    # Unpack every 9-digit code back into its three channel values.
    colorDict_RGB = []
    for packed in colorDict:
        color = str(packed).rjust(9, '0')
        colorDict_RGB.append([int(color[0 : 3]), int(color[3 : 6]), int(color[6 : 9])])
    colorDict_RGB = np.array(colorDict_RGB)

    # Reshape to an (n, 1, 3) uint8 "image" so OpenCV can convert it to gray.
    colorDict_GRAY = colorDict_RGB.reshape((colorDict_RGB.shape[0], 1 ,colorDict_RGB.shape[1])).astype(np.uint8)
    colorDict_GRAY = cv2.cvtColor(colorDict_GRAY, cv2.COLOR_BGR2GRAY)
    return colorDict_RGB, colorDict_GRAY

In [None]:
#  Data processes: one-hot encoding of the labels (image normalization is disabled)
#  img: data of image
#  label: data of label
#  classNum: number of classes, 2
#  colorDict_GRAY = np.array([[0],[1]])
def dataPreprocess(img, label, classNum, colorDict_GRAY):
    """Map gray label values to class indices and one-hot encode them.

    Parameters:
        img: image data; returned unchanged.
        label: integer label array holding the gray values listed in
            ``colorDict_GRAY``.  It is no longer modified in place.
        classNum: number of classes, i.e. the depth of the one-hot axis.
        colorDict_GRAY: array of shape (n, 1); row ``i`` holds the gray value
            of class ``i``.

    Returns:
        (img, one_hot) where ``one_hot`` has shape ``label.shape + (classNum,)``
        and is 1 at ``[..., i]`` wherever the pixel belongs to class ``i``.
        Pixels whose value maps to no class index below ``classNum`` stay all
        zero.
    """
    # img = (img - img.min((0,1))) / (img.max((0,1)) + img.min((0,1)))

    # Write the class indices into a copy while always comparing against the
    # ORIGINAL values.  Remapping in place let an index written for one class
    # be mistaken for the gray value of a later class (e.g. grays [5, 0]).
    class_index = np.array(label, copy=True)
    for i in range(colorDict_GRAY.shape[0]):
        class_index[label == colorDict_GRAY[i][0]] = i

    new_label = np.zeros(label.shape + (classNum,))
    for i in range(classNum):
        new_label[class_index == i, i] = 1
    return (img, new_label)

#### Step4: Training dataset and validation dataset generator

In [None]:
#  batch_size: size of batch
#  train_image_path: folder of the images
#  train_label_path: folder of the labelled images
#  classNum: number of classes, 2
#  colorDict_GRAY: colorDict_GRAY = np.array([[0],[1]])
#  csv_path: CSV file with the columns "image" and "label" listing the samples

def _csv_batch_generator(csv_path, batch_size, image_dir, label_dir, classNum, colorDict_GRAY):
    """Yield ``(images, one_hot_labels)`` batches forever from a sample CSV.

    Only the last 9 characters of each listed path (the file name) are used;
    the files themselves are loaded from ``image_dir`` / ``label_dir``.  Each
    batch is a run of ``batch_size`` consecutive CSV rows starting at a
    random row.
    """
    samples = pd.read_csv(csv_path)
    imageList = samples["image"]
    labelList = samples["label"]
    if len(imageList) < batch_size:
        # random.randint would fail below with a far less helpful message.
        raise ValueError("%s lists %d samples, fewer than batch_size=%d"
                         % (csv_path, len(imageList), batch_size))

    # The first patch is read only to learn the (rows, cols, bands) shape.
    probe = readTif(image_dir + imageList[0][-9:])
    patch_shape = (probe.shape[0], probe.shape[1], probe.shape[2])

    while(True):
        # Images are stored in a uint16 buffer, labels in a uint8 buffer.
        img_generator = np.zeros((batch_size,) + patch_shape, np.uint16)
        label_generator = np.zeros((batch_size,) + patch_shape[:2], np.uint8)
        #  Randomly generate the starting point of a batch
        rand = random.randint(0, len(imageList) - batch_size)
        for j in range(batch_size):
            img_generator[j] = readTif(image_dir + imageList[rand + j][-9:])
            label = cv2.imread((label_dir + labelList[rand + j][-9:]),-1).astype(np.uint8)
            label_generator[j] = label
        img_generator, label_generator = dataPreprocess(img_generator, label_generator, classNum, colorDict_GRAY)
        yield (img_generator,label_generator)


def trainGenerator(batch_size, train_image_path, train_label_path, classNum, colorDict_GRAY,
                   csv_path="E:/CNN/GF2NEW/sample/sample_result/train_ano.csv"):
    """Endless batch generator over the (augmented) training samples."""
    for batch in _csv_batch_generator(csv_path, batch_size, train_image_path,
                                      train_label_path, classNum, colorDict_GRAY):
        yield batch


def validationGenerator(batch_size, train_image_path, train_label_path, classNum, colorDict_GRAY,
                        csv_path="E:/CNN/GF2NEW/sample/sample_result/validation.csv"):
    """Endless batch generator over the validation samples."""
    for batch in _csv_batch_generator(csv_path, batch_size, train_image_path,
                                      train_label_path, classNum, colorDict_GRAY):
        yield batch

#### Step5: U-net Model with ResNet34 as backbone
##### https://www.kaggle.com/meaninglesslives/unet-resnet34-in-keras/notebook#Build-U-Net-Model

In [None]:
def handle_block_names(stage):
    """Return the (conv, bn, relu, upsample) layer-name stems of a decoder stage."""
    stem = 'decoder_stage{}_'.format(stage)
    return tuple(stem + part for part in ('conv', 'bn', 'relu', 'upsample'))


def Upsample2D_block(filters, stage, kernel_size=(3,3), upsample_rate=(2,2),
                     batchnorm=False, skip=None):
    """Build a decoder stage: UpSampling2D, optional skip concat, two convs.

    Returns a function that applies the stage to a tensor.  Each of the two
    convolutions is followed by an optional BatchNormalization (when
    ``batchnorm`` is true) and a ReLU; their layer names carry the suffixes
    '1' and '2'.
    """
    def layer(input_tensor):
        conv_name, bn_name, relu_name, up_name = handle_block_names(stage)

        features = UpSampling2D(size=upsample_rate, name=up_name)(input_tensor)
        if skip is not None:
            features = Concatenate()([features, skip])

        # Two identical conv -> (bn) -> relu passes, told apart by the suffix.
        for suffix in ('1', '2'):
            features = Conv2D(filters, kernel_size, padding='same',
                              name=conv_name + suffix)(features)
            if batchnorm:
                features = BatchNormalization(name=bn_name + suffix)(features)
            features = Activation('relu', name=relu_name + suffix)(features)

        return features
    return layer


def Transpose2D_block(filters, stage, kernel_size=(3,3), upsample_rate=(2,2),
                      transpose_kernel_size=(4,4), batchnorm=False, skip=None):
    """Build a decoder stage: Conv2DTranspose, optional skip concat, one conv.

    Returns a function that applies the stage to a tensor.  Both the
    transposed convolution and the following convolution are followed by an
    optional BatchNormalization (when ``batchnorm`` is true) and a ReLU.
    """
    def layer(input_tensor):
        conv_name, bn_name, relu_name, up_name = handle_block_names(stage)

        def norm_then_relu(tensor, suffix):
            # Optional batch norm followed by ReLU, named with ``suffix``.
            if batchnorm:
                tensor = BatchNormalization(name=bn_name + suffix)(tensor)
            return Activation('relu', name=relu_name + suffix)(tensor)

        upsampled = Conv2DTranspose(filters, transpose_kernel_size, strides=upsample_rate,
                                    padding='same', name=up_name)(input_tensor)
        upsampled = norm_then_relu(upsampled, '1')

        if skip is not None:
            upsampled = Concatenate()([upsampled, skip])

        refined = Conv2D(filters, kernel_size, padding='same', name=conv_name+'2')(upsampled)
        return norm_then_relu(refined, '2')
    return layer

def build_unet(backbone, classes, last_block_filters, skip_layers,
               n_upsample_blocks=5, upsample_rates=(2,2,2,2,2),
               block_type='upsampling', activation='sigmoid',
               **kwargs):
    """Attach a U-Net decoder to ``backbone`` and return the full model.

    Parameters:
        backbone: encoder model; its output feeds the first decoder stage.
        classes: number of output channels of the final convolution.
        last_block_filters: filter count of the last decoder stage; each
            earlier stage has twice the filters of the one after it.
        skip_layers: backbone layers used as skip connections, given as
            indices or layer names, in decoder order.  Stages beyond the end
            of this list get no skip connection.
        n_upsample_blocks: number of decoder stages.
        upsample_rates: per-stage upsampling factor (used for both axes).
        block_type: 'transpose' selects Transpose2D_block, anything else
            Upsample2D_block.
        activation: final activation; forced to 'sigmoid' when classes < 2.
        **kwargs: forwarded to the decoder block builder.
    """
    up_block = Transpose2D_block if block_type == 'transpose' else Upsample2D_block

    # convert layer names to indices
    # NOTE(review): get_layer_number is not defined in the visible part of
    # this file; it is only reached when a skip layer is given by name.
    skip_indices = []
    for entry in skip_layers:
        if isinstance(entry, str):
            entry = get_layer_number(backbone, entry)
        skip_indices.append(entry)

    net_input = backbone.input
    x = backbone.output

    for stage in range(n_upsample_blocks):
        # check if there is a skip connection
        if stage < len(skip_indices):
            skip = backbone.layers[skip_indices[stage]].output
        else:
            skip = None

        rate = upsample_rates[stage]
        stage_filters = last_block_filters * 2**(n_upsample_blocks-(stage+1))

        x = up_block(stage_filters, stage, upsample_rate=(rate, rate), skip=skip, **kwargs)(x)

    if classes < 2:
        activation = 'sigmoid'

    x = Conv2D(classes, (3,3), padding='same', name='final_conv')(x)
    x = Activation(activation, name=activation)(x)

    return Model(net_input, x)

# https://github.com/raghakot/keras-resnet/blob/master/resnet.py
def _bn_relu(input):
    """Apply batch normalization along CHANNEL_AXIS, then a ReLU activation.
    """
    normalized = BatchNormalization(axis=CHANNEL_AXIS)(input)
    relu = Activation("relu")
    return relu(normalized)


def _conv_bn_relu(**conv_params):
    """Return a function applying conv -> BN -> relu to a tensor.

    ``filters`` and ``kernel_size`` are required keywords.  ``strides``,
    ``kernel_initializer``, ``padding`` and ``kernel_regularizer`` default to
    (1, 1), "he_normal", "same" and l2(1.e-4).
    """
    conv_kwargs = {
        "filters": conv_params["filters"],
        "kernel_size": conv_params["kernel_size"],
    }
    conv_kwargs["strides"] = conv_params.get("strides", (1, 1))
    conv_kwargs["kernel_initializer"] = conv_params.get("kernel_initializer", "he_normal")
    conv_kwargs["padding"] = conv_params.get("padding", "same")
    conv_kwargs["kernel_regularizer"] = conv_params.get("kernel_regularizer", l2(1.e-4))

    def f(input):
        convolved = Conv2D(**conv_kwargs)(input)
        return _bn_relu(convolved)

    return f


def _bn_relu_conv(**conv_params):
    """Return a function applying BN -> relu -> conv to a tensor.

    This is the pre-activation ordering proposed in
    http://arxiv.org/pdf/1603.05027v2.pdf

    ``filters`` and ``kernel_size`` are required keywords.  ``strides``,
    ``kernel_initializer``, ``padding`` and ``kernel_regularizer`` default to
    (1, 1), "he_normal", "same" and l2(1.e-4).
    """
    conv_kwargs = {
        "filters": conv_params["filters"],
        "kernel_size": conv_params["kernel_size"],
    }
    conv_kwargs["strides"] = conv_params.get("strides", (1, 1))
    conv_kwargs["kernel_initializer"] = conv_params.get("kernel_initializer", "he_normal")
    conv_kwargs["padding"] = conv_params.get("padding", "same")
    conv_kwargs["kernel_regularizer"] = conv_params.get("kernel_regularizer", l2(1.e-4))

    def f(input):
        pre_activated = _bn_relu(input)
        return Conv2D(**conv_kwargs)(pre_activated)

    return f


def _shortcut(input, residual):
    """Merge ``input`` and ``residual`` with an element-wise sum.

    When the two tensors differ in spatial size or channel count, ``input``
    is first passed through a strided 1 x 1 convolution so its shape matches
    ``residual``; otherwise it is added unchanged.
    """
    in_shape = K.int_shape(input)
    res_shape = K.int_shape(residual)

    # The size ratios should be integers if the network is configured correctly.
    strides = (
        int(round(in_shape[ROW_AXIS] / res_shape[ROW_AXIS])),
        int(round(in_shape[COL_AXIS] / res_shape[COL_AXIS])),
    )
    channels_differ = in_shape[CHANNEL_AXIS] != res_shape[CHANNEL_AXIS]

    if max(strides) <= 1 and not channels_differ:
        # Identity shortcut: shapes already agree.
        return add([input, residual])

    projected = Conv2D(filters=res_shape[CHANNEL_AXIS],
                       kernel_size=(1, 1),
                       strides=strides,
                       padding="valid",
                       kernel_initializer="he_normal",
                       kernel_regularizer=l2(0.0001))(input)
    return add([projected, residual])

def basic_block(filters, init_strides=(1, 1), is_first_block_of_first_layer=False):
    """Two-convolution 3 X 3 residual unit for resnets with layers <= 34.

    Returns a function that applies the unit to a tensor and adds the
    shortcut from that tensor.
    """
    def f(input):
        if not is_first_block_of_first_layer:
            first = _bn_relu_conv(filters=filters, kernel_size=(3, 3),
                                  strides=init_strides)(input)
        else:
            # The stem already ended with bn->relu->maxpool, so the very
            # first unit starts with a plain convolution.
            first = Conv2D(filters=filters, kernel_size=(3, 3),
                           strides=init_strides,
                           padding="same",
                           kernel_initializer="he_normal",
                           kernel_regularizer=l2(1e-4))(input)

        second = _bn_relu_conv(filters=filters, kernel_size=(3, 3))(first)
        return _shortcut(input, second)

    return f

def _residual_block(block_function, filters, repetitions, is_first_layer=False):
    """Chain ``repetitions`` units built by ``block_function``.

    The first unit uses strides (2, 2) unless this is the first layer; all
    other units use (1, 1).  Only the first unit of the first layer is built
    with ``is_first_block_of_first_layer`` set.
    """
    def f(input):
        tensor = input
        for index in range(repetitions):
            leads = index == 0
            strides = (2, 2) if (leads and not is_first_layer) else (1, 1)
            unit = block_function(filters=filters, init_strides=strides,
                                  is_first_block_of_first_layer=(is_first_layer and leads))
            tensor = unit(tensor)
        return tensor

    return f

def _handle_dim_ordering():
    """Set the global ROW_AXIS / COL_AXIS / CHANNEL_AXIS for the backend layout.

    Uses ``K.image_data_format()`` rather than the deprecated
    ``K.image_dim_ordering()``, which newer Keras releases no longer provide;
    'channels_last' corresponds to the old 'tf' ordering.
    """
    global ROW_AXIS
    global COL_AXIS
    global CHANNEL_AXIS
    if K.image_data_format() == 'channels_last':
        # (batch, rows, cols, channels)
        ROW_AXIS = 1
        COL_AXIS = 2
        CHANNEL_AXIS = 3
    else:
        # (batch, channels, rows, cols)
        CHANNEL_AXIS = 1
        ROW_AXIS = 2
        COL_AXIS = 3


def _get_block(identifier):
    """Resolve a block builder given either the callable or its global name.

    Non-string identifiers are returned unchanged.  Strings are looked up in
    this module's globals; a missing (or falsy) entry raises ValueError.
    """
    if not isinstance(identifier, six.string_types):
        return identifier

    resolved = globals().get(identifier)
    if resolved:
        return resolved
    raise ValueError('Invalid {}'.format(identifier))


class ResnetBuilder(object):
    """Factory for headless ResNet encoders (no pooling/classifier on top)."""

    @staticmethod
    def build(input_shape, block_fn, repetitions,input_tensor):
        """Build a ResNet feature extractor.

        Parameters:
            input_shape: tuple (nb_channels, nb_rows, nb_cols); it is
                reordered to channels-last when the backend uses that layout.
            block_fn: residual unit builder, or the name of one defined in
                this module.
            repetitions: number of units in each residual stage; the filter
                count starts at 64 and doubles after every stage.
            input_tensor: optional tensor to use as the model input; when
                None a new Input layer is created.

        Returns:
            A Model mapping the input to the final BN -> relu feature map.

        Raises:
            Exception: if ``input_shape`` does not have exactly 3 entries.
        """
        _handle_dim_ordering()
        if len(input_shape) != 3:
            raise Exception("Input shape should be a tuple (nb_channels, nb_rows, nb_cols)")

        # Permute dimension order if necessary.
        # K.image_data_format() replaces the deprecated K.image_dim_ordering();
        # 'channels_last' corresponds to the old 'tf' ordering.
        if K.image_data_format() == 'channels_last':
            input_shape = (input_shape[1], input_shape[2], input_shape[0])

        # Load function from str if needed.
        block_fn = _get_block(block_fn)

        if input_tensor is None:
            img_input = Input(shape=input_shape)
        elif not K.is_keras_tensor(input_tensor):
            # Wrap a backend tensor in an Input layer.
            img_input = Input(tensor=input_tensor, shape=input_shape)
        else:
            img_input = input_tensor

        # Stem: 7x7 stride-2 convolution followed by 3x3 stride-2 max pooling.
        conv1 = _conv_bn_relu(filters=64, kernel_size=(7, 7), strides=(2, 2))(img_input)
        pool1 = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(conv1)

        block = pool1
        filters = 64
        for i, r in enumerate(repetitions):
            block = _residual_block(block_fn, filters=filters, repetitions=r, is_first_layer=(i == 0))(block)
            filters *= 2

        # Last activation
        block = _bn_relu(block)

        model = Model(inputs=img_input, outputs=block)
        return model

    @staticmethod
    def build_resnet_34(input_shape,input_tensor):
        """Build the ResNet-34 encoder: basic blocks repeated [3, 4, 6, 3] times."""
        return ResnetBuilder.build(input_shape, basic_block, [3, 4, 6, 3],input_tensor)

In [None]:
def UResNet34(input_shape=(None, None, 4), classes=2, decoder_filters=16, decoder_block_type='upsampling',
                       encoder_weights=None, input_tensor=None, activation='sigmoid', **kwargs):
    """Build the U-Net with a ResNet34 encoder and name it 'u-resnet34'.

    ``input_shape`` and ``input_tensor`` go to the ResNet34 builder; the
    remaining arguments configure the decoder built by ``build_unet``.
    ``encoder_weights`` is accepted but not used.

    NOTE(review): ResnetBuilder.build documents input_shape as
    (nb_channels, nb_rows, nb_cols), whereas the default here lists the
    channel count last -- confirm the intended default ordering.
    """
    encoder = ResnetBuilder.build_resnet_34(input_shape=input_shape,
                                            input_tensor=input_tensor)

    # Encoder layer indices used as skip connections (for resnet 34).
    resnet34_skips = [97, 54, 25]

    model = build_unet(encoder, classes, decoder_filters, resnet34_skips,
                       block_type=decoder_block_type,
                       activation=activation,
                       **kwargs)
    model.name = 'u-resnet34'
    return model

#### Step5: Loss function
##### The loss function used during training was the focal loss, which handles class imbalance in segmentation better than the common classification (cross-entropy) loss.

In [None]:
# focal losses function
from keras.losses import binary_crossentropy
from keras import backend as K
def binary_focal_loss(gamma=2, alpha=0.25):
    """
    Binary form of focal loss.
    focal_loss(p_t) = -alpha_t * (1 - p_t)**gamma * log(p_t)
        where p = sigmoid(x), p_t = p or 1 - p depending on if the label is 1 or 0, respectively.
    References:
        https://arxiv.org/pdf/1708.02002.pdf
    Usage:
     model.compile(loss=[binary_focal_loss(alpha=.25, gamma=2)], metrics=["accuracy"], optimizer=adam)

    Parameters:
        gamma: focusing exponent applied to (1 - p_t).
        alpha: weight of the positive class; negatives get 1 - alpha.
    Returns:
        A loss function ``(y_true, y_pred) -> scalar`` (mean over all elements).
    """
    alpha = tf.constant(alpha, dtype=tf.float32)
    gamma = tf.constant(gamma, dtype=tf.float32)

    def binary_focal_loss_fixed(y_true, y_pred):
        """
        y_true: labels with the same shape as y_pred
        y_pred: probabilities, i.e. computed after the sigmoid
        """
        y_true = tf.cast(y_true, tf.float32)
        # Keep the probabilities strictly inside (0, 1).  Adding epsilon to
        # p_t instead allowed p_t > 1, i.e. a negative base for pow (NaN for
        # a non-integer gamma) and a positive log term.
        y_pred = K.clip(y_pred, K.epsilon(), 1.0 - K.epsilon())

        alpha_t = y_true*alpha + (K.ones_like(y_true)-y_true)*(1-alpha)
        p_t = y_true*y_pred + (K.ones_like(y_true)-y_true)*(K.ones_like(y_true)-y_pred)
        focal_loss = - alpha_t * K.pow((K.ones_like(y_true)-p_t),gamma) * K.log(p_t)
        return K.mean(focal_loss)
    return binary_focal_loss_fixed

#### Step6: Model training
##### parameters

In [None]:
# Restrict training to the first GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Gray value of each class: background 0, target 1.
colorDict_GRAY = np.array([[0],[1]])


# parameters of dataset
train_image_path = "E:/CNN/GF2NEW/sample/sample_result/image/"

train_label_path = "E:/CNN/GF2NEW/sample/sample_result/label/"

validation_image_path = "E:/CNN/GF2NEW/sample/sample_result/image_val/"

validation_label_path = "E:/CNN/GF2NEW/sample/sample_result/label_val/"

# parameters of model
batch_size = 8

classNum = 2

# epochs
epochs = 100

# learning rate
learning_rate = 1e-3

# save of model
model_path = "E:/CNN/GF2NEW/sample/sample_result/unet_model_res34.hdf5"

# number of training set
train_path = pd.read_csv("E:/CNN/GF2NEW/sample/sample_result/train_ano.csv")
imageList1 = train_path["image"]
train_num = len(imageList1)
print (train_num)

# number of validation set
train_path = pd.read_csv("E:/CNN/GF2NEW/sample/sample_result/validation.csv")
imageList = train_path["image"]
validation_num = len(imageList)
print (validation_num)

# Step counts must be whole numbers.  Ceil division (written with floor
# division so it stays an int) runs one extra step for a partial last batch
# instead of handing Keras a fractional step count.
steps_per_epoch = -(-train_num // batch_size)
validation_steps = -(-validation_num // batch_size)

#  training and validation dataset with rate of batch_size
train_data = trainGenerator(batch_size,
                    train_image_path,
                    train_label_path,
                    classNum,
                    colorDict_GRAY)

validation_data = validationGenerator(batch_size,
                      validation_image_path,
                      validation_label_path,
                      classNum,
                      colorDict_GRAY)

# Build, compile and train the model on the first GPU, then log the wall time.
with tf.device("/gpu:0"):
    model = UResNet34(input_shape=(4, 128,128))
    model.summary()
    # Pass the configured learning rate to the optimizer; the plain "adam"
    # string silently ignored the learning_rate setting.
    model.compile(optimizer = Adam(lr=learning_rate), loss = [binary_focal_loss(alpha=.25, gamma=2)], metrics = ["accuracy"])
    early_stopping = EarlyStopping(patience=10, verbose=1)
    # NOTE(review): the checkpoint tracks the training loss ('loss') while
    # early stopping and the LR schedule use their default monitor --
    # confirm whether the best model should be picked on validation loss.
    model_checkpoint = ModelCheckpoint(model_path,
                      monitor = 'loss',
                      verbose = 1,
                      mode='min',
                      save_best_only = True)
    reduce_lr = ReduceLROnPlateau(factor=0.1, patience=4, min_lr=0.00001, verbose=1)
    start_time = datetime.datetime.now()
    history = model.fit_generator(train_data,
                        steps_per_epoch = steps_per_epoch,
                        epochs = epochs,
                        callbacks=[early_stopping, model_checkpoint, reduce_lr],
                        validation_data = validation_data,
                        validation_steps = validation_steps)

    end_time = datetime.datetime.now()
    # total_seconds() covers the whole duration; .seconds dropped whole days.
    log_time = "training time: " + str((end_time - start_time).total_seconds() / 60) + "m"
    print(log_time)
    with open('E:/CNN/GF2NEW/sample/sample_result/TrainTime.txt','w') as f:
        f.write(log_time)

#### Step7: Model predictions
##### Feeding a large remote sensing image directly into the network would exhaust memory, so the image to be classified is cropped into a series of smaller tiles that are fed to the network for prediction. The predicted tiles are then stitched into one final image in the cropping order.

In [None]:
def TifCroppingArray(img, SideLength, tile_size=128):
    """Cut an image into overlapping square tiles, row by row.

    Tiles are taken on a regular grid whose step is
    ``tile_size - 2 * SideLength``.  Every tile row is then completed with
    one extra tile aligned to the last column of the image, and one extra
    tile row aligned to the last image row is appended, so the far edges
    are always covered.

    Args:
        img: Array indexed as ``img[row, col, ...]``; only the first two
            axes are sliced.
        SideLength: Overlap margin in pixels on each side of a tile.
        tile_size: Edge length of a tile in pixels (default 128, the value
            that used to be hard-coded).

    Returns:
        ``(TifArrayReturn, RowOver, ColumnOver)`` where ``TifArrayReturn``
        is a list of tile rows (each a list of array slices), ``RowOver``
        is the remainder of the grid along axis 1 plus ``SideLength`` and
        ``ColumnOver`` is the same quantity along axis 0.
    """
    height, width = img.shape[0], img.shape[1]
    stride = tile_size - SideLength * 2

    # Number of regular grid positions along axis 0 and axis 1.
    ColumnNum = int((height - SideLength * 2) / stride)
    RowNum = int((width - SideLength * 2) / stride)

    def _tile_row(top):
        # Regular tiles of one row, followed by the tile flush with the
        # last image column.
        tiles = [img[top : top + tile_size, j * stride : j * stride + tile_size]
                 for j in range(RowNum)]
        tiles.append(img[top : top + tile_size, (width - tile_size) : width])
        return tiles

    TifArrayReturn = [_tile_row(i * stride) for i in range(ColumnNum)]
    # Final row of tiles, flush with the last image row.
    TifArrayReturn.append(_tile_row(height - tile_size))

    ColumnOver = (height - SideLength * 2) % stride + SideLength
    RowOver = (width - SideLength * 2) % stride + SideLength
    return TifArrayReturn, RowOver, ColumnOver

def labelVisualize(img):
    """Collapse per-pixel scores into a label map.

    For every position ``(i, j)`` the result holds the index of the
    largest value in ``img[i][j]`` (first index on ties).  The arg-max is
    computed for the whole array at once instead of pixel by pixel.

    Args:
        img: Array with at least two axes, indexed as ``img[row, col, ...]``.

    Returns:
        ``float64`` array of shape ``(img.shape[0], img.shape[1])``.
    """
    img = np.asarray(img)
    rows, cols = img.shape[0], img.shape[1]
    if rows == 0 or cols == 0:
        # Nothing to label; reshape(-1) cannot infer a size for empty input.
        return np.zeros((rows, cols))
    # Flatten everything behind the first two axes so argmax sees one score
    # vector per pixel, exactly like np.argmax(img[i][j]) does.
    scores = img.reshape(rows, cols, -1)
    return scores.argmax(axis=-1).astype(np.float64)


def testGenerator(TifArray):
    """Yield every tile of a tile grid as a batch of one.

    Tiles are visited row by row; the number of tiles read from each row
    is the length of the first row.  Each tile is yielded with a new
    leading axis of length 1.
    """
    for tile_row in TifArray:
        for col in range(len(TifArray[0])):
            tile = tile_row[col]
            yield np.reshape(tile, (1,) + tile.shape)


def Result(shape, TifArray, npyfile, num_class, RepetitiveLength, RowOver, ColumnOver):
    """Stitch per-tile predictions back into one label image.

    Predictions are consumed in row-major tile order.  Each one is turned
    into a label map with ``labelVisualize`` and only part of it is copied
    into the output: the overlap margin is dropped on sides that touch a
    neighbouring tile, while tiles on the first/last row or column keep the
    side that lies on the image border.

    Args:
        shape: ``(rows, cols)`` of the output image.
        TifArray: Tile grid (list of tile rows); only its row count and the
            length of its first row are used.
        npyfile: Iterable of per-tile predictions, one per tile.
        num_class: Unused; kept for interface compatibility.
        RepetitiveLength: Overlap margin in pixels on each tile side.
        RowOver: Width in pixels taken from the last tile of every row.
        ColumnOver: Height in pixels taken from the tiles of the last row.

    Returns:
        ``uint8`` array of the requested shape holding the stitched labels.
    """
    tile = 128  # tile edge length the slicing below is written for
    overlap = RepetitiveLength
    stride = tile - 2 * overlap
    n_rows = len(TifArray)
    n_cols = len(TifArray[0]) if TifArray else 0

    result = np.zeros(shape, np.uint8)
    j = 0  # index of the tile row currently being placed
    for i, item in enumerate(npyfile):
        img = labelVisualize(item)
        img = img.astype(np.uint8)
        col = i % n_cols
        if col == 0:
            # First tile of a row: its left edge lies on the image border.
            if j == 0:
                result[0 : tile - overlap, 0 : tile - overlap] = img[0 : tile - overlap, 0 : tile - overlap]
            elif j == n_rows - 1:
                result[shape[0] - ColumnOver - overlap : shape[0], 0 : tile - overlap] = img[tile - ColumnOver - overlap : tile, 0 : tile - overlap]
            else:
                result[j * stride + overlap : (j + 1) * stride + overlap,
                       0 : tile - overlap] = img[overlap : tile - overlap, 0 : tile - overlap]

        elif col == n_cols - 1:
            # Last tile of a row: aligned with the right image border.
            if j == 0:
                result[0 : tile - overlap, shape[1] - RowOver : shape[1]] = img[0 : tile - overlap, tile - RowOver : tile]
            elif j == n_rows - 1:
                result[shape[0] - ColumnOver : shape[0], shape[1] - RowOver : shape[1]] = img[tile - ColumnOver : tile, tile - RowOver : tile]
            else:
                result[j * stride + overlap : (j + 1) * stride + overlap,
                       shape[1] - RowOver : shape[1]] = img[overlap : tile - overlap, tile - RowOver : tile]
            j = j + 1  # the row is complete, move on to the next one
        else:
            # Interior column: only the central `stride` columns are kept.
            left = col * stride + overlap
            right = (col + 1) * stride + overlap
            if j == 0:
                # Fixed: the source slice used to end at 256 - overlap, which
                # on a 128-pixel tile is wider than the destination and fails
                # to broadcast for any non-zero overlap.
                result[0 : tile - overlap, left : right] = img[0 : tile - overlap, overlap : tile - overlap]
            # Deliberately a separate `if`, as before: on the first row the
            # `else` branch rewrites the central rows with the same values.
            if j == n_rows - 1:
                result[shape[0] - ColumnOver : shape[0], left : right] = img[tile - ColumnOver : tile, overlap : tile - overlap]
            else:
                result[j * stride + overlap : (j + 1) * stride + overlap,
                       left : right] = img[overlap : tile - overlap, overlap : tile - overlap]
    return result

# Fraction of a tile's area that is kept from its centre when stitching.
area_perc = 0.2
TifPath = r"E:/CNN/GF2NEW/sample/gf2image/GF2_PMS1_E103.1_N39.0_20191007_L1A0004292629_gs.tif"
ModelPath = r"E:/CNN/GF2NEW/sample/sample_result/unet_model_res34.hdf5"
ResultPath = r"E:/CNN/GF2NEW/sample/gf2image/unet/GF2_PMS1_E103.1_N39.0_20191007_L1A0004292629_unet.tif"
# Margin in pixels on each side of a 128-pixel tile such that the central
# (128 - 2 * margin) square covers `area_perc` of the tile (truncated to int).
RepetitiveLength = int((1 - math.sqrt(area_perc)) * 128 / 2)


testtime = []
starttime = datetime.datetime.now()

im_width, im_height, im_bands, im_data, im_geotrans, im_proj = readTif(TifPath)
# The two swaps move axis 0 to the end: (a, b, c) -> (b, c, a).
# NOTE(review): presumably readTif returns band-first data, so this makes the
# array channels-last — confirm against readTif.
im_data = im_data.swapaxes(1, 0)
im_data = im_data.swapaxes(1, 2)

TifArray, RowOver, ColumnOver = TifCroppingArray(im_data, RepetitiveLength)
endtime = datetime.datetime.now()
# All timings are measured from `starttime`.  total_seconds() is used because
# timedelta.seconds silently drops whole days.
text = "finished tiffread with time: " + str(int((endtime - starttime).total_seconds())) + "s"
print(text)
testtime.append(text)

model = load_model(ModelPath)
testGene = testGenerator(TifArray)
# One prediction step per tile of the grid.
results = model.predict_generator(testGene,
                                  len(TifArray) * len(TifArray[0]),
                                  verbose = 1)
endtime = datetime.datetime.now()
text = "finished prediction with time: " + str(int((endtime - starttime).total_seconds())) + "s"
print(text)
testtime.append(text)


result_shape = (im_data.shape[0], im_data.shape[1])
result_data = Result(result_shape, TifArray, results, 2, RepetitiveLength, RowOver, ColumnOver)
writeTiff(result_data, im_geotrans, im_proj, ResultPath)
endtime = datetime.datetime.now()
text = "save result with time: " + str(int((endtime - starttime).total_seconds())) + "s"
print(text)
testtime.append(text)

time = datetime.datetime.strftime(datetime.datetime.now(), '%Y%m%d-%H%M%S')
with open('timelog_%s.txt'%time, 'w') as f:
    for log_line in testtime:
        f.write(log_line)
        # The file is opened in text mode, which already translates "\n" to
        # the platform line ending; writing "\r\n" produced "\r\r\n" on Windows.
        f.write("\n")

#### Step8: Testing

In [1]:
def ConfusionMatrix(numClass, imgPredict, Label):
    """Count (label, prediction) pairs into a numClass x numClass matrix.

    Entries whose label lies outside ``[0, numClass)`` are ignored.  Rows
    are indexed by the label value, columns by the predicted value.
    """
    valid = (Label < numClass) & (Label >= 0)
    # Encode every (label, prediction) pair as one flat cell index.
    cell_index = Label[valid] * numClass + imgPredict[valid]
    cell_counts = np.bincount(cell_index, minlength=numClass ** 2)
    return cell_counts.reshape(numClass, numClass)

def OverallAccuracy(confusionMatrix):
    """Return the diagonal total divided by the total of all cells."""
    on_diagonal = np.trace(confusionMatrix)
    everything = confusionMatrix.sum()
    return on_diagonal / everything
  
def Precision(confusionMatrix):
    """Return, per class, the diagonal entry divided by its column total."""
    hits = np.diagonal(confusionMatrix)
    column_totals = np.sum(confusionMatrix, axis=0)
    return hits / column_totals

def Recall(confusionMatrix):
    """Return, per class, the diagonal entry divided by its row total."""
    hits = np.diagonal(confusionMatrix)
    row_totals = np.sum(confusionMatrix, axis=1)
    return hits / row_totals
  
def F1Score(confusionMatrix):
    """Return the per-class harmonic mean of precision and recall.

    Precision divides the diagonal by the column totals, recall divides it
    by the row totals.
    """
    hits = np.diag(confusionMatrix)
    per_class_precision = hits / confusionMatrix.sum(axis=0)
    per_class_recall = hits / confusionMatrix.sum(axis=1)
    numerator = 2 * per_class_precision * per_class_recall
    return numerator / (per_class_precision + per_class_recall)

def IntersectionOverUnion(confusionMatrix):
    """Return the per-class intersection over union.

    The intersection is the diagonal entry; the union is the row total
    plus the column total minus that diagonal entry.
    """
    overlap = np.diag(confusionMatrix)
    row_totals = confusionMatrix.sum(axis=1)
    column_totals = confusionMatrix.sum(axis=0)
    combined = row_totals + column_totals - overlap
    return overlap / combined

#### Step9: Post-processing
##### Convert the classification raster to a polygon shapefile, then remove patches that are too large or too small, using arcpy

In [None]:
import arcpy
from arcpy import env
from arcpy.sa import *

# For every classified raster in the workspace: vectorise it, project the
# polygons to UTM 48N, keep the polygons that pass the size filter and copy
# them to the post-processing folder.
arcpy.env.overwriteOutput = True
env.workspace = "E:/CNN/GF2NEW/sample/gf2image/unet/"
bcn = arcpy.ListRasters("*","tif")

# Target coordinate system; it is the same for every raster.
outCS = arcpy.SpatialReference('WGS 1984 UTM Zone 48N')

for i in bcn:
    # raster to polygon shapefile; i[:-4] strips the ".tif" extension
    shp_name = "E:/CNN/GF2NEW/sample/gf2image/unet_shp/%s.shp"%i[:-4]
    arcpy.RasterToPolygon_conversion(i, shp_name, "SIMPLIFY", "Value")
    pro_name = "E:/CNN/GF2NEW/sample/gf2image/unet_shp/%s_pro.shp"%i[:-4]

    arcpy.Project_management(shp_name, pro_name, outCS)
    arcpy.MakeFeatureLayer_management(pro_name, "lyr")
    # Predicates are combined with the SQL keyword AND; "&" is not an
    # operator in ArcGIS where clauses.
    # NOTE(review): the clause relies on Shape_Area and Shape_Leng fields
    # being present on the projected shapefile — confirm they exist.
    arcpy.SelectLayerByAttribute_management("lyr", "NEW_SELECTION", "(GRIDCODE = 1) AND (Shape_Area>1.8) AND (Shape_Area<800) AND (Shape_Leng<80)")

    out_path = "E:/CNN/GF2NEW/sample/gf2image/unet_post/%s_post.shp"%i[:-4]
    # Copying a layer with an active selection writes only the selected features.
    arcpy.CopyFeatures_management("lyr", out_path)
    arcpy.Delete_management("lyr")

In [None]:
from pathlib import Path
try:
    from osgeo import gdal
    from osgeo import ogr
    from osgeo import osr
except ImportError:
    import gdal
    import ogr
    import osr

def compute_metrics(inShpPath):
    """Add Length, Area, X and Y attributes to every feature of a shapefile.

    The shapefile is opened for update, four real-valued fields are created
    and filled per feature: the boundary length of the original geometry,
    and the area and envelope centre of a transformed copy of the geometry.

    Args:
        inShpPath: Path of the ESRI shapefile to modify in place.

    Raises:
        IOError: If the shapefile cannot be opened for update.
    """
    driver = ogr.GetDriverByName("ESRI Shapefile")
    dataSource = driver.Open(inShpPath, 1)  # 1 = open for update
    if dataSource is None:
        raise IOError("Could not open shapefile for update: %s" % inShpPath)
    layer = dataSource.GetLayer()

    # (field name, decimal precision); every field is a real of width 32.
    for field_name, precision in (("Length", 2), ("Area", 16), ("X", 16), ("Y", 16)):
        field_defn = ogr.FieldDefn(field_name, ogr.OFTReal)
        field_defn.SetWidth(32)
        field_defn.SetPrecision(precision)
        layer.CreateField(field_defn)

    for feature in layer:

        geom = feature.GetGeometryRef()
        geom2 = geom.Clone()
        # NOTE(review): `transform` is not defined inside this function; a
        # module-level coordinate transformation with that name must exist
        # before this is called — confirm where it is created.
        geom2.Transform(transform)

        xmin, xmax, ymin, ymax = geom2.GetEnvelope()
        x = (xmin + xmax) / 2
        y = (ymin + ymax) / 2
        area_in_sq_m = geom2.GetArea()
        # NOTE(review): the perimeter is taken from the untransformed
        # geometry while the area uses the transformed copy — confirm.
        perimeter = geom.Boundary().Length() 

        feature.SetField("Length", perimeter)
        feature.SetField("Area", area_in_sq_m)
        feature.SetField("X", x)
        feature.SetField("Y", y)
        # One write per feature stores all four values; the feature used to
        # be rewritten after every single field.
        layer.SetFeature(feature)

    # Dropping the last reference closes the data source.
    del dataSource

# Add the metric fields to every feature class in the post-processed folder.
arcpy.env.overwriteOutput = True
env.workspace = "E:/CNN/GF2NEW/sample/gf2image/unet_post/"
Fealist = arcpy.ListFeatureClasses()
for i in Fealist:
    # ListFeatureClasses returns names relative to env.workspace, but
    # compute_metrics opens the file with OGR, which resolves paths against
    # the process working directory — so pass the full path.
    compute_metrics(os.path.join(env.workspace, i))
print ("finish")

In [None]:
import arcpy,os
def creat_point(in_polygon):
    """Write one point per polygon, carrying over Length, Area, X and Y.

    The true centroid and the four attributes of every polygon in
    *in_polygon* are read first; a new point feature class is then created
    in the point output folder and filled with those records.
    """
    source_fields = ['SHAPE@TRUECENTROID', 'Length', 'Area', 'X','Y']
    with arcpy.da.SearchCursor(in_polygon, source_fields) as reader:
        featuresList = [[rec[0], rec[1], rec[2], rec[3], rec[4]] for rec in reader]

    # NOTE(review): the output name is built from the last seven characters
    # of the input name only — confirm these are unique per input.
    fc_name = 'Point_%s'%in_polygon[-7:]
    output_location = r'E:/CNN/GF2NEW/sample/gf2image/point/'
    sr = arcpy.SpatialReference('WGS 1984 UTM Zone 48N')

    arcpy.CreateFeatureclass_management(output_location,fc_name,'POINT',spatial_reference=sr)
    point_layer = os.path.join(output_location,fc_name)
    for attribute in ('Length', 'Area', 'X', 'Y'):
        arcpy.AddField_management(point_layer,attribute,'DOUBLE')

    target_fields = ['SHAPE@XY', 'Length', 'Area', 'X','Y']
    with arcpy.da.InsertCursor(point_layer, target_fields) as writer:
        for record in featuresList:
            writer.insertRow(record)

# Create a point feature class for every feature class listed in the
# post-processed workspace; existing outputs may be overwritten.
arcpy.env.overwriteOutput = True
env.workspace = "E:/CNN/GF2NEW/sample/gf2image/unet_post/"
Fealist = arcpy.ListFeatureClasses()
for i in Fealist:
    creat_point(i)
print ("finish")

#### Step10: Statistics and Analysis
##### These datasets are derived from the DBF files of the point shapefiles, whose attributes hold the TWI and SOM values extracted from the estimated images. Point shapefiles were generated because the spatial resolutions of these images are 30 m and 16 m, respectively.

In [None]:
# Extracts the cells of multiple rasters as attributes in an output point feature class.  

import arcpy
from arcpy import env
from arcpy.sa import *

def Ex_multi_value(featurefile):
    """Sample five rasters at the points of *featurefile*.

    Calls ExtractMultiValuesToPoints with bilinear interpolation for the
    TWI, SOM, PV, NPV and BS rasters; each raster is paired with the name
    of the field that receives its values.  Returns nothing.
    """
    # [raster path, output field name] pairs
    raster_field_pairs = [
        ["E:/CNN/GF2NEW/sample/twi.tif", "RASTERVALU"],
        ["E:/CNN/GF2NEW/sample/som.tif", "som"],
        ["E:/CNN/GF2NEW/sample/sma_pv.tif", "pv"],
        ["E:/CNN/GF2NEW/sample/sma_npv.tif", "npv"],
        ["E:/CNN/GF2NEW/sample/sma_bs.tif", "bs"],
    ]

    # The Spatial Analyst licence is checked out before the tool runs.
    arcpy.CheckOutExtension("Spatial")

    ExtractMultiValuesToPoints(featurefile, raster_field_pairs, "BILINEAR")
    
# environment settings
env.workspace = "E:/CNN/GF2NEW/sample/gf2image/point/"
Fealist = arcpy.ListFeatureClasses()
# Run the raster sampling on every feature class listed in the point workspace.
for i in Fealist:
    Ex_multi_value(i)

In [None]:
# Multi-value of shrub islands read as dataframe.
import arcpy
import os
import dbfread
from pandas import DataFrame
# Read every .dbf table in the point folder and stack them into one frame.
input_path = "E:/CNN/GF2NEW/sample/gf2image/point/"
filenames = os.listdir(input_path)
df_list = []
for i in filenames:
    if os.path.splitext(i)[1] == '.dbf':
        # Fixed: this used to be `path + filename`, two names that are not
        # defined here; the table lives in input_path under the name `i`.
        inDBF = os.path.join(input_path, i)
        table = dbfread.DBF(inDBF, encoding='GBK',load=True)
        dftable = DataFrame(iter(table))
        df_list.append(dftable)
# Rows are renumbered 0..n-1 across all tables.
frame_twi = pd.concat(df_list,axis=0,ignore_index=True)
print(frame_twi.head())

In [None]:
# Number of rows in the combined frame_twi table.
print(len(frame_twi))

##### Figure of shrub island features

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator, FormatStrFormatter
# Four-panel figure (2 x 2): size distribution per zone, density against
# distance, density against precipitation and density against TWI.
fig,axs = plt.subplots(nrows = 2, ncols = 2,figsize = (7, 6))
plt.rc('font',family='Times New Roman',size = 10)
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'

# Panel 1: box plot with a violin outline drawn over it, one per zone.
plt.subplot(2,2,1)
# NOTE(review): x and y are not defined in this cell; they must already exist
# in the notebook namespace (one sample per zone, judging by the tick labels
# below) — confirm which cell provides them.
data = [x,y]
f = plt.boxplot(data,showmeans=True,widths = 0.2, showfliers = False, patch_artist = True,boxprops = {'color':'black'},medianprops= {'color':'black','linewidth':3},meanprops={'marker':'s',"markersize":3})
# Restyle the artists returned by boxplot.
for box in f['boxes']:   
    box.set(color='black', linewidth=1)
    box.set(facecolor='lightgray')
for whisker in f['whiskers']:
    whisker.set(color='black', linewidth=1)
for cap in f['caps']:
    cap.set(color='black', linewidth=2)
for median in f['medians']:
    median.set(color='black', linewidth=1)
for flier in f['fliers']:
    flier.set(marker='o', color='y', alpha=0.5)
# Violin bodies are drawn unfilled so the boxes stay visible underneath.
violin_parts = plt.violinplot(data,widths = 0.7,showmeans=False,showmedians=False,showextrema=False)
for pc in violin_parts['bodies']:
    pc.set(facecolor='none')
    pc.set(edgecolor='black',linewidth=2)
plt.xlim(0.5,2.5)
plt.ylim([0,200])
# plt.yticks([0,0.02,0.04,0.06,0.08,0.10,0.12,0.14,0.16,0.18,0.20],fontsize = 9)
plt.xticks([1,2],["Zone 1","Zone 2"])
plt.xlabel('Zones')
plt.ylabel('Size')

import pandas as pd
# Panel 2: columns Z1 and Z2 of oasis.csv plotted against column D.
plt.subplot(2,2,2)
data1 = pd.read_csv('E:/CNN/Figure/oasis.csv')
dis = data1['D']
Z1 = data1["Z1"]
Z2 = data1["Z2"]
plt.plot(dis,Z1,label='Zone 1')
plt.plot(dis,Z2,label='Zone 2')
plt.xlim(0,32)
plt.xticks([0,8,16,24,32])
# NOTE(review): km^2 is an area unit although the axis is labelled as a
# distance — confirm the intended unit.
plt.xlabel('Distance (km^2)')
plt.ylim(500,2000)
plt.yticks([500,1000,1500,2000])
plt.ylabel('Density')
plt.legend()
# Panel 3: bars of column MEAN against column RA of rain.csv.
plt.subplot(2,2,3)
data2 = pd.read_csv('E:/CNN/Figure/rain.csv')
rain = data2['RA']
mean = data2["MEAN"]
plt.bar(rain,mean,width=8,color= "#8064a2")
plt.xlim(90,210)
plt.xticks([90,120,150,180,210])
plt.xlabel('Precipitation (mm)')
plt.ylim(700,1500)
plt.yticks([700,900,1100,1300,1500])
plt.ylabel('Density')

# Panel 4: bars of column mean against column twi of twi.csv.
plt.subplot(2,2,4)
data3 = pd.read_csv('E:/CNN/Figure/twi.csv')
twi = data3['twi']
me = data3["mean"]
plt.bar(twi,me,width=4,color= "#4bacc6" )
plt.xlim(0,40)
plt.xticks([0,10,20,30,40])
plt.xlabel('TWI')
plt.ylim(800,3200)
plt.yticks([800,1600,2400,3200])
plt.ylabel('Density')
plt.subplots_adjust(left = 0.10, bottom=0.10, right=0.90, top=0.90,wspace= 0.3, hspace = 0.3)

# Save before show().
plt.savefig(fname= "E:/CNN/Figure/STA.pdf", dpi=600)
plt.show()

In [None]:
# Figure of fraction
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator, FormatStrFormatter

# Box plot with a violin outline for the three endmember fraction columns.
# Fixed: several columns are selected with a list of names; the former
# frame_twi['pv','npv','bs'] looked up a single tuple key and raised KeyError.
fc_twi = frame_twi[['pv', 'npv', 'bs']]
# Plain 2-D array, so each column is drawn as one distribution.
fc_values = fc_twi.values

fig,axs = plt.subplots(nrows = 1, ncols = 1,figsize = (3.5, 3))
plt.rc('font',family='Arial',size = 9)
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'
plt.subplot(1,1,1)
f = plt.boxplot(fc_values,showmeans=True,widths = 0.2, showfliers = False, patch_artist = True,boxprops = {'color':'black'},medianprops= {'color':'black','linewidth':3},meanprops={'marker':'s',"markersize":3})
# Restyle the artists returned by boxplot.
for box in f['boxes']:
    box.set(color='black', linewidth=1)
    box.set(facecolor='lightgray')
for whisker in f['whiskers']:
    whisker.set(color='black', linewidth=1)
for cap in f['caps']:
    cap.set(color='black', linewidth=2)
for median in f['medians']:
    median.set(color='black', linewidth=1)
for flier in f['fliers']:
    flier.set(marker='o', color='y', alpha=0.5)
# Fixed: the violins used an undefined name (data_twi); they overlay the
# boxes at the same three positions, so they are drawn from the same data.
violin_parts = plt.violinplot(fc_values,widths = 0.8,showmeans=False,showmedians=False,showextrema=False)
for pc in violin_parts['bodies']:
    pc.set(facecolor='none')
    pc.set(edgecolor='black',linewidth=2)
plt.xlim(0.5,3.5)
plt.ylim([0,10000])
# Tick labels show the 0-10000 axis values as 0-100.
plt.yticks([0,2000,4000,6000,8000,10000],[0,20,40,60,80,100])
plt.xticks([1,2,3],["PV","NPV","BS"])
plt.xlabel('Endmembers')
plt.ylabel('Fractions (%)')
plt.subplots_adjust(left = 0.20, bottom=0.20, right=0.95, top=0.95)
plt.savefig(fname= "E:/CNN/Figure/em_fraction.pdf", dpi=600)
plt.show()

##### Ecosystem productivity represented by fractional non-photosynthetic vegetation (NPV) and photosynthetic vegetation (PV) along a gradient of topographic wetness index (TWI)

In [None]:
# Figure 1a TWI vs pv+npv(%)
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from pylab import *
import pandas as pd
from matplotlib.colors import ListedColormap,LinearSegmentedColormap

# Productivity proxy: sum of the pv and npv columns.
frame_twi['fv'] =  frame_twi['pv']+frame_twi['npv']

x = frame_twi['fv']          # pv + npv
y = frame_twi['RASTERVALU']  # TWI value sampled at each point

# Median of pv + npv for every integer TWI bin 1..39 (NaN for empty bins).
# Fixed: the loop had a syntax error (missing colon), np.range does not
# exist, and the bin value itself was collected instead of the fv values
# falling into the bin.
yr = y.round()
xr = np.arange(1,40,1)
yr_med = x.groupby(yr).median().reindex(xr.astype(float)).values


clist=['#a50026','#d73027','#f46d43','#fdae61','#fee090','#ffffff','#e0f3f8','#abd9e9','#74add1','#4575b4','#313695']
newcmp = LinearSegmentedColormap.from_list('ggp',clist)
cm.register_cmap(cmap=newcmp)
#------------------------------------Figure---------------------------------------
fig,axs = plt.subplots(nrows = 1, ncols = 1,figsize = (3.5, 3))
extent = (0,1,0,1)
plt.rc('font',family='Arial',size = 9)
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'

plt.subplot(1,1,1)
# TWI goes on the horizontal axis and pv + npv on the vertical axis, matching
# the axis limits and labels set below (the arguments used to be swapped).
# cmin=1 leaves bins without any point blank.
h = plt.hist2d(y, x, 100, cmap=cm.get_cmap('ggp'),cmin = 1)#density = True
plt.plot(xr,yr_med,color="black",linewidth=2)

plt.xlim(0,40)
plt.ylim(0,10000)
plt.xticks([0,5,10,15,20,25,30,35],[0,5,10,15,20,25,30,35])
# Tick labels show the 0-10000 axis values as 0-100.
plt.yticks([0,2000,4000,6000,8000,10000],[0,20,40,60,80,100])
plt.xlabel("TWI")
plt.ylabel("Ecosystem productivity (NPV+PV, %)")
plt.subplots_adjust(left = 0.15, bottom=0.15, right=0.95, top=0.95)
plt.savefig(fname= "E:/CNN/Figure/twi_v3.pdf", dpi=600)
plt.show()

In [None]:
# Figure 1b
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde
from matplotlib.colors import LogNorm

import pandas as pd
from pylab import *
from matplotlib.colors import ListedColormap,LinearSegmentedColormap


# Two stacked panels: PV / NPV bars on top, the "bz" column with three
# fitted curves below, all against the twi column of pv_npv.csv.
# Fixed: the figure is created with the same 2 x 1 grid that the
# plt.subplot(2,1,k) calls below address; it used to be created as 1 x 1,
# which leaves a stray full-size axes behind the panels on matplotlib
# versions that do not remove overlapped axes.
fig,axs = plt.subplots(nrows = 2, ncols = 1,figsize = (3.6, 3))
# figsize = (width, height)
extent = (0,1,0,1)
plt.rc('font',family='Arial',size = 9)
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'

plt.subplot(2,1,1)
data_npv_pv = pd.read_csv('E:/CNN/Figure/pv_npv.csv')
twi_x = data_npv_pv["twi"]
twi_y1 = data_npv_pv["pv"]
twi_y2 = data_npv_pv["npv"]
twi_y3 = data_npv_pv["bz"]
width = 0.4
# The two bar series are shifted by half a bar width to sit side by side.
plt.bar(twi_x-0.2,twi_y1,width= width,color = "forestgreen",label = "PV")
plt.bar(twi_x+0.2,twi_y2,width= width,color = "goldenrod",label = "NPV")
plt.xlim(0,40)
plt.ylim(0,8000)
# The upper panel keeps the x ticks but hides their labels.
plt.xticks([0,5,10,15,20,25,30,35],[])
plt.yticks([0,2000,4000,6000,8000],[0,20,40,60,80])
plt.legend(frameon = False,labelspacing=0.2,borderpad=0.15)
plt.ylabel("Fractions (%)")
plt.subplot(2,1,2)
plt.scatter(twi_x,twi_y3,s=10,c = "gray")
twi_y4 = data_npv_pv["nh"]
twi_y5 = data_npv_pv["nh2"]
twi_y6 = data_npv_pv["nh3"]
# NOTE(review): all three curves carry the same label "fitting" and no legend
# is drawn for this panel — confirm whether a legend is wanted.
plt.plot(twi_x,twi_y6,color = "red",label = "fitting")
plt.plot(twi_x,twi_y4,color = "brown",label = "fitting")
plt.plot(twi_x,twi_y5,color = "coral",label = "fitting")
plt.xlim(0,40)
plt.ylim(0.1,0.3)
plt.xticks([0,5,10,15,20,25,30,35],[0,5,10,15,20,25,30,35])
plt.yticks([0.1,0.2,0.3],[0.1,0.2,0.3])
plt.xlabel("TWI")
plt.ylabel("PV/(NPV+PV)")
plt.subplots_adjust(left = 0.15, bottom=0.15, right=0.95, top=0.95)
plt.savefig(fname= "E:/CNN/Figure/twi_v_3.pdf", dpi=600)

##### Power law fitting of probability densities of shrub island sizes

In [None]:
# Split frame_twi into TWI bands.  Every band is open at its lower bound and
# closed at its upper bound; frame_twi1/2/3 are the unions of the narrower
# bands around them and frame_twi4 has no upper bound.
_twi = frame_twi["RASTERVALU"]


def _twi_band(lower, upper=None):
    """Return the rows of frame_twi with lower < RASTERVALU (<= upper)."""
    keep = _twi > lower
    if upper is not None:
        keep = keep & (_twi <= upper)
    return frame_twi[keep]


frame_twi12 = _twi_band(2, 4)
frame_twi13 = _twi_band(4, 6)
frame_twi14 = _twi_band(6, 8)
frame_twi1 = _twi_band(0, 8)
print (frame_twi1.head())

frame_twi21 = _twi_band(8, 10)
frame_twi22 = _twi_band(10, 12)
frame_twi23 = _twi_band(12, 14)
frame_twi24 = _twi_band(14, 16)
frame_twi2 = _twi_band(8, 16)

frame_twi31 = _twi_band(16, 19)
frame_twi32 = _twi_band(19, 22)
frame_twi33 = _twi_band(22, 25)
frame_twi3 = _twi_band(16, 25)

frame_twi4 = _twi_band(25)

In [None]:
def pl(x):
    """Fit power laws to the 'Area' column of *x* and summarise the fits.

    Two fits are made on the areas (rounded to two decimals, NaN removed):
    one with ``xmin`` fixed at 1.0 and one where ``powerlaw`` chooses
    ``xmin`` itself.  Both are compared against a truncated power law.

    Args:
        x: Table with an 'Area' column.

    Returns:
        A list ``[results, results_no_xmin, aph, aph_noxmin, xmin_noxmin,
        R, p, R_noxmin, p_noxmin, plr]``: the two fit objects, their
        ``alpha`` values, the ``xmin`` of the free fit, the two values
        returned by ``distribution_compare`` for each fit, and ``plr``,
        the share of the log-size range that lies above the free-fit xmin.
    """
    sizes = np.array(x['Area'].round(2))
    sizes = sizes[~np.isnan(sizes)]
    results = powerlaw.Fit(sizes,xmin = 1.0, discrete=True)
    results_no_xmin = powerlaw.Fit(sizes,discrete=True)
    aph = results.power_law.alpha
    xmin_noxmin = results_no_xmin.power_law.xmin
    aph_noxmin = results_no_xmin.power_law.alpha
    R, p = results.distribution_compare('power_law', 'truncated_power_law')
    R_noxmin, p_noxmin = results_no_xmin.distribution_compare('power_law', 'truncated_power_law')
    # Fixed: the size array used to be named `p` and was overwritten by the
    # comparison result above, so the range below was taken from a single
    # number (zero width) instead of from the sizes.
    log_min = np.log(np.min(sizes))
    log_max = np.log(np.max(sizes))
    plr = 1-(np.log(xmin_noxmin)-log_min)/(log_max-log_min)
    return [results,results_no_xmin,aph,aph_noxmin,xmin_noxmin,R, p,R_noxmin, p_noxmin,plr]

In [None]:
# TWI subsets in plotting order; positions 3, 8 and 12 hold the aggregated
# ranges (frame_twi1, frame_twi2, frame_twi3).
pl_twi = [frame_twi12,frame_twi13,frame_twi14,frame_twi1,frame_twi21,frame_twi22,
          frame_twi23,frame_twi24,frame_twi2,frame_twi31,frame_twi32,frame_twi33,frame_twi3,frame_twi4]

plr_list = []
R_list = []
R_noxmin_list = []
# Fixed: the former condition also let position 12 (frame_twi3) through,
# giving 12 values for the 11 band labels of the bar chart in Figure 3.
aggregated_positions = {3, 8, 12}
for i in range(len(pl_twi)):
    # Fit once per subset; the fit used to be recomputed for the print and
    # again for each of the three values taken from it.
    fit = pl(pl_twi[i])
    print (i,fit[2:])
    if i not in aggregated_positions:
        plr_list.append(fit[-1])       # plr
        R_list.append(fit[-5])         # R of the fit with xmin = 1.0
        R_noxmin_list.append(fit[-3])  # R of the fit with free xmin

In [None]:
# Figure 2: one panel per TWI subset showing the empirical PDF of the sizes
# together with three fitted curves.
fig,axs = plt.subplots(nrows = 3, ncols = 5,figsize = (7, 3))
# figsize = (width, height)
extent = (0,1,0,1)
plt.rc('font',family='Arial',size = 9)
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'

x_ticks = [1,10,100,1000]
y_ticks = [0.0000001,0.0001,0.1]
for i in range(len(pl_twi)):
    # Fit once per panel (the fit used to be computed twice).
    fit = pl(pl_twi[i])
    results = fit[0]
    results_no_xmin = fit[1]
    # The first four subsets fill slots 1-4; the rest start at slot 6, so the
    # fifth slot of the first row stays empty.
    plt.subplot(3,5,i+1 if i<4 else i+2)
    ax = results.plot_pdf(linewidth=2)
    results.power_law.plot_pdf(ax=ax,color='blue',linestyle='--')
    results.truncated_power_law.plot_pdf(ax=ax,color='green',linestyle='--')
    results_no_xmin.power_law.plot_pdf(ax=ax,color='red',linestyle='--')
    # NOTE(review): the lower y limit differs between the first row and the
    # other rows, as before — confirm that this is intended.
    plt.ylim(0.0000001 if i<4 else 0.000001,1.0)
    # Fixed: plt.xmin does not exist; the x range is set with plt.xlim.
    plt.xlim(1,1000)
    # Tick labels only on the bottom row (x) and the first column (y).  The
    # former branches cleared the y labels of panel 4 right after setting
    # them and never set the y labels of panels 10-13.
    plt.xticks(x_ticks, x_ticks if i>=9 else [])
    plt.yticks(y_ticks, y_ticks if i in (0, 4, 9) else [])
plt.savefig(fname= "E:/CNN/Figure/twi_size.pdf", dpi=600)    
plt.show()

In [None]:
# Figure 3: plr per TWI band (top) and the two R values per band (bottom,
# each on its own y axis).
fig,axs = plt.subplots(nrows = 2, ncols = 1,figsize = (3.5, 6))
# figsize = (width, height)
extent = (0,1,0,1)
plt.rc('font',family='Arial',size = 9)
plt.rcParams['xtick.direction'] = 'out'
plt.rcParams['ytick.direction'] = 'out'

x_list = ['TWI(2-4)','TWI(4-6)','TWI(6-8)','TWI(8-10)','TWI(10-12)','TWI(12-14)',
         'TWI(14-16)','TWI(16-19)','TWI(19-22)','TWI(22-25)','TWI(>25)']

# One bar position per band label.
x = np.arange(len(x_list))
bar_width = 0.6

# The axes created by plt.subplots are used directly.  The former
# fig.add_subplot(2,1,2) call could stack a second axes on top of the
# existing lower one.
ax0 = axs[0]
ax0.bar(x, plr_list, bar_width)
ax0.set_xticks(x)
ax0.set_xticklabels([])
ax0.set_yticks([0,0.2,0.4])
ax0.set_yticklabels([0,0.2,0.4])

ax1 = axs[1]
# The two R series sit left and right of each tick; the second series uses a
# twin axis because its value range is set separately below.
ax1.bar(x-bar_width/2,R_list,bar_width/2)
ax1.set_xticks(x)
ax1.set_xticklabels(x_list)
ax1.set_ylim(-550000,0)
ax1.set_yticks([-500000,-400000,-300000,-200000,-100000,0])
ax1.set_yticklabels([-50,-40,-30,-20,-10,0])
ax2 = ax1.twinx()
ax2.bar(x+bar_width/2,R_noxmin_list,bar_width/2)
ax2.set_xticks(x)
ax2.set_xticklabels([])
ax2.set_ylim(-120,0)
ax2.set_yticks([-120,-100,-80,-60,-40,-20,0])
ax2.set_yticklabels([-120,-100,-80,-60,-40,-20,0])
plt.show()