In [None]:
# ## Mount Google Drive (only needed when running on Colab)
# from google.colab import drive
# drive.mount('/content/drive/')

In [None]:
import config
import os
os.chdir(config.root)
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from utils.geotif_io import readTiff,writeTiff
from utils.imgShow import imgShow
from utils.imgPatch import imgPatch


In [None]:
# Scenes to process. For each scene: the input image path and the output
# water-map path (written next to the input, with a '_pred' suffix).
names_scene = [
    'l5_scene_12', 'l5_scene_14', 'l5_scene_18',
    'l7_scene_05', 'l7_scene_08', 'l7_scene_14', 'l7_scene_18', 'l7_scene_20',
    'l8_scene_04', 'l8_scene_06', 'l8_scene_10', 'l8_scene_14', 'l8_scene_15',
    'l8_scene_18', 'l8_scene_20',
    'l5_scene_21', 'l5_scene_22', 'l5_scene_23', 'l5_scene_24', 'l5_scene_25',
    'l5_scene_26', 'l5_scene_27', 'l5_scene_28', 'l5_scene_29', 'l5_scene_30',
    'l7_scene_21', 'l7_scene_22', 'l7_scene_23', 'l7_scene_24', 'l7_scene_25',
    'l7_scene_26', 'l7_scene_27', 'l7_scene_28', 'l7_scene_29', 'l7_scene_30',
    'l8_scene_21', 'l8_scene_22', 'l8_scene_23', 'l8_scene_24', 'l8_scene_25',
    'l8_scene_26', 'l8_scene_27', 'l8_scene_28', 'l8_scene_29', 'l8_scene_30',
]
dir_scene = config.root + '/data/dset-l578/scene/'
paths_img = [dir_scene + scene + '.tif' for scene in names_scene]
paths_wat_map = [dir_scene + scene + '_pred.tif' for scene in names_scene]
len(paths_wat_map)



In [None]:
## ----- model ------
# Pre-trained WatNet v2 weights. compile=False: the model is only called for
# prediction in this notebook, so no optimizer/loss needs to be restored.
path_model = config.root + '/model/pretrained/watnetv2.h5'
model = tf.keras.models.load_model(path_model, compile=False)


In [None]:
# Run the water-mapping model over every scene and write one binary map per scene.
# Constants that control preprocessing, tiling and classification:
SCALE_PIXEL = 10000       # raw pixel values are divided by this before clipping to [0, 1]
PATCH_SIZE = 512          # side length of the patches fed to the model
EDGE_OVERLAY = 160        # overlap passed to imgPatch when tiling / re-assembling
THRESHOLD_WATER = 0.5     # probability above which a pixel is labelled water (1)

for path_in_img, path_out_map in zip(paths_img, paths_wat_map):
    print('image:', path_in_img)
    ### --- read tif
    img, img_info = readTiff(path_in=path_in_img)
    img = np.float32(np.clip(img / SCALE_PIXEL, a_min=0, a_max=1))  # normalization

    ### ---- surface water mapping (l5/l7/l8 scenes; presumably Landsat, not Sentinel-2)
    imgPat_ins = imgPatch(img=img, patch_size=PATCH_SIZE, edge_overlay=EDGE_OVERLAY)
    patch_list, start_list, img_patch_row, img_patch_col = imgPat_ins.toPatch()
    # Add a batch axis for the model, run in inference mode, then drop the batch axis.
    result_patch_list = [np.squeeze(model(patch[np.newaxis, :], training=False), axis=0)
                         for patch in patch_list]
    pro_map = imgPat_ins.toImage(result_patch_list, img_patch_row, img_patch_col)
    cla_map = np.int8(pro_map > THRESHOLD_WATER)  # 1 = water, 0 = non-water

    ### --- save the result
    writeTiff(im_data=cla_map,
              im_geotrans=img_info['geotrans'],
              im_geosrs=img_info['geosrs'],
              path_out=path_out_map)
    print('saved image -->:', path_out_map)
    

In [None]:
## show the image and the prediction map
# Side-by-side view of one scene and its predicted water map.
scene_name = 'l5_scene_12'
# Absolute paths via config.root (like paths_img / paths_wat_map) so this cell
# does not depend on the current working directory.
path_img = config.root + '/data/dset-l578/scene/' + scene_name + '.tif'
path_pred = config.root + '/data/dset-l578/scene/' + scene_name + '_pred.tif'
img, _ = readTiff(path_in=path_img)
pred, _ = readTiff(path_in=path_pred)
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
imgShow(img, color_bands=(3, 2, 1), clip_percent=2)
plt.title('image: ' + scene_name)
plt.subplot(1, 2, 2)
imgShow(pred, color_bands=(0, 0, 0), clip_percent=1)
plt.title('predicted water map: ' + scene_name)
plt.show()