In [385]:
# -*- coding: utf-8 -*-

"""
Individual Populus Trees detection
~~~~~~~~~~~~~~~~
code by wHy
Aerospace Information Research Institute, Chinese Academy of Sciences
Ghent University
Haoyu.Wang@ugent.be
"""

'\n胡杨拆解计数\n~~~~~~~~~~~~~~~~\ncode by wHy\nAerospace Information Research Institute, Chinese Academy of Sciences\nGhent University\nHaoyu.Wang@ugent.be\n'

In [386]:
from tqdm import tqdm
import gdal
import ogr
import os
import osr

from PIL import Image
from pylab import *
from scipy.ndimage import filters
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import math
from skimage import io, color
import numpy as np
from skimage import img_as_ubyte
from skimage import img_as_float
from tqdm import tqdm

from pathlib import Path

import fnmatch
from collections import deque

import time

In [387]:
def write_img(out_path, im_proj, im_geotrans, im_data):
    """output img
    Args:
        out_path: Output path
        im_proj: Affine transformation parameters
        im_geotrans: spatial reference
        im_data: Output image data
    """
    # identify data type 
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32

    # calculate number of bands
    if len(im_data.shape) > 2:  
        im_bands, im_height, im_width = im_data.shape
    else:  
        im_bands, (im_height, im_width) = 1, im_data.shape

    # create new img
    driver = gdal.GetDriverByName("GTiff")
    new_dataset = driver.Create(
        out_path, im_width, im_height, im_bands, datatype)
    new_dataset.SetGeoTransform(im_geotrans)
    new_dataset.SetProjection(im_proj)
    if im_bands == 1:
        new_dataset.GetRasterBand(1).WriteArray(im_data.squeeze())
    else:
        for i in range(im_bands):
            new_dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    del new_dataset

def read_img(sr_img):
    """read img
    Args:
        sr_img: The full path of the original image
    """
    im_dataset = gdal.Open(sr_img)
    if im_dataset == None:
        print('open sr_img false')
        sys.exit(1)
    im_geotrans = im_dataset.GetGeoTransform()
    im_proj = im_dataset.GetProjection()
    im_width = im_dataset.RasterXSize
    im_height = im_dataset.RasterYSize
    im_bands = im_dataset.RasterCount 
    im_data = im_dataset.ReadAsArray(0, 0, im_width, im_height)
    del im_dataset

    return im_data, im_proj, im_geotrans, im_height, im_width, im_bands

def imagexy2geo(trans, row, col):
    '''
    根据GDAL的六参数模型将影像图上坐标（行列号）转为投影坐标或地理坐标（根据具体数据的坐标系统转换）
    :param dataset: GDAL地理数据
    :param row: 像素的行号
    :param col: 像素的列号
    :return: 行列号(row, col)对应的投影坐标或地理坐标(x, y)
    '''
    px = trans[0] + col * trans[1] + row * trans[2]
    py = trans[3] + col * trans[4] + row * trans[5]
    return px, py

def write_point_to_layer(coord, out_lyr, def_out_feature):
    """
    write point features in layer
    """
    point = ogr.Geometry(ogr.wkbPoint)
    point.AddPoint(coord[0], coord[1])
    outfeat = ogr.Feature(def_out_feature)  # 定义输出要素
    outfeat.SetGeometry(point)  # 设置要素几何
    outfeat.SetField2('type', coord[2]) # 写入类别
    out_lyr.CreateFeature(outfeat)  # 图层新建要素
    outfeat = None

def formatted_write_opencv_img(img_data, write_path, img_name, prefix, suffix, im_proj, im_geotrans):
    """
    opencv格式化输出
    """
    time.sleep(1)
    if len(img_data.shape) > 2:
        output_img_data = cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB)
        output_img_data = output_img_data.transpose(2, 0, 1)
        output_full_path = write_path + '/' + prefix + img_name[:-4] + suffix + '.tif'
        write_img(output_full_path, im_proj, im_geotrans, output_img_data)
    else:
        output_full_path = write_path + '/' + prefix + img_name[:-4] + suffix + '.tif'
        write_img(output_full_path, im_proj, im_geotrans, img_data)

def euclidean_distance(matrix1, matrix2):
    """计算两个矩阵的欧式距离
    """
    # 将两个矩阵展平为向量
    vector1 = matrix1.flatten().astype(np.float64)
    vector2 = matrix2.flatten().astype(np.float64)
    
    # 计算欧氏距离
    distance = np.sqrt(np.sum((vector1 - vector2) ** 2))
    
    return distance

def average_euclidean_distance(matrix1, matrix2, circle_mask):
    """计算两个矩阵的平均欧式距离
    """
    # 将两个矩阵展平为向量
    bands = shape(matrix1)[2]
    circle_mask_expand_bands = np.zeros((shape(matrix1)[0],shape(matrix1)[1],shape(matrix1)[2]))
    for i in range(bands):
        circle_mask_expand_bands[:,:,i] = circle_mask

    vector1 = matrix1.flatten().astype(np.float64)
    vector2 = matrix2.flatten().astype(np.float64)
    vector3 = circle_mask_expand_bands.flatten()
    
    # 计算每个像素之间的欧氏距离
    pixel_distances = vector3 * np.sqrt((vector1 - vector2) ** 2)
    
    # 计算平均欧氏距离
    average_distance = np.sum(pixel_distances)/sum(vector3)
    
    return average_distance

def cal_patch_march_num(tmp_labels, circle_mask):
    """返回匹配的像素总数
    """
    return np.sum(tmp_labels * circle_mask)

def cal_lockmap_match_num(lock_map_patch, circle_mask):
    """返回匹配锁定图像素总数
    """
    return np.sum(lock_map_patch * circle_mask)


def geo2imagexy(geoX, geoY, im_geotrans):
    # 地理坐标系转图上坐标系
    g0 = float(im_geotrans[0])
    g1 = float(im_geotrans[1])
    g2 = float(im_geotrans[2])
    g3 = float(im_geotrans[3])
    g4 = float(im_geotrans[4])
    g5 = float(im_geotrans[5])

    x = (geoX*g5 - g0*g5 - geoX*g2 + g3*g2)/(g1*g5 - g4*g2)
    y = (geoY - g3 - geoX*g4)/ g5

    return x, y

def write_to_log(log_filename, log_message):
    # 打开日志文件，如果不存在则创建
    with open(log_filename, 'a') as log_file:
        # 获取当前时间
        
        # 构造要写入日志文件的内容
        log_entry = f'{log_message}'
        
        # 将日志内容写入文件
        log_file.write(log_entry)

def truncated_linear_stretch(image, truncated_value, max_out = 255, min_out = 0):
    # 图像拉伸
    def gray_process(gray):
        nonzero_pixels = gray[gray != 0] # 0不参与计算

        if len(nonzero_pixels) == 0:
        # 如果所有像素都是0，返回原始图像
            return gray

        truncated_down = np.percentile(nonzero_pixels, truncated_value)
        truncated_up = np.percentile(nonzero_pixels, 100 - truncated_value)
        gray = (gray - truncated_down) / (truncated_up - truncated_down) * (max_out - min_out) + min_out 
        gray[gray < min_out] = min_out
        gray[gray > max_out] = max_out
        if(max_out <= 255):
            gray = np.uint8(gray)
        elif(max_out <= 65535):
            gray = np.uint16(gray)
        return gray
    
    #  如果是多波段
    if(len(image.shape) == 3):
        image_stretch = []
        for i in range(image.shape[0]):
            gray = gray_process(image[i])
            image_stretch.append(gray)
        image_stretch = np.array(image_stretch)
    #  如果是单波段
    else:
        image_stretch = gray_process(image)
    return image_stretch


In [388]:
# 每做一组实验需要：
# 1. 检查sr_img_path，semantic_segmentation_result_path是否正确
# 2. 修改output_img_path和output_shp_path最后的路径(最后1个或者2个)
# 3. 修改实验相关参数

sr_img_path = r'C:\Users\75198\OneDrive\论文\SCI-4 Populus counting\1-evaluation\1-clip_img_d2'
semantic_segmentation_result_path = r'C:\Users\75198\OneDrive\论文\SCI-4 Populus counting\1-evaluation\2-ss_d2'
output_img_path = r'C:\Users\75198\OneDrive\论文\SCI-4 Populus counting\1-evaluation\4-predict\TS_7-21_d2'
output_shp_path = r'C:\Users\75198\OneDrive\论文\SCI-4 Populus counting\1-evaluation\4-predict\TS_7-21_d2'
glt = 0.9
mlt = 0.8

foreground_value = 1
template_path = r'C:\Users\75198\OneDrive\论文\SCI-4 Populus counting\画图\图6-模版结果图'
log_file_name = r'C:\Users\75198\OneDrive\论文\SCI-4 Populus counting\画图\实验日志.txt'

if not os.path.exists(output_img_path):
    os.mkdir(output_img_path)
if not os.path.exists(output_shp_path):
    os.mkdir(output_shp_path)

In [389]:
'''读取模板'''
'''模板由plt_template.ipynb计算生成'''

# 注意，必须由大到小排列
Tempalte = []
#Tempalte.append(np.load(template_path + '/1_19_21.npy'))
#Tempalte.append(np.load(template_path + '/1_17_19.npy'))
Tempalte.append(np.load(template_path + '/1_15_17.npy')) # opencv BGR (h,w,c) uint8 
Tempalte.append(np.load(template_path + '/1_13_15.npy'))
Tempalte.append(np.load(template_path + '/1_11_13.npy'))
Tempalte.append(np.load(template_path + '/1_9_11.npy'))
Tempalte.append(np.load(template_path + '/1_7_9.npy'))
Tempalte.append(np.load(template_path + '/1_5_7.npy'))
# Tempalte.append(np.load(template_path + '/1_0_5.npy'))

T_num = len(Tempalte) # 模板数量

print(shape(Tempalte[0]), Tempalte[0].dtype)
c_size = []
for i in range(T_num):
    c_size.append(shape(Tempalte[i])[0])
print(c_size)

'''生成对应的圆蒙版'''
circle_masks = []
circle_valid_sums = []

# 蒙版
for k in range(T_num):
    circle_mask = np.zeros((c_size[k], c_size[k]), dtype=np.int32)
    circle_radius = c_size[k]//2
    circle_center = (c_size[k]//2, c_size[k]//2)  # 圆心的坐标，这里假设圆心位于矩阵中心
    # 遍历矩阵中的每个像素，并计算与圆心的距离
    for i in range(c_size[k]): # 注意修改c_size
        for j in range(c_size[k]):
            distance = np.sqrt((i - circle_center[0])**2 + (j - circle_center[1])**2)
            if distance <= circle_radius:
                circle_mask[i, j] = 1
    circle_masks.append(circle_mask)


(21, 21, 3) uint8
[21, 19, 17, 15, 13, 11, 9, 7]


In [390]:
shandow_threshold = 160 # 识别为阴影的阈值
minimum_threshold = 3 # 单株胡杨的最低识别阈值 长宽均大于该阈值才被视为单株胡杨
skip_threshold = 0.3 # 任意一个patch的有效像素的比例必须大于该值，才被视为有效像素 230920对于小模板是否应该调低这个阈值待测试

suit_map_binary_values = []
for i in range(T_num): 
    suit_map_binary_values.append(10)

max_distances = [] # 树冠中心搜索的最大距离 保证足够的搜索距离
for i in range(T_num): 
    max_distances.append(c_size[i])

lock_thresholds = [] # 范围内锁定像素的阈值，必须大于该值才能视为有效匹配,用来控制重叠区
for i in range(T_num): 
    lock_thresholds.append(int(math.pi * (c_size[i]//2) **2) * glt) # 0意味可以无限重叠，1意味着禁止重叠

# 单独填充最小patch 要求强匹配
if T_num>3:
    suit_map_binary_values[T_num-1] = 10
    circle_masks[T_num-1][:] = 1
    lock_thresholds[T_num-1] = c_size[T_num-1] **2 * mlt

print(lock_thresholds)

from datetime import datetime
current_time = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
write_to_log(log_file_name, current_time)
write_to_log(log_file_name, '\nTmp:'+ str(c_size) +', suit_binaryT:' +str(suit_map_binary_values) +', shadowT:'+str(shandow_threshold)+', minimumT:'+str(minimum_threshold)+', skipT:' + str(skip_threshold)+', maxD:' + str(max_distances) +', lockT:' + str(lock_thresholds) +'\n')

[282.6, 228.6, 180.9, 137.70000000000002, 101.7, 70.2, 45.0, 39.2]


In [391]:
listpic = fnmatch.filter(os.listdir(sr_img_path), '*.tif')

for img_name in tqdm(listpic):
    img_full_path = sr_img_path + '/' + img_name
    # sen_seg_full_path = semantic_segmentation_result_path + '/' + img_name[:-12] + 'result.tif' # 小图
    sen_seg_full_path = semantic_segmentation_result_path + '/' + img_name

    '''读取数据'''
    data_ss, im_proj, im_geotrans = read_img(sen_seg_full_path)[:3]
    data_img = read_img(img_full_path)[0]

    '''将gdal格式转化为opencv格式'''
    data_img = cv2.cvtColor(data_img.transpose(1, 2, 0), cv2.COLOR_RGB2BGR) # (c, h, w) -> (h, w, c) & RGB->BGR

    '''少量特定预处理'''
    height, width, channel = shape(data_img)
    data_ss = cv2.resize(data_ss, (width, height)) # 处理裁剪中的些微偏移

    '''掩膜操作'''
    # 将掩膜中前景的像素设置为1，背景的像素设置为0
    binary_mask = (data_ss == foreground_value).astype(np.uint8)
    # 将原始图像和二值化的掩膜相乘，仅保留前景像素
    masked_img = np.zeros((height, width, channel), dtype=np.uint8)

    for i in range(channel):
        masked_img[:,:,i] = data_img[:,:,i] * binary_mask

    # 新建标记像素
    marker_img = np.zeros((height, width), dtype=np.uint8)
    marker_img = masked_img[:,:,2].copy() # 复制NIR2波段

    T = 235
    marker_img[marker_img < T] = 0
    marker_img[marker_img > T] = 255

    # 去除阴影
    rt = masked_img[:,:,2].copy() # 以第红外波段为基准二值化
    rt[rt<=shandow_threshold] = 0
    rt[rt>shandow_threshold] = 255

    # 构造连通图
    opening = rt.copy() # 230920 取消应用形态学预处理
    markers_com = opening.copy()


    # Find connected components in the binary mask
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(markers_com, connectivity=4)


    ## Create a random color map for visualization
    # 设置随机种子以确保生成的随机数相同
    np.random.seed(9145)
    colorTab = np.zeros((num_labels, 3))

    # 生成0~255之间的随机数
    for i in range(len(colorTab)):
        aa = np.random.uniform(0, 255)
        bb = np.random.uniform(0, 255)
        cc = np.random.uniform(0, 255)
        colorTab[i] = np.array([aa, bb, cc], np.uint8)
    colorTab[0] = [0, 0, 0]

    # Apply the color map to the labels to visualize the connected components
    # 遍历marks每一个元素值，对每一个区域进行颜色填充
    label_vis = np.zeros((height, width, 3), dtype=np.uint8)

    for i in range(height):
        for j in range(width):
            # index值一样的像素表示在一个区域
            index = labels[i][j]
            if index == 0:
                continue
            else:
                label_vis[i][j] = colorTab[index]

    '''胡杨填充'''
    suit_map = np.zeros((T_num, height, width), dtype=np.float64) # 新建适宜度图
    suitmap_for_vis = suit_map.copy()
    minimum_threshold = minimum_threshold # 单株胡杨的最低识别阈值
    populus_center = [] # 储存识别的胡杨中心
    skip_threshold = skip_threshold # 如果有效像素比例低于该阈值则自动跳过

    # print('calculating suitability map...')
    for label in range(1, num_labels):  # 从1开始，跳过背景标签0
        area = stats[label, cv2.CC_STAT_AREA]
        x, y, p_width, p_height = stats[label, cv2.CC_STAT_LEFT], stats[label, cv2.CC_STAT_TOP], stats[label, cv2.CC_STAT_WIDTH], stats[label, cv2.CC_STAT_HEIGHT]
        
        # 处理极小connected component 判断单株胡杨
        if p_width<c_size[0] and p_height<c_size[0]: # 当patch长宽均小于最大模板尺寸时被认为是极小patch
            if p_width > minimum_threshold and p_height > minimum_threshold:
                populus_center.append([y+p_height//2, x+p_width//2, -1])
                continue
            else:
                continue

        # 正常connected component
        # 生成适宜度图
        for k in range(T_num): # k需要和模板数一致
            for i in range(0, p_height):
                for j in range(0, p_width):
                    if y+i+c_size[k] < height and x+j+c_size[k] < width: # 防止越界
                        tmp_labels = labels[y+i: y+i+c_size[k], x+j: x+j+c_size[k]].copy()
                        tmp_labels[tmp_labels!=label] = 0
                        tmp_labels[tmp_labels==label] = 1
                        if cal_patch_march_num(tmp_labels, circle_masks[k]) >= skip_threshold * sum(circle_masks[k]): # 首先判断前景匹配的比例是否足够大
                            tmp_img = data_img[y+i: y+i+c_size[k], x+j: x+j+c_size[k], :].copy() # opencv左上角出发，向右为x轴，向下为y轴
                            average_distance = average_euclidean_distance(Tempalte[k], tmp_img, circle_masks[k]) # 计算模板和小块影像的欧式距离
                            average_distance = 255-average_distance # 使更大的值对应更好的适宜度，此时average_distance取值应在0~255之间
                            suit_map[k, y+i, x+j] = average_distance # 树冠左上角位置为(x,y)
                            suitmap_for_vis[k, y+i+c_size[k]//2, x+j+c_size[k]//2] = average_distance
                        else:
                            continue # 前景比例不足直接跳过，用默认的0作为适宜度最终值，减少计算量
                    else:
                        continue

    ## 处理适宜度图
    # 计算非零像素的值
    nonzero_values = []
    for i in range(T_num):
        tmp_suit_map = suit_map[i,:,:]
        nonzero_value = tmp_suit_map[tmp_suit_map > 0]
        nonzero_values.append(nonzero_value)

    # 计算前x%分位数
    threshold_values = []
    for i in range(T_num):
        if (len(nonzero_values[i]) > 1):
            threshold_values.append(np.percentile(nonzero_values[i], suit_map_binary_values[i])) # 参数二为分位数值 它使得至少有 p% 的数据项小于或等于这个值，且至少有 (100-p)% 的数据项大于或等于这个值。
        else:
            threshold_values.append(1) # 处理无适宜位置的情况

    lock_map = opening.copy() # 锁定图中，值0为锁定像素，值1为未锁定像素 初始状态与形态学处理后的语义分割结果一致
    lock_map[lock_map==255] = 1    

    ## 广度优先搜索
    # 初始化
    directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    # print('searching crown centers...')
    for k in range(T_num): # k需要和模板数一致
        for i in range(0, height-c_size[k]):
            for j in range(0, width-c_size[k]):

                visited = np.zeros((height, width), dtype=bool) # 在每次寻找起点时重置访问数组 这种方法会引起大量重复搜索，但适合应用情况
                max_distance = max_distances[k] # 搜索的最大距离，当前设置为模板圆半径
                max_value = -999

                if suit_map[k, i, j] > threshold_values[k] and not visited[i, j] and lock_map[i,j]==1 and cal_lockmap_match_num(lock_map[i: i+c_size[k], j: j+c_size[k]], circle_masks[k]) >= lock_thresholds[k]:
                    # 满足条件才能作为有效搜索起点 1.合适度值大于阈值 2.本身没有被锁定 3.范围内不得有过少有效像素（过多被锁定像素）
                    queue = deque([(i, j, 0)])  # 添加第三个元素表示距离
                    visited[i][j] = True
                    max_value = suit_map[k, i, j]
                    max_coords = (i, j)

                    while queue:
                        h, w, distance = queue.popleft()
                        if distance > max_distance:
                            continue

                        for dh, dw in directions:
                            nh, nw = h + dh, w + dw
                            if 0 <= nh < height and 0 <= nw < width and suit_map[k, nh, nw] > threshold_values[k] and not visited[nh][nw] and lock_map[nh,nw]==1 and cal_lockmap_match_num(lock_map[nh: nh+c_size[k], nw: nw+c_size[k]], circle_masks[k]) >= lock_thresholds[k]:
                                # 满足条件作为下一个有效的搜索值：12.未越界 3.合适度大于阈值 4.没有被访问过 5.本身没有被锁定 6.范围内不得有过少有效像素（过多被锁定像素）
                                queue.append((nh, nw, distance + 1))
                                visited[nh][nw] = True
                                if suit_map[k, nh, nw] > max_value:
                                    max_value = suit_map[k, nh, nw]
                                    max_coords = (nh, nw)
                    
                    # queue为空，获得该次检索的最优位置max_coords
                    for m in range(c_size[k]): # 更新锁定图
                        for n in range(c_size[k]):
                            if circle_masks[k][m][n] == 1:
                                lock_map[max_coords[0]+m, max_coords[1]+n] = 0
                                
                    populus_center.append([max_coords[0]+c_size[k]//2, max_coords[1]+c_size[k]//2, k]) # 添加检索到的胡杨中心     

    '''输出胡杨统计点到shp'''
    '''新建输出shp'''
    os.environ['GDAL_DATA'] = r'C:\Users\75198\.conda\envs\learn\Lib\site-packages\GDAL-2.4.1-py3.6-win-amd64.egg-info\gata-data'  # 防止报error4错误

    gdal.SetConfigOption("GDAL_FILENAME_IS_UTF8", "YES")
    gdal.SetConfigOption("SHAPE_ENCODING", "GBK")

    ogr.RegisterAll()  # 注册所有的驱动

    driver = ogr.GetDriverByName('ESRI Shapefile')

    prj = osr.SpatialReference()
    prj.ImportFromWkt(im_proj)  # 读取栅格数据的投影信息，用来为后面生成的矢量做准备

    '''准备输出'''
    out_shp_full_path = output_shp_path + '/' +img_name[:-4] + '_populus_infer.shp'
    if Path(out_shp_full_path).exists():
        driver.DeleteDataSource(out_shp_full_path)
    out_ds = driver.CreateDataSource(out_shp_full_path)
    out_lyr = out_ds.CreateLayer(
        out_shp_full_path, prj, ogr.wkbPoint)
    def_out_feature = out_lyr.GetLayerDefn()  # 读取feature类型

    oField = ogr.FieldDefn('type', ogr.OFTInteger)
    out_lyr.CreateField(oField)

    '''遍历输出'''
    # print('outputing...')
    for i in range(len(populus_center)):
        point = ogr.Geometry(ogr.wkbPoint)
        x_coord, y_coord = imagexy2geo(im_geotrans, populus_center[i][0], populus_center[i][1])
        write_point_to_layer([x_coord, y_coord, populus_center[i][2]], out_lyr, def_out_feature)
    out_ds.Destroy()

100%|██████████| 43/43 [06:17<00:00,  8.77s/it]
