In [1]:
#!/.conda/envs/dp python
# -*- coding: utf-8 -*-

"""
8位影像标签重映射
比如原标签是1 10 20 30, 可以自动映射为0 1 2 3
~~~~~~~~~~~~~~~~
code by wHy
Aerospace Information Research Institute, Chinese Academy of Sciences
wanghaoyu191@mails.ucas.ac.cn
"""
from pathlib import Path
import gdal
import os
import ogr
import osr
import sys
import math
from osgeo.ogr import Geometry, Layer
from tqdm import tqdm
import numpy as np
import fnmatch
import copy

def write_img(out_path, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF file.

    Args:
        out_path: Output file path.
        im_proj: Spatial reference (projection) string, passed to
            ``SetProjection``.
        im_geotrans: Affine geotransform, passed to ``SetGeoTransform``.
        im_data: Image data, either 2-D ``(height, width)`` or 3-D
            ``(bands, height, width)``.

    Raises:
        RuntimeError: If GDAL cannot create the output file (for example
            because the output directory does not exist).
    """
    # Map the numpy dtype to the GDAL type with the same width/signedness.
    # Previously signed int16 was written as UInt16 (negatives wrapped) and
    # 32-bit integers as Float32 (precision lost above 2**24). Any dtype not
    # listed here keeps the original Float32 fallback.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        # GDT_Int8 only exists in newer GDAL; Int16 holds int8 losslessly.
        'int8': getattr(gdal, 'GDT_Int8', gdal.GDT_Int16),
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
        'float64': gdal.GDT_Float64,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    # A 2-D array is a single band; a 3-D array is band-first.
    if im_data.ndim > 2:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    # create new img
    driver = gdal.GetDriverByName("GTiff")
    new_dataset = driver.Create(
        out_path, im_width, im_height, im_bands, datatype)
    if new_dataset is None:
        raise RuntimeError(f'GDAL could not create output file: {out_path}')
    new_dataset.SetGeoTransform(im_geotrans)
    new_dataset.SetProjection(im_proj)
    if im_data.ndim == 2:
        new_dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # Also covers a (1, H, W) array: band 1 gets the 2-D slice im_data[0].
        for i in range(im_bands):
            new_dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Dropping the last reference closes the dataset so GDAL flushes it.
    del new_dataset

def read_img(sr_img):
    """Read a whole raster file into memory.

    Args:
        sr_img: The full path of the original image.

    Returns:
        Tuple ``(im_data, im_proj, im_geotrans)``:
        - the pixel array returned by ``ReadAsArray`` for the full extent;
        - the projection string from ``GetProjection``;
        - the geotransform from ``GetGeoTransform``.

    If GDAL cannot open the file, an error naming the file is printed to
    stderr and ``sys.exit(1)`` is called.
    """
    im_dataset = gdal.Open(sr_img)
    if im_dataset is None:
        # Name the offending file so a failure inside a long batch is traceable.
        print(f'open sr_img false: {sr_img}', file=sys.stderr)
        sys.exit(1)
    im_geotrans = im_dataset.GetGeoTransform()
    im_proj = im_dataset.GetProjection()
    im_width = im_dataset.RasterXSize
    im_height = im_dataset.RasterYSize
    im_data = im_dataset.ReadAsArray(0, 0, im_width, im_height)
    # Release the dataset handle once the data has been read.
    del im_dataset

    return im_data, im_proj, im_geotrans

# Folder holding the original label rasters.
sr_label_path = r'E:\YJS\0-code\3-data\2-enhance_label'
# Folder the remapped label rasters are written to.
output_label_path = r'E:\YJS\0-code\3-data\2-enhance_relabel'

# Sorted so the processing order is reproducible across runs and machines.
listlabel = sorted(fnmatch.filter(os.listdir(sr_label_path), '*.tif'))

# First pass: collect every label value that occurs in any file, so that all
# files can share a single global old -> new value mapping.
uni_val = set()
for label in tqdm(listlabel):
    label_full_path = os.path.join(sr_label_path, label)
    # Read the label data. Only the pixel values are needed here; after the
    # loop, proj_temp / geotrans_temp hold the LAST file's georeferencing only,
    # so they are not valid for any other file.
    data, proj_temp, geotrans_temp = read_img(label_full_path)
    # tolist() stores plain Python scalars rather than numpy scalars.
    uni_val.update(np.unique(data).tolist())


100%|██████████| 7842/7842 [01:22<00:00, 95.33it/s] 


In [4]:
print("classnum", len(uni_val))

# Ascending order makes the mapping deterministic: the smallest original value
# becomes 0, the next 1, and so on (e.g. 1, 10, 20, 30 -> 0, 1, 2, 3).
# Iterating the set directly gives no ordering guarantee.
class_values = np.array(sorted(uni_val))
print("mapping", dict(zip(class_values.tolist(), range(len(class_values)))))

os.makedirs(output_label_path, exist_ok=True)
for label in tqdm(listlabel):
    label_full_path = os.path.join(sr_label_path, label)
    # Use each file's OWN projection/geotransform; reusing values read from a
    # different file would mis-georeference the output.
    data, proj, geotrans = read_img(label_full_path)

    # One vectorised lookup instead of one boolean mask per class:
    # searchsorted gives each pixel's position in the sorted class list.
    new_index = np.searchsorted(class_values, data)
    # Guard: every pixel value must appear in class_values, otherwise
    # searchsorted would silently return an insertion point.
    if class_values.size == 0 or not np.array_equal(
            class_values[np.minimum(new_index, class_values.size - 1)], data):
        raise ValueError(f'{label} contains values not collected in uni_val')
    # Keep the input dtype so write_img picks the same GDAL type as the input.
    remapped = new_index.astype(data.dtype)

    output_full_path = os.path.join(output_label_path, label)
    write_img(output_full_path, proj, geotrans, remapped)

classnum 11


  6%|▌         | 480/7842 [00:08<02:23, 51.18it/s]