## Load libraries and modules

In [None]:
import tensorflow as tf

In [None]:
#check the GPU colab assigns to you
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)


In [None]:
import os

# Expose only the GPU with index 2 to CUDA in this process.
# NOTE(review): on a single-GPU machine (e.g. the Colab runtime mentioned in the
# cell above) index 2 presumably does not exist -- confirm and adjust per machine.
os.environ.update(CUDA_VISIBLE_DEVICES="2")

In [None]:
# A bare `tf.device(...)` call only builds a context manager; unless it is entered
# with `with`, it places nothing on any device. In addition, CUDA_VISIBLE_DEVICES
# (set in the cell above) already limits this process to a single physical GPU,
# which TensorFlow then enumerates as GPU:0 -- so '/device:GPU:2' would not exist.
# Report what TensorFlow can actually see instead, so a misconfiguration is obvious.
print("GPUs visible to TensorFlow:", tf.config.list_physical_devices('GPU'))

In [None]:

# Notebook magic: render matplotlib figures inline in the cell output.
%matplotlib inline
# `%pylab inline` supplies the `pylab` name used on the next line
# (it is not imported explicitly anywhere in this cell).
%pylab inline
# Default size for every figure; matplotlib reads this as (width, height) in inches.
pylab.rcParams['figure.figsize'] = (15, 10)


from tensorflow.keras.models import *
from tensorflow.keras.layers import *

from tensorflow.keras import regularizers

import os
import rasterio
#import rasterio.warp             # Reproject raster samples
from rasterio import windows
#import geopandas as gps
#import PIL.Image
#import PIL.ImageDraw

import gc

from pathlib import Path

from itertools import product
from tqdm import tqdm

import geopandas as gpd

# !pip install ipython-autotime
# %load_ext autotime

In [None]:
# Put the folder that contains the project modules first on sys.path so the
# imports below (preprocess, data_generator, model, ...) can be resolved.
# "core" is a relative path, so it is resolved against the current working directory.
import sys
sys.path.insert(0,"core")

#If you are using Google Colaboratory, modify the path here
#sys.path.insert(0,"/content/drive/MyDrive/Colab/zijingwu-Satellite-based-monitoring-of-wildebeest/core")
from preprocess import *
from data_generator import DataGenerator, SimpleDataGenerator

from model import *

from evaluation import *

from visualization import *

import importlib

from predict import *

## Load the satellite images

In [None]:

# Root locations: the folder holding the satellite images to predict on, and the
# working folder that holds the model checkpoints.
Data_folder = "/home/zijing/wildebeest/data/test"
Folder = "/home/zijing/wildebeest/tmp"

# Images are read straight from the data folder.
image_path = Data_folder

# Both output folders live inside the data folder.
Output_dir, Final_Output_dir = (
    os.path.join(Data_folder, sub_dir)
    for sub_dir in ("predict_test", "predict_test_combine")
)

# Path passed to the detection functions as the weights location.
WEIGHT_PATH = os.path.join(Folder, 'checkpoint/weights')

In [None]:
# Collect the images to run detection on from `image_path`. Each entry is later
# iterated over and opened with rasterio.open in the detection cells.
target_images = get_images_to_predict(image_path)

In [None]:
# Patch geometry: side length of one model input patch, and the maximum tile
# side expressed as a whole number of patches.
PATCH_SIZE = 336
NUM = 2
TILE_MAX_SIZE = NUM * PATCH_SIZE

# Zero-based indices of the image bands fed to the model.
INPUT_BANDS = list(range(3))
NUMBER_BANDS = len(INPUT_BANDS)

# Passed to the detection functions as `stretch`.
CONTRAST = False
fold_nums = 5

## Detect the wildebeest on the images

In [None]:

# Detect wildebeest on every target image in one pass per image and write the
# detections to <Final_Output_dir>/<image name>.shp.

cluster_size = 16
nfold = 5

# exist_ok=True replaces the check-then-create pattern and keeps re-runs safe.
os.makedirs(Output_dir, exist_ok=True)
os.makedirs(Final_Output_dir, exist_ok=True)

for ti in target_images:
    print(ti)
    f = ti
    file_name = os.path.split(f)[1]
    img_name, file_extension = os.path.splitext(file_name)
    print(img_name)

    final_shp_path = os.path.join(Final_Output_dir, img_name + '.shp')
    # Computed for reference only: None is passed to writeResultsToDisk below in
    # what is presumably the mask-output position -- TODO confirm against its signature.
    final_mask_path = os.path.join(Final_Output_dir, img_name + '.tif')

    # Resume support: an existing shapefile means this image was already processed.
    if Path(final_shp_path).is_file():
        print("Prediction already exists. Skip.")
        continue

    with rasterio.open(f) as src:
        # NOTE(review): the network is rebuilt for every image. If detect_wildebeest
        # only loads weights into it, building it once before the loop would avoid
        # the repeated construction -- confirm before hoisting.
        model = unet(pretrained_weights=None,
                     input_size=(PATCH_SIZE, PATCH_SIZE, NUMBER_BANDS),
                     regularizers=regularizers.l2(0.0001))

        # width and height should be the same. Note that stride=256 is about 76%
        # of the 336-pixel patch width (not 50%).
        detectedMask = detect_wildebeest(model, WEIGHT_PATH, src,
                                         width=PATCH_SIZE, height=PATCH_SIZE, stride=256,
                                         batch_size=12, stretch=CONTRAST, num_folds=nfold)
        # Optional visual checks:
        # visualize_prediction(detectedMask)
        # visualize_data(np.moveaxis(np.uint8(src.read()), 0,-1),np.expand_dims(detectedMask, axis=2))

        # Write the detections to the final shapefile, using the image's transform.
        writeResultsToDisk(detectedMask, src, src.meta['transform'], final_shp_path, None, cluster_size)

In [None]:
# If the satellite image size is too large and you would like to process in tiles:
# each image is processed tile by tile into its own sub-folder of Output_dir, and
# the per-tile shapefiles are then merged into <Final_Output_dir>/<image name>.shp.

# Imported explicitly so the merge below does not depend on geopandas happening
# to expose pandas as `gpd.pd`.
import pandas as pd

cluster_size = 16
nfold = 5

# exist_ok=True replaces the check-then-create pattern and keeps re-runs safe.
os.makedirs(Output_dir, exist_ok=True)
os.makedirs(Final_Output_dir, exist_ok=True)

for ti in target_images:
    print(ti)
    f = ti
    file_name = os.path.split(f)[1]
    img_name, file_extension = os.path.splitext(file_name)
    print(img_name)

    # Per-image folder that receives the tile-level outputs.
    ti_Output_dir = os.path.join(Output_dir, img_name)
    os.makedirs(ti_Output_dir, exist_ok=True)
    final_shp_path = os.path.join(Final_Output_dir, img_name + '.shp')
    # Computed for reference only: mask_outpath=None is passed below.
    final_mask_path = os.path.join(Final_Output_dir, img_name + '.tif')

    # Resume support: an existing merged shapefile means this image is done.
    if Path(final_shp_path).is_file():
        print("Prediction already exists. Skip.")
        continue

    with rasterio.open(f) as src:
        # NOTE(review): the network is rebuilt for every image. If
        # detect_wildebeest_tile only loads weights into it, building it once
        # before the loop would avoid the repeated construction -- confirm first.
        model = unet(pretrained_weights=None,
                     input_size=(PATCH_SIZE, PATCH_SIZE, NUMBER_BANDS),
                     regularizers=regularizers.l2(0.0001))
        detect_wildebeest_tile(model, WEIGHT_PATH, src, ti_Output_dir, f,
                               tile_width=5000, tile_height=5000,
                               width=PATCH_SIZE, height=PATCH_SIZE, stride=256,
                               batch_size=12, stretch=CONTRAST, num_folds=nfold,
                               mask_outpath=None, cluster_size=cluster_size)

    # Gather every shapefile found under the per-image folder (walked bottom-up).
    file_list = []
    for root, dirs, files in os.walk(ti_Output_dir, topdown=False):
        for name in files:
            if os.path.splitext(name)[1] == '.shp':
                file_list.append(gpd.read_file(os.path.join(root, name)))

    # concat raises ValueError on an empty list, which would abort the remaining
    # images; report the situation and move on instead.
    if not file_list:
        print(f"No tile shapefiles were found for image {img_name}; nothing to merge.")
        gc.collect()
        continue

    rdf = pd.concat(file_list, ignore_index=True)

    rdf.to_file(final_shp_path)
    print(f"Number of detected wildebeest on image {img_name} is: {rdf.count()['id']}")
    # Free per-image memory before moving on to the next (potentially large) image.
    gc.collect()