In [1]:
import os
import numpy as np
import pandas as pd
from osgeo import gdal
from skimage import io

  from .collection import imread_collection_wrapper


### Utility functions

In [2]:
def writeTif(img, out_dir, transform=None, proj=None):
    """Write a band-first image array to a GeoTIFF file.

    Args:
        img (ndarray): image indexed as (band, row, col). The GDAL data type
            of every band is chosen from the dtype of the first band.
        out_dir (string): path of the GeoTIFF file to create (a file path,
            despite the parameter name).
        transform (tuple, optional): GDAL geotransform to store in the file.
            Defaults to None.
        proj (string, optional): projection to store in the file.
            Defaults to None. Georeferencing is written only when both
            ``transform`` and ``proj`` are given.

    Returns:
        None. Nothing is written when ``img`` is None or empty.
    """
    if img is None or len(img) == 0:
        return

    band1 = img[0]
    img_height, img_width = band1.shape[0], band1.shape[1]
    num_bands = len(img)

    # Exact dtype names are compared: the previous substring test treated
    # signed int16 as unsigned, which corrupts negative values.
    dtype_name = band1.dtype.name
    if dtype_name in ('uint8', 'int8'):
        datatype = gdal.GDT_Byte
    elif dtype_name == 'uint16':
        datatype = gdal.GDT_UInt16
    elif dtype_name == 'int16':
        datatype = gdal.GDT_Int16
    else:
        datatype = gdal.GDT_Float32

    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(out_dir, img_width, img_height, num_bands, datatype)
    if dataset is None:
        # Create() returns None when the file cannot be created
        # (e.g. the output folder does not exist); do not report success.
        print("save image failed: " + str(out_dir))
        return

    if transform and proj:
        dataset.SetGeoTransform(transform)  # write the affine transform parameters
        dataset.SetProjection(proj)  # write the projection
    for i in range(num_bands):
        dataset.GetRasterBand(i + 1).WriteArray(img[i])

    # Flush and release the dataset so the file is complete on disk
    # before the caller reads it back.
    dataset.FlushCache()
    dataset = None
    print("save image success.")
    
def match_image_size(source_dir, target_dir, out_dir):
    """Crop an image and its label raster to their common size and save both.

    Both rasters are cropped from the top-left corner to the smaller of the
    two widths and heights. The outputs are written into ``out_dir`` as
    ``region1_matched.tif`` and ``region1_label_matched.tif``, both carrying
    the geotransform and projection of the source image.

    Args:
        source_dir (str): path of the source image.
        target_dir (str): path of the target (label) raster.
        out_dir (str): folder that receives the two cropped files.
    """
    source_img = gdal.Open(source_dir)
    target_img = gdal.Open(target_dir)
    img_geotrans = source_img.GetGeoTransform()
    img_proj = source_img.GetProjection()

    x_size_new = min(source_img.RasterXSize, target_img.RasterXSize)
    y_size_new = min(source_img.RasterYSize, target_img.RasterYSize)

    source_array = source_img.ReadAsArray(0, 0, source_img.RasterXSize, source_img.RasterYSize)
    target_array = target_img.ReadAsArray(0, 0, target_img.RasterXSize, target_img.RasterYSize)
    # ReadAsArray yields a 2-D array for single-band rasters; promote those to
    # (1, rows, cols) so the band-first slicing below works for either input.
    if source_array.ndim == 2:
        source_array = source_array[np.newaxis, :, :]
    if target_array.ndim == 2:
        target_array = target_array[np.newaxis, :, :]

    writeTif(source_array[:, :y_size_new, :x_size_new],
             os.path.join(out_dir, 'region1_matched.tif'), img_geotrans, img_proj)
    writeTif(target_array[:, :y_size_new, :x_size_new],
             os.path.join(out_dir, 'region1_label_matched.tif'), img_geotrans, img_proj)



### Generate sample-label pair for individual tree mapping

In [5]:
def random_segmentation(in_dir_img, out_dir, x_range, y_range, num, require_proj=False, in_dir_target=None):
    """Cut randomly centred patches out of an image and, optionally, its label raster.

    Patch centres are drawn with a fixed seed (1111), so repeated runs produce
    the same patches. A patch is skipped when more than 5% of its values are
    zero after the uint8/uint16 fill values (255 / 65535) have been set to 0.

    Args:
        in_dir_img (str): path of the image to sample.
        out_dir (str): folder for the image patches. Label patches go to the
            same path with 'img' replaced by 'target'.
        x_range (int): patch width in pixels (2 * (x_range // 2) is used).
        y_range (int): patch height in pixels (2 * (y_range // 2) is used).
        num (int): number of patch centres to draw.
        require_proj (bool, optional): write georeferencing into each patch.
            Defaults to False.
        in_dir_target (str, optional): path of the label raster. Defaults to None.
    """
    target_out_dir = out_dir.replace('img', 'target')
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(target_out_dir, exist_ok=True)

    img = gdal.Open(in_dir_img)
    # Without require_proj the patches are written without georeferencing
    # (previously this path raised NameError on the undefined origin values).
    img_geotrans = img.GetGeoTransform() if require_proj else None
    img_proj = img.GetProjection() if require_proj else None

    x_size, y_size = img.RasterXSize, img.RasterYSize
    img_array = img.ReadAsArray(0, 0, x_size, y_size)
    if img_array.ndim == 2:
        # single-band rasters come back 2-D; use a band-first layout throughout
        img_array = img_array[np.newaxis, :, :]

    # Read the label raster once instead of once per patch.
    target_array = None
    if in_dir_target:
        target = gdal.Open(in_dir_target)
        target_array = target.ReadAsArray(0, 0, x_size, y_size)
        target_array[target_array > 1] = 0

    ## generate pixels coordinates
    np.random.seed(1111)
    x_radius = x_range // 2
    y_radius = y_range // 2  # previously x_range was used for both axes
    x_coor = np.random.choice(range(x_radius + 1, x_size - x_radius - 1), num)
    y_coor = np.random.choice(range(y_radius + 1, y_size - y_radius - 1), num)

    for center_x, center_y in zip(x_coor, y_coor):
        x_off = center_x - x_radius
        y_off = center_y - y_radius
        # copy so that masking fill values does not modify the source array
        patch = img_array[:, y_off:center_y + y_radius, x_off:center_x + x_radius].copy()
        if patch.dtype == 'uint8':
            patch[patch == 255] = 0
        if patch.dtype == 'uint16':
            patch[patch == 65535] = 0
        if (patch == 0).mean() > 0.05:
            continue

        dst_transform = None
        if require_proj:
            # Standard GDAL affine mapping of the patch's top-left pixel; for a
            # north-up raster this equals the former origin +/- offset * |resolution|.
            dst_transform = (
                img_geotrans[0] + x_off * img_geotrans[1] + y_off * img_geotrans[2],
                img_geotrans[1], img_geotrans[2],
                img_geotrans[3] + x_off * img_geotrans[4] + y_off * img_geotrans[5],
                img_geotrans[4], img_geotrans[5])

        patch_file = str(center_x) + '_' + str(center_y) + '_0.tif'
        writeTif(patch, os.path.join(out_dir, patch_file), transform=dst_transform, proj=img_proj)
        if target_array is not None:
            target_patch = target_array[y_off:center_y + y_radius,
                                        x_off:center_x + x_radius][np.newaxis, :, :]
            writeTif(target_patch, os.path.join(target_out_dir, patch_file),
                     transform=dst_transform, proj=img_proj)
    
# Draw 10 random 1000 x 1000 georeferenced image/label patch pairs per region.
for idx in ['5']:
    segmentation_kwargs = dict(
        in_dir_img=fr"F:\DigitalAG\morocco\unet\data\img\region{idx}_gr.tif",
        in_dir_target=fr"F:\DigitalAG\morocco\unet\data\label\region{idx}_label_gr.tif",
        # out_dir=r"F:\DigitalAG\morocco\unet\pretrain\\160\validation\img",
        out_dir=r"F:\DigitalAG\sam\img",
        x_range=1000,
        y_range=1000,
        num=10,
        require_proj=True,
    )
    random_segmentation(**segmentation_kwargs)
# for item in os.listdir("Z:\Morocco\second_data\georeferenced"):
#     if item.endswith(".tif"):
#         random_segmentation(in_dir_img=os.path.join("Z:\Morocco\second_data\georeferenced", item),
#                                 in_dir_target=None,
#                                 out_dir="Z:\Morocco\patch",
#                                 x_range=32, 
#                                 y_range=32,
#                                 num=800,
#                                 require_proj=True)

save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.


### generate visualization for ViT label

In [6]:
def scale_percentile_n(matrix):
    """Stretch every band of an image to [0, 1] between its 1st and 99th percentile.

    Args:
        matrix (ndarray): image indexed as (row, col, band).

    Returns:
        ndarray: float64 array indexed as (band, row, col). For each band the
        lower bound is the 1st percentile of its non-zero values and the upper
        bound is the 99th percentile of all values. Bands whose mean is 0 are
        returned unchanged; bands whose two bounds coincide are set to 0.
    """
    # np.float was removed in NumPy 1.24; it was an alias of the builtin float.
    matrix = matrix.transpose(2, 0, 1).astype(np.float64)
    for i in range(matrix.shape[0]):
        band = matrix[i]
        if band.mean() == 0:
            continue
        mins = np.percentile(band[band != 0], 1)
        maxs = np.percentile(band, 99)
        if maxs == mins:
            # constant band: the stretch would be 0 / 0 and fill the band with NaN
            matrix[i] = 0.0
            continue
        matrix[i] = (band.clip(mins, maxs) - mins) / (maxs - mins)
    return matrix

def generate_visualization(in_dir, out_dir):
    """Render every image in a folder as a percentile-stretched JPEG preview.

    ``out_dir`` is deleted first if it already exists, then recreated, so it
    only ever contains the previews of the current run.

    Args:
        in_dir (str): folder with the input images.
        out_dir (str): folder that receives one ``<name>.jpg`` per input file.
    """
    import shutil
    if os.path.exists(out_dir):
        shutil.rmtree(out_dir)
    os.makedirs(out_dir)
    for file in os.listdir(in_dir):
        image = scale_percentile_n(io.imread(os.path.join(in_dir, file))).transpose(1, 2, 0)
        # target = io.imread(os.path.join(in_dir.replace("img", "target"), file))
        # if target.mean() > 0.1:
        #     filename = file.replace('tif', 'jpg').replace('_0', '_1')
        # else:
        #     filename = file.replace('tif', 'jpg')
        # Swap only the extension; str.replace('tif', 'jpg') also rewrote any
        # "tif" that happened to occur inside the file name itself.
        filename = os.path.splitext(file)[0] + '.jpg'
        io.imsave(os.path.join(out_dir, filename), image)


# generate_visualization(in_dir=r"F:\DigitalAG\morocco\unet\baseline\\32\validation\img",
#                        out_dir=r"F:\DigitalAG\morocco\unet\baseline\\32\validation\visualization")
# Build JPEG previews for the patches sampled into the SAM image folder.
generate_visualization(
    in_dir=r"F:\DigitalAG\sam\img",
    out_dir=r"F:\DigitalAG\sam\visualization",
)



### Generate testing img for individual tree mapping

In [3]:
def _patch_geotransform(geotrans, x_off, y_off):
    """Return the GDAL geotransform of a window whose top-left pixel is (x_off, y_off)."""
    return (
        geotrans[0] + x_off * geotrans[1] + y_off * geotrans[2],
        geotrans[1], geotrans[2],
        geotrans[3] + x_off * geotrans[4] + y_off * geotrans[5],
        geotrans[4], geotrans[5])


def non_overlap_segmentation(in_dir_img, out_dir, mode,
                             require_proj=False, check_valid=False,
                             pixel=None, num=None, invalid_value=None, invalid_thres=None, in_dir_target=None):
    '''
    Split an image into a regular grid of non-overlapping patches.

    Parameters
    ----------
    in_dir_img: str
        path of the image to split
    out_dir: str
        the directory to save the output; label patches are written to the
        same path with 'img' replaced by 'target'
    mode: str
        'by_pixel' cuts patches of ``pixel`` size, 'by_num' cuts a grid of
        ``num`` patches; any other value writes nothing
    require_proj: bool
        write georeferencing into each patch
    check_valid: bool
        'by_num' only: skip patches whose share of values different from
        ``invalid_value`` is below ``invalid_thres``
    pixel: sequence of two ints
        (width, height) of a patch in pixels, used by 'by_pixel'
    num: sequence of two ints
        (columns, rows) of the grid, used by 'by_num'
    invalid_value, invalid_thres:
        see ``check_valid``
    in_dir_target: str
        'by_pixel' only: path of the label raster to split alongside the image
    '''
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    img = gdal.Open(in_dir_img)
    # Without require_proj the patches are written without georeferencing
    # (the 'by_pixel' branch previously raised NameError in that case).
    img_geotrans = img.GetGeoTransform() if require_proj else None
    img_proj = img.GetProjection() if require_proj else None

    x_size, y_size = img.RasterXSize, img.RasterYSize
    img_array = img.ReadAsArray(0, 0, x_size, y_size)
    if img_array.ndim == 2:
        # single-band rasters come back 2-D; use a band-first layout throughout
        img_array = img_array[np.newaxis, :, :]

    if mode == 'by_pixel':
        patch_w, patch_h = pixel[0], pixel[1]
        x_num = x_size // patch_w
        y_num = y_size // patch_h  # was pixel[0], wrong for non-square patches

        # Read the label raster once instead of once per patch.
        target_array = None
        target_out_dir = out_dir.replace('img', 'target')
        if in_dir_target:
            target = gdal.Open(in_dir_target)
            target_array = target.ReadAsArray(0, 0, x_size, y_size)
            target_array[target_array > 1] = 0
            if not os.path.exists(target_out_dir):
                os.makedirs(target_out_dir)

        for i in range(0, x_num):  # column
            for j in range(0, y_num):  # row
                x_off_patch = i * patch_w
                y_off_patch = j * patch_h
                patch = img_array[:, y_off_patch:y_off_patch + patch_h, x_off_patch:x_off_patch + patch_w]
                # fill values of unsigned rasters count as empty (0)
                if patch.dtype == 'uint8':
                    patch[patch == 255] = 0
                if patch.dtype == 'uint16':
                    patch[patch == 65535] = 0
                if (patch == 0).mean() > 0.05:
                    continue
                dst_transform = (_patch_geotransform(img_geotrans, x_off_patch, y_off_patch)
                                 if require_proj else None)
                patch_file = str(i * y_num + j) + '_0.tif'
                writeTif(patch, os.path.join(out_dir, patch_file), transform=dst_transform, proj=img_proj)
                if target_array is not None:
                    target_patch = target_array[y_off_patch:y_off_patch + patch_h,
                                                x_off_patch:x_off_patch + patch_w][np.newaxis, :, :]
                    writeTif(target_patch, os.path.join(target_out_dir, patch_file),
                             transform=dst_transform, proj=img_proj)

    if mode == 'by_num':
        x_range = x_size // num[0]
        y_range = y_size // num[1]
        # The file prefix used to come from an undefined name `c` that only
        # existed if another notebook cell had run first. The input file's base
        # name is used instead. NOTE(review): confirm this is the intended prefix.
        prefix = os.path.splitext(os.path.basename(in_dir_img))[0]
        for i in range(0, num[0]):  # column
            for j in range(0, num[1]):  # row
                x_off_patch = i * x_range
                y_off_patch = j * y_range
                patch = img_array[:, y_off_patch:y_off_patch + y_range, x_off_patch:x_off_patch + x_range]
                if check_valid:
                    if (patch != invalid_value).mean() < invalid_thres:
                        continue
                patch_name = os.path.join(out_dir, prefix + '_' + str(i * num[1] + j) + '.tif')

                if require_proj:
                    dst_transform = _patch_geotransform(img_geotrans, x_off_patch, y_off_patch)
                    writeTif(patch, patch_name, dst_transform, img_proj)
                else:
                    writeTif(patch, patch_name)

##################################
### generate img/tgt pairs for all images
##################################
# Cut every georeferenced GeoTIFF into 224 x 224 patches, one output folder per image.
for img in os.listdir(r"Z:\Morocco\second_data\georeferenced"):
    if not img.endswith("tif"):
        continue
    patch_dir = r"Z:\Morocco\patch\\" + img.split(".")[0]
    if not os.path.exists(patch_dir):
        os.makedirs(patch_dir)
    non_overlap_segmentation(
        in_dir_img=os.path.join(r"Z:\Morocco\second_data\georeferenced", img),
        #  in_dir_target=r"F:\DigitalAG\morocco\unet\data\label\region5_label_gr.tif",
        out_dir=patch_dir,
        mode='by_pixel',
        pixel=[224, 224],
        require_proj=True,
    )

##################################
### generate img/tgt pairs for any specific image
##################################
# img_dir = r"Z:\Morocco\georeference_task\qualified\\19FEB13111500-M1BS-505246646070_01_P002_gr.tif"
# non_overlap_segmentation(in_dir_img=img_dir,
#                         #  in_dir_target=r"F:\DigitalAG\morocco\unet\data\label\region6_label_gr.tif",
#                          out_dir=r"F:\DigitalAG\morocco\unet\baseline\\32\testing\region5\img",
#                          mode='by_pixel',
#                          pixel=[32, 32], 
#                          require_proj=True)

# generate_visualization(in_dir=r"F:\DigitalAG\morocco\unet\baseline\\32\testing\region5\img",
#                        out_dir=r"F:\DigitalAG\morocco\unet\baseline\\32\testing\region5\visualization")

save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.


### Make smaller dataset

In [None]:
from shutil import copy
in_dir = r"F:\DigitalAG\morocco\unet\testing\region2"
out_dir = r"F:\DigitalAG\morocco\unet\testing\region2_reduced"
# `subfolder` instead of `type`, which shadowed the builtin for the rest of the notebook
for subfolder in ['csv', 'img', 'target', 'visualization']:
    os.makedirs(os.path.join(out_dir, subfolder), exist_ok=True)
img_list = np.array(os.listdir(os.path.join(in_dir, 'img')))
# Sample without replacement: the default (with replacement) picks the same
# file several times, so fewer than 300 distinct images were copied.
# The sample is capped at the number of available images.
sample_size = min(300, len(img_list))
candidate = img_list[np.random.choice(len(img_list), sample_size, replace=False)]
for c in candidate:
    # swap only the extension; str.replace('tif', 'csv') also hits "tif" inside the name
    csv_name = os.path.splitext(c)[0] + '.csv'
    copy(os.path.join(in_dir, 'img', c), os.path.join(out_dir, 'img', c))
    copy(os.path.join(in_dir, 'target', c), os.path.join(out_dir, 'target', c))
    copy(os.path.join(in_dir, 'csv', csv_name), os.path.join(out_dir, 'csv', csv_name))
    # copy(os.path.join(in_dir, 'visualization', c.replace('tif', 'jpg')), os.path.join(out_dir, 'img', c.replace('tif', 'jpg')))



### Generate label csv

In [28]:
def create_label(in_dir, out_dir):
    """Write a CSV that maps every image in a folder to the label in its file name.

    A file named ``<id>_<label>.<ext>`` (the id may itself contain
    underscores) yields one row with ``img = <id>`` and ``label = <label>``;
    the id is also used as the row index. Rows are written in sorted file
    name order so the CSV does not depend on directory listing order.

    Args:
        in_dir (str): folder whose file names are parsed.
        out_dir (str): path of the CSV file to write (overwritten if present).
    """
    records = []
    for item in sorted(os.listdir(in_dir)):
        stem = item.split('.')[0]
        # everything before the last underscore is the id, the rest the label
        img_idx, _, img_lb = stem.rpartition('_')
        records.append((img_idx, img_lb))
    # Building the index from the records keeps an empty folder from crashing
    # (indexing np.array([])[:, 0] raised IndexError); it now yields a header-only CSV.
    df = pd.DataFrame(data=records, columns=['img', 'label'], index=[r[0] for r in records])
    df.to_csv(out_dir)

# Label CSV for the 32-pixel baseline validation previews.
dataset = r'baseline\\32\validation'
create_label(
    in_dir=fr"F:\DigitalAG\morocco\unet\{dataset}\visualization",
    out_dir=fr"F:\DigitalAG\morocco\unet\{dataset}\label.csv",
)

### match folders

In [27]:
# Delete every file in the 'img' and 'target' folders that has no matching
# preview in the 'visualization' folder.
dst_folder_list = ['img', 'target']
src_folder = 'visualization'
root_dir = r"F:\DigitalAG\morocco\unet\baseline\\32\validation"

# A preview is named like "115_139_0.jpg": "115_139" are the x/y coordinates
# and the trailing number is the label. The matching patch file is always
# "<coordinates>_0.tif".
src_img_list = np.array([
    '_'.join(item.split('.')[0].split('_')[:-1]) + "_0.tif"
    for item in os.listdir(os.path.join(root_dir, src_folder))
])

for f in dst_folder_list:
    dst_img_list = np.array(os.listdir(os.path.join(root_dir, f)))
    kept = np.intersect1d(dst_img_list, src_img_list)
    for file in np.setdiff1d(dst_img_list, kept):
        os.remove(os.path.join(root_dir, f, file))

In [55]:
# Print every candidate tile that has an id but no georeferenced output yet.
candidate_tiles = pd.read_csv(r"Z:\Morocco\small_tiles.csv")
# Raw string: the plain literal relied on "\M", "\g" and "\q" not being escape
# sequences, which Python reports as invalid escapes. The path value is unchanged.
output_tiles = os.listdir(r"Z:\Morocco\georeference_task\qualified")
output_tiles = [o.replace('_gr.tif', '') for o in output_tiles]
# iterate the two columns directly instead of one iloc lookup per cell
for tile_id, tile in zip(candidate_tiles['id'], candidate_tiles['tile']):
    if tile_id != 'none' and tile not in output_tiles:
        print(tile)

18DEC12113125-M1BS-505246646030_01_P004
19JUL14112204-M1BS-505246645090_01_P005
19NOV09111258-M1BS-505246646080_01_P004
20JAN13111723-M1BS-505246645020_01_P002


### Generate subpatch for main patch

In [98]:
def generate_visualization_subpatch(in_dir, out_dir):
    """Save one randomly chosen 32 x 32 sub-patch of every image in a folder.

    ``out_dir`` is removed first if it exists and then recreated. The random
    generator is seeded with 200. For each file, one cell of a 7 x 7 grid of
    32-pixel cells is drawn and saved as
    ``<first two name parts>_<cell number>_<last name part>.jpg``, where the
    cell number is ``first_index * 7 + second_index``.

    Args:
        in_dir (str): folder with the input images.
        out_dir (str): folder that receives the sub-patches.
    """
    import shutil
    cell = 32
    output_root = os.path.join(out_dir)
    if os.path.exists(output_root):
        shutil.rmtree(output_root)
    os.makedirs(output_root)
    np.random.seed(200)
    for file in os.listdir(in_dir):
        if not os.path.exists(out_dir):
            os.makedirs(out_dir)
        patch = io.imread(os.path.join(in_dir, file))
        # two separate draws, in this order, to keep the random sequence unchanged
        first_cells = np.random.choice(7, 1)
        second_cells = np.random.choice(7, 1)
        name_parts = file.split('.')[0].split('_')
        patch_idx = '_'.join(name_parts[:2])
        patch_class = name_parts[-1]
        for first, second in zip(first_cells, second_cells):
            sub_patch = patch[first * cell:(first + 1) * cell, second * cell:(second + 1) * cell, :]
            filename = patch_idx + '_' + str(first * 7 + second) + '_' + patch_class + '.jpg'
            io.imsave(os.path.join(out_dir, filename), sub_patch)


# One random sub-patch per pre-training preview image.
generate_visualization_subpatch(
    in_dir=r"F:\DigitalAG\morocco\unet\pretrain\\training\visualization",
    out_dir=r"F:\DigitalAG\morocco\unet\pretrain\\training\visualization_patch_3",
)



In [99]:
def create_label_patch(in_dir, out_dir):
    """Write a CSV that maps every sub-patch in a folder to its image, patch and label.

    File names are split on underscores: the first two parts form ``img``,
    the third part is ``patch`` and the last part is ``label``. ``img`` is
    also used as the row index. Rows are written in sorted file name order so
    the CSV does not depend on directory listing order.

    Args:
        in_dir (str): folder whose file names are parsed.
        out_dir (str): path of the CSV file to write (overwritten if present).

    Raises:
        IndexError: if a file name has fewer than three underscore-separated parts.
    """
    records = []
    for item in sorted(os.listdir(in_dir)):
        parts = item.split('.')[0].split('_')
        records.append(('_'.join(parts[:2]), parts[2], parts[-1]))
    # Building the index from the records keeps an empty folder from crashing
    # (indexing np.array([])[:, 0] raised IndexError); it now yields a header-only CSV.
    df = pd.DataFrame(data=records, columns=['img', 'patch', 'label'], index=[r[0] for r in records])
    df.to_csv(out_dir)

# Label CSV for the pre-training sub-patches.
dataset = r'pretrain\training'
create_label_patch(
    in_dir=fr"F:\DigitalAG\morocco\unet\{dataset}\visualization_patch",
    out_dir=fr"F:\DigitalAG\morocco\unet\{dataset}\label_patch.csv",
)

### Prepare data for GCP

In [2]:
# Move the mosaic of every result folder into one flat "gcp" folder,
# naming each file after the folder it came from.
in_dir = r"Z:\Morocco\results\vit"
out_dir = r"Z:\Morocco\results\vit\gcp"
import shutil
# The copy below fails with FileNotFoundError if the output folder is missing.
os.makedirs(out_dir, exist_ok=True)
for f in os.listdir(in_dir):
    if f == "gcp":
        continue
    img_dir = os.path.join(in_dir, f, "mosaic")
    if not os.path.exists(img_dir):
        # report result folders that have no mosaic
        print (img_dir)
        continue
    dst_path = os.path.join(out_dir, f + '.tif')
    # NOTE(review): every file in a mosaic folder is copied to the SAME
    # destination and then deleted, so with more than one file only the last
    # one listed survives. Confirm each mosaic folder holds a single file, or
    # restore the "_post" filter below.
    for img in os.listdir(img_dir):
        # if img.split('.')[0].split('_')[-1] == "post":
        src_path = os.path.join(img_dir, img)
        shutil.copy(src_path, dst_path)
        os.remove(src_path)


### Aggregate classification results

In [42]:
def aggre(in_dir, img_range=64):
    """Replace every ``img_range`` x ``img_range`` block of a raster by its mean.

    The result is written as a single-band float32 GeoTIFF to ``in_dir`` with
    'gcp' replaced by 'gcp<img_range>', using the geotransform and projection
    of the input. Rows and columns beyond the last complete block keep their
    original values.

    Args:
        in_dir (str): path of a single-band raster.
        img_range (int, optional): block edge length in pixels. Defaults to 64.
    """
    img = gdal.Open(in_dir)
    img_geotrans = img.GetGeoTransform()
    img_proj = img.GetProjection()
    rows = img.RasterYSize
    cols = img.RasterXSize
    img_prob = img.ReadAsArray(0, 0, cols, rows)
    assert len(img_prob.shape) == 2, 'please make sure the input is 2D binary mask'

    # Work in float: with an integer mask the block mean was truncated back to
    # an integer when stored, turning every partly covered block into 0.
    new_prob = img_prob.astype(np.float32)

    # Visit every complete block. The former range(1, n - 1) combined with
    # (row - 1) indexing left the last two block rows and columns untouched.
    for row in range(rows // img_range):
        row_start = row * img_range
        for col in range(cols // img_range):
            col_start = col * img_range
            block = img_prob[row_start:row_start + img_range, col_start:col_start + img_range]
            new_prob[row_start:row_start + img_range, col_start:col_start + img_range] = block.mean()

    out_path = in_dir.replace('gcp', f'gcp{img_range}')
    out_folder = os.path.dirname(out_path)
    if out_folder:
        # GDAL cannot create the file when the output folder is missing
        os.makedirs(out_folder, exist_ok=True)
    writeTif(new_prob[np.newaxis, :, :], out_path, transform=img_geotrans, proj=img_proj)
# Aggregate every GeoTIFF in the gcp folder over 64-pixel blocks.
for item in os.listdir(r"Z:\Morocco\results\vit\gcp"):
    if not item.endswith('tif'):
        continue
    aggre(os.path.join(r"Z:\Morocco\results\vit\gcp", item), 64)

save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.
save image success.


KeyboardInterrupt: 