In [3]:
import math
from pathlib import Path
from random import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import rasterio
from rasterio.features import shapes
from rasterio.windows import Window
from shapely.geometry import Point, Polygon, shape


In [4]:
# Root folder for the exported masks: <base_path>/<split>/<count:02d>/<label>.png
base_path = Path(r'C:\Users\xianyu\GraduationProject\UAV_YUNNAN_DATA\last_labels_061303')

# Pre-create one sub-folder per plant-count class (01..30) for both splits.
# Path components are joined with `/` so the layout is nested on every OS
# (a literal backslash inside the f-string was an invalid escape and Windows-only).
for i in range(1, 31):
    for split in ('test', 'train'):
        (base_path / split / f'{i:02d}').mkdir(parents=True, exist_ok=True)


In [5]:
# Grid cell length along a shape's long edge, in raster CRS units
# (NOTE(review): presumably metres between neighbouring plants — confirm against the CRS).
interval = 0.58
# Area per plant in squared CRS units; the export cell divides a polygon's area by this
# to estimate the count for shapes whose short edge is too wide to count along the long edge.
mean_area = 0.305


def cut_polygon(geom, grid_interval=None):
    """
    Estimate how many grid cells fit along the long edge of a geometry's oriented bounding box.

    The minimum rotated rectangle (OBB) of ``geom`` is measured. Its long edge is divided by
    the grid interval and rounded to the nearest integer, with a floor of 1 so any shape
    counts as at least one cell.

    Parameters
    ----------
    geom : shapely geometry
        Shape to measure. Must expose ``minimum_rotated_rectangle.exterior.coords``,
        i.e. its OBB must be a (non-degenerate) polygon.
    grid_interval : float, optional
        Cell length along the long edge. Defaults to the module-level ``interval``,
        read at call time.

    Returns
    -------
    tuple[int, float]
        ``(cut_num, short_edge)``: the estimated number of cells along the long edge
        and the length of the OBB's short edge.

    Raises
    ------
    ValueError
        If the grid interval is not positive.
    """
    step = interval if grid_interval is None else grid_interval
    if step <= 0:
        raise ValueError(f'grid_interval must be positive, got {step!r}')

    # Three consecutive OBB corners give the rectangle's two edge lengths.
    # Only X/Y are used, matching shapely's planar Point.distance.
    corners = geom.minimum_rotated_rectangle.exterior.coords
    c0, c1, c2 = corners[0], corners[1], corners[2]
    edge_a = math.dist(c0[:2], c1[:2])
    edge_b = math.dist(c1[:2], c2[:2])
    long_edge, short_edge = max(edge_a, edge_b), min(edge_a, edge_b)

    # round() alone yields 0 for short shapes. Round-half-to-even even maps an exact
    # half interval to 0, so clamp to guarantee at least one cell.
    cut_num = max(1, round(long_edge / step))
    return cut_num, short_edge

In [6]:
# Export every connected component of the mask raster as a centred size x size PNG,
# filed under <base_path>/<train|test>/<estimated plant count>/<label>.png.
# Uses `base_path`, `cut_polygon` and `mean_area` from the earlier cells.
mask_path = r'C:\Users\xianyu\GraduationProject\tobacco_plant_count\data\temp\061303cut\061303_cut_1.tif'

size = 224                     # side length (pixels) of each exported square crop
min_pixels = 32                # smaller components are skipped as noise
max_short_edge = 1.2           # OBB short edge (CRS units) at/above which the count comes from area
min_count, max_count = 6, 28   # only these plant-count classes are exported
test_fraction = 0.2            # share of exported crops routed to the test split
split_rng = np.random.default_rng(42)  # seeded so the train/test split is reproducible

# Context manager releases the dataset handle even if the loop raises.
with rasterio.open(mask_path) as src:
    # NOTE(review): assumes band 1 is a single-channel 8-bit mask, as cv2 requires — confirm.
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(src.read(1), connectivity=4)

    # Label 0 is the background.
    for i in range(1, num_labels):
        # cv2 stats columns: left, top, width, height, area.
        left, top, width, height, num_pixels = stats[i]

        if width > size or height > size or num_pixels < min_pixels:
            continue

        # Offsets that centre the component's bounding box inside the crop.
        pad_top = (size - height) // 2
        pad_left = (size - width) // 2

        component = labels[top:top + height, left:left + width] == i
        output = np.zeros((size, size), dtype=np.uint8)
        output[pad_top:pad_top + height, pad_left:pad_left + width][component] = 255

        # Georeference the padded crop itself (not just the bbox) so polygon
        # coordinates line up with the source raster.
        transform = src.window_transform(Window(left - pad_left, top - pad_top, size, size))

        for geom, _value in shapes(output, transform=transform, mask=output):
            polygon = shape(geom)
            cut_num, short_edge = cut_polygon(polygon)
            if short_edge >= max_short_edge:
                # Too wide to count along the long edge: estimate from area instead.
                cut_num = round(polygon.area / mean_area)

            if cut_num < min_count or cut_num > max_count:
                continue

            split = 'test' if split_rng.random() < test_fraction else 'train'
            path = base_path / split / f'{cut_num:02d}' / f'{i}.png'
            # cv2.imwrite signals failure (e.g. missing folder) only via its return value.
            if not cv2.imwrite(str(path), output):
                raise OSError(f'cv2.imwrite failed for {path}')

In [None]:
# src = rasterio.open(r"D:\UAV_DATA_NEW\output\2_dilated\061301_dilated.tif")
# window = Window(7000, 3000, 2000, 4000)
# data = src.read(1, window=window)

# with rasterio.open(r'C:\Users\xianyu\GraduationProject\tobacco_plant_count\data\temp\cut.tif',
#                    'w',
#                    driver='GTiff',
#                    width=2000,
#                    height=4000,
#                    count=1,
#                    dtype=rasterio.uint8,
#                    transform=src.window_transform(window),
#                    crs=src.crs) as dst:
#     dst.write(data, 1)

In [None]:
# import torch
# import numpy as np
# import time

# a = torch.arange(0, 20000 * 20000).reshape(20000, 20000)

# time1 = time.time()
# for i in range(30):
#     a * a + 1
# time2 = time.time()
# print(time2 - time1)

# a = a.cuda()

# time1 = time.time()
# for i in range(30):
#     a * a + 1
# time2 = time.time()
# print(time2 - time1)

# a = np.arange(0, 20000 * 20000).reshape(20000, 20000)

# time1 = time.time()
# for i in range(30):
#     a * a + 1
# time2 = time.time()
# print(time2 - time1)