### Predict cloud masks for Sentinel-2 tiles using SVM/RF/ANN models

In [1]:
import time
import pickle
import numpy as np
from osgeo import gdal
# from tensorflow.keras.models import load_model # uncomment if you are using ANN
import matplotlib.pyplot as plt

## Load Models

In [4]:
# Load SVM Model
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# load model files produced by a trusted source.
model_name = './models/svm_hollstein/svm_c100_rbf_g1_hollstein_13i_6o.sav'

# Use a context manager so the file handle is closed once the model is loaded.
with open(model_name, 'rb') as model_file:
    svm = pickle.load(model_file)

## Read and predict tile

In [5]:
# Function to predict the cloud mask
# Inputs: Filename of tif with 13 bands and loaded model (RF or ANN or SVM)
def predict_tile_ann(filename, model):
    """Classify every pixel of a multi-band raster with ``model``.

    Parameters
    ----------
    filename : str
        Path to a raster readable by GDAL. It must have several bands: the data
        is read as (bands, rows, cols) and transposed to (rows, cols, bands).
        The comment above expects 13 bands.
    model : object
        A fitted model exposing ``predict``. When ``predict`` returns a 2-D
        array (e.g. per-class scores from an ANN), the class index is taken
        with argmax along axis 1.

    Returns
    -------
    numpy.ndarray
        A (rows, cols) array of predicted classes. Pixels whose band at index 1
        is exactly 0 are set to NaN (presumably no-data padding — confirm).
    """
    dataset = gdal.Open(filename)
    # (bands, rows, cols) -> (rows, cols, bands), with digital numbers scaled by 1/10000
    pixels = np.transpose(dataset.ReadAsArray(), (1, 2, 0)) / 10000
    n_rows, n_cols, n_bands = pixels.shape

    # One sample per pixel, one feature per band.
    flat_pixels = np.reshape(pixels, (n_rows * n_cols, n_bands))
    labels = model.predict(flat_pixels)
    # To obtain class probabilities instead, call model.predict_proba(flat_pixels).

    if labels.ndim > 1:
        # 2-D output (one score per class, as produced by an ANN): keep the top class.
        labels = labels.argmax(axis=1)

    label_grid = labels.reshape(n_rows, n_cols)
    masked = pixels[:, :, 1] == 0
    return np.where(masked, np.nan, label_grid)

In [None]:
# Predict tile and report how long inference took.
# Alternative tile:
# tile_name = "./tiff_images/S2A_MSIL1C_20220416T074621_N0400_R135_T35JPN_20220416T094522.tif"
tile_name = "./tiff_images/S2A_MSIL1C_20220124T053111_N0301_R105_T43QEV_20220124T062018.tif"
# perf_counter is a monotonic, high-resolution clock intended for measuring durations.
start = time.perf_counter()
output_array = predict_tile_ann(tile_name, svm)
end = time.perf_counter()
# This cell only runs prediction, so report it as prediction time (not training).
print("Prediction time is :", end - start)

In [None]:
# Save the classified image as a single-band Float32 GeoTIFF that reuses the
# georeferencing (geotransform + projection) of the input tile.
output_image = "./results/S2A_MSIL1C_20220124T053111_N0301_R105_T43QEV_20220124T062018_0.tif"
outdriver = gdal.GetDriverByName("GTiff")

shp = output_array.shape
# GDAL's Create takes (xsize = columns, ysize = rows, band count, data type).
outdata = outdriver.Create(output_image, shp[1], shp[0], 1, gdal.GDT_Float32)
# Write the array produced by the prediction cell. `final_array` only exists
# inside predict_tile_ann, so referencing it here raised NameError.
outdata.GetRasterBand(1).WriteArray(output_array)
# predict_tile_ann marks masked pixels with NaN; record that as the nodata value.
outdata.GetRasterBand(1).SetNoDataValue(np.nan)

# Copy georeferencing from the source tile so the mask overlays it exactly.
test_image = gdal.Open(tile_name)
trans = test_image.GetGeoTransform()
proj = test_image.GetProjection()
outdata.SetGeoTransform(trans)
outdata.SetProjection(proj)
# Dropping the last reference closes the dataset and flushes it to disk.
del outdata