## Parse the TFRecord training data and load it into a `tf.data.Dataset`

In [2]:
## Mount Google Drive at /content/drive/
from google.colab import drive
drive.mount(mountpoint='/content/drive/')


Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


### <font color=red>**Note: the TensorFlow version must be consistent with the TensorFlow version on Google AI Platform; we use TensorFlow 2.2.0 here!**</font>

In [12]:
# Uncomment to pin the runtime to the TensorFlow release used in this notebook:
# !pip install tensorflow==2.2.0
import tensorflow as tf
print(tf.version.VERSION)

2.2.0


In [13]:
import os
# Switch the working directory to the project folder on Drive *before* the
# `utils` / `models` imports below -- presumably these are project-local
# packages resolved relative to this directory (TODO confirm).
os.chdir("/content/drive/My Drive/Earth-Engine-with-Deep-Learning")
# NOTE(review): imgShow, plt, folium and datetime are not referenced in the
# cells visible here; confirm they are still needed.
from utils import imgShow
import matplotlib.pyplot as plt
from models.models import UNet
import folium
import datetime

In [19]:
## Hyperparameters and paths
# Location and file-name prefix of the TFRecord training data.
Image_Folder = 'EE_Image'   # NOTE: can't write into a second-level directory
tfrecord_Name = 'Train-Landsat-8-2016'
tfrecord_Path = f'/content/drive/My Drive/{Image_Folder}'

## TFRecord features
Bands = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7']
Targets = ['impervious']
Features = [*Bands, *Targets]
Kernel_shape = [256, 256]   # shape of the patches expected by the model

# Model training parameters.
Batch_size = 16
Epochs = 100
Buffer_size = 2000
# Alternative optimizer:
# Optimizer = tf.keras.optimizers.SGD(learning_rate=0.02, momentum=0.0)
Optimizer = tf.keras.optimizers.Adam(learning_rate=0.001, beta_1=0.9)
Loss = 'MeanSquaredError'
Metrics = ['RootMeanSquaredError']

# Where the trained model and its TensorBoard logs are stored.
path_pretrain = "/content/drive/My Drive/Earth-Engine-with-Deep-Learning/models/pretrain"
model_path = 'unet_v1'
path_save_model = f'{path_pretrain}/{model_path}/model'

## TensorBoard configuration
log_dir = f'{path_pretrain}/{model_path}/logs'
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)
log_dir

'/content/drive/My Drive/Earth-Engine-with-Deep-Learning/models/pretrain/unet_v1/logs'

### Parse the tfrecord data

In [20]:
## Feature spec for the tf.train.Example records: every name in `Features`
## maps to a fixed-length float32 feature of shape `Kernel_shape`.
Columns = [
    tf.io.FixedLenFeature(shape=Kernel_shape, dtype=tf.float32)
    for _ in Features
]
Features_dict = {name: column for name, column in zip(Features, Columns)}

def parse_tfrecord(example_proto):
    """Decode one serialized `tf.Example` record.

    Args:
        example_proto: a serialized `tf.Example`, passed straight to
            `tf.io.parse_single_example`.

    Returns:
        The result of `tf.io.parse_single_example` for the module-level
        `Features_dict` spec.
    """
    parsed = tf.io.parse_single_example(
        serialized=example_proto,
        features=Features_dict,
    )
    return parsed

def to_tuple(inputs):
    """Split a dictionary of tensors into an (inputs, outputs) tuple.

    Args:
        inputs: dictionary of tensors, looked up by the names in `Features`.

    Returns:
        A tuple of two slices of the channel-last stack: the first
        `len(Bands)` channels, and the remaining channels.
    """
    n_bands = len(Bands)
    layers = [inputs.get(name) for name in Features]
    # Stack along a new leading axis (CHW), then move channels last (HWC).
    chw = tf.stack(layers, axis=0)
    hwc = tf.transpose(chw, [1, 2, 0])
    return hwc[:, :, :n_bands], hwc[:, :, n_bands:]

def get_dataset(pattern):
    """Read, parse and format a set of GZIP-compressed TFRecord files.

    Args:
        pattern: glob pattern handed to `tf.io.gfile.glob` to locate the
            TFRecord files.

    Returns:
        A `tf.data.Dataset` whose elements are the tuples produced by
        `to_tuple` applied to each record parsed by `parse_tfrecord`.

    Raises:
        FileNotFoundError: if no file matches `pattern`.
    """
    filenames = tf.io.gfile.glob(pattern)
    if not filenames:
        # Fail fast: an empty file list would otherwise build a dataset with
        # no records, and the problem would only surface later in training.
        raise FileNotFoundError(
            'No TFRecord files match pattern: {}'.format(pattern))
    dataset = tf.data.TFRecordDataset(filenames, compression_type='GZIP')  # read the TFRecord data
    dataset = dataset.map(parse_tfrecord, num_parallel_calls=5)
    dataset = dataset.map(to_tuple, num_parallel_calls=5)
    return dataset

## Build the training dataset
tfrecord_files = f'{tfrecord_Path}/{tfrecord_Name}*'
traData = (
    get_dataset(tfrecord_files)
    .shuffle(Buffer_size)
    .batch(Batch_size)
    .repeat()   # repeat indefinitely so model.fit() can read the data across multiple epochs
)
traData

<RepeatDataset shapes: ((None, 256, 256, 6), (None, 256, 256, 1)), types: (tf.float32, tf.float32)>

### Model training

In [None]:
# Derive the input shape from the configured patch size and band list instead
# of repeating the literal (256, 256, 6), so the model input cannot drift from
# the dataset definition.
# NOTE(review): nclasses=2 while Targets holds a single band -- confirm against
# the UNet implementation that the output shape matches the label shape.
model = UNet(input_shape=(*Kernel_shape, len(Bands)), nclasses=2)
# model.summary()

model.compile(
    optimizer=tf.keras.optimizers.get(Optimizer),
    loss=tf.keras.losses.get(Loss),
    metrics=[tf.keras.metrics.get(metric) for metric in Metrics])

# traData repeats indefinitely, so steps_per_epoch defines the epoch length.
model.fit(
    x=traData,
    epochs=Epochs,
    steps_per_epoch=10,
    callbacks=[tensorboard_callback])

In [None]:
# Launch TensorBoard inside the notebook. The --logdir path is relative, so it
# relies on the working directory set by the earlier os.chdir(...) call.
%load_ext tensorboard
%tensorboard --logdir models/pretrain/unet_v1/logs


In [None]:
# Export the trained model in the TensorFlow SavedModel format.
model.save(filepath=path_save_model, save_format='tf')
