In [1]:
import tensorflow as tf
import tensorflow_hub as hub
import tempfile
from six.moves.urllib.request import urlopen
from six import BytesIO

import numpy as np
from PIL import Image
from PIL import ImageColor
from PIL import ImageDraw
from PIL import ImageFont
from PIL import ImageOps

import time
import math
import random
import numpy as np
import sklearn
import matplotlib.pyplot as plt
from tensorflow.keras import datasets, layers, models
%matplotlib inline

In [2]:
import os
# Expose only physical GPUs 4 and 5 to this process.
# NOTE(review): this only takes effect if it runs before TensorFlow first
# initializes its GPU devices — confirm no GPU work happens earlier.
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'
# tf.test.is_gpu_available() is deprecated; list the visible GPUs instead.
# The cell output stays a boolean: True when at least one GPU is visible.
len(tf.config.list_physical_devices('GPU')) > 0

True

# Helper functions

In [3]:
def display_image(image):
    """Show `image` with matplotlib in a new 20x15 figure with the grid off."""
    plt.figure(figsize=(20, 15))
    plt.grid(False)
    plt.imshow(image)

def _text_size(font, text):
    """Return (width, height) of `text` when rendered with `font`.

    Uses `font.getbbox` when the font provides it (Pillow removed `getsize`
    in version 10) and falls back to `font.getsize` otherwise. The right and
    bottom edges of the bounding box are used so the result matches what
    `getsize` used to report.
    """
    if hasattr(font, "getbbox"):
        bbox = font.getbbox(text)
        return bbox[2], bbox[3]
    return font.getsize(text)


def draw_bounding_box_on_image(image,
                               ymin, xmin,
                               ymax, xmax,
                               color, font, thickness=4, display_str_list=()):
    """Adds a bounding box, with optional stacked text labels, to an image.

    The image is modified in place.

    Args:
        image: PIL image to draw on.
        ymin, xmin, ymax, xmax: box corners as fractions of the image height
            and width (they are multiplied by the image size below).
        color: colour used for the box outline and the label backgrounds.
        font: PIL font used to measure and render the labels.
        thickness: width of the box outline in pixels.
        display_str_list: label strings; the first entry ends up on top.
            May be empty, in which case only the box is drawn.
    """
    draw = ImageDraw.Draw(image)
    im_width, im_height = image.size
    (left, right, top, bottom) = (xmin * im_width, xmax * im_width,
                                  ymin * im_height, ymax * im_height)
    draw.line([(left, top), (left, bottom), (right, bottom), (right, top), (left, top)],
              width=thickness, fill=color)

    text_sizes = [_text_size(font, ds) for ds in display_str_list]
    # Each label gets a 5% margin above and below its text.
    total_display_str_height = (1 + 2 * 0.05) * sum(height for _, height in text_sizes)
    # Put the labels above the box when they fit inside the image there,
    # otherwise below the box.
    if top > total_display_str_height:
        text_bottom = top
    else:
        text_bottom = bottom + total_display_str_height

    # Draw from the bottom label upwards, so iterate in reverse order.
    for display_str, (text_width, text_height) in zip(display_str_list[::-1], text_sizes[::-1]):
        margin = np.ceil(0.05 * text_height)
        draw.rectangle([(left, text_bottom - text_height - 2 * margin),
                        (left + text_width, text_bottom)],
                       fill=color)
        draw.text((left + margin, text_bottom - text_height - margin),
                  display_str, fill="black", font=font)
        # Move up by the full height of the rectangle just drawn so the next
        # label sits directly on top of it instead of overlapping it.
        text_bottom -= text_height + 2 * margin

def draw_boxes(image, boxes, class_names, scores, max_boxes=10, min_score=0.1):
    """Overlay labeled boxes on an image with formatted scores and label names.

    Only the first `max_boxes` detections are considered, and of those only
    the ones whose score is at least `min_score` are drawn.

    Args:
        image: numpy array that is drawn on in place and also returned.
            NOTE(review): it is cast with np.uint8 before drawing, so it is
            assumed to hold 0-255 RGB pixel values — confirm against callers.
        boxes: array with one (ymin, xmin, ymax, xmax) row per detection.
        class_names: per-detection class names, as bytes or str.
        scores: per-detection confidence scores.
        max_boxes: maximum number of leading detections to look at.
        min_score: minimum score a detection needs to be drawn.

    Returns:
        The same `image` array, updated in place when any box was drawn.
    """
    colors = list(ImageColor.colormap.values())
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationSansNarrow-Regular.ttf", 25)
    except IOError:
        print("Font not found, using default font.")
        font = ImageFont.load_default()

    # The PIL copy is created lazily, once, and shared by every box. When
    # nothing is drawn, `image` is left completely untouched.
    image_pil = None
    for i in range(min(boxes.shape[0], max_boxes)):
        if not (scores[i] >= min_score):
            continue
        ymin, xmin, ymax, xmax = tuple(boxes[i])
        class_name = class_names[i]
        # Decode bytes so the label reads "Bird: 87%" rather than "b'Bird': 87%".
        if isinstance(class_name, bytes):
            label = class_name.decode("utf-8", errors="replace")
        else:
            label = str(class_name)
        display_str = "{}: {}%".format(label, int(100 * scores[i]))
        # Same class -> same colour within one Python process.
        color = colors[hash(class_name) % len(colors)]
        if image_pil is None:
            image_pil = Image.fromarray(np.uint8(image)).convert("RGB")
        draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, font, display_str_list=[display_str])

    if image_pil is not None:
        np.copyto(image, np.array(image_pil))
    return image

In [4]:
# Handle passed to hub.load(); presumably a local directory holding the saved
# detection model — TODO confirm.
module_handle = "pretrained_model"
# Load the module and keep its 'default' signature as the callable detector.
detector = hub.load(module_handle).signatures['default']

INFO:tensorflow:Saver not created because there are no variables in the graph to restore


In [5]:
def run_detector(detector, img, top_k: int = 5):
    """Run `detector` on `img` and return its first `top_k` detections.

    Args:
        detector: callable returning a dict whose values expose `.numpy()`
            and which contains the keys "detection_boxes",
            "detection_class_entities" and "detection_scores".
        img: image batch accepted by `detector` (this notebook passes a
            float32 tensor with a leading batch axis).
        top_k: number of leading detections to keep (default 5).

    Returns:
        A `(boxes, class_entities, scores)` tuple of numpy arrays, each cut
        down to its first `top_k` entries in the detector's output order.
        NOTE(review): this assumes the detector emits detections sorted by
        descending score — confirm for the loaded model.
    """
    result = detector(img)
    wanted = ("detection_boxes", "detection_class_entities", "detection_scores")
    # Convert only the outputs that are returned.
    boxes, class_entities, scores = (result[key].numpy()[:top_k] for key in wanted)
    return boxes, class_entities, scores

# Load all image file paths

In [9]:
data_path = '/home/tjy/data/china-birds-images'
import os

# Maps bird name (sub-directory name) -> list of file paths found in it.
bird_file_map = {}

# Sorted names of the entries in data_path; each sub-directory is one bird.
birdList = sorted(os.listdir(data_path))
n_birds = len(birdList)
# n_birds = 10 # use ten kinds of birds first
loadedImages = []

n_images = 0
n_birds_loaded = 0
for b in birdList:
    # Stop once the configured number of birds has been collected.
    if n_birds_loaded >= n_birds:
        break
    curdir = os.path.join(data_path, b)
    # Skip stray non-directory entries silently, before announcing a load.
    if not os.path.isdir(curdir):
        continue
    print("Loading images for '" + b + "'")

    # os.listdir() returns entries in arbitrary order; sort them so the file
    # order (and the numbering derived from it later) is the same on every
    # run. Nested directories are not image files, so drop them here.
    filenames = []
    for f in sorted(os.listdir(curdir)):
        candidate = os.path.join(curdir, f)
        if os.path.isfile(candidate):
            filenames.append(candidate)
    n_f = len(filenames)
    if n_f == 0:
        print("no data for '" + b + "', skipping...")
        continue
    bird_file_map[b] = filenames
    # Only the paths are collected here; the pixels are read later.
    print(n_f, "images loaded for '" + b + "'")
    n_birds_loaded += 1
    n_images += n_f

Loading images for 'Aberrant Bush-Warbler'
302 images loaded for 'Aberrant Bush-Warbler'
Loading images for 'Ala Shan Redstart'
470 images loaded for 'Ala Shan Redstart'
Loading images for 'Aleutian Tern'
316 images loaded for 'Aleutian Tern'
Loading images for 'Altai Snowcock'
457 images loaded for 'Altai Snowcock'
Loading images for 'American Wigeon'
500 images loaded for 'American Wigeon'
Loading images for 'Amur Falcon'
no data for 'Amur Falcon', skipping...
Loading images for 'Arctic Warbler'
353 images loaded for 'Arctic Warbler'
Loading images for 'Ashy Bulbul'
340 images loaded for 'Ashy Bulbul'
Loading images for 'Ashy Drongo'
455 images loaded for 'Ashy Drongo'
Loading images for 'Ashy Minivet'
499 images loaded for 'Ashy Minivet'
Loading images for 'Ashy Wood Pigeon'
375 images loaded for 'Ashy Wood Pigeon'
Loading images for 'Ashy Woodswallow'
381 images loaded for 'Ashy Woodswallow'
Loading images for 'Ashy-throated Parrotbill'
297 images loaded for 'Ashy-throated Parrotbi

# Crop each image to its bounding box if it contains a bird

In [None]:
import math
import os
# Root directory for the cropped images.
output_dir = 'data-filtered'

# Detector class labels that are treated as "bird".
# NOTE(review): 'Spatula' may be the kitchen-utensil class of the detector's
# label set rather than a bird — confirm before relying on it.
bird_category_names = [
    'Magpie', 'Bird', 'Woodpecker', 'Blue jay', 'Spatula', 'Ostrich', 'Raven',
    'Owl', 'Duck', 'Goose', 'Swan', 'Falcon', 'Sparrow',
]

# Maps each original image path to the file name of its cropped copy.
orig_filter_map = {}

def crop_and_save(im, box, output_path):
    """Crop a batched image to `box` and write the crop as a JPEG file.

    Args:
        im: image batch; `im.shape[1]` and `im.shape[2]` are used as the
            height and width, and only the first batch element is saved.
        box: `(ymin, xmin, ymax, xmax)` as fractions of the height and width.
        output_path: destination path of the JPEG file.

    The pixel box is clamped to the image and forced to cover at least one
    pixel. Without this, a degenerate or slightly out-of-range box would make
    `tf.image.crop_to_bounding_box` raise.
    """
    height, width = int(im.shape[1]), int(im.shape[2])
    ymin = min(max(math.floor(box[0] * height), 0), height - 1)
    xmin = min(max(math.floor(box[1] * width), 0), width - 1)
    ymax = min(max(math.floor(box[2] * height), ymin + 1), height)
    xmax = min(max(math.floor(box[3] * width), xmin + 1), width)
    cropped_image = tf.image.crop_to_bounding_box(im, ymin, xmin, ymax - ymin, xmax - xmin)
    # tensorflow encode_jpeg only supports uint8, so convert 0 -> 1.0 to 0 -> 255
    cropped_image = tf.image.convert_image_dtype(cropped_image[0], tf.uint8)
    encoded = tf.io.encode_jpeg(cropped_image)
    tf.io.write_file(output_path, encoded)
    
for bird in bird_file_map.keys():
    output_path = os.path.join(output_dir, bird)
    # Decide once per bird whether it was already processed. Doing this check
    # per image would skip every image after the first, because the first
    # image creates the directory.
    if os.path.exists(output_path):
        print("skipping already filtered data at:", output_path)
        continue
    os.makedirs(output_path, exist_ok=True)

    im_no = 0
    for file in bird_file_map[bird]:
        try:
            im = tf.io.read_file(file)
            im = tf.io.decode_image(im, channels=3, expand_animations=False)
            # Scale to float32 and add a leading batch axis for the detector.
            im = tf.image.convert_image_dtype(im, tf.float32)[tf.newaxis, ...]
        except Exception:
            # Unreadable or undecodable file. KeyboardInterrupt is not caught,
            # so the run can still be stopped by hand.
            print("skipping", file)
            continue

        boxes, class_names, _ = run_detector(detector, im)
        # Nothing detected at all: there is no box to crop to.
        if len(boxes) == 0:
            continue
        # Only the first returned detection decides whether the image is kept.
        b = boxes[0]
        c = class_names[0]

        # bird_category_names are classes that are birds, extracted from categories in openimagev4
        if c.decode('ascii') not in bird_category_names: # class names are encoded to bytes
            continue
        file_path = '{0}_{1}.jpeg'.format(bird, im_no)
        target_path = os.path.join(output_path, file_path)

        # Remove any stale file at the real destination before writing.
        if os.path.exists(target_path):
            os.remove(target_path)
        crop_and_save(im, b, target_path)
        orig_filter_map[file] = file_path # remember original file paths
        im_no += 1
    print('{0} images for {1}'.format(im_no, bird))

In [None]:
import json
# Persist the original-path -> cropped-file-name mapping as JSON.
with open('orig2filtered.json', 'w') as mapping_file:
    mapping_file.write(json.dumps(orig_filter_map))