# Mount your Google Drive
Note that the annotations and the images must be stored in your Google Drive.

In [None]:
# Connection to Google Drive: mounts the Drive at /content/drive.
# (Colab-only: relies on the google.colab package.)
from google.colab import drive
drive.mount('/content/drive')

# Classes and function to get subframes

In [None]:
# Install required package: albumentations, installed/upgraded (-U) straight
# from its GitHub repository. The next cell imports Compose, BboxParams and
# Crop from it.
! pip install -U git+https://github.com/albu/albumentations

In [None]:
import os
from PIL import Image
from albumentations import Compose, BboxParams, Crop
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import math
import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset
import json
import time
import datetime
from datetime import date
import csv

class Subframes(object):
    ''' 
    Class allowing the visualisation and the cropping of a labeled 
    image (bbox) into sub-frames whose dimensions are specified 
    by the user.

    Attributes
    -----------
    img_name : str
        name of the image (with extension, e.g. "My_image.JPG").
    image : PIL
        PIL image.
    target : dict
        Must have 'boxes' and 'labels' keys at least.
        'boxes' must be a list in the 'coco' bounding box format :
        [[xmin, ymin, width, height], ...]
    width : int
        width of the sub-frames
    height : int
        height of the sub-frames
    strict : bool
        set to True get sub-frames of exact same size 
        (e.g width x height) (default: False)
    img_width, img_height : int
        size of the full image, read from image.size.
    x_sub, y_sub : int
        number of complete sub-frames along each axis, plus one.
    
    Methods
    --------
    getlist(overlap=False)
        Produces a results list containing, for each row :
        the sub-frame (3D list, dtype=uint8), the bboxes (2D list),
        the labels (1D list) and the filename (str).
    visualise(results)
        Displays ordered sub-frames of the entire image.
    topoints(results)
        Converts the bounding boxes into points annotations.
    displayobjects(results, points_results, ann_type='point')
        Displays only sub-frames containing objects.
    save(results, output_path, object_only=True)
        Saves sub-frames to a specific path.
    '''

    def __init__(self, img_name, image, target, width, height, strict=False):
        '''
        Parameters
        -----------
        img_name : str
            name of the image (with extension, e.g. "My_image.JPG")
        image : PIL
            PIL image
        target : dict
            Must have 'boxes' and 'labels' keys at least.
        width : int
            width of the sub-frames
        height : int
            height of the sub-frames
        strict : bool
            set to True get sub-frames of exact same size 
            (e.g width x height) (default: False)
        '''

        self.img_name = img_name
        self.image = image
        self.target = target
        self.width = width
        self.height = height
        self.strict = strict

        self.img_width = image.size[0]
        self.img_height = image.size[1]

        # Number of complete tiles along each axis, plus one.
        self.x_sub = 1 + int(self.img_width // width)
        self.y_sub = 1 + int(self.img_height // height)

    def _sub_name(self, index):
        '''
        Builds the file name of the sub-frame number `index`
        (e.g. "My_image_S3.JPG").
        '''
        # splitext only removes the final extension, so dots inside the
        # base name (e.g. "flight.01.JPG") are kept.
        base = os.path.splitext(self.img_name)[0]
        return base + "_S" + str(index) + ".JPG"

    @staticmethod
    def _axis_spans(img_size, size, overlap=False, strict=False):
        '''
        Computes the (start, end) pixel intervals of the sub-frames 
        along one axis.
        Parameters
        -----------
        img_size : int
            Size of the image along the axis.
        size : int
            Size of a sub-frame along the axis.
        overlap : bool, optional
            Set to True to move by half a sub-frame (default: False)
        strict : bool, optional
            Set to True to shift the border sub-frame backwards so 
            that it keeps the full size (default: False)
        Returns
        --------
        list
            A list of (start, end) tuples.
        '''
        if overlap is True:
            # max(1, ...) keeps range() valid for 1-pixel sub-frames
            step = max(1, int(np.round(size * 0.5)))
            limit = img_size - step
        else:
            step = size
            limit = img_size

        spans = []
        for start in range(0, limit, step):
            if start + size <= img_size:
                spans.append((start, start + size))
            elif strict is True:
                # Border sub-frame anchored on the image border; the origin
                # is clamped to 0 when the image is smaller than a sub-frame.
                spans.append((max(0, img_size - size), img_size))
            else:
                # Border sub-frame: always stop at the image border so that
                # no strip of the image is left uncovered.
                spans.append((start, img_size))

        return spans

    def _crop_boxes(self, overlap=False):
        '''
        Lists the crop windows [xmin, ymin, xmax, ymax] of all the 
        sub-frames, row by row (left to right, then top to bottom).
        '''
        x_spans = self._axis_spans(self.img_width, self.width, overlap, self.strict)
        y_spans = self._axis_spans(self.img_height, self.height, overlap, self.strict)

        return [[xmin, ymin, xmax, ymax]
                for ymin, ymax in y_spans
                for xmin, xmax in x_spans]

    def _grid_shape(self, n_frames):
        '''
        Finds the (n_row, n_col) layout matching `n_frames` sub-frames, 
        trying the layout without overlap first and then the layout 
        with overlap. Falls back on the number of columns without 
        overlap and as many rows as needed.
        '''
        for overlap in (False, True):
            n_col = len(self._axis_spans(self.img_width, self.width, overlap))
            n_row = len(self._axis_spans(self.img_height, self.height, overlap))
            if n_row * n_col == n_frames:
                return n_row, n_col

        n_col = max(1, len(self._axis_spans(self.img_width, self.width)))
        n_row = max(1, int(math.ceil(n_frames / n_col)))
        return n_row, n_col

    def getlist(self, overlap=False):
        '''
        Produces a results list containing, for each row :
        the sub-frame (3D list, dtype=uint8), the bboxes (2D list),
        the labels (1D list) and the filename (str).
        Parameters
        -----------
        overlap : bool, optional
            Set to True to get an overlap of 50% between 
            2 sub-frames (default: False)
        Returns
        --------
        list
        '''

        # Image preprocessing
        annotations = {
            'image': np.array(self.image),
            'bboxes': self.target['boxes'],
            'labels': self.target['labels']
        }

        results = []
        for sub, (xmin, ymin, xmax, ymax) in enumerate(self._crop_boxes(overlap)):
            # Boxes are cropped together with the image; albumentations
            # filters them using min_visibility=0.25.
            transf = Compose([Crop(xmin, ymin, xmax, ymax, p=1.0)], 
                                bbox_params=BboxParams(format='coco',
                                                        min_visibility=0.25, 
                                                        label_fields=['labels']))
            augmented = transf(**annotations)
            results.append([augmented['image'],
                            augmented['bboxes'],
                            augmented['labels'],
                            self._sub_name(sub)])

        return results

    def visualise(self, results):
        '''
        Displays ordered sub-frames of the entire image.
        Parameters
        -----------
        results : list
            The list obtained by the method getlist().
        Returns
        --------
        matplotlib plot
        '''

        plt.figure(1)
        plt.suptitle(self.img_name)

        if not results:
            return

        # Rows and columns are derived independently from the image size
        n_row, n_col = self._grid_shape(len(results))

        # Font scale: shrinks when the grid holds more sub-frames than
        # the layout without overlap.
        if self.width > self.height:
            base = len(self._axis_spans(self.img_height, self.height))
            f = self.height*(base/n_row)
        else:
            base = len(self._axis_spans(self.img_width, self.width))
            f = self.width*(base/n_col)

        for index, row in enumerate(results):
            frame = row[0]

            plt.subplot(n_row, n_col, index + 1, xlim=(0,self.width), ylim=(self.height,0))
            plt.imshow(Image.fromarray(frame))
            plt.axis('off')

            text_y, text_x = np.shape(frame)[:2]

            plt.text(0.5*text_x, 0.5*text_y, 
                    "S"+str(index),
                    horizontalalignment='center',
                    verticalalignment='center',
                    fontsize=0.02*f,
                    color='w')

        plt.subplots_adjust(wspace=0.1,hspace=0.1)

    def topoints(self, results):
        '''
        Converts the bounding boxes into points annotations.
        Parameters
        -----------
        results : list
            The list obtained by the method getlist().
        Returns
        --------
        list
            A 2D list with headers : "id", "filename", "count",
            "locations" where
            - "id" represents the unique id of the sub-frame within 
              the image
            - "filename" is the name of the sub-frame 
              (e.g. "My_image_S1.JPG")
            - "count" is the number of objects into the sub-frame
            - "locations" is a list of tuple representing the 
              locations of the objects (y,x)
            Sub-frames without any bounding box are skipped.
    
        '''

        points_results = [['id','filename','count','locations']]

        for index, row in enumerate(results):
            bboxes = row[1]

            # Verify that bbox exists
            if not bboxes:
                continue

            # Centre of each [xmin, ymin, width, height] box, as (y,x)
            loc = [(int(box[1]+(box[3])/2), int(box[0]+(box[2])/2))
                   for box in bboxes]

            points_results.append([index, self._sub_name(index), len(bboxes), loc])

        return points_results

    def displayobjects(self, results, points_results, ann_type='point'):
        '''
        Displays only sub-frames containing objects.
        Parameters
        -----------
        results : list
            The list obtained by the method getlist().
        points_results : list
            The list obtained by the method topoints(results).
        ann_type : str, optional
            A string used to specify the annotation type. Choose
            between :
            - 'point' to visualise points
            - 'bbox' to visualise bounding boxes
            - 'both' to visualise both
            (default is 'point')
        Returns
        --------
        matplotlib plot
            Nothing is drawn when no sub-frame contains an object.
        Raises
        --------
        ValueError
            If ann_type is not 'point', 'bbox' or 'both'.
        '''

        # Checked first so that nothing is drawn for a bad argument
        if ann_type not in ('point', 'bbox', 'both'):
            raise ValueError('Annotation of type \'{}\' unsupported. Choose between \'point\',\'bbox\' or \'both\'.'.format(ann_type))

        # First row of points_results is the header
        n_objects = len(points_results) - 1
        if n_objects < 1:
            return

        # Near-square grid, with one more row when needed
        n_row = int(np.round(math.sqrt(n_objects)))
        n_col = n_row
        if n_objects > n_row*n_col:
            n_row += 1

        fig, ax = plt.subplots(nrows=n_row, ncols=n_col, squeeze=False)
        plt.subplots_adjust(wspace=0.1,hspace=0.1)

        for axis in ax.flat:
            axis.axis('off')

        for position, entry in enumerate(points_results[1:]):

            axis = ax[position // n_col, position % n_col]

            id_object = entry[0]
            patch_object = results[id_object][0]
            text_y, text_x = np.shape(patch_object)[:2]

            # Plot
            axis.imshow(Image.fromarray(patch_object))
            axis.text(0.5*text_x, 0.5*text_y, 
                    "S"+str(id_object),
                    horizontalalignment='center',
                    verticalalignment='center',
                    fontsize=15,
                    color='w',
                    alpha=0.6)

            points = entry[3]
            bboxes = results[id_object][1]

            if ann_type == 'point':
                for point in points:
                    axis.scatter(point[1], point[0], color='r')

            elif ann_type == 'bbox':
                for box in bboxes:
                    rect = patches.Rectangle((box[0],box[1]),box[2],box[3], linewidth=1, edgecolor='r', facecolor='none')
                    axis.add_patch(rect)

            else:
                for point, box in zip(points, bboxes):
                    axis.scatter(point[1], point[0], color='b')
                    rect = patches.Rectangle((box[0],box[1]),box[2],box[3], linewidth=1, edgecolor='r', facecolor='none')
                    axis.add_patch(rect)

    def save(self, results, output_path, object_only=True):
        '''
        Saves sub-frames (.JPG) to a specific path.
        Parameters
        -----------
        results : list
            The list obtained by the method getlist().
        output_path : str
            The path to the folder chosen to save sub-frames.
        object_only : bool, optional
            A flag used to choose between :
            - saving all the sub-frames of the entire image
              (set to False)
            - saving only sub-frames with objects
              (set to True, default)
        Returns
        --------
        None
        '''

        for row in results:
            # Skip empty sub-frames only when explicitly requested
            if object_only is True and not row[1]:
                continue

            subframe = Image.fromarray(row[0])
            subframe.save(os.path.join(output_path, row[3]))

class CustomDataset(Dataset):
    '''
    Dataset reading images and their bounding boxes from a 
    coco-style JSON file.

    Each item is a tuple (img, target) where img is a RGB PIL image 
    and target is a dict with the keys 'boxes', 'labels', 'image_id' 
    and 'area'.

    Parameters
    -----------
    img_root : str
        Path to the folder containing the images.
    ann_root : str
        Path to the coco-style JSON file.
    target_type : str, optional
        Format of the returned boxes :
        - 'coco' : [xmin, ymin, width, height] (default)
        - 'pascal' : [xmin, ymin, xmax, ymax]
    transforms : optional
        Stored as an attribute; not applied by this class.

    Raises
    --------
    ValueError
        If target_type is not 'coco' or 'pascal'.
    '''

    def __init__(self, img_root, ann_root, target_type='coco', transforms=None):

        # An unknown type would otherwise give labels without any box
        if target_type not in ('coco', 'pascal'):
            raise ValueError('Target of type \'{}\' unsupported. Choose between \'coco\' or \'pascal\'.'.format(target_type))

        self.img_root = img_root
        self.ann_root = ann_root
        self.target_type = target_type
        self.transforms = transforms

        with open(ann_root) as json_file:
            self.data = json.load(json_file)

        # Annotations grouped by image id (file order kept), built once so
        # that __getitem__ does not scan every annotation for every image.
        self._anns_by_image = {}
        for ann in self.data.get('annotations', []):
            self._anns_by_image.setdefault(ann['image_id'], []).append(ann)

    def _build_target(self, idx):
        '''
        Builds the target dict of the image at position idx, without 
        loading the image itself.
        '''
        image_id = self.data['images'][idx]['id']

        boxes = []
        area = []
        labels = []

        for ann in self._anns_by_image.get(image_id, []):

            if self.target_type == 'coco':
                boxes.append(ann['bbox'])
            else:
                # coco [xmin, ymin, width, height] -> pascal corners
                bndbox = ann['bbox']
                boxes.append([bndbox[0],
                              bndbox[1],
                              bndbox[0] + bndbox[2],
                              bndbox[1] + bndbox[3]])

            labels.append(ann['category_id'])
            area.append(ann['area'])

        return {
            "boxes": boxes,
            "labels": labels,
            "image_id": image_id,
            "area": area
        }

    def __getitem__(self, idx):

        image_name = self.data['images'][idx]['file_name']
        img_path = os.path.join(self.img_root, image_name)

        img = Image.open(img_path).convert("RGB")

        return img, self._build_target(idx)

    def __len__(self):
        return len(self.data['images'])

def subexport(img_root, ann_root, width, height, output_folder,
            overlap=False, strict=False, pr_rate=50,
            object_only=True, export_ann=True):
    '''
    Function that exports sub-frames created on the basis of 
    images loaded by a dataloader, and their associated new 
    annotations.

    This function uses the 'Subframes' class for image processing.

    Parameters
    -----------
    img_root : str
        Path to images.

    ann_root : str
        Path to a coco-style dict (.json) containing annotations of 
        the initial dataset.

    width : int
        Width of the sub-frames.
    
    height : int
        Height of the sub-frames.
    
    output_folder : str
        Output folder path where to save sub-frames and new annotations.
        The folder is created when it does not exist.
    
    overlap : bool, optional
        Set to True to get an overlap of 50% between 
        2 sub-frames (default: False)
    
    strict : bool, optional
        Set to True get sub-frames of exact same size 
        (e.g width x height) (default: False)

    pr_rate : int, optional
        Console print rate of image processing progress.
        Default : 50. Set to 0 to disable progress printing.
    
    object_only : bool, optional
        A flag used to choose between :
            - saving all the sub-frames of the entire image
            (set to False)
            - saving only sub-frames with objects
            (set to True, default)

    export_ann : bool, optional
        A flag used to choose between :
            - not exporting annotations with sub-frames
            (set to False)
            - exporting annotations with sub-frames
            (set to True, default)
   
    Returns
    --------
    list
        A 2D list whose first row is the header 
        ['filename','boxes','labels'], followed by one row per 
        exported sub-frame.

    a coco-type JSON file named 'coco_subframes.json'
    is created inside the subframes' folder (when export_ann is True)
    
    '''

    # Get annos
    with open(ann_root) as json_file:
        coco_dic = json.load(json_file)

    # Image id -> file name, built once instead of scanning the image
    # list for every processed image
    names_by_id = {im['id']: im['file_name'] for im in coco_dic['images']}

    # Dataset
    dataset = CustomDataset(img_root, ann_root, target_type='coco')

    # Sampler
    sampler = torch.utils.data.SequentialSampler(dataset)

    # Collate_fn
    def collate_fn(batch):
        return tuple(zip(*batch))

    # Dataloader
    dataloader = torch.utils.data.DataLoader(dataset, 
                                            batch_size=1,
                                            sampler=sampler,
                                            num_workers=0,
                                            collate_fn=collate_fn)

    # Sub-frames and annotations are written here
    os.makedirs(output_folder, exist_ok=True)

    # Header
    all_results = [['filename','boxes','labels','HxW']]

    # initial time
    t_i = time.time()

    print(' ')
    print('-'*38)
    print('Sub-frames creation started...')
    print('-'*38)

    for i, (image, target) in enumerate(dataloader):

        # batch_size is 1: unwrap the single element of the batch
        image = image[0]
        target = target[0]

        # image id and name (KeyError if the id is unknown, instead of
        # silently reusing the name of the previous image)
        img_id = int(target['image_id'])
        img_name = names_by_id[img_id]

        # Get subframes
        sub_frames = Subframes(img_name, image, target, width, height, strict=strict)
        results = sub_frames.getlist(overlap=overlap)

        # Save
        sub_frames.save(results, output_path=output_folder, object_only=object_only)

        for frame, boxes, labels, sub_name in results:
            # Same filter as Subframes.save()
            if object_only is True and not boxes:
                continue

            h, w = np.shape(frame)[:2]
            all_results.append([sub_name, boxes, labels, [h, w]])

        if pr_rate and i % pr_rate == 0:
            print('Image [{:<4}/{:<4}] done.'.format(i, len(coco_dic['images'])))

    # Printed once every image has actually been processed
    print('-'*38)
    print('Sub-frames creation finished!')
    print('-'*38)

    # final time
    t_f = time.time()

    print('Elapsed time : {}'.format(str(datetime.timedelta(seconds=int(np.round(t_f-t_i))))))
    print('-'*38)
    print(' ')

    # Plain slicing: rows are ragged (variable number of boxes), which
    # np.array() cannot hold in a regular array.
    return_var = [row[:3] for row in all_results]

    # Export new annos
    if export_ann is True:
        file_name = 'coco_subframes.json'
        output_f = os.path.join(output_folder, file_name)

        # Initializations
        images = []
        annotations = []
        id_ann = 0

        # Row 0 is the header; image ids start at 1
        for id_img, (sub_name, bndboxes, labels, (h, w)) in enumerate(all_results[1:], start=1):

            images.append({
                "license": 1,
                "file_name": sub_name,
                "coco_url": "None",
                "height": h,
                "width": w,
                "date_captured": "None",
                "flickr_url": "None",
                "id": id_img
            })

            # Bounding boxes
            if not bndboxes:
                continue

            for bndbox, label_id in zip(bndboxes, labels):

                id_ann += 1

                # Convert to integer pixel values
                x_min = int(np.round(bndbox[0]))
                y_min = int(np.round(bndbox[1]))
                box_w = int(np.round(bndbox[2]))
                box_h = int(np.round(bndbox[3]))

                # Store the values into a dict
                annotations.append({
                        "segmentation": [[]],
                        "area": box_w*box_h,
                        "iscrowd": 0,
                        "image_id": id_img,
                        "bbox": [x_min,y_min,box_w,box_h],
                        "category_id": label_id,
                        "id": id_ann
                })

        # Update info ('info' and 'licenses' may be absent from the input)
        info = dict(coco_dic.get('info') or {})
        info['date_created'] = str(date.today())
        info['year'] = str(date.today().year)

        new_dic = {
            'info': info,
            'licenses': coco_dic.get('licenses', []),
            'images': images,
            'annotations': annotations,
            'categories': coco_dic.get('categories', [])
        }

        # Export json file
        with open(output_f, 'w') as outputfile:
            json.dump(new_dic, outputfile)

        if os.path.isfile(output_f) is True:
            print('File \'{}\' correctly saved at \'{}\'.'.format(file_name, output_folder))
            print(' ')
        else:
            print('An error occurs, file \'{}\' not found at \'{}\'.'.format(file_name, output_folder))

    return return_var


# Export sub-frames
We used 
* `~/ground_truth/general_dataset/json/big_size/train_big_size_A_B_E_K_WH_WB.json` for training's sub-frames.

* `~/ground_truth/general_dataset/json/big_size/val_big_size_A_B_E_K_WH_WB.json` for validation's sub-frames (for using evaluation tools of mmdetection), with `object_only = False`.

* `~/ground_truth/general_dataset/json/big_size/test_big_size_A_B_E_K_WH_WB.json` for test's sub-frames (for using evaluation tools of mmdetection), with `object_only = False`.

You can find these sub-frames in `mmdetection/data/mammals`.


In [None]:
# Replace the three placeholders below with your own paths
images_folder = '< your specific path here >'
annos_path = '< your JSON annotations file path here >'
output_folder = '< your specific path here >'

# Export strict 2000 x 2000 sub-frames, keeping only those with objects
json_dic = subexport(
    img_root=images_folder,
    ann_root=annos_path,
    width=2000,
    height=2000,
    output_folder=output_folder,
    strict=True,
    object_only=True,
)