In [46]:
### mount on google drive
from google.colab import drive
drive.mount('/content/drive/')
import os
os.chdir("/content/drive/My Drive/WatNet/notebooks")
# !pip install rasterio

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [47]:
import os
os.chdir('..')
import time
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from dataloader.tfrecord_io import parse_image,parse_shape,toPatchPair
from utils.acc_eval import acc_patch
from utils.imgShow import imgShow
from dataloader.dataloader import get_scene, get_patch, image_aug
from models.seg_model.watnet import watnet


In [48]:
## super parameters
patch_size = 512
epochs = 100
lr = 0.005
batch_size = 4
buffer_size = 200

In [4]:
# ## get training data
# # previous, 8s/epoch
# # training dataset (!!your dataset(.tif file) directory)
# dir_tra_scene = config.root + '/data/dataset-2/tra_scene'
# dir_tra_truth = config.root + '/data/dataset-2/tra_truth'
# scenes, truths = get_scene(dir_tra_scene, dir_tra_truth, patch_size)
# traData = get_patch(scenes, truths, patch_size, batch_size, buffer_size)
# traData

## Load and parse the tfrecord data

In [49]:
### training dataset (!!your dataset(.tfrecords file) directory)
path_tra_data = 'data/tfrecord-s2/tra_data.tfrecords'
traData = tf.data.TFRecordDataset(path_tra_data)
traData = traData.map(parse_image).map(parse_shape)\
            .cache()\
            .map(toPatchPair)\
            .map(image_aug)

traData = traData.batch(batch_size)
# traData = traData.shuffle(buffer_size).batch(batch_size)
traData


<BatchDataset shapes: ((None, 512, 512, 6), (None, 512, 512, 1)), types: (tf.float32, tf.float32)>

In [55]:
# for i in range(5):
#     start = time.time()
#     i = 1
#     for tra_patch, val_path in traData:
#         i+=1
#     imgShow(tra_patch.numpy()[0])
#     plt.show()
#     print('num:', i)
#     print('time:', time.time()-start)


In [63]:
## model configuration
model = watnet(input_shape=(patch_size, patch_size, 6), nclasses=2)
model_loss = tf.keras.losses.BinaryCrossentropy()
model_optimizer = tf.keras.optimizers.Adam(lr)


*** Building DeepLabv3Plus Network ***
*** Output_Shape => (None, 512, 512, 1) ***


In [57]:
class miou_binary(tf.keras.metrics.MeanIoU):
    def update_state(self, y_true, y_pred, sample_weight=None):
        y_pred = tf.where(y_pred>0.5, 1, 0)
        super().update_state(y_true, y_pred, sample_weight)

tra_oa = tf.keras.metrics.BinaryAccuracy('tra_oa')
tra_miou = miou_binary(num_classes=2,name='tra_miou')
tra_loss_tracker = tf.keras.metrics.Mean(name="tra_loss")
test_oa = tf.keras.metrics.BinaryAccuracy('test_oa')
test_miou = miou_binary(num_classes=2,name='test_miou')
test_loss_tracker = tf.keras.metrics.Mean(name="test_loss")



In [82]:
'''------train step------'''
@tf.function
def train_step(model, loss_fun, optimizer, x, y):
    with tf.GradientTape() as tape:
        y_pre = model(x, training=True)
        loss = loss_fun(y, y_pre)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    tra_loss_tracker.update_state(loss)
    tra_oa.update_state(y, y_pre)
    tra_miou.update_state(y, y_pre)
    return tra_loss_tracker.result(), tra_oa.result(), tra_miou.result()

'''------train loops------'''
def train_loops(model, loss_fun, optimizer, tra_dset, epochs):
    max_miou_pre = 0.8
    tra_loss_plot = []
    tra_oa_plot = []
    tra_miou_plot = []
    for epoch in range(epochs):
        start = time.time()
        '''---train the model---'''
        for x_batch, y_batch in tra_dset:
            # x_batch=x_batch[2]   ##!!note: x_batch[2] for single-scale model
            tra_loss_epoch,tra_oa_epoch,tra_miou_epoch = train_step(model, loss_fun, optimizer, x_batch, y_batch)
        tra_loss_plot.append(tra_loss_epoch)
        tra_oa_plot.append(tra_oa_epoch)
        tra_miou_plot.append(tra_miou_epoch)
        
        # '''---test the model---'''
        # for x_batch, y_batch in test_dset:
        #     # x_batch=x_batch[2]  ##!note: x_batch[2] for single-scale model
        #     test_loss_epoch, test_oa_epoch, test_miou_epoch = test_step(model, loss_fun, x_batch, y_batch)
        tra_loss_tracker.reset_states(), tra_oa.reset_states(), tra_miou.reset_states()
        # config.test_loss_tracker.reset_states(), config.test_oa.reset_states(), config.test_miou.reset_states()
        print('epoch {}: traLoss:{:.3f}, traOA:{:.2f}, traMIoU:{:.2f}, time:{:.0f}s'\
              .format(epoch + 1, tra_loss_epoch, tra_oa_epoch, tra_miou_epoch, time.time() - start))
    return [tra_loss_plot, tra_oa_plot,tra_miou_plot]


In [83]:
## training
tra_records = train_loops(model, \
                        loss_fun=model_loss, \
                        optimizer=model_optimizer, \
                        tra_dset=traData, \
                        epochs=10)



epoch 1: traLoss:0.078, traOA:0.97, traMIoU:0.93, time:10s
epoch 2: traLoss:0.074, traOA:0.97, traMIoU:0.93, time:7s
epoch 3: traLoss:0.068, traOA:0.98, traMIoU:0.94, time:7s
epoch 4: traLoss:0.075, traOA:0.97, traMIoU:0.93, time:7s
epoch 5: traLoss:0.075, traOA:0.97, traMIoU:0.93, time:7s
epoch 6: traLoss:0.061, traOA:0.98, traMIoU:0.94, time:7s
epoch 7: traLoss:0.064, traOA:0.98, traMIoU:0.94, time:7s
epoch 8: traLoss:0.085, traOA:0.97, traMIoU:0.92, time:7s
epoch 9: traLoss:0.070, traOA:0.97, traMIoU:0.93, time:7s
epoch 10: traLoss:0.073, traOA:0.97, traMIoU:0.92, time:7s


In [None]:
### model saving
# path_save = 'models/pretrained_model/watnet_tmp'
# model.save(path_save)

