In [1]:
# -*- coding: utf-8 -*-
"""
Individual Populus Trees detection
~~~~~~~~~~~~~~~~
code by wHy
Ghent University
Haoyu.Wang@ugent.be
"""

'\nIndividual Populus Trees detection\n~~~~~~~~~~~~~~~~\ncode by wHy\nGhent University\nHaoyu.Wang@ugent.be\n'

In [2]:
import gdal
import ogr
import os
import osr
import cv2
import math
import numpy as np
from pathlib import Path
import fnmatch
from collections import deque
from tqdm import tqdm
import time
from pylab import *

In [3]:
def write_img(out_path, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF file.

    Args:
        out_path: destination .tif path (overwritten by the GTiff driver).
        im_proj: projection as WKT string.
        im_geotrans: 6-element GDAL geotransform.
        im_data: 2-D (height, width) array, or 3-D (bands, height, width).

    The GDAL pixel type follows the array dtype: uint8/int8 -> Byte,
    int16 -> Int16, uint16 -> UInt16, int32 -> Int32, uint32 -> UInt32,
    anything else -> Float32.

    Raises:
        RuntimeError: if GDAL cannot create the output dataset.
    """
    dtype_name = im_data.dtype.name
    if 'int8' in dtype_name:
        # covers uint8 and int8 (older GDAL has no signed 8-bit type)
        datatype = gdal.GDT_Byte
    elif dtype_name == 'int16':
        # signed data must not be written as UInt16 or negatives wrap around
        datatype = gdal.GDT_Int16
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int32':
        # Float32 cannot represent all 32-bit integers exactly
        datatype = gdal.GDT_Int32
    elif dtype_name == 'uint32':
        datatype = gdal.GDT_UInt32
    else:
        datatype = gdal.GDT_Float32

    if im_data.ndim > 2:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    driver = gdal.GetDriverByName("GTiff")
    new_dataset = driver.Create(
        out_path, im_width, im_height, im_bands, datatype)
    if new_dataset is None:
        raise RuntimeError(f'GDAL could not create raster: {out_path}')
    new_dataset.SetGeoTransform(im_geotrans)
    new_dataset.SetProjection(im_proj)
    if im_bands == 1:
        # reshape (not squeeze) so degenerate 1-pixel dims are kept
        new_dataset.GetRasterBand(1).WriteArray(im_data.reshape(im_height, im_width))
    else:
        for i in range(im_bands):
            new_dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    new_dataset.FlushCache()
    del new_dataset  # dropping the last reference closes the file

def read_img(sr_img):
    """Read a raster with GDAL.

    Args:
        sr_img: path to the raster file.

    Returns:
        (im_data, im_proj, im_geotrans, im_height, im_width, im_bands) where
        im_data is the full ReadAsArray result (2-D for one band,
        (bands, height, width) otherwise), im_proj the WKT projection and
        im_geotrans the 6-element geotransform.

    Exits the interpreter (SystemExit(1)) if the file cannot be opened.
    """
    # The file has no explicit `import sys`; import locally instead of relying
    # on `from pylab import *` happening to leak it.
    import sys

    im_dataset = gdal.Open(sr_img)
    if im_dataset is None:
        print('open sr_img false')
        sys.exit(1)
    im_geotrans = im_dataset.GetGeoTransform()
    im_proj = im_dataset.GetProjection()
    im_width = im_dataset.RasterXSize
    im_height = im_dataset.RasterYSize
    im_bands = im_dataset.RasterCount
    im_data = im_dataset.ReadAsArray(0, 0, im_width, im_height)
    del im_dataset  # release the GDAL handle

    return im_data, im_proj, im_geotrans, im_height, im_width, im_bands

def imagexy2geo(trans, row, col):
    """Map a pixel position (row, col) to georeferenced coordinates.

    Applies the 6-term affine geotransform ``trans``:
        geo_x = trans[0] + col * trans[1] + row * trans[2]
        geo_y = trans[3] + col * trans[4] + row * trans[5]

    Returns:
        (geo_x, geo_y)
    """
    origin_x, origin_y = trans[0], trans[3]
    geo_x = origin_x + trans[1] * col + trans[2] * row
    geo_y = origin_y + trans[4] * col + trans[5] * row
    return geo_x, geo_y

def write_point_to_layer(coord, out_lyr, def_out_feature):
    """Append one point feature to an OGR layer.

    Args:
        coord: indexable (x, y, type_value); x/y are georeferenced coordinates,
            type_value is stored in the layer's 'type' field.
        out_lyr: target OGR layer.
        def_out_feature: the layer's feature definition.
    """
    geom = ogr.Geometry(ogr.wkbPoint)
    geom.AddPoint(coord[0], coord[1])

    feature = ogr.Feature(def_out_feature)
    feature.SetGeometry(geom)
    feature.SetField2('type', coord[2])
    out_lyr.CreateFeature(feature)
    # drop the feature reference once it has been written
    feature = None

def formatted_write_opencv_img(img_data, write_path, img_name, prefix, suffix, im_proj, im_geotrans):
    """Write an OpenCV-layout image as GeoTIFF named ``<prefix><stem><suffix>.tif``.

    ``<stem>`` is ``img_name`` without its last four characters (e.g. ".tif").
    Multi-channel input is treated as (h, w, c) BGR and converted to the
    (c, h, w) RGB layout that ``write_img`` expects; 2-D input is written as is.
    Sleeps one second before writing (reason not documented in this file).
    """
    time.sleep(1)
    if img_data.ndim > 2:
        img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    target = write_path + '/' + prefix + img_name[:-4] + suffix + '.tif'
    write_img(target, im_proj, im_geotrans, img_data)

def euclidean_distance(matrix1, matrix2):
    """Euclidean (L2) distance between two arrays viewed as flat vectors.

    Both inputs are cast to float64 first, so unsigned integer inputs do not
    wrap around when subtracted.
    """
    diff = matrix1.ravel().astype(np.float64) - matrix2.ravel().astype(np.float64)
    return np.sqrt(np.sum(np.square(diff)))

def average_euclidean_distance(matrix1, matrix2, circle_mask):
    """Mean absolute per-element difference between two patches inside a mask.

    Despite the name, each element contributes sqrt(d**2) == |d|, so this is
    the masked mean absolute difference over all pixels and bands.

    Args:
        matrix1, matrix2: arrays of identical shape, (h, w) or (h, w, bands).
            Cast to float64 before subtracting (no uint8 wrap-around).
        circle_mask: (h, w) weights (0/1 in this notebook), applied to every band.

    Returns:
        np.float64: sum(mask * |matrix1 - matrix2|) / sum(mask), with the mask
        broadcast over bands; NaN if the mask is all zero.
    """
    a = np.asarray(matrix1, dtype=np.float64)
    b = np.asarray(matrix2, dtype=np.float64)
    mask = np.asarray(circle_mask, dtype=np.float64)
    if a.ndim == 3:
        mask = mask[:, :, np.newaxis]  # same spatial mask for every band
    weights = np.broadcast_to(mask, a.shape)

    total_weight = np.sum(weights)
    if total_weight == 0:
        return np.float64(np.nan)
    # np.sum (not builtin sum) so the result does not depend on star imports
    return np.sum(weights * np.abs(a - b)) / total_weight

def cal_patch_march_num(tmp_labels, circle_mask):
    """Return the sum of ``tmp_labels`` weighted by ``circle_mask``.

    With a 0/1 label patch and a 0/1 mask this is the number of component
    pixels that fall inside the mask.
    """
    masked = tmp_labels * circle_mask
    return masked.sum()

def cal_lockmap_match_num(lock_map_patch, circle_mask):
    """Return the sum of ``lock_map_patch`` weighted by ``circle_mask``.

    With a 0/1 lock map (1 = still free) this counts the free pixels that lie
    under the mask.
    """
    weighted_patch = lock_map_patch * circle_mask
    total = weighted_patch.sum()
    return total


def geo2imagexy(geoX, geoY, im_geotrans):
    """Convert georeferenced coordinates to (fractional) pixel coordinates.

    Inverts the GDAL affine geotransform applied by ``imagexy2geo``:
        geoX = g0 + x * g1 + y * g2
        geoY = g3 + x * g4 + y * g5
    where x is the column and y the row.

    Args:
        geoX, geoY: coordinates in the raster's spatial reference (scalars or
            numpy arrays of equal shape).
        im_geotrans: 6-element geotransform (g0, g1, g2, g3, g4, g5).

    Returns:
        (x, y): column and row, not rounded.

    Raises:
        ZeroDivisionError: if the transform is singular (g1*g5 == g2*g4) and
            scalar inputs are given.
    """
    g0, g1, g2, g3, g4, g5 = (float(v) for v in im_geotrans[:6])

    dx = geoX - g0
    dy = geoY - g3
    det = g1 * g5 - g2 * g4
    # Cramer's rule on [[g1, g2], [g4, g5]] @ [x, y] = [dx, dy]; the previous
    # formula mixed geoX into the x numerator and ignored det for y, which is
    # only correct for north-up rasters (g2 == g4 == 0).
    x = (g5 * dx - g2 * dy) / det
    y = (g1 * dy - g4 * dx) / det

    return x, y

def write_to_log(log_filename, log_message):
    """Append ``log_message`` (formatted as text) to ``log_filename``.

    Nothing is added around the message (no newline, no timestamp); callers
    embed their own separators.
    """
    with open(log_filename, 'a') as handle:
        handle.write('{}'.format(log_message))

def truncated_linear_stretch(image, truncated_value, max_out = 255, min_out = 0):
    """Percentile-clipped linear contrast stretch.

    For each band, the ``truncated_value``-th and ``(100 - truncated_value)``-th
    percentiles of the non-zero pixels are mapped linearly onto
    ``[min_out, max_out]`` and values outside are clipped. Zero pixels are
    ignored when computing the percentiles but are still transformed.

    Args:
        image: 2-D (h, w) array, or 3-D array with bands on the first axis.
        truncated_value: percentage cut from each tail (e.g. 2 for a 2 % stretch).
        max_out, min_out: output range. The result is uint8 when
            ``max_out <= 255``, uint16 when ``max_out <= 65535``, else float.

    Returns:
        The stretched array, same layout as ``image``. A band with no non-zero
        pixels is returned unchanged. A band whose two percentiles coincide
        (e.g. a constant band) becomes a step: pixels above the cut-off map to
        ``max_out`` and the rest to ``min_out``, instead of dividing by zero.
    """
    def gray_process(gray):
        nonzero_pixels = gray[gray != 0]

        if len(nonzero_pixels) == 0:
            return gray

        truncated_down = np.percentile(nonzero_pixels, truncated_value)
        truncated_up = np.percentile(nonzero_pixels, 100 - truncated_value)
        value_range = truncated_up - truncated_down
        if value_range != 0:
            gray = (gray - truncated_down) / value_range * (max_out - min_out) + min_out
        else:
            # degenerate range: avoid 0/0 -> NaN being cast to an integer type
            gray = np.where(gray > truncated_down, max_out, min_out).astype(np.float64)
        gray = np.clip(gray, min_out, max_out)
        if max_out <= 255:
            gray = np.uint8(gray)
        elif max_out <= 65535:
            gray = np.uint16(gray)
        return gray

    if len(image.shape) == 3:
        image_stretch = np.array([gray_process(band) for band in image])
    else:
        image_stretch = gray_process(image)
    return image_stretch


In [4]:
'''parameters'''
# ---- input / output locations (relative to the notebook's working dir) ----
dataset = 'd2' # test set to process: d0, d1 or d2
sr_img_path = r'../data/Test_sets/1-clip_img_' + dataset # folder of input .tif images (3-band, read as RGB)
semantic_segmentation_result_path = r'../data/Test_sets/2-semantic_segmentation_' + dataset # label maps, same file names as the images
output_shp_path = r'../output/' + dataset # one point shapefile per image is written here
template_path = r'../data/Templates' # folder holding the crown templates (.npy)
log_file_name = r'../log/log.txt' # run parameters are appended here

# ---- core parameters ----
# glt: fraction of a template's (analytic) circle area, pi * (c//2)**2, that
# must still be unlocked for a template position to be accepted.
glt = 0.9 # global lock threshold ***core parameter
# mlt: same idea for the smallest template, which (when more than three
# templates are used) is matched with a full square mask of c*c pixels.
mlt = 0.8 # minimal lock threshold ***core parameter


'''*** NOT recommended to modify the following parameters ***'''
foreground_value = 1 # label value of Populus pixels in the semantic segmentation result
# Pixels whose channel-2 value (red, after the RGB->BGR conversion) is <= this
# are treated as shadow and discarded.
shandow_threshold = 160 # shadow threshold
# A connected component smaller than the largest template is kept as a single
# tree only if both its bbox width and height exceed this (pixels).
minimum_threshold = 3 # single Populus detection threshold
# A template position is scored only if at least this fraction of the mask
# pixels belong to the component being examined.
skip_threshold = 0.3 # valid pixel threshold

# Crown templates, largest first: c_size[0] (the first template's size) is
# the cut-off used later to separate single trees from clustered crowns.
# File names look like '1_<a>_<b>.npy' — presumably a crown-size range;
# TODO confirm. Templates are compared against (c, c, bands) image patches
# that were converted to BGR — presumably stored in BGR too; verify.
Tempalte = []  # big first
#Tempalte.append(np.load(template_path + '/1_19_21.npy'))
#Tempalte.append(np.load(template_path + '/1_17_19.npy'))
Tempalte.append(np.load(template_path + '/1_15_17.npy')) # templates 7..17 gave the best performance
Tempalte.append(np.load(template_path + '/1_13_15.npy'))
Tempalte.append(np.load(template_path + '/1_11_13.npy'))
Tempalte.append(np.load(template_path + '/1_9_11.npy'))
Tempalte.append(np.load(template_path + '/1_7_9.npy'))
Tempalte.append(np.load(template_path + '/1_5_7.npy'))
# Tempalte.append(np.load(template_path + '/1_0_5.npy'))

T_num = len(Tempalte) # Number of used templates
# Side length of each (square) template, taken from its first axis.
c_size = []
for i in range(T_num):
    c_size.append(np.shape(Tempalte[i])[0])

# mask generation: for each template a c x c int32 mask that is 1 where the
# pixel lies within distance c//2 of the centre (c//2, c//2).
circle_masks = []
circle_valid_sums = []  # NOTE(review): not populated in the cells shown

for k in range(T_num):
    circle_mask = np.zeros((c_size[k], c_size[k]), dtype=np.int32)
    circle_radius = c_size[k]//2
    circle_center = (c_size[k]//2, c_size[k]//2)
    for i in range(c_size[k]):
        for j in range(c_size[k]):
            distance = np.sqrt((i - circle_center[0])**2 + (j - circle_center[1])**2)
            if distance <= circle_radius:
                circle_mask[i, j] = 1
    circle_masks.append(circle_mask)

# Percentile of each template's positive suit-map scores used as its
# acceptance threshold (scores must exceed it to seed / extend a search).
suit_map_binary_values = []
for i in range(T_num): 
    suit_map_binary_values.append(10)

# Maximum BFS step count when searching for the best crown centre.
max_distances = [] # the max searching distance of crown center
for i in range(T_num): 
    max_distances.append(c_size[i])

# Minimum number of unlocked pixels under the mask for a position to be
# accepted: int(pi * r**2) * glt with r = c//2 (analytic circle area, not
# the exact pixel count of circle_masks[k]).
lock_thresholds = []
for i in range(T_num): 
    lock_thresholds.append(int(math.pi * (c_size[i]//2) **2) * glt)

# With more than three templates, the smallest one uses a full square mask
# and a square-area lock threshold scaled by mlt. The percentile line keeps
# the default (10) but is left as a tuning hook.
if T_num>3:
    suit_map_binary_values[T_num-1] = 10
    circle_masks[T_num-1][:] = 1
    lock_thresholds[T_num-1] = c_size[T_num-1] **2 * mlt

In [5]:
'''write log'''
from datetime import datetime

# Create the log folder on first run: write_to_log opens the file in append
# mode, which fails when ../log does not exist on a fresh checkout.
os.makedirs(os.path.dirname(log_file_name) or '.', exist_ok=True)

current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
write_to_log(log_file_name, current_time)
# Record the template sizes and thresholds used for this run.
write_to_log(
    log_file_name,
    '\nTmp:' + str(c_size)
    + ', suit_binaryT:' + str(suit_map_binary_values)
    + ', shadowT:' + str(shandow_threshold)
    + ', minimumT:' + str(minimum_threshold)
    + ', skipT:' + str(skip_threshold)
    + ', maxD:' + str(max_distances)
    + ', lockT:' + str(lock_thresholds)
    + '\n\n'
)

In [6]:
listpic = fnmatch.filter(os.listdir(sr_img_path), '*.tif')

# Shapefiles are written into output_shp_path; create it up front so a fresh
# checkout can Restart & Run All.
os.makedirs(output_shp_path, exist_ok=True)

# GDAL/OGR global settings: set once rather than once per image.
gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
gdal.SetConfigOption("SHAPE_ENCODING", "GBK")
ogr.RegisterAll()  # register all drivers

directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]  # 4-connected search neighbourhood
# Pixel count of each template mask (loop-invariant). np.sum is explicit so the
# full 2-D sum does not depend on `from pylab import *` shadowing builtin sum.
mask_pixel_counts = [np.sum(circle_masks[k]) for k in range(T_num)]

for img_name in tqdm(listpic):  # for each test image
    img_full_path = sr_img_path + '/' + img_name
    sen_seg_full_path = semantic_segmentation_result_path + '/' + img_name

    # --- read image and its semantic segmentation ---
    data_ss, im_proj, im_geotrans = read_img(sen_seg_full_path)[:3]
    data_img = read_img(img_full_path)[0]
    # (c, h, w) -> (h, w, c) & RGB -> BGR
    data_img = cv2.cvtColor(data_img.transpose(1, 2, 0), cv2.COLOR_RGB2BGR)

    height, width, channel = np.shape(data_img)
    # Align the label map with the image. Nearest-neighbour keeps class ids
    # intact; cv2.resize's default bilinear interpolation blends label values.
    if np.shape(data_ss)[:2] != (height, width):
        data_ss = cv2.resize(data_ss, (width, height), interpolation=cv2.INTER_NEAREST)

    # --- keep only Populus pixels ---
    binary_mask = (data_ss == foreground_value).astype(np.uint8)
    masked_img = np.zeros((height, width, channel), dtype=np.uint8)
    for i in range(channel):
        masked_img[:, :, i] = data_img[:, :, i] * binary_mask

    # --- shadow removal: binarise channel 2 (red, after the BGR swap) ---
    rt = masked_img[:, :, 2].copy()
    rt[rt <= shandow_threshold] = 0
    rt[rt > shandow_threshold] = 255

    # --- connected components of the remaining foreground ---
    opening = rt.copy()
    markers_com = opening.copy()
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(markers_com, connectivity=4)

    # --- suit map: template similarity at each top-left patch position ---
    suit_map = np.zeros((T_num, height, width), dtype=np.float64)
    suitmap_for_vis = suit_map.copy()  # same scores, stored at the patch centre
    populus_center = []  # entries: [row, col, template index, or -1 for small single trees]

    for label in range(1, num_labels):  # skip background 0
        x, y = stats[label, cv2.CC_STAT_LEFT], stats[label, cv2.CC_STAT_TOP]
        p_width, p_height = stats[label, cv2.CC_STAT_WIDTH], stats[label, cv2.CC_STAT_HEIGHT]

        # A component smaller than the largest template is a single Populus
        # (kept only if not too thin); its bbox centre is the crown centre.
        if p_width < c_size[0] and p_height < c_size[0]:
            if p_width > minimum_threshold and p_height > minimum_threshold:
                populus_center.append([y + p_height // 2, x + p_width // 2, -1])
            continue

        # Clustered Populus: score every template at every in-bounds position.
        for k in range(T_num):
            size = c_size[k]
            min_match = skip_threshold * mask_pixel_counts[k]
            for i in range(p_height):
                for j in range(p_width):
                    top, left = y + i, x + j
                    if top + size >= height or left + size >= width:
                        continue
                    tmp_labels = (labels[top: top + size, left: left + size] == label).astype(labels.dtype)
                    if cal_patch_march_num(tmp_labels, circle_masks[k]) < min_match:
                        continue  # too few pixels of this component under the mask
                    tmp_img = data_img[top: top + size, left: left + size, :]
                    score = 255 - average_euclidean_distance(Tempalte[k], tmp_img, circle_masks[k])
                    suit_map[k, top, left] = score
                    suitmap_for_vis[k, top + size // 2, left + size // 2] = score

    # --- per-template acceptance threshold: a percentile of positive scores ---
    nonzero_values = [suit_map[k][suit_map[k] > 0] for k in range(T_num)]
    threshold_values = [
        np.percentile(nonzero_values[k], suit_map_binary_values[k]) if len(nonzero_values[k]) > 1 else 1
        for k in range(T_num)
    ]

    # lock_map == 1 marks foreground pixels not yet claimed by a detected crown.
    lock_map = opening.copy()
    lock_map[lock_map == 255] = 1

    # --- individual tree detection ---
    for k in range(T_num):
        size = c_size[k]
        mask = circle_masks[k]
        max_distance = max_distances[k]
        # suit_map is not modified below, so the above-threshold seeds can be
        # listed once, in the same row-major order as a nested loop over
        # [0, height-size) x [0, width-size); only lock_map changes as we go.
        seeds = np.argwhere(
            suit_map[k, :max(height - size, 0), :max(width - size, 0)] > threshold_values[k]
        ).tolist()

        for i, j in seeds:
            if lock_map[i, j] != 1 or cal_lockmap_match_num(lock_map[i: i + size, j: j + size], mask) < lock_thresholds[k]:
                continue

            # Constrained BFS from the seed to the best-scoring unlocked position.
            # `visited` is fresh for every seed (as before) but is a set, not a
            # newly allocated height x width array for every pixel.
            visited = {(i, j)}
            queue = deque([(i, j, 0)])
            max_value = suit_map[k, i, j]
            max_coords = (i, j)

            while queue:
                h, w, distance = queue.popleft()
                if distance > max_distance:
                    continue
                for dh, dw in directions:
                    nh, nw = h + dh, w + dw
                    if (0 <= nh < height and 0 <= nw < width
                            and suit_map[k, nh, nw] > threshold_values[k]
                            and (nh, nw) not in visited
                            and lock_map[nh, nw] == 1
                            and cal_lockmap_match_num(lock_map[nh: nh + size, nw: nw + size], mask) >= lock_thresholds[k]):
                        queue.append((nh, nw, distance + 1))
                        visited.add((nh, nw))
                        if suit_map[k, nh, nw] > max_value:
                            max_value = suit_map[k, nh, nw]
                            max_coords = (nh, nw)

            # Lock the crown footprint at the best position so later seeds
            # cannot claim the same pixels (slice is a view into lock_map).
            top, left = max_coords
            lock_map[top: top + size, left: left + size][mask == 1] = 0

            populus_center.append([top + size // 2, left + size // 2, k])

    # --- write crown centres as a point shapefile ---
    driver = ogr.GetDriverByName('ESRI Shapefile')
    prj = osr.SpatialReference()
    prj.ImportFromWkt(im_proj)  # projection taken from the segmentation raster

    out_shp_full_path = output_shp_path + '/' + img_name[:-4] + '_populus_infer.shp'
    if Path(out_shp_full_path).exists():
        driver.DeleteDataSource(out_shp_full_path)
    out_ds = driver.CreateDataSource(out_shp_full_path)
    out_lyr = out_ds.CreateLayer(out_shp_full_path, prj, ogr.wkbPoint)
    def_out_feature = out_lyr.GetLayerDefn()  # read the feature type
    out_lyr.CreateField(ogr.FieldDefn('type', ogr.OFTInteger))

    for row, col, crown_type in populus_center:
        x_coord, y_coord = imagexy2geo(im_geotrans, row, col)
        write_point_to_layer([x_coord, y_coord, crown_type], out_lyr, def_out_feature)
    out_ds.Destroy()

100%|██████████| 43/43 [05:13<00:00,  7.28s/it]
