In [10]:
#!/.conda/envs/dp python
# -*- coding: utf-8 -*-

"""
图像转为8位uint格式
~~~~~~~~~~~~~~~~
code by wHy
Aerospace Information Research Institute, Chinese Academy of Sciences
wanghaoyu191@mails.ucas.ac.cn
"""
from pathlib import Path
import gdal
import os
import ogr
import osr
import sys
import math
from osgeo.ogr import Geometry, Layer
from tqdm import tqdm
import numpy as np
import fnmatch
import copy

In [11]:
def write_img(out_path, im_proj, im_geotrans, im_data):
    """Write a numpy array to a GeoTIFF file.

    Args:
        out_path: Output file path.
        im_proj: Spatial reference / projection string, as returned by
            ``Dataset.GetProjection()`` (see ``read_img``).
        im_geotrans: Affine geotransform parameters, as returned by
            ``Dataset.GetGeoTransform()`` (see ``read_img``).
        im_data: Image data, either ``(height, width)`` for a single band or
            ``(bands, height, width)``.

    Raises:
        RuntimeError: if GDAL cannot create the output file.
    """
    # Map the numpy dtype to a GDAL type by exact name. Signed and unsigned
    # types are kept apart: a substring test ('int16' in name) would send
    # signed int16 to UInt16 and wrap negative values. Anything not listed
    # (floats, 64-bit ints, bool, ...) falls back to Float32 as before.
    dtype_to_gdal = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # NOTE: older GDAL has no signed 8-bit type; negatives wrap
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = dtype_to_gdal.get(im_data.dtype.name, gdal.GDT_Float32)

    # Treat a 2-D array as a single-band stack so both cases share one path.
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, ...]
    im_bands, im_height, im_width = im_data.shape

    # Create the new image.
    driver = gdal.GetDriverByName("GTiff")
    new_dataset = driver.Create(
        str(out_path), im_width, im_height, im_bands, datatype)
    if new_dataset is None:
        raise RuntimeError(f'could not create output image: {out_path}')
    new_dataset.SetGeoTransform(im_geotrans)
    new_dataset.SetProjection(im_proj)
    # GDAL band indices are 1-based.
    for band_index, band in enumerate(im_data, start=1):
        new_dataset.GetRasterBand(band_index).WriteArray(band)

    new_dataset.FlushCache()
    # Dropping the last reference closes the dataset and finalizes the file.
    del new_dataset

def read_img(sr_img):
    """Read a raster image fully into memory with GDAL.

    Args:
        sr_img: The full path of the original image.

    Returns:
        A tuple ``(im_data, im_proj, im_geotrans)``:
            im_data: pixel array from ``ReadAsArray`` -- ``(height, width)``
                for a single-band image, ``(bands, height, width)`` otherwise.
            im_proj: projection string from ``GetProjection()``.
            im_geotrans: geotransform from ``GetGeoTransform()``.

    Raises:
        OSError: if GDAL cannot open ``sr_img``.
    """
    im_dataset = gdal.Open(str(sr_img))
    if im_dataset is None:
        # Raise instead of sys.exit(): exiting from inside a helper kills the
        # caller (in a notebook it only shows up as an opaque SystemExit).
        raise OSError(f'open sr_img false: {sr_img}')
    im_geotrans = im_dataset.GetGeoTransform()
    im_proj = im_dataset.GetProjection()
    im_width = im_dataset.RasterXSize
    im_height = im_dataset.RasterYSize
    im_data = im_dataset.ReadAsArray(0, 0, im_width, im_height)

    # Release the GDAL dataset handle.
    del im_dataset

    return im_data, im_proj, im_geotrans

In [12]:

input_folder = r'E:\project_UAV_GF2_2\2-clip_img_UAV_321'  # input folder path
output_folder = r'E:\project_UAV_GF2_2\2-clip_img_UAV_321_8bit'  # output folder path

img_type = '*.tif'


def _stretch_to_uint8(band):
    """Linearly rescale one 2-D band to uint8 (band min -> 0, band max -> 255).

    NaN pixels are replaced with 0 before the range is computed. Values are
    truncated (not rounded) by the final cast to uint8.

    Args:
        band: 2-D numeric array holding a single band.

    Returns:
        ``(stretched, min_value, max_value)``: the uint8 band plus the range
        used for the stretch. A constant band (max == min) is returned as all
        zeros instead of producing a 0/0 division.
    """
    band = np.nan_to_num(band, nan=0)
    min_value = np.min(band)
    max_value = np.max(band)
    if max_value == min_value:
        return np.zeros(band.shape, dtype=np.uint8), min_value, max_value
    # Work in float64: subtracting in a signed integer dtype (e.g. int16) can
    # overflow when the data range exceeds what the dtype can hold.
    value_range = float(max_value) - float(min_value)
    scaled = (band.astype(np.float64) - min_value) / value_range * 255
    return scaled.astype(np.uint8), min_value, max_value


# Make sure the output folder exists.
os.makedirs(output_folder, exist_ok=True)

listpic = fnmatch.filter(os.listdir(input_folder), img_type)
if not listpic:
    print(f'No files matching {img_type} found in {input_folder}')

# Convert the images one by one.
for img in tqdm(listpic, desc='to 8-bit'):
    img_full_path = os.path.join(input_folder, img)
    data, proj_temp, geotrans_temp = read_img(img_full_path)
    if data.ndim == 2:
        output_data, min_value, max_value = _stretch_to_uint8(data)
        # Watch the min/max: they define the stretch (str() keeps numpy's repr).
        tqdm.write(f'{img}: {min_value!s} {max_value!s}')
    else:
        # Multi-band data is (bands, height, width); stretch each band independently.
        output_data = np.zeros(data.shape, dtype=np.uint8)
        for i in range(data.shape[0]):
            band_u8, min_value, max_value = _stretch_to_uint8(data[i])
            output_data[i] = band_u8
            tqdm.write(f'{img} band {i + 1}: {min_value!s} {max_value!s}')

    output_full_path = os.path.join(output_folder, img)
    write_img(output_full_path, proj_temp, geotrans_temp, output_data)

0 2275.7104
0 2135.0203
0 1915.6339
0 2175.848
0 2006.7286
0 1734.491
0 1776.831
0 1677.7324
0 1432.8801
0 2096.5042
0 2050.7434


KeyboardInterrupt: 