### Imports

In [None]:
import json
import os
import shutil
import random
import copy
from pandas.core.common import flatten
import numpy as np
import pandas as pd
from PIL import Image, ImageEnhance
import cv2
import albumentations as A
from collections import Counter
import time


### Image Augmentation Functions

In [None]:
def getRandTup(low_low, low_high, high_low, high_high):
    """Return a degenerate ``(v, v)`` pair with ``v`` taken from one of two bands.

    One candidate is drawn uniformly from ``[low_low, low_high]`` and one from
    ``[high_low, high_high]``; a third uniform draw acts as a fair coin that
    selects the high candidate when it is ``>= .5`` and the low one otherwise.
    Returning equal endpoints presumably pins a range-style augmentation
    parameter to exactly ``v`` (verify against the albumentations version in use).
    """
    low_candidate = random.uniform(low_low, low_high)
    high_candidate = random.uniform(high_low, high_high)
    coin = random.uniform(0, 1)

    chosen = high_candidate if coin >= .5 else low_candidate
    return (chosen, chosen)



In [None]:
def aug_brightness(og_img):
    """Return a brightness-jittered version of *og_img*.

    The factor pair comes from ``getRandTup(.3, .8, 2, 3.5)``: the image is
    either darkened (0.3-0.8) or strongly brightened (2-3.5), never left near 1.
    Contrast, saturation and hue are passed as 0.
    """
    factor = getRandTup(.3, .8, 2, 3.5)
    transform = A.ColorJitter(
        brightness=factor,
        contrast=0,
        saturation=0,
        hue=0,
        always_apply=True,
        p=1,
    )
    result = transform(image=og_img)
    return result['image']



In [None]:
def aug_contrast(og_img):
    """Return a contrast-jittered version of *og_img*.

    The factor pair comes from ``getRandTup(.3, .8, 2, 3.5)``: contrast is
    either reduced (0.3-0.8) or strongly increased (2-3.5). Brightness,
    saturation and hue are passed as 0.
    """
    factor = getRandTup(.3, .8, 2, 3.5)
    transform = A.ColorJitter(
        brightness=0,
        contrast=factor,
        saturation=0,
        hue=0,
        always_apply=True,
        p=1,
    )
    result = transform(image=og_img)
    return result['image']

In [None]:
def aug_saturation(og_img):
    """Return a saturation-jittered version of *og_img*.

    The factor pair comes from ``getRandTup(.3, .8, 2, 3.5)``: saturation is
    either reduced (0.3-0.8) or strongly increased (2-3.5). Brightness,
    contrast and hue are passed as 0.

    NOTE(review): callers pass images from ``cv2.imread`` (BGR order); if
    albumentations' ColorJitter assumes RGB, the saturation grayscale
    weights are applied to swapped channels — confirm.
    """
    factor = getRandTup(.3, .8, 2, 3.5)
    transform = A.ColorJitter(
        brightness=0,
        contrast=0,
        saturation=factor,
        hue=0,
        always_apply=True,
        p=1,
    )
    result = transform(image=og_img)
    return result['image']

In [None]:
def aug_hue(og_img):
    """Return a hue-shifted version of *og_img*.

    The shift pair comes from ``getRandTup(-.5, -.075, .075, .5)``, so the
    shift magnitude is always at least 0.075 in either direction and stays
    within [-0.5, 0.5]. Brightness, contrast and saturation are passed as 0.

    NOTE(review): callers pass images from ``cv2.imread`` (BGR order); if
    albumentations assumes RGB, the shift direction is mirrored — confirm.
    """
    shift = getRandTup(-.5, -.075, .075, .5)
    transform = A.ColorJitter(
        brightness=0,
        contrast=0,
        saturation=0,
        hue=shift,
        always_apply=True,
        p=1,
    )
    result = transform(image=og_img)
    return result['image']

In [None]:
def aug_color_jitter(og_img):
    """Return *og_img* with brightness, contrast, saturation and hue all jittered.

    Brightness, contrast and saturation factors each come from
    ``getRandTup(.3, .8, 1.5, 2.25)``; the hue shift comes from
    ``getRandTup(-.5, -.075, .075, .5)``.
    """
    scale_band = (.3, .8, 1.5, 2.25)
    # Drawn in the same order as the keyword arguments below, so the random
    # stream is consumed brightness -> contrast -> saturation -> hue.
    brightness = getRandTup(*scale_band)
    contrast = getRandTup(*scale_band)
    saturation = getRandTup(*scale_band)
    hue = getRandTup(-.5, -.075, .075, .5)

    jitter = A.ColorJitter(
        brightness=brightness,
        contrast=contrast,
        saturation=saturation,
        hue=hue,
        always_apply=True,
        p=1,
    )
    result = jitter(image=og_img)
    return result['image']

In [None]:
def aug_posterize(og_img):
    """Return a heavily posterized version of *og_img*.

    ``num_bits=(1, 2)`` is passed to ``A.Posterize``; presumably each channel
    keeps only 1 or 2 bits (confirm against the albumentations docs).
    """
    posterizer = A.Posterize(num_bits=(1, 2), always_apply=True, p=1)
    result = posterizer(image=og_img)
    return result['image']

In [None]:
def aug_blur(og_img):
    """Return a blurred version of *og_img*.

    ``blur_limit=(5, 9)`` is passed to ``A.Blur``; presumably the kernel size
    is drawn from that range (confirm against the albumentations docs).
    """
    blurrer = A.Blur(blur_limit=(5, 9), always_apply=True, p=1)
    result = blurrer(image=og_img)
    return result['image']

In [None]:
def createImgJson(img_id, fn, width=1920, height=1080):
    """Build a COCO ``images`` entry.

    Args:
        img_id: Value stored under ``"id"``.
        fn: Relative file name stored under ``"file_name"``.
        width: Image width in pixels. Defaults to 1920, the value that was
            previously hard-coded; pass the real width for other sizes.
        height: Image height in pixels. Defaults to 1080 (previously
            hard-coded).

    Returns:
        dict with keys ``width``, ``height``, ``id``, ``file_name`` in that
        order (the order matters for the serialized JSON layout).
    """
    return {
        'width': width,
        'height': height,
        'id': img_id,
        'file_name': fn,
    }

In [None]:
def createAnnJson (ann_id, img_id, cat_id, seg, bbox, ignore, iscrowd, area):
    """Build a COCO ``annotations`` entry from its individual fields.

    The segmentation and bbox objects are stored by reference (not copied).
    Keys appear in the order id, image_id, category_id, segmentation, bbox,
    ignore, iscrowd, area.
    """
    return dict(
        id=ann_id,
        image_id=img_id,
        category_id=cat_id,
        segmentation=seg,
        bbox=bbox,
        ignore=ignore,
        iscrowd=iscrowd,
        area=area,
    )

### Annotation Stats Calculation Functions

In [None]:
def createCatMap (cat_list):
    """Map each COCO category id to its name.

    If an id repeats, the last name wins but the id keeps its first position.
    """
    return {category["id"]: category["name"] for category in cat_list}

In [None]:
def createImgAnnMap (img_list, ann_list):
    """Map each image id to the list of annotations that reference it.

    Annotations are bucketed in one pass over *ann_list*. The previous
    version rescanned the whole list for every image, which was
    O(images x annotations). Every image in *img_list* gets an entry, an
    empty list if it has no annotations. Keys follow *img_list* order, and
    each list keeps *ann_list* order. Annotations whose ``image_id`` matches
    no image are ignored.
    """
    anns_by_img = {}
    for ann in ann_list:
        anns_by_img.setdefault(ann["image_id"], []).append(ann)

    # list(...) gives each image its own list object, as the original did.
    return {img["id"]: list(anns_by_img.get(img["id"], ())) for img in img_list}


In [None]:
def calcImgStats (ann_list):
    """Count annotations per ``category_id`` for one image's annotation list.

    Returns a plain dict ``{category_id: count}`` ordered by first occurrence;
    an empty list yields ``{}``.
    """
    per_category = Counter(annotation["category_id"] for annotation in ann_list)
    return dict(per_category)

In [None]:
def calcImgStatsAll (img_ann_map):
    """Compute per-category annotation counts for every image.

    Args:
        img_ann_map: mapping image id -> list of annotations, as built by
            ``createImgAnnMap``.

    Returns:
        dict image id -> ``calcImgStats`` result, in *img_ann_map* order.

    The map's own keys are iterated. The previous version indexed
    ``0..len-1``, so a dataset with image ids starting at 1 or with gaps
    raised KeyError or silently skipped images.
    """
    return {img_id: calcImgStats(ann_list) for img_id, ann_list in img_ann_map.items()}

In [None]:
def calcGlobStats (img_stats, cat_map):
    """Sum per-category annotation counts over all images.

    Args:
        img_stats: mapping image id -> ``{category_id: count}``, as built by
            ``calcImgStatsAll``.
        cat_map: mapping category id -> name, as built by ``createCatMap``.

    Returns:
        dict category id -> total annotation count. It has one entry (possibly
        0) for every category id in *cat_map*, in ascending id order.

    The real category ids are used. The previous version assumed ids
    ``0..len(cat_map)-1`` and swallowed every error with a bare ``except``.
    """
    glob_stats = {}
    for cat_id in sorted(cat_map):
        # Images that do not contain this category contribute 0.
        glob_stats[cat_id] = sum(stats.get(cat_id, 0) for stats in img_stats.values())
    return glob_stats

### Create Images Function

In [None]:

def createNewImgs(img_list, img_list_og, ann_list_og, cat_list_og):
    """Write augmented copies of *img_list* and rewrite ``out_dir/result.json``.

    Each enabled augmentation is applied to every image in *img_list*, as
    many times as its module-level ``adj_*`` count says (read at call time).
    Each result is saved as a PNG under ``out_dir/images``. It is then
    registered as a new COCO image, together with a copy of every annotation
    of its source image. Annotation geometry is copied unchanged, which
    presumes the transforms are non-geometric (color jitter, posterize, blur).

    *img_list_og* and *ann_list_og* are extended in place. The full dataset
    (original + new entries) is then written to ``out_dir/result.json``.

    Args:
        img_list: COCO image entries to augment (``file_name`` like ``"./x.png"``).
        img_list_og: all COCO image entries of the current dataset.
        ann_list_og: all COCO annotation entries of the current dataset.
        cat_list_og: COCO categories, written back unchanged.

    Raises:
        OSError: if an augmented image cannot be written.
    """
    img_list_master = img_list_og
    ann_list_master = ann_list_og

    # New ids continue after the largest existing id. max() tolerates
    # unsorted lists and an empty annotation list; [-1] did not.
    img_id_new = max((im["id"] for im in img_list_og), default=-1) + 1
    ann_id_new = max((an["id"] for an in ann_list_og), default=-1) + 1

    # (copies per image, file-name tag, transform), in generation order.
    aug_plan = [
        (adj_brightness, "bright", aug_brightness),
        (adj_contrast, "contrast", aug_contrast),
        (adj_saturation, "saturate", aug_saturation),
        (adj_hue, "hue", aug_hue),
        (adj_color_jitter, "color_jitter", aug_color_jitter),
        (adj_posterize, "posterize", aug_posterize),
        (adj_blur, "blur", aug_blur),
    ]

    img_out_dir = os.path.join(out_dir, "images")
    os.makedirs(img_out_dir, exist_ok=True)

    # Bucket the source annotations by image once instead of rescanning the
    # whole list per input image. Annotations appended below get fresh image
    # ids, so they never belong to a source image.
    anns_by_img = {}
    for ann in ann_list_og:
        anns_by_img.setdefault(ann["image_id"], []).append(ann)

    for i, img in enumerate(img_list):

        fn_rel = img["file_name"]
        # file_name is expected to look like "./<name>.<ext>".
        rel_path = fn_rel[2:] if fn_rel.startswith("./") else fn_rel
        fn = os.path.splitext(rel_path)[0]
        fn_full = os.path.join(in_dir, "images", rel_path)

        print(str(i) + ")", fn_full)

        og_img = cv2.imread(fn_full)
        if og_img is None:
            # Previously augmented images exist only in the output tree.
            og_img = cv2.imread(os.path.join(img_out_dir, rel_path))
        if og_img is None:
            print("Could not read image, skipping:", fn_full)
            continue

        # Record the source dimensions (previously hard-coded to 1920x1080);
        # the transforms in aug_plan are not expected to resize.
        height, width = og_img.shape[:2]
        ann_og = anns_by_img.get(img["id"], [])

        for count, tag, aug_fn in aug_plan:
            for _ in range(count):
                fn_new = fn + "-" + tag + "-" + str(img_id_new) + ".png"
                out_path = os.path.join(img_out_dir, fn_new)
                aug_img = aug_fn(og_img)
                if not cv2.imwrite(out_path, aug_img):
                    raise OSError("Failed to write augmented image: " + out_path)

                img_json = createImgJson(img_id_new, "./" + fn_new)
                img_json["width"] = width
                img_json["height"] = height
                img_list_master.append(img_json)

                for ann in ann_og:
                    ann_list_master.append(createAnnJson(
                        ann_id_new, img_id_new, ann["category_id"],
                        ann["segmentation"], ann["bbox"], ann["ignore"],
                        ann["iscrowd"], ann["area"]))
                    ann_id_new += 1

                img_id_new += 1

    json_out_obj = {
        "images": img_list_master,
        "categories": cat_list_og,
        "annotations": ann_list_master
    }

    # Serialize fully before opening the file so a serialization error cannot
    # leave a truncated result.json behind.
    jstr = json.dumps(json_out_obj, indent=4)
    with open(os.path.join(out_dir, "result.json"), "w") as f:
        f.write(jstr)

### Utility Functions

In [None]:
def getXRandValFromList(input_list, x, use_only_og):
    """Randomly pick *x* entries from *input_list* without replacement.

    When *use_only_og* is true, only entries whose ``file_name`` contains at
    most one ``-`` are eligible; augmented copies get extra ``-tag-id``
    suffixes. If fewer than *x* entries are eligible, a warning is printed and
    the whole eligible list is returned as-is (not shuffled, not copied).
    """
    candidates = (
        [entry for entry in input_list if len(entry["file_name"].split("-")) <= 2]
        if use_only_og
        else input_list
    )

    if x <= len(candidates):
        return random.sample(candidates, x)

    print("Step size too large, defaulting to the smallest category list length.")
    return candidates

### Global Variables

In [None]:
# Source COCO dataset: result.json plus an images/ folder read by createNewImgs.
in_dir = ("/mnt/nis_lab_research/data/coco_files/merged/"
          "far_shah_1247_v1")

In [None]:
# Destination for the upsampled dataset; wiped and re-copied from in_dir at start.
out_dir = ("/mnt/nis_lab_research/data/coco_files/aug/"
           "far_shah_1247_v1_aug_us/")

In [None]:
# Number of annotations the rarest non-empty category must reach before the
# main loop stops.
glob_upsamp_thold = 1000

# Number of source images randomly drawn per iteration from the images that
# contain the currently rarest category.
step_size = 20

# Augmentation mix: how many augmented copies of each type are generated for
# every chosen source image. A value of 0 disables that type.
adj_brightness = adj_contrast = adj_saturation = adj_hue = 1
adj_color_jitter = 9
adj_posterize = adj_blur = 1

# Flag to specify whether only original (non-augmented) images may be used
# as augmentation sources.
use_only_og = True

# Speed-reduction threshold: once (rarest count / glob_upsamp_thold) exceeds
# this fraction, the main loop switches to a smaller step size and a lighter
# augmentation mix so it approaches the target in smaller increments.
speed_red_thold = .9

## Main

In [None]:
# Start from a fresh copy of the input dataset. out_dir is deleted
# recursively, so refuse to run when that would also delete in_dir (same or
# nested path) or copy in_dir into itself.
_in_real = os.path.realpath(in_dir)
_out_real = os.path.realpath(out_dir)
if os.path.commonpath([_in_real, _out_real]) in (_in_real, _out_real):
    raise ValueError(
        "in_dir and out_dir must not be the same directory or nested inside "
        "each other: " + repr(in_dir) + ", " + repr(out_dir))

if os.path.exists(out_dir):
    shutil.rmtree(out_dir)

shutil.copytree(in_dir, out_dir)

In [None]:

# Repeatedly find the rarest non-empty category, pick source images that
# contain it, and write augmented copies until every non-empty category has
# at least glob_upsamp_thold annotations.
min_val = 0

while min_val < glob_upsamp_thold:

    # Reload the dataset as written by the previous iteration.
    with open(os.path.join(out_dir, "result.json")) as f:
        og_coco_obj = json.load(f)

    img_list_og = og_coco_obj["images"]
    ann_list_og = og_coco_obj["annotations"]
    cat_list_og = og_coco_obj["categories"]

    cat_map = createCatMap(cat_list_og)
    img_ann_map = createImgAnnMap (img_list_og, ann_list_og)
    all_img_stats = calcImgStatsAll (img_ann_map)
    glob_stats = calcGlobStats (all_img_stats, cat_map)

    key_list = list(glob_stats.keys())
    val_list = list(glob_stats.values())

    # Categories with no annotations cannot be upsampled, so ignore them.
    nonzero_vals = [x for x in val_list if x != 0]
    if not nonzero_vals:
        print("No annotated categories found; nothing to upsample.")
        break
    min_val = min(nonzero_vals)
    min_val_key = key_list[val_list.index(min_val)]
    print(min_val_key, min_val)

    # Use the freshly computed count: the loop condition alone would run one
    # extra augmentation round after the target was already reached.
    if min_val >= glob_upsamp_thold:
        break

    # Near the target, switch to small increments.
    if min_val / glob_upsamp_thold > speed_red_thold:

        step_size = 5
        adj_brightness = 0
        adj_contrast = 0
        adj_saturation = 0
        adj_hue = 0
        adj_color_jitter = 1
        adj_posterize = 0
        adj_blur = 0

    # One entry per annotation of the rarest category (first image with a
    # given id wins), so images holding more instances of it are more likely
    # to be sampled.
    img_by_id = {}
    for img in img_list_og:
        img_by_id.setdefault(img["id"], img)
    tmp_img_list = [
        img_by_id[ann["image_id"]]
        for ann in ann_list_og
        if ann["category_id"] == min_val_key and ann["image_id"] in img_by_id
    ]

    input_imgs = getXRandValFromList (tmp_img_list, step_size, use_only_og)
    if not input_imgs:
        # Nothing eligible to augment; looping again would never make progress.
        print("No eligible source images for category", min_val_key, "- stopping.")
        break

    createNewImgs(input_imgs, img_list_og, ann_list_og, cat_list_og)