In [29]:
#!/.conda/envs/learn python
# -*- coding: utf-8 -*-

"""
灰度标签或预测结果转颜色图
~~~~~~~~~~~~~~~~
code by wHy
Aerospace Information Research Institute, Chinese Academy of Sciences
wanghaoyu191@mails.ucas.ac.cn
"""

'\n灰度标签或预测结果转颜色图\n~~~~~~~~~~~~~~~~\ncode by wHy\nAerospace Information Research Institute, Chinese Academy of Sciences\nwanghaoyu191@mails.ucas.ac.cn\n'

In [30]:
from torchvision import transforms
import numpy as np
import os
import fnmatch
import torch
from tqdm import tqdm
import gdal
import ogr

In [31]:
def _gdal_datatype(im_data):
    """Return the GDAL band data type used to store ``im_data``.

    8-bit integers (signed or not) map to GDT_Byte; 16- and 32-bit integers
    keep their width and signedness; everything else falls back to GDT_Float32.
    """
    dtype_map = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,    # signed: negatives would wrap in UInt16
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,    # Float32 cannot hold all int32 values exactly
    }
    return dtype_map.get(im_data.dtype.name, gdal.GDT_Float32)


def _fill_dataset(dataset, im_proj, im_geotrans, im_data):
    """Copy georeferencing and every band of ``im_data`` into an open GDAL dataset."""
    dataset.SetGeoTransform(im_geotrans)
    dataset.SetProjection(im_proj)
    if im_data.ndim == 2:
        dataset.GetRasterBand(1).WriteArray(im_data)
    else:
        # Iterating a (bands, rows, cols) array yields one 2-D band at a time;
        # this also covers a single band stored as (1, rows, cols).
        for band_idx, band in enumerate(im_data, start=1):
            dataset.GetRasterBand(band_idx).WriteArray(band)


def write_img(out_path, im_proj, im_geotrans, im_data):
    """Write an array to a georeferenced raster file.

    The output format follows the file extension: a ``.png`` path is written
    as a real PNG, and any other path is written as GeoTIFF. GDAL's PNG driver
    only supports CreateCopy(), not Create(), so PNG output is first staged in
    an in-memory (MEM) dataset. PNG cannot store georeferencing itself, so
    GDAL may keep it in a sidecar file. PNG output also only accepts
    8/16-bit data.

    Args:
        out_path: Output file path.
        im_proj: Projection / spatial reference (WKT string, as returned by
            ``Dataset.GetProjection()``).
        im_geotrans: Affine geotransform 6-tuple (as returned by
            ``Dataset.GetGeoTransform()``).
        im_data: Array of shape (rows, cols) or (bands, rows, cols).

    Raises:
        RuntimeError: if GDAL fails to create the output file.
    """
    datatype = _gdal_datatype(im_data)

    # Work out the band count and raster size.
    if im_data.ndim > 2:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    if os.path.splitext(out_path)[1].lower() == '.png':
        staging = gdal.GetDriverByName('MEM').Create(
            '', im_width, im_height, im_bands, datatype)
        _fill_dataset(staging, im_proj, im_geotrans, im_data)
        new_dataset = gdal.GetDriverByName('PNG').CreateCopy(out_path, staging)
        del staging
    else:
        new_dataset = gdal.GetDriverByName('GTiff').Create(
            out_path, im_width, im_height, im_bands, datatype)
        if new_dataset is not None:
            _fill_dataset(new_dataset, im_proj, im_geotrans, im_data)

    if new_dataset is None:
        raise RuntimeError('GDAL could not create ' + str(out_path))

    # Dropping the last reference flushes pending writes and closes the file.
    del new_dataset

def read_img(sr_img):
    """Read a raster file and its georeferencing with GDAL.

    Args:
        sr_img: The full path of the source raster.

    Returns:
        Tuple ``(im_data, im_proj, im_geotrans)``:
        - im_data: the full-extent pixel array from ``Dataset.ReadAsArray``
          (typically 2-D for a single band, (bands, rows, cols) otherwise);
        - im_proj: projection string from ``GetProjection()``;
        - im_geotrans: affine geotransform from ``GetGeoTransform()``.

    Raises:
        OSError: if GDAL cannot open ``sr_img``.
    """
    im_dataset = gdal.Open(sr_img)
    if im_dataset is None:
        # Raise instead of sys.exit(): exiting from inside a notebook kernel
        # aborts the cell without saying which file was the problem.
        raise OSError('open sr_img false: ' + str(sr_img))

    im_geotrans = im_dataset.GetGeoTransform()
    im_proj = im_dataset.GetProjection()
    im_width = im_dataset.RasterXSize
    im_height = im_dataset.RasterYSize
    im_data = im_dataset.ReadAsArray(0, 0, im_width, im_height)
    # Release the dataset handle now that everything has been read.
    del im_dataset

    return im_data, im_proj, im_geotrans


def gray2color(image_gray, rgb_mapping):
    """Colourise a single-band label image using a colour lookup table.

    Args:
        image_gray: Array-like of labels with shape (rows, cols). A single
            band stored as (1, rows, cols) is also accepted.
        rgb_mapping: Sequence of RGB triples; entry ``i`` is the colour given
            to every pixel whose value equals ``i``.

    Returns:
        uint8 array of shape (3, rows, cols) (band-first). Pixels whose value
        has no entry in ``rgb_mapping`` stay (0, 0, 0).

    Raises:
        ValueError: if ``image_gray`` is not a single 2-D band.
    """
    image_gray = np.asarray(image_gray)
    if image_gray.ndim == 3 and image_gray.shape[0] == 1:
        image_gray = image_gray[0]
    if image_gray.ndim != 2:
        raise ValueError(
            'gray2color expects a (rows, cols) label image, got shape {}'.format(image_gray.shape))

    img_rgb = np.zeros((3,) + image_gray.shape, dtype=np.uint8)
    for label, rgb in enumerate(rgb_mapping):
        # Boolean-mask assignment colours all matching pixels at once instead of
        # looping over them in Python; the (channels, 1) column broadcasts
        # across the selected pixels.
        img_rgb[:, image_gray == label] = np.asarray(rgb, dtype=np.uint8).reshape(-1, 1)

    return img_rgb

In [32]:
# TODO: machine-specific absolute paths; consider moving them to a config cell.
images_path = r'C:\Users\75198\OneDrive\论文\SCI-3-3 Remote sensing data augmentation\图片\7-预测结果展示图\待预测原始影像\predict_result_ASM_LV1'
output_path = r'C:\Users\75198\OneDrive\论文\SCI-3-3 Remote sensing data augmentation\图片\7-预测结果展示图\待预测原始影像\predict_result_ASM_LV1'

# Create the output folder (and any missing parents) if it does not exist yet.
os.makedirs(output_path, exist_ok=True)

# Keep only the .tif rasters. Outputs get a .png extension, so even though
# output_path == images_path here they are not picked up again on a re-run.
image_list = fnmatch.filter(os.listdir(images_path), '*.tif')

# Custom colour table: row i is the RGB colour for label value i.
cmap = np.array(
    [
        (0, 0, 0),  # background
        (205, 245, 122),
        (122, 142, 245),
    ],
    dtype=np.uint8,
)

for img_name in tqdm(image_list, desc='gray2color'):
    img_full_path = os.path.join(images_path, img_name)
    output_full_path = os.path.join(output_path, os.path.splitext(img_name)[0] + '.png')

    # Read the label / prediction raster together with its georeferencing.
    data, proj_temp, geotrans_temp = read_img(img_full_path)

    data_output = gray2color(data, cmap)

    write_img(output_full_path, proj_temp, geotrans_temp, data_output)
