In [1]:
import os

os.environ["OMP_NUM_THREADS"] = '1'
import cv2
import geopandas as gpd
import numpy as np
from shapely.geometry import LineString, Point, Polygon
from sklearn.cluster import KMeans
import rasterio
from tqdm import tqdm
import torch

In [2]:
class Predictor():
    """Thin wrapper around a serialized PyTorch model that predicts a cluster count.

    The wrapped model is called on a single-patch batch and its scalar output is
    rounded and shifted by one (see ``predict``).
    """

    def __init__(self, model_path, device=None):
        """Load the model at ``model_path`` and put it in evaluation mode.

        Args:
            model_path: Path handed to ``torch.load``. The loaded object is used
                directly as a callable module, so the file must contain a whole
                pickled model rather than a bare state dict.
            device: Optional device (string or ``torch.device``) to run on.
                Defaults to CUDA when it is available and to the CPU otherwise.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)

        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        # map_location lets a checkpoint saved on a GPU be opened on a CPU-only host.
        self.model = torch.load(model_path, map_location=self.device).to(self.device)
        self.model.eval()

    def predict(self, img):
        """Return the predicted cluster count for one 2-D patch.

        Args:
            img: 2-D tensor; two leading singleton dimensions are added before
                it is passed to the model.

        Returns:
            int: the model's scalar output rounded to the nearest integer, plus 1.
            # NOTE(review): the +1 suggests the model regresses (count - 1) - confirm
            against the training code.
        """
        with torch.no_grad():
            pic_tensor = img.unsqueeze(0).unsqueeze(0).to(self.device)
            return round(self.model(pic_tensor).item()) + 1

In [3]:
# 1. File paths

# Alternative input (thinned raster), currently disabled:
# input_raster_thin = r"D:\UAV_DATA_NEW\output\4_thin\061410_thin.tif"
# Raster that is opened and split into connected components below.
input_raster = r"D:\UAV_DATA_NEW\output\2_dilated\061301_dilated.tif"
# Destination raster for the per-pixel cluster labels.
output_clustering_kmeans_result = r'D:\UAV_DATA_NEW\output\5_results\061301\kmeans\061301_kmeans_pointprocessed.tif'
# Destination shapefile for the cluster-centre points.
output_centroid = r'D:\UAV_DATA_NEW\output\5_results\061301\kmeans\061301_kmeans_centroid_pointprocessed.shp'

In [4]:
# Open the input raster. The handle is deliberately left open: later cells
# still read src.xy, src.meta and src.crs from it.
src = rasterio.open(input_raster)
transform = src.transform

# Area covered by one pixel: |a * e| of the affine transform
# (a = coefficient 0, e = coefficient 4).
area_per_pixel = abs(transform.a * transform.e)

# Tunable constants.
size = 224                # side length of the square patch fed to the predictor
interval = 0.62           # not referenced by any other visible cell
mean_area = 0.313748971   # presumably the mean plant area in squared CRS units - TODO confirm
mean_area_pixel = mean_area / area_per_pixel   # the same area expressed in pixels

# Cluster-count model loaded from a fixed training run directory.
predictor = Predictor(
    r'C:\Users\xianyu\GraduationProject\tobacco_plant_count\output\run\2023-05-31_23-28-17\model.pth'
)


In [5]:
# Label the 4-connected regions of band 1 and collect their bounding-box stats.
img = src.read(1)
img_shape = img.shape
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img, connectivity=4)
img = None  # release this reference to the band array; only its shape is needed from here on

# Area-based cluster-count estimate per component; estimates of 0 are raised to 1.
n_clusters = np.round(stats[:, cv2.CC_STAT_AREA] / mean_area_pixel).astype(int)
n_clusters[n_clusters == 0] = 1

# Containers filled by the clustering loop: per-pixel cluster ids and centre points.
all_points = []
output_img = np.zeros(img_shape, dtype=np.uint8)

In [6]:
# KMeans: split every connected component into k clusters and collect the
# cluster centres as georeferenced points.

# Label 0 is the background component, so the real components start at 1.
for label in tqdm(range(1, num_labels), desc='Progress'):
    # OpenCV stats columns are (left, top, width, height, area).
    left, top, width, height, num_pixels = (int(v) for v in stats[label])

    # Row/column indices of this component's pixels, relative to its bounding box.
    rows, cols = np.where(labels[top:top + height, left:left + width] == label)

    if num_pixels <= 32:
        # Tiny components are always treated as a single cluster.
        k = 1
    elif width >= size or height >= size:
        # Too large for the size x size network input: use the area-based estimate.
        k = round(num_pixels / mean_area_pixel)
    else:
        # Use the ResNet neural network to predict the number of clusters:
        # draw the component centred inside an empty size x size patch.
        patch = torch.zeros((size, size), dtype=torch.float32)
        patch[rows + (size - height) // 2, cols + (size - width) // 2] = 1.0
        k = predictor.predict(patch)

    # KMeans needs 1 <= n_clusters <= n_samples. The area-based estimate can
    # round down to 0 for thin components, and any estimate could exceed the
    # number of pixels, so clamp for every branch.
    k = max(1, min(int(k), num_pixels))

    # Shift to full-raster coordinates.
    rows = rows + top
    cols = cols + left

    kmeans = KMeans(n_clusters=k, n_init='auto', random_state=0)
    kmeans.fit(np.stack((rows, cols), axis=1))

    # Cluster ids are 1-based so that 0 stays background.
    # NOTE(review): output_img is uint8, so ids above 255 would wrap around.
    output_img[rows, cols] = kmeans.labels_ + 1

    # Cluster centres are (row, col) pairs; src.xy converts them to CRS coordinates.
    all_points.extend(Point(src.xy(*center)) for center in kmeans.cluster_centers_)

    # An earlier experiment, previously kept here as commented-out code, merged
    # centres of one component lying within 0.45 CRS units of each other by
    # averaging them; it was disabled and has been removed.


Progress: 100%|██████████| 179639/179639 [20:22<00:00, 146.99it/s]


In [7]:
# Write the cluster-id image as band 1 of a new raster that reuses the
# metadata (driver, size, CRS, transform, ...) of the input dataset.
with rasterio.open(
    output_clustering_kmeans_result,
    mode='w',
    **src.meta,
) as dst:
    dst.write(output_img, indexes=1)

In [8]:
# Export all cluster centres as a point shapefile in the input raster's CRS.
geometry = gpd.GeoSeries(
    data=all_points,
    crs=src.crs,
)
geometry.to_file(
    output_centroid,
    driver='ESRI Shapefile',
)

: 