In [11]:
# -*- coding: utf-8 -*-

"""
根据小块影像裁剪大影像中的对应部分
~~~~~~~~~~~~~~~~
code by wHy
Aerospace Information Research Institute, Chinese Academy of Sciences
Ghent University
Haoyu.Wang@ugent.be
"""

from tqdm import tqdm
import gdal
import ogr
import fnmatch
from PIL import Image
from numpy import *
from pylab import *
from scipy.ndimage import filters
import cv2
import matplotlib.pyplot as plt
import os

%matplotlib inline
%config InlineBackend.figure_format = 'svg'

In [12]:
def write_img(out_path, im_proj, im_geotrans, im_data):
    """Write a NumPy array to a GeoTIFF file.

    Args:
        out_path: Path of the GeoTIFF file to create.
        im_proj: Projection (spatial reference), passed to ``SetProjection``.
        im_geotrans: Geotransform (affine transformation parameters),
            passed to ``SetGeoTransform``.
        im_data: Image data, either 2-D ``(height, width)`` or 3-D
            ``(bands, height, width)``.

    Raises:
        RuntimeError: If the GTiff driver fails to create ``out_path``.

    """
    # Look the dtype name up exactly. A substring test such as
    # ``'int16' in name`` also matches signed int16 and would store it as
    # UInt16, corrupting negative values; 32-bit integers would be pushed
    # through Float32 and lose precision above 2**24.
    exact_types = {
        'uint8': gdal.GDT_Byte,
        'int8': gdal.GDT_Byte,  # unchanged from the previous mapping
        'uint16': gdal.GDT_UInt16,
        'int16': gdal.GDT_Int16,
        'uint32': gdal.GDT_UInt32,
        'int32': gdal.GDT_Int32,
    }
    # Every other dtype keeps the previous fallback of Float32.
    datatype = exact_types.get(im_data.dtype.name, gdal.GDT_Float32)

    # Work out the band count from the array rank.
    if len(im_data.shape) > 2:
        im_bands, im_height, im_width = im_data.shape
    else:
        im_bands, (im_height, im_width) = 1, im_data.shape

    # Create the output dataset.
    driver = gdal.GetDriverByName("GTiff")
    new_dataset = driver.Create(
        out_path, im_width, im_height, im_bands, datatype)
    if new_dataset is None:
        # Without GDAL exceptions enabled, Create() reports failure by
        # returning None; fail here with a readable message instead of an
        # AttributeError on the next line.
        raise RuntimeError('failed to create output image: {}'.format(out_path))
    new_dataset.SetGeoTransform(im_geotrans)
    new_dataset.SetProjection(im_proj)
    if im_bands == 1:
        # reshape (not squeeze) so a 1-row or 1-column image stays 2-D.
        new_dataset.GetRasterBand(1).WriteArray(
            im_data.reshape(im_height, im_width))
    else:
        for i in range(im_bands):
            new_dataset.GetRasterBand(i + 1).WriteArray(im_data[i])

    # Flush pending writes, then drop the reference to close the file.
    new_dataset.FlushCache()
    del new_dataset

def read_img(sr_img):
    """Read a raster file and return its pixel data and metadata.

    Args:
        sr_img: The full path of the image to open.

    Returns:
        A tuple ``(im_data, im_proj, im_geotrans, im_height, im_width,
        im_bands)``: the array returned by ``ReadAsArray`` for the full
        raster extent, the projection, the geotransform, ``RasterYSize``,
        ``RasterXSize`` and ``RasterCount``.

    Raises:
        RuntimeError: If ``gdal.Open`` returns ``None`` for ``sr_img``.

    """
    im_dataset = gdal.Open(sr_img)
    if im_dataset is None:
        # Raise instead of calling sys.exit(): ``sys`` is not explicitly
        # imported in this notebook, and exiting would tear down the
        # kernel session rather than report which file failed.
        raise RuntimeError('open sr_img false: {}'.format(sr_img))

    im_geotrans = im_dataset.GetGeoTransform()
    im_proj = im_dataset.GetProjection()
    im_width = im_dataset.RasterXSize
    im_height = im_dataset.RasterYSize
    im_bands = im_dataset.RasterCount
    # Read the whole raster in one call.
    im_data = im_dataset.ReadAsArray(0, 0, im_width, im_height)
    # Drop the reference so GDAL closes the file.
    del im_dataset

    return im_data, im_proj, im_geotrans, im_height, im_width, im_bands

In [13]:
sr_img_full_path = r'D:\xj_populus_count\sr_img_3_bands\sr_img_3_bands.tif'
img_path = r'C:\Users\75198\OneDrive\论文\SCI-4-\画图\图4-语义分割示意图\清晰区域8-5-3波段\原始图像'
output_path = r'C:\Users\75198\OneDrive\论文\SCI-4-\画图\图4-语义分割示意图\清晰区域8-5-3波段'

# Small reference tiles; each one selects a window of the big image.
listpic = fnmatch.filter(os.listdir(img_path), '*.tif')

# Read the big image once, including its geotransform.
sr_data, sr_proj, sr_geotrans, sr_height, sr_width, sr_bands = read_img(sr_img_full_path)

print(sr_geotrans)

for img_name in listpic:
    img_full_path = os.path.join(img_path, img_name)
    im_data, im_proj, im_geotrans, im_height, im_width, im_bands = read_img(img_full_path)

    # Pixel position of the tile's upper-left corner inside the big image.
    # round() rather than int(): truncation turns a floating-point result
    # such as 8968.9999 into 8968 and shifts the window by one pixel.
    offset_x = int(round((im_geotrans[0] - sr_geotrans[0]) / sr_geotrans[1]))
    offset_y = int(round((im_geotrans[3] - sr_geotrans[3]) / sr_geotrans[5]))

    print(offset_x, offset_y)

    # A negative offset would silently wrap around in NumPy slicing, and a
    # window past the edge would yield a truncated tile, so skip those.
    if (offset_x < 0 or offset_y < 0
            or offset_x + im_width > sr_width
            or offset_y + im_height > sr_height):
        print('skip {}: window lies outside the big image'.format(img_name))
        continue

    # Ellipsis keeps this valid for both (bands, H, W) and single-band (H, W).
    output_data = sr_data[..., offset_y:offset_y + im_height,
                          offset_x:offset_x + im_width].copy()

    # NOTE(review): the window size is taken from the tile in tile pixels and
    # the tile's own geotransform is written out; this assumes both images
    # share the same pixel size and grid alignment - confirm.
    output_full_path = os.path.join(
        output_path, os.path.splitext(img_name)[0] + '_from_big.tif')
    write_img(output_full_path, im_proj, im_geotrans, output_data)

(613264.0, 0.5, 0.0, 4446456.0, 0.0, -0.5)
8969 3653
7945 10203
9148 12632
5960 15749
1998 4688
6663 12358
