In [13]:
# -*- coding: utf-8 -*-

"""
图像增强
~~~~~~~~~~~~~~~~
code by wHy
Aerospace Information Research Institute, Chinese Academy of Sciences
Ghent University
Haoyu.Wang@ugent.be
"""

import fnmatch
import os

from tqdm import tqdm
import gdal
import ogr
from PIL import Image
import numpy as np
from numpy import *
from pylab import *
from scipy.ndimage import filters
import cv2
import matplotlib.pyplot as plt

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [14]:
def write_img(out_path, im_proj, im_geotrans, im_data):
    """Write an array to a GeoTIFF file.

    Args:
        out_path: Output file path.
        im_proj: Spatial reference / projection string, as returned by
            ``Dataset.GetProjection()`` (e.g. the second value from ``read_img``).
        im_geotrans: Affine geotransform, as returned by
            ``Dataset.GetGeoTransform()`` (e.g. the third value from ``read_img``).
        im_data: Image array; 2-D for a single band, or 3-D with bands on the
            first axis: (bands, rows, cols).

    Raises:
        IOError: if GDAL cannot create ``out_path`` (for example when the
            output directory does not exist).
    """
    # Choose a GDAL pixel type that can hold the array's values.
    # Signed int16/int32 must not be stored as unsigned (negatives would wrap),
    # and 32-bit integers would lose precision if stored as Float32.
    # int8 keeps mapping to Byte, as before: GDAL < 3.7 has no Int8 type.
    gdal_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    datatype = gdal_types.get(im_data.dtype.name, gdal.GDT_Float32)

    # Number of bands and raster size.
    if im_data.ndim > 2:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands = 1
        im_height, im_width = im_data.shape

    # Create the output image.
    driver = gdal.GetDriverByName("GTiff")
    new_dataset = driver.Create(
        out_path, im_width, im_height, im_bands, datatype)
    if new_dataset is None:
        raise IOError('could not create output image: %s' % out_path)

    try:
        new_dataset.SetGeoTransform(im_geotrans)
        new_dataset.SetProjection(im_proj)
        for i in range(im_bands):
            # Index the band explicitly instead of squeeze(): squeeze() would
            # also drop a height/width of 1 and produce a 1-D array.
            band = im_data[i] if im_data.ndim > 2 else im_data
            new_dataset.GetRasterBand(i + 1).WriteArray(band)
    finally:
        # Releasing the last reference closes the dataset, even on error.
        del new_dataset

def read_img(sr_img):
    """Read a raster file with GDAL.

    Args:
        sr_img: The full path of the original image.

    Returns:
        Tuple ``(im_data, im_proj, im_geotrans, im_height, im_width, im_bands)``:
        the pixel array returned by ``ReadAsArray`` for the full raster extent,
        the projection string, the geotransform, the raster height (rows),
        the raster width (columns), and the band count.

    Raises:
        IOError: if GDAL cannot open ``sr_img``.
    """
    im_dataset = gdal.Open(sr_img)
    if im_dataset is None:
        # Raise instead of sys.exit(1): in a notebook SystemExit merely aborts
        # the cell, and `sys` is not imported explicitly in this notebook.
        raise IOError('open sr_img false: %s' % sr_img)

    try:
        im_geotrans = im_dataset.GetGeoTransform()
        im_proj = im_dataset.GetProjection()
        im_width = im_dataset.RasterXSize
        im_height = im_dataset.RasterYSize
        im_bands = im_dataset.RasterCount
        im_data = im_dataset.ReadAsArray(0, 0, im_width, im_height)
    finally:
        # Drop the reference so GDAL closes the file, even if a read fails.
        del im_dataset

    return im_data, im_proj, im_geotrans, im_height, im_width, im_bands

In [15]:
def truncated_linear_stretch(image, truncated_value, max_out=255, min_out=0, nodata=-9999):
    """Percentile-truncated linear contrast stretch, applied band by band.

    For every band, the ``truncated_value``-th and ``(100 - truncated_value)``-th
    percentiles of the valid pixels are mapped linearly onto
    ``[min_out, max_out]``; values beyond them are clipped.

    Args:
        image: 2-D array (one band) or 3-D array whose first axis indexes bands.
        truncated_value: Percentage (0-100 scale, as used by ``np.percentile``)
            cut from each tail of the histogram.
        max_out: Upper bound of the output range.
        min_out: Lower bound of the output range.
        nodata: Background value excluded from the percentile computation
            (default -9999). ``None`` disables it. NaN pixels are always
            treated as background.

    Returns:
        Array with the same shape as ``image``. Background pixels, and every
        pixel of a band without contrast (upper percentile == lower
        percentile), are set to ``min_out``. The dtype is uint8 if
        ``max_out <= 255``, uint16 if ``max_out <= 65535``, else float64.
        Stretched values are truncated (not rounded) when cast to an integer
        type.
    """

    def gray_process(gray):
        """Stretch a single 2-D band."""
        band = np.asarray(gray, dtype=np.float64)
        if nodata is not None:
            band = np.where(band == nodata, np.nan, band)
        valid = ~np.isnan(band)

        # Start from min_out so background pixels and flat bands get a defined
        # value instead of NaN, whose cast to an integer dtype is undefined.
        stretched = np.full(band.shape, min_out, dtype=np.float64)
        if valid.any():
            values = band[valid]
            truncated_down = np.percentile(values, truncated_value)
            truncated_up = np.percentile(values, 100 - truncated_value)
            # Skip the division on constant-valued bands (zero span).
            if truncated_up > truncated_down:
                scaled = ((values - truncated_down) / (truncated_up - truncated_down)
                          * (max_out - min_out) + min_out)
                stretched[valid] = np.clip(scaled, min_out, max_out)

        if max_out <= 255:
            return stretched.astype(np.uint8)
        if max_out <= 65535:
            return stretched.astype(np.uint16)
        return stretched

    # Multi-band: stretch each band independently.
    if np.ndim(image) == 3:
        return np.array([gray_process(band) for band in image])
    # Single band.
    return gray_process(image)

In [16]:
# Batch job: contrast-stretch every GeoTIFF in `img_path` and save it under the
# same file name in `output_path`.
img_path = r'E:\project_UAV_GF2_2\2-clip_img_GF2_432'
output_path = r'E:\project_UAV_GF2_2\3-clip_img_GF2_432_enhanced'

# Percentage cut from each histogram tail (0-100 scale, passed to the
# percentile computation). NOTE(review): 0.005 means 0.005 %, not 0.5 % — confirm intent.
TRUNCATE_PERCENT = 0.005

# GDAL's driver.Create fails when the output directory does not exist.
os.makedirs(output_path, exist_ok=True)

# Sorted for a deterministic processing order.
listpic = sorted(fnmatch.filter(os.listdir(img_path), '*.tif'))

for img_name in tqdm(listpic, desc='stretching'):
    img_full_path = os.path.join(img_path, img_name)
    im_data, im_proj, im_geotrans = read_img(img_full_path)[:3]

    data_stretch = truncated_linear_stretch(im_data, TRUNCATE_PERCENT, max_out=255)  # stretch

    output_full_path = os.path.join(output_path, os.path.splitext(img_name)[0] + '.tif')
    write_img(output_full_path, im_proj, im_geotrans, data_stretch)