## Extract patch images from patch annotations
For semantic segmentation or object detection.

## Patch for Semantic Segmentation
Read the mask images (PNG, created from the JSON annotations) and extract the patch images that contain objects.

Optional data augmentation (flips and rotations) is also provided.

In [None]:
import numpy as np

def _aligned_window(lo, hi, limit, patch_size):
    """Return (start, length) of a patch-aligned window covering [lo, hi].

    ``length`` is the smallest multiple of ``patch_size`` that is >= the
    object extent ``hi - lo + 1``. The window is centred on [lo, hi] and
    shifted so that it lies inside [0, limit); if ``length > limit`` it
    starts at 0 (slicing then clips it to the image).
    """
    extent = hi - lo + 1
    length = -(-extent // patch_size) * patch_size  # round up to a multiple of patch_size
    start = lo + (hi - lo) // 2 - length // 2        # centre the window on the object
    # Keep the window inside the image. The outer max() covers length > limit,
    # where limit - length is negative and would otherwise become a negative
    # (wrap-around) slice index that cuts the object off.
    start = max(0, min(start, limit - length))
    return start, length


def get_subimg(img, patch_size):
    """
    Get the smallest sub-image that contains every object pixel and whose
    height and width are multiples of patch_size, so that only patches
    around the object need to be cropped.

    Object pixels are those with a non-zero value in channel 0. The window
    is centred on the object's bounding box and shifted back inside the
    image when it would cross a border. Along an axis where the image is
    smaller than the window, the full image extent is returned (that side
    is then not a multiple of patch_size).

    Args:
        img(ndarray): Original image array. shape: (h, w, 3)
        patch_size(int): Patch size.
    Returns:
        (ndarray): Sub-image, a view into ``img``. shape: (sub_h, sub_w, 3).
            Empty (0 x 0) when ``img`` has no object pixels.
    """
    ys, xs = np.nonzero(img[..., 0])
    if xs.size == 0:
        # No object pixels: there is no bounding box to centre on.
        return img[:0, :0]

    height, width = img.shape[:2]
    tx, sub_width = _aligned_window(int(xs.min()), int(xs.max()), width, patch_size)
    ty, sub_height = _aligned_window(int(ys.min()), int(ys.max()), height, patch_size)
    return img[ty:ty + sub_height, tx:tx + sub_width]

In [None]:
import cv2
import os

# User input parameters
###############################
home = '../preliminary'
save_home = '../preliminary/masks'
os.makedirs(save_home, exist_ok=True)

patch_size = 224
ovr_ratio = 0.3 # ovr_ratio must be in range of [0, 1)
apply_data_augmentation = True
###############################

if not 0 <= ovr_ratio < 1:
    raise ValueError(f'ovr_ratio must be in range of [0, 1), got {ovr_ratio}')

overlap = int(patch_size * ovr_ratio)  # pixels shared by neighbouring patches
stride = patch_size - overlap          # >= 1 because ovr_ratio < 1

# (file-name suffix, transform) for each augmentation. Together with the
# original patch these are the 8 distinct flips/rotations of a square patch.
AUGMENTATIONS = [
    ('vflip', lambda im: cv2.flip(im, 0)),                                    # Vertical flip
    ('hflip', lambda im: cv2.flip(im, 1)),                                    # Horizontal flip
    ('r90', lambda im: cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)),              # Rotate 90 degree
    ('r180', lambda im: cv2.rotate(im, cv2.ROTATE_180)),                      # Rotate 180 degree
    ('r270', lambda im: cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)),      # Rotate 270 degree
    ('r90hflip', lambda im: cv2.flip(cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE), 1)),  # Rotate 90, then hflip
    ('hflipr90', lambda im: cv2.rotate(cv2.flip(im, 1), cv2.ROTATE_90_CLOCKWISE)),  # hflip, then rotate 90
]


def _window_starts(length, win, step):
    """Start offsets of ``win``-sized windows (spaced by ``step``) covering [0, length).

    Every window lies fully inside [0, length). If the regular grid leaves a
    tail uncovered, one extra window aligned to the end is added. If
    ``length < win`` a single window at 0 is returned (that patch is then
    smaller than ``win``).
    """
    last = max(length - win, 0)
    starts = list(range(0, last + 1, step))
    if starts[-1] != last:
        starts.append(last)
    return starts


paths = [os.path.join(home, name) for name in sorted(os.listdir(home))
         if name.lower().endswith('.png')]
num_patches = 0

for path in paths:
    img = cv2.imread(path)  # black(0,0,0) or white(255,255,255) image
    if img is None:
        print('Cannot read image: ', path)
        continue
    if not img[..., 0].any():
        continue  # No object pixels at all: nothing to crop

    img = get_subimg(img, patch_size)
    base_name = os.path.splitext(os.path.basename(path))[0]

    height, width = img.shape[:2]
    for y in _window_starts(height, patch_size, stride):
        for x in _window_starts(width, patch_size, stride):
            sub_img = img[y:y + patch_size, x:x + patch_size]

            if not (sub_img[..., 0] == 255).any():
                continue  # Empty patch

            name = f'{base_name}_{x}_{y}.png'
            cv2.imwrite(os.path.join(save_home, name), sub_img)
            num_patches += 1

            if apply_data_augmentation:
                for suffix, transform in AUGMENTATIONS:
                    aug_name = f'{base_name}_{x}_{y}_{suffix}.png'
                    cv2.imwrite(os.path.join(save_home, aug_name), transform(sub_img))
                num_patches += len(AUGMENTATIONS)

print('Num. object patches: ', num_patches)

## Patch for Object Detection
Read the XML label files (labelImg format) and extract patch images from the original images.

In [None]:
import cv2
import os
import json
import xml.etree.ElementTree as ET

# User input parameters
###############################
home = '/media/data1/Ace/11000_Marine_objects/11200_Dataset/11220_Images/11222_Sentinel-1'
img_home = os.path.join(home, 'PNG')                      # source images
ann_home = '/home/sjhong/work/script/work/Annotations'    # one XML label file per image
save_home = '/home/sjhong/work/script/work/patch'         # output directory for patches
os.makedirs(save_home, exist_ok=True)

size = 1024  # Size of bounding box (side length of each patch, in pixels)
###############################

# Running state updated by the processing loop.
empty_img = []  # image names whose annotation contains no objects
num_patch = 0   # number of patches written

def check_box(box, size=1024):
    """Return True if ``box`` is exactly ``size`` x ``size``.

    Args:
        box(list[int,]): bounding box. [xmin, ymin, xmax, ymax]
        size(int): expected side length of the box. default: 1024

    Returns:
        (bool): True if both the width (xmax - xmin) and the height
            (ymax - ymin) equal ``size``; False otherwise.
    """
    box_dims = (box[2] - box[0], box[3] - box[1])
    return box_dims == (size, size)

def correct_box(box, size=1024):
    """Return a ``size`` x ``size`` box anchored at the top-left of ``box``.

    Negative ``xmin``/``ymin`` are clamped to 0, then ``xmax``/``ymax`` are
    recomputed as ``xmin + size`` and ``ymin + size``. The right/bottom image
    border is not checked here. The input is not modified.

    Args:
        box(Sequence[int]): bounding box. [xmin, ymin, xmax, ymax]
        size(int): side length of the corrected box. default: 1024

    Returns:
        (list[int]): corrected box. [xmin, ymin, xmin + size, ymin + size]
    """
    xmin = max(box[0], 0)
    ymin = max(box[1], 0)
    return [xmin, ymin, xmin + size, ymin + size]

for img_name in os.listdir(img_home):
    print('Processing {}'.format(img_name))
    img_path = os.path.join(img_home, img_name)
    stem = os.path.splitext(img_name)[0]
    ann_path = os.path.join(ann_home, stem + '.xml')

    if not os.path.isfile(ann_path):
        # No annotation file: report it together with images that have no objects.
        empty_img.append(img_name)
        continue

    objs = ET.parse(ann_path).getroot().findall('object')
    if not objs:
        empty_img.append(img_name)
        continue

    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        print('Cannot read image: ', img_path)
        continue

    for obj in objs:
        # Coord of labelImg starts from 1
        # NOTE(review): no 1-based -> 0-based offset is applied here; confirm whether one is needed.
        bnd = obj.find('bndbox')
        # Look coordinates up by tag: Element.getchildren() was removed in Python 3.9,
        # and this does not depend on the child order.
        bndbox = [int(float(bnd.findtext(tag))) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
        bndbox = correct_box(bndbox, size=size)
        x0, y0 = bndbox[0], bndbox[1]
        patch_img = img[y0:y0 + size, x0:x0 + size]

        # correct_box fixes the box size, but a box running past the right/bottom
        # border still yields a truncated crop; report it instead of saving it.
        if not check_box(bndbox, size=size) or patch_img.shape[:2] != (size, size):
            print('Wrong boxes: ', bndbox)
            continue

        # Encode both the x and y ranges so boxes sharing an x offset don't overwrite each other.
        patch_name = f'{stem}_{x0}_{x0 + size}_{y0}_{y0 + size}.png'
        cv2.imwrite(os.path.join(save_home, patch_name), patch_img)
        num_patch += 1

if empty_img:
    print('Annotations not exist: ', empty_img)
print('Total number of patches: ', num_patch)
print('Done')