In [1]:
import os
import glob
import numpy as np

import tensorflow as tf

In [2]:
from create_model import create_detection_model
from utilities import plot_detection, read_image_to_numpy

In [3]:
def obj_detect(model, input_tensor):
    """
        Args:
            input_tensor: [height, width, 3]
    """
    preprocessed_images, shapes = model.preprocess(tf.expand_dims(input_tensor, axis=0))
    prediction_dict = model.predict(preprocessed_images, shapes)
    postprocessed_dict = model.postprocess(prediction_dict, shapes)

    return postprocessed_dict

In [4]:
name_label_mappings = {'roof': 1}
# By convention, our non-background classes start counting at 1.  
category_index = {}
for name, label in name_label_mappings.items():
    category_index[label] = {'id': label, 'name': name}

num_classes = len(name_label_mappings)

label_id_offset = 1

In [5]:
load_latest_checkpoint = True


if load_latest_checkpoint:

    pipeline_config_path = 'object_detection/configs/tf2/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8.config'
    detection_model = create_detection_model(pipeline_config_path, num_classes = num_classes)
    
    checkpoint_dir = 'models/finetuned_models/ssd_resnet50_v1_fpn_640x640_coco17_tpu-8'
    
    ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
    ckpt.restore(latest_checkpoint)
    print("The weights of detection_model restored from : \n{}".format(latest_checkpoint))

The weights of detection_model restored from : 
None


In [6]:
image_format = '.tif'
test_image_dir = 'data/sample_data/roof/rawdata/test'
test_filenames = glob.glob(test_image_dir + '/*' + image_format)

detected_img_format = '.jpg'

for image_file in test_filenames:
    image_np = read_image_to_numpy(image_file)
    input_tensor = tf.convert_to_tensor(image_np, dtype=tf.float32)

    detections = obj_detect(model = detection_model, input_tensor=input_tensor)
    
    filename = os.path.basename(image_file)
    filename_prefix = filename.split('.tif')[0]
    filename_prefix = filename_prefix.replace(' ', '_').replace('.', '-')
    detected_file = os.path.join(test_image_dir, 'detected_'+ filename_prefix + detected_img_format)
    print("Detected image saved in {}".format(detected_file))
    plot_detection(image_np,
                   detections['detection_boxes'][0].numpy(), 
                   detections['detection_classes'][0].numpy().astype(np.int32) + label_id_offset,
                   detections['detection_scores'][0].numpy(),
                   category_index, 
                   min_score_thresh = 0.5,
                   figsize=(15, 20), 
                   image_name=detected_file)

Detected image saved in data/sample_data/roof/rawdata/test/detected_prediction_image_1-2.jpg
Detected image saved in data/sample_data/roof/rawdata/test/detected_prediction_image_1-1.jpg
