# In [None]:
import os
import torch
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import cv2
import rasterio
import glob
from affine import Affine

# Per-sensor dataset settings, keyed by sensor name. Each entry holds the
# directory name of the full-size GeoTIFFs ('dir'), the bit depth of the pixel
# values ('image_bits') and the band order string ('bands').
# NOTE(review): nothing in the visible code reads this table yet; the 'bands'
# strings look like the values accepted by open_geotiff - confirm before use.
data_dict = dict(
    planet=dict(dir='full_size_planet_tif', image_bits=12, bands='BGRN'),
    kompsat=dict(dir='full_size_kompsat_tif', image_bits=14, bands='RGBN'),
)


def open_geotiff(img_file, bands='RGBN'):
    """Read every band of a raster file into a float32 array laid out (H, W, C).

    Args:
        img_file: Path of the raster to open with rasterio.
        bands: Band order of the file. 'BGRN' swaps channels 0 and 2 of a
            4-band image and 'BGR' swaps channels 0 and 2 while keeping only
            the first three; any other value leaves the channels untouched.

    Returns:
        numpy float32 array of shape (H, W, C).
    """
    with rasterio.open(img_file) as src:
        raster = src.read()  # rasterio returns (C, H, W)

    # Move the band axis last so that pixels are indexed as [row, col, band].
    image = np.transpose(raster, (1, 2, 0)).astype(np.float32)

    if bands == 'BGRN':
        channel_order = [2, 1, 0, 3]
    elif bands == 'BGR':
        channel_order = [2, 1, 0]
    else:
        # Nothing to reorder for any other band layout.
        return image

    return image[:, :, channel_order]

def create_img_label_patches(img_file, patch_size=512, patch_stride=512, num_patches=100,
                             out_dir='./data/patches', bands='BGRN', image_bits=12):
    """Cut the central part of a GeoTIFF into square patches saved as GeoTIFF and PNG.

    The image is zero-padded on its bottom/right edges so that it tiles into a
    grid of patches, a square window of patches around the grid centre is
    selected, and every selected patch that is (almost) free of zero values is
    written to ``<out_dir>/tif`` (all bands, uint16) and ``<out_dir>/png``
    (first three bands, scaled to 8 bit).

    Args:
        img_file: Path of the source GeoTIFF (str or Path).
        patch_size: Side length of a patch in pixels.
        patch_stride: Step in pixels between neighbouring patches.
        num_patches: Requested number of patches. The selected window has
            ``int(sqrt(num_patches))`` patches per side, clipped to the grid,
            so fewer patches than requested may be produced.
        out_dir: Directory that receives the ``tif`` and ``png`` sub-folders.
        bands: Band order of the source file, forwarded to ``open_geotiff``.
        image_bits: Bit depth of the pixel values; used to scale the PNG
            preview to the 0-255 range.

    Returns:
        List of paths (str) of the GeoTIFF patches that were written.

    Raises:
        FileNotFoundError: If ``img_file`` does not exist.
        ValueError: If ``num_patches`` exceeds the number of patches in the grid.
    """
    img_patch_files = []
    print('creating patches for %s...' % img_file, end=' ')

    img_file = str(img_file)
    if not os.path.exists(img_file):
        raise FileNotFoundError('%s not found' % img_file)

    img = open_geotiff(img_file, bands=bands)  # (H, W, C) float32
    h, w = img.shape[:2]

    # numpy (H, W, C) -> tensor (C, H, W).
    # NOTE(review): int16 holds 12/14-bit data; values above 32767 would not fit.
    img = torch.from_numpy(img.transpose(2, 0, 1)).to(dtype=torch.int16)

    # Zero-pad bottom/right so the last patch in each direction is complete.
    pad_h = int((np.ceil(h / patch_stride) - 1) * patch_stride + patch_size - h)
    pad_w = int((np.ceil(w / patch_stride) - 1) * patch_stride + patch_size - w)
    padded_img = F.pad(img, pad=[0, pad_w, 0, pad_h])

    # [C, NH, NW, patch_size, patch_size]
    patches = padded_img.unfold(1, patch_size, patch_stride).unfold(2, patch_size, patch_stride)
    NH, NW = patches.shape[1:3]

    if num_patches > NH * NW:
        raise ValueError('num_patches should be less than or equal to %d' % (NH * NW))

    # Select a square window of patches around the centre of the grid.
    patches_per_side = int(num_patches**0.5)
    half_patches = patches_per_side // 2
    extra = patches_per_side % 2  # keeps an odd-sized window centred

    center_row = NH // 2
    center_col = NW // 2
    start_row = max(center_row - half_patches, 0)
    end_row = min(center_row + half_patches + extra, NH)
    start_col = max(center_col - half_patches, 0)
    end_col = min(center_col + half_patches + extra, NW)

    center_patches = patches[:, start_row:end_row, start_col:end_col, :, :]

    tif_dir = Path(out_dir)/'tif'
    png_dir = Path(out_dir)/'png'
    os.makedirs(tif_dir, exist_ok=True)
    os.makedirs(png_dir, exist_ok=True)

    # Path.stem drops only the final suffix, even if '.tif' occurs earlier in the name.
    stem = Path(img_file).stem
    max_value = 2**image_bits - 1

    # NOTE: patches get a fixed placeholder transform and CRS; the
    # georeferencing of the source file is not carried over.
    transform = Affine.translation(0, 0) * Affine.scale(1, -1)

    for y in range(center_patches.shape[1]):
        for x in range(center_patches.shape[2]):
            # Index the centre window (not the full grid) so that the patches
            # written are the ones that were selected above.
            patch_img = center_patches[:, y, x, :, :].permute(1, 2, 0).numpy().astype(np.uint16)

            # Skip patches in which more than 0.1% of the values are zero.
            if np.count_nonzero(patch_img) < 0.999 * patch_img.size:
                continue

            # save image patch in geotiff format
            tif_path = os.path.join(tif_dir, stem + '_%d_%d.tif' % (y, x))
            with rasterio.open(
                tif_path,
                'w',
                driver='GTiff',
                height=patch_img.shape[0],
                width=patch_img.shape[1],
                count=patch_img.shape[2],
                dtype=rasterio.uint16,
                crs='+proj=latlong',
                transform=transform,
            ) as dst:
                for i in range(patch_img.shape[2]):
                    dst.write(patch_img[:, :, i], i + 1)
            img_patch_files.append(tif_path)

            # save an 8-bit preview of the first three bands in png format
            preview = (patch_img[..., :3] / max_value * 255.0).clip(0.0, 255.0).astype(np.uint8)
            preview = cv2.cvtColor(preview, cv2.COLOR_RGB2BGR)  # cv2.imwrite expects BGR
            cv2.imwrite(os.path.join(png_dir, stem + '_%d_%d.png' % (y, x)), preview)

    print('success')

    return img_patch_files


if __name__ == '__main__':
    # Cut every full-size Planet GeoTIFF under the dataset root into patches.
    planet_dir = Path('data/jbnu_cloud') / 'full_size_planet_tif'
    for tif_path in planet_dir.glob('*.tif'):
        create_img_label_patches(
            tif_path,
            patch_size=512,
            patch_stride=512,
            num_patches=150,
            out_dir='./data/patches/planet',
        )