# 0. Setup

In [1]:
import os

In [2]:
# Folder name of the custom model under workspace/models.
CUSTOM_MODEL_NAME = "my_ssd_mobnet"
# File name of the label map stored in the annotations folder.
LABEL_MAP_NAME = "label_map.pbtxt"

In [3]:
# Relative directories shared by the rest of the notebook, keyed by role.
paths = dict(
    APIMODEL_PATH=os.path.join('models'),
    ANNOTATION_PATH=os.path.join('workspace', 'annotations'),
    IMAGE_PATH=os.path.join('workspace', 'images'),
    CHECKPOINT_PATH=os.path.join('workspace', 'models', CUSTOM_MODEL_NAME),
)

In [4]:
# Individual files read by later cells: the pipeline config that sits in the
# checkpoint directory and the label map that sits in the annotations directory.
files = dict(
    PIPELINE_CONFIG=os.path.join(paths['CHECKPOINT_PATH'], 'pipeline.config'),
    LABELMAP=os.path.join(paths['ANNOTATION_PATH'], LABEL_MAP_NAME),
)

# 1. Model evaluation

In [None]:
# Path to model_main_tf2.py below the API model directory.
EVALUATION_SCRIPT = os.path.join(
    paths['APIMODEL_PATH'],
    'research',
    'object_detection',
    'model_main_tf2.py',
)

In [None]:
# Assemble the command line for the evaluation script. The checkpoint
# directory is passed twice: once as --model_dir and once as --checkpoint_dir.
command = " ".join([
    "python",
    EVALUATION_SCRIPT,
    "--model_dir=" + paths['CHECKPOINT_PATH'],
    "--pipeline_config_path=" + files['PIPELINE_CONFIG'],
    "--checkpoint_dir=" + paths['CHECKPOINT_PATH'],
])

In [None]:
# Display the assembled command line so it can be inspected before it is run.
print(command)

In [None]:
# IPython shell escape: the value of the Python variable `command` is
# substituted into the line and executed in a system shell.
!{command}

In [None]:
# See tensorboard results
# Go to the eval directory and run the command printed below, then visit localhost:6006.
# A dedicated variable is used so the evaluation command stored in `command`
# is not overwritten; otherwise re-running the `!{command}` cell after this
# one would launch tensorboard instead of the evaluation.
tensorboard_command = "tensorboard --logdir=."
print(tensorboard_command)

# 2. Load the trained model from the checkpoint

In [None]:
import os
import tensorflow as tf
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils as viz_utils
from object_detection.builders import model_builder
from object_detection.utils import config_util

In [None]:
# Build a detection model
# The pipeline config is parsed first; its 'model' section drives the builder,
# which is asked for a non-training model.
configs = config_util.get_configs_from_pipeline_file(files['PIPELINE_CONFIG'])
detection_model = model_builder.build(
    model_config=configs['model'],
    is_training=False,
)

# Restore checkpoint
checkpoint_index = 91  # set the checkpoint index
checkpoint_prefix = os.path.join(
    paths['CHECKPOINT_PATH'],
    'ckpt-{}'.format(checkpoint_index),
)
ckpt = tf.compat.v2.train.Checkpoint(model=detection_model)
# NOTE(review): expect_partial() presumably silences warnings about checkpoint
# values the inference model does not consume - confirm against the TF docs.
ckpt.restore(checkpoint_prefix).expect_partial()

@tf.function
def detect_fn(image):
    """Run the detection model end to end on ``image``.

    The input goes through the model's ``preprocess``, ``predict`` and
    ``postprocess`` steps, in that order.

    Args:
        image: Input handed to ``detection_model.preprocess``. The callers in
            this notebook pass a float32 tensor with a leading batch axis.

    Returns:
        Whatever ``detection_model.postprocess`` returns for the input.
    """
    preprocessed, image_shapes = detection_model.preprocess(image)
    raw_predictions = detection_model.predict(preprocessed, image_shapes)
    return detection_model.postprocess(raw_predictions, image_shapes)

# 3. Perform prediction on the test images

In [None]:
!pip install opencv-python-headless

In [None]:
! pip install opencv-python
import cv2
import numpy as np
from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# Build the category index from the label map file; it is handed to the
# visualisation call when boxes and labels are drawn.
label_map_path = files['LABELMAP']
category_index = label_map_util.create_category_index_from_labelmap(label_map_path)

In [None]:
# Loop through the whole test set, run the detector on every .jpg image and
# display the image with the best detection drawn on it.
test_dir = os.path.join(paths['IMAGE_PATH'], 'test')
label_id_offset = 1  # added to the predicted class ids before the label lookup

# sorted() gives a reproducible order; os.listdir() order is arbitrary.
for image in sorted(os.listdir(test_dir)):
    if not image.endswith(".jpg"):
        continue
    img = cv2.imread(os.path.join(test_dir, image))
    if img is None:
        # cv2.imread returns None instead of raising when a file cannot be decoded.
        print("Could not read image {}".format(image))
        continue
    image_np = np.array(img)
    # NOTE(review): cv2.imread yields BGR channel order and the array is fed to
    # the model unchanged - confirm this matches the channel order used in training.
    input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
    try:
        detections = detect_fn(input_tensor)
    except Exception as exc:
        # Best effort: skip the failing image, but say which one failed and why.
        # (A bare `except:` would also swallow KeyboardInterrupt.)
        print("Exception while processing {}: {}".format(image, exc))
        continue

    # Keep only the first `num_detections` entries of the single batch element
    # and convert every tensor to a NumPy array.
    num_detections = int(detections.pop('num_detections'))
    detections = {key: value[0, :num_detections].numpy()
                  for key, value in detections.items()}
    detections['num_detections'] = num_detections
    detections['detection_classes'] = detections['detection_classes'].astype(np.int64)

    image_np_with_detections = image_np.copy()

    # set the proper arguments to change results
    viz_utils.visualize_boxes_and_labels_on_image_array(
                image_np_with_detections,
                detections['detection_boxes'],
                detections['detection_classes']+label_id_offset,
                detections['detection_scores'],
                category_index,
                use_normalized_coordinates=True,
                max_boxes_to_draw=1,
                min_score_thresh=.6,
                agnostic_mode=False)

    # The array is still in OpenCV's BGR order; convert for matplotlib.
    plt.imshow(cv2.cvtColor(image_np_with_detections, cv2.COLOR_BGR2RGB))
    plt.show()
    # No cv2.waitKey()/destroyAllWindows() here: no HighGUI window is ever
    # created, so a key press could never be captured, and the call raises on
    # headless OpenCV builds. Interrupt the kernel to stop the loop early.

# 4. Detection on test videos (find galaxy videos and test the model on them)

In [None]:
import cv2
import datetime

# Run the detector on every video in `video_path` and write an annotated copy
# of the first MAX_FRAMES frames to a timestamped .avi file in the working directory.
video_path = 'workspace/raw_data/'
MAX_FRAMES = 251  # frames processed per video (previously: stop once count > 250)
label_id_offset = 1  # added to the predicted class ids before the label lookup

for video in os.listdir(video_path):
    print("Processing video", video)
    vcapture = cv2.VideoCapture(os.path.join(video_path, video))
    if not vcapture.isOpened():
        # Not a readable video: skip it instead of creating an empty output
        # file and reporting it as saved.
        print("Could not open video", video)
        vcapture.release()
        continue

    width = int(vcapture.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(vcapture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = vcapture.get(cv2.CAP_PROP_FPS)
    file_name = "splash_{:%Y%m%dT%H%M%S}.avi".format(datetime.datetime.now())
    vwriter = cv2.VideoWriter(file_name, cv2.VideoWriter_fourcc(*'MJPG'), fps, (width, height))
    try:
        for _ in range(MAX_FRAMES):
            success, img = vcapture.read()
            if not success:
                break
            image_np = np.array(img)
            # NOTE(review): frames are in OpenCV's BGR order and are fed to the
            # model unchanged - confirm this matches the order used in training.
            input_tensor = tf.convert_to_tensor(np.expand_dims(image_np, 0), dtype=tf.float32)
            detections = detect_fn(input_tensor)

            # Keep only the first `num_detections` entries of the single batch
            # element and convert every tensor to a NumPy array.
            num_detections = int(detections.pop('num_detections'))
            detections = {key: value[0, :num_detections].numpy()
                          for key, value in detections.items()}
            detections['num_detections'] = num_detections
            detections['detection_classes'] = detections['detection_classes'].astype(np.int64)

            image_np_with_detections = image_np.copy()
            viz_utils.visualize_boxes_and_labels_on_image_array(
                        image_np_with_detections,
                        detections['detection_boxes'],
                        detections['detection_classes']+label_id_offset,
                        detections['detection_scores'],
                        category_index,
                        use_normalized_coordinates=True,
                        max_boxes_to_draw=1,
                        min_score_thresh=.95,
                        agnostic_mode=False)
            vwriter.write(image_np_with_detections)
    finally:
        # Release both handles even if detection fails part-way through, so the
        # input is closed and the output file is finalised.
        vcapture.release()
        vwriter.release()
    print("Saved to ", file_name)