In [1]:
import os

os.environ["OMP_NUM_THREADS"] = '1'
from pathlib import Path

import cv2
import geopandas as gpd
import numpy as np
import rasterio
import torch
from shapely.geometry import Point
from sklearn.cluster import KMeans, SpectralClustering
from tqdm import tqdm

In [2]:
# 1. 文件路径
# input_raster = r"D:\UAV_DATA_NEW\output\2_dilated\061301_dilated.tif"
# output_clustering_spectr_result = r'D:\UAV_DATA_NEW\output\5_result_pixels\061301_spec.tif'
# output_centroid = r'D:\UAV_DATA_NEW\output\5_result_centroids\061301_spec_centroid.shp'

# input_raster = r"C:\Users\xianyu\GraduationProject\tobacco_plant_count\data\temp\cut.tif"
# output_clustering_spectr_result = r'C:\Users\xianyu\GraduationProject\tobacco_plant_count\data\temp\cut_spec.tif'
# output_centroid = r'C:\Users\xianyu\GraduationProject\tobacco_plant_count\data\temp\cut_spec_centroid.shp'


In [3]:
class Predictor:
    """Wrap a pickled PyTorch counting model and turn its output into a plant count.

    The model receives a batch of shape (1, 1, H, W) built from a 2-D image, and
    must return a single-element tensor (``.item()`` is called on the output).
    """

    def __init__(self, model_path, device='cuda'):
        """Load the whole pickled model from ``model_path`` onto ``device``.

        Args:
            model_path: path of a file holding a complete module (the loaded object
                has ``.eval()`` / ``.to()`` called on it, so it is not a state_dict).
            device: torch device for inference. Defaults to ``'cuda'`` as before;
                pass ``'cpu'`` on machines without a GPU.
        """
        self.device = torch.device(device)
        # map_location lets a checkpoint saved on a GPU be loaded on any device.
        # weights_only=False is required to unpickle a full nn.Module (torch>=2.6
        # defaults to True). SECURITY: this executes arbitrary pickle code, so
        # only load checkpoints from trusted sources.
        self.model = torch.load(model_path, map_location=self.device, weights_only=False).to(self.device)
        self.model.eval()

    def predict(self, img):
        """Return the predicted number of plants for a 2-D image tensor.

        Two leading singleton dimensions are added to ``img`` to form a
        (1, 1, H, W) batch. The scalar model output is rounded and shifted by +1
        (presumably the model regresses ``count - 1`` -- TODO confirm).
        """
        with torch.no_grad():
            pic_tensor = img.unsqueeze(0).unsqueeze(0).to(self.device)
            return round(self.model(pic_tensor).item()) + 1




In [4]:
# Shared constants for the cells below.
#   interval  -- only referenced by the commented-out cut_polygon helper in this notebook
#   mean_area -- mean plant area in squared CRS units (presumably m^2 -- TODO confirm)
#   size      -- side length, in pixels, of the square image fed to the counting network
interval, mean_area, size = 0.59, 0.305, 224

# Global counting model used by run(); the checkpoint path is machine-specific.
predictor = Predictor(r'C:\Users\xianyu\GraduationProject\tobacco_plant_count\output\run\2023-05-24_01-04-53\model.pth')

# def cut_polygon(geom):
#     """
#     Cut OBB rectangle into grids
#     """
#     obb = geom.minimum_rotated_rectangle.exterior.coords
#     p1, p2, p3 = Point(obb[0]), Point(obb[1]), Point(obb[2])
#     dist1, dist2 = p1.distance(p2), p2.distance(p3)
#     if dist1 > dist2:
#         cut_num = round(dist1 / interval) + int(dist1 * 2 < interval)
#         short_edge = dist2
#     else:
#         cut_num = round(dist2 / interval) + int(dist2 * 2 < interval)
#         short_edge = dist1
#     return cut_num, short_edge

In [5]:
def run(input_raster, output_spectral_result, output_centroid, mean_area=0.305):
    """Split each connected foreground blob of a raster into individual plants.

    Band 1 of ``input_raster`` is labelled into 4-connected components. Each
    component gets a plant count ``k``:

    * components of at most 32 pixels count as one plant;
    * components whose bounding box is at least ``size`` pixels wide or tall
      (too big for the network input) get ``round(area / mean plant area)``;
    * all others are centred in a ``size`` x ``size`` binary image and counted
      by the module-level ``predictor``.

    ``k`` is clamped to ``[1, number of pixels]`` so clustering is always valid.
    Components with ``k > 1`` are split by spectral clustering on their pixel
    (row, col) coordinates.

    Args:
        input_raster: path of the raster to process. OpenCV's connected-component
            labelling needs an 8-bit single-channel image, so band 1 is assumed to
            be uint8 with non-zero = foreground.
        output_spectral_result: path of the single-band uint8 raster to write. Each
            foreground pixel holds its cluster index + 1 inside its component;
            values wrap within 1..255 for components with more than 255 clusters.
            0 is background.
        output_centroid: path of the ESRI Shapefile to write, with one point per
            cluster (mean pixel centre), in the input raster's CRS.
        mean_area: mean area of one plant in squared CRS units (presumably m^2 --
            TODO confirm); used for the area-based count.

    Relies on the module-level globals ``predictor`` and ``size``.
    """
    with rasterio.open(input_raster) as src:
        # Ground area covered by one pixel: |pixel width * pixel height| from the affine transform.
        area_per_pixel = abs(src.transform[0] * src.transform[4])
        mean_area_pixel = mean_area / area_per_pixel

        img = src.read(1)
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(img, connectivity=4)
        img_shape = img.shape
        del img  # only the label image is needed from here on

        labels = torch.tensor(labels)
        output_img = torch.zeros(img_shape, dtype=torch.uint8)
        all_points = []

        for i in tqdm(range(1, num_labels), desc='Progress'):
            # cv2 stats row layout: [left, top, width, height, area].
            col0, row0, width, height, num_pixels = stats[i]
            rows, cols = torch.where(labels[row0:row0 + height, col0:col0 + width] == i)

            if num_pixels <= 32:
                k = 1
            elif width >= size or height >= size:
                k = round(num_pixels / mean_area_pixel)
            else:
                # Centre the blob's mask in a size x size canvas for the network.
                image = torch.zeros((size, size), dtype=torch.float32)
                image[rows + (size - height) // 2, cols + (size - width) // 2] = 1.0
                k = predictor.predict(image)
            # Area rounding can yield 0 for long thin blobs and the network output is
            # unbounded; SpectralClustering needs 1 <= n_clusters <= n_samples.
            k = int(min(max(k, 1), num_pixels))

            # Shift bounding-box-local indices to full-raster (row, col).
            rows += row0
            cols += col0

            if k == 1:
                output_img[rows, cols] = 1
                # DatasetReader.xy(row, col) defaults to offset='center'.
                all_points.append(src.xy(rows.float().mean().item(), cols.float().mean().item()))
                continue

            coords = torch.stack((rows, cols), dim=1).numpy()
            cluster = SpectralClustering(n_clusters=k, affinity='rbf').fit(coords)
            # Group by int64 labels: casting to uint8 first would merge clusters j and
            # j + 256 and turn label 255 into background (255 + 1 wraps to 0).
            cluster_labels = torch.as_tensor(cluster.labels_, dtype=torch.int64)
            output_img[rows, cols] = (cluster_labels % 255 + 1).to(torch.uint8)

            for j in range(k):
                member = cluster_labels == j
                if member.any():  # skip empty clusters instead of emitting a NaN point
                    all_points.append(src.xy(rows[member].float().mean().item(),
                                             cols[member].float().mean().item()))

        # Always write a single uint8 band, whatever the input's band count.
        profile = {**src.meta, 'count': 1, 'dtype': 'uint8'}
        crs = src.crs

    with rasterio.open(output_spectral_result, 'w', **profile) as dst:
        dst.write(output_img.numpy(), 1)

    geometry = gpd.GeoSeries([Point(p) for p in all_points]).set_crs(crs)
    geometry.to_file(output_centroid, driver='ESRI Shapefile')

In [6]:
# Batch driver: process every dilated raster in input_path. A file is skipped when
# both its raster and shapefile outputs already exist, so an interrupted batch can resume.
input_path = Path(r'D:\UAV_DATA_NEW\output\2_dilated')
output_path = Path(r'D:\UAV_DATA_NEW\output\5_results')

for input_raster in input_path.glob('*.tif'):
    # The first six characters of the file stem name the output folder and files.
    file_name = input_raster.stem[:6]
    subfolder = output_path / file_name
    output_raster = subfolder / f'{file_name}_spec_classes_resnext50_weighted_ep50.tif'
    output_shapefile = subfolder / f'{file_name}_spec_centroid_resnext50_weighted_ep50.shp'

    if not all(target.exists() for target in (output_raster, output_shapefile)):
        subfolder.mkdir(parents=True, exist_ok=True)
        print(output_raster.name)
        run(input_raster, output_raster, output_shapefile)

061302_spec_classes_resnext50_weighted_ep50.tif


Progress: 100%|██████████| 170858/170858 [2:52:28<00:00, 16.51it/s]   


061303_spec_classes_resnext50_weighted_ep50.tif


Progress:  17%|█▋        | 65507/378053 [1:02:41<3:02:13, 28.59it/s] 

In [None]:
# Spatial spectral clustering (CPU), plant count from area only:
# k = round(component area / mean plant area) for every connected component.
# Needs `input_raster` and `mean_area` from the cells above. The output paths
# `output_clustering_spectr_result` / `output_centroid` must be defined first
# (see the commented-out path cell near the top of the notebook).
src = rasterio.open(input_raster)
# Mean plant area in pixels: ground mean_area divided by the ground area of one pixel.
mean_area_pixel = mean_area / abs(src.transform[0] * src.transform[4])

img = src.read(1)
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img, connectivity=4)
img_shape = img.shape
img = None  # free the raw band; only the label image is needed below

labels = torch.tensor(labels)

output_img = torch.zeros(img_shape, dtype=torch.uint8)

# Expected plant count per component (stats column 4 is the pixel area), clamped
# to [1, area] so SpectralClustering always receives a valid n_clusters.
n_clusters = np.clip(np.round(stats[:, 4] / mean_area_pixel).astype(int), 1, stats[:, 4])

all_points = []

for i in tqdm(range(1, num_labels), desc='Progress'):
    # cv2 stats row layout: [left, top, width, height, area].
    col0, row0, width, height, _ = stats[i]
    rows, cols = torch.where(labels[row0:row0 + height, col0:col0 + width] == i)
    rows += row0
    cols += col0

    k = int(n_clusters[i])

    if k <= 1:
        output_img[rows, cols] = 1
        # DatasetReader.xy(row, col) defaults to offset='center'.
        all_points.append(src.xy(rows.float().mean().item(), cols.float().mean().item()))
    else:
        coords = torch.stack((rows, cols), dim=1).numpy()
        cluster = SpectralClustering(n_clusters=k, affinity='rbf').fit(coords)
        # Group by int64 labels; a uint8 cast would merge clusters j and j + 256
        # and turn label 255 into background.
        cluster_labels = torch.as_tensor(cluster.labels_, dtype=torch.int64)
        output_img[rows, cols] = (cluster_labels % 255 + 1).to(torch.uint8)

        for j in range(k):
            member = cluster_labels == j
            if member.any():  # skip empty clusters instead of emitting a NaN point
                all_points.append(src.xy(rows[member].float().mean().item(),
                                         cols[member].float().mean().item()))

with rasterio.open(output_clustering_spectr_result, 'w', **{**src.meta, 'count': 1, 'dtype': 'uint8'}) as dst:
    dst.write(output_img.numpy(), 1)

geometry = gpd.GeoSeries([Point(p) for p in all_points]).set_crs(src.crs)
geometry.to_file(output_centroid, driver='ESRI Shapefile')

In [None]:
# Spatial spectral clustering (CPU), area-based count, with handling for "too many
# points": centroids of the same component that land too close together are merged.
# Needs `input_raster` and `mean_area` from the cells above. The output paths
# `output_clustering_spectr_result` / `output_centroid` must be defined first.
src = rasterio.open(input_raster)
# Mean plant area in pixels (ground mean_area / ground area of one pixel). It is
# computed here so this cell does not depend on a variable left over from another cell.
mean_area_pixel = mean_area / abs(src.transform[0] * src.transform[4])
# Centroids closer than this (CRS units, presumably metres -- TODO confirm) are
# replaced by their midpoint.
merge_distance = 0.45

img = src.read(1)
num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(img, connectivity=4)

labels = torch.tensor(labels)
img = torch.tensor(img)

output_img = torch.zeros_like(img, dtype=torch.uint8)

# Expected plant count per component (stats column 4 is the pixel area), clamped to [1, area].
n_clusters = np.clip(np.round(stats[:, 4] / mean_area_pixel).astype(int), 1, stats[:, 4])

all_points = []

for i in tqdm(range(1, num_labels), desc='Progress'):
    # cv2 stats row layout: [left, top, width, height, area].
    col0, row0, width, height, _ = stats[i]
    rows, cols = torch.where(labels[row0:row0 + height, col0:col0 + width] == i)
    rows += row0
    cols += col0

    k = int(n_clusters[i])

    if k <= 1:
        output_img[rows, cols] = 1
        # DatasetReader.xy(row, col) defaults to offset='center'.
        all_points.append(Point(src.xy(rows.float().mean().item(), cols.float().mean().item())))
        continue

    coords = torch.stack((rows, cols), dim=1).numpy()
    cluster = SpectralClustering(n_clusters=k, affinity='rbf').fit(coords)
    # Group by int64 labels; a uint8 cast would merge clusters j and j + 256.
    cluster_labels = torch.as_tensor(cluster.labels_, dtype=torch.int64)
    output_img[rows, cols] = (cluster_labels % 255 + 1).to(torch.uint8)

    component_points = []
    for j in range(k):
        member = cluster_labels == j
        if not member.any():
            continue  # empty cluster: no centroid (its mean would be NaN)
        centroid = Point(src.xy(rows[member].float().mean().item(),
                                cols[member].float().mean().item()))
        if component_points:
            distances = [centroid.distance(p) for p in component_points]
            nearest = min(range(len(distances)), key=distances.__getitem__)
            if distances[nearest] <= merge_distance:
                # Replace the nearest earlier centroid with the midpoint of the pair.
                other = component_points.pop(nearest)
                centroid = Point((other.x + centroid.x) / 2, (other.y + centroid.y) / 2)
        component_points.append(centroid)

    all_points.extend(component_points)

with rasterio.open(output_clustering_spectr_result, 'w', **{**src.meta, 'count': 1, 'dtype': 'uint8'}) as dst:
    dst.write(output_img.numpy(), 1)

geometry = gpd.GeoSeries(all_points).set_crs(src.crs)
geometry.to_file(output_centroid, driver='ESRI Shapefile')


In [None]:
# import geopandas as gpd
# from rtree import index
# from shapely.geometry import Point

# # 假设你有以下的 GeoDataFrame
# gdf = gpd.GeoDataFrame(geometry=[Point(1, 1), Point(2, 2), Point(1, 2), Point(3, 3), Point(4, 4), Point(5, 5)])

# # 创建一个空的 R-tree 索引
# idx = index.Index()

# # 填充 R-tree 索引
# for i, geom in enumerate(gdf.geometry):
#     idx.insert(i, geom.bounds)

# # 定义一个阈值，小于该阈值的点对会被选出
# threshold = 1.5

# # 查询所有距离小于阈值的点对
# pairs = set()
# for i in range(len(gdf)):
#     geom = gdf.geometry[i]
#     possible_matches_index = list(idx.nearest((geom.x, geom.y), num_results=4))  # 获得可能的匹配
#     possible_matches = gdf.iloc[possible_matches_index]
#     precise_matches = possible_matches[possible_matches.distance(geom) < threshold]

#     for j, match in precise_matches.iterrows():
#         if i != j:
#             # 使用 frozenset 可以确保 (a, b) 和 (b, a) 被视为同一对
#             pair = frozenset([geom, match.geometry])
#             pairs.add(pair)

# pairs

In [None]:
# KMeans variant: split every component into n_clusters[i] clusters and use the
# cluster centres as plant positions.
# Reuses `src`, `labels`, `stats`, `num_labels` and `n_clusters` from one of the
# spectral-clustering cells above, so run one of them first.
# NOTE(review): `output_clustering_kmeans_result` and `output_centroid` are not
# defined in this notebook -- set them before running this cell.
# (The earlier commented-out "merge close centres" experiment was dropped here; the
# spectral cell above implements that merge.)
output_img = torch.zeros(labels.shape, dtype=torch.uint8)  # fresh raster, not the previous cell's result
all_points = []  # fresh list so points from the previous cell are not written again

for i in tqdm(range(1, num_labels), desc='Progress'):
    # cv2 stats row layout: [left, top, width, height, area].
    col0, row0, width, height, num_pixels = stats[i]
    rows, cols = torch.where(labels[row0:row0 + height, col0:col0 + width] == i)
    rows += row0
    cols += col0

    coords = torch.stack((rows, cols), dim=1).numpy()
    # KMeans requires 1 <= n_clusters <= n_samples.
    k = int(min(max(n_clusters[i], 1), num_pixels))
    kmeans = KMeans(n_clusters=k, n_init='auto').fit(coords)

    # int64 labels wrapped into 1..255 so no clustered pixel becomes background (0).
    output_img[rows, cols] = (torch.as_tensor(kmeans.labels_, dtype=torch.int64) % 255 + 1).to(torch.uint8)

    # Cluster centres are (row, col) in pixel units; convert them to CRS coordinates.
    cluster_centers = kmeans.cluster_centers_
    all_points.extend(src.xy(float(r), float(c)) for r, c in cluster_centers)

# torch tensors have no .astype(); convert to a numpy array for rasterio.
with rasterio.open(output_clustering_kmeans_result, 'w', **{**src.meta, 'count': 1, 'dtype': 'uint8'}) as dst:
    dst.write(output_img.numpy(), 1)

all_centroids = np.array(all_points, dtype=float).reshape(-1, 2)

geometry = gpd.GeoSeries([Point(p) for p in all_centroids]).set_crs(src.crs)
geometry.to_file(output_centroid, driver='ESRI Shapefile')

In [None]:
import time

# Micro-benchmark: growing a list with `result = result + chunk` (builds a new list
# and copies everything each iteration, so quadratic overall) versus `list.extend`
# (appends in place).
# time.perf_counter() is the monotonic, high-resolution clock intended for measuring
# intervals; time.time() follows the wall clock and can jump if it is adjusted.
n_iterations = 100000

# Time the + operator.
list1 = [1, 2, 3]
result = list1
start_time = time.perf_counter()
for i in range(n_iterations):
    result = result + [j for j in range(10)]
end_time = time.perf_counter()
print("使用 + 运算符的时间：", end_time - start_time)

# Time the extend method.
list1 = [1, 2, 3]
start_time = time.perf_counter()
for i in range(n_iterations):
    list1.extend(j for j in range(10))
end_time = time.perf_counter()
print("使用 extend 方法的时间：", end_time - start_time)