<a href="https://colab.research.google.com/github/vatj/boat-count/blob/main/boat_count.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from pathlib import Path
import cv2
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_hub as hub
import tensorflow as tf

from PIL import Image
from PIL import ImageColor
from PIL import ImageDraw
from PIL import ImageFont
from PIL import ImageOps

In [2]:
# Faster R-CNN (Inception-ResNet-v2 backbone) trained on Open Images V4, served from TF Hub.
module_handle = (
    "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"
)

# Keep only the default serving signature: it is the callable used for inference.
loaded_module = hub.load(module_handle)
detector = loaded_module.signatures['default']

INFO:tensorflow:Saver not created because there are no variables in the graph to restore


INFO:tensorflow:Saver not created because there are no variables in the graph to restore


In [53]:
class BoatCounter:
  """Count boats in a video with a TF-Hub object detector and a greedy IoU tracker.

  Every frame is run through the detector. Detections are filtered down to
  confident, not-too-large "Boat" boxes, drawn onto the frame and written to an
  output video. The boxes are then matched by IoU against the boxes tracked on
  the previous frame. Detections that match nothing get a fresh label, and the
  boat count is the number of labels handed out.

  Args:
    video_path: input video file.
    output_path: where the annotated mp4 video is written.
    detector: an already-loaded detector callable (e.g. a TF-Hub signature).
      If None, it is loaded from ``module_handle``.
    max_frames: processing stops once the frame index exceeds this value;
      None processes the whole video.
  """

  def __init__(self,
               video_path="/content/Test-Task Sequence from Wörthersee.mp4",
               output_path="/content/test.mp4",
               detector=None,
               max_frames=100):
    colors = list(ImageColor.colormap.values())

    # External resources
    self.video_path = video_path
    self.output_path = output_path
    self.module_handle = "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"

    # Hyper parameters
    self.confidence_threshold = 0.4
    self.iou_threshold = 0.6
    # Upper bound (pixels^2) on a detection's area; larger boxes are treated as
    # the boat carrying the camera and discarded (see check_area).
    self.max_area = 300 * 500
    self.max_frames = max_frames

    # Helpers
    self.boat_count = 0
    self.frame_count = 0
    self.last_object_label = 0
    self.boxed_frames = []  # currently unused

    # Graphics
    self.color = colors[1]
    self.thickness = 10

    # Tracker state
    self.tracked_objects = dict()  # label (str) -> box of the last frame
    self.object_ious = []          # per detection: {label: iou}
    self.new_objects = []          # indices into postprocess_outputs
    self.matched_iou = dict()      # label -> [(detection index, iou), ...]

    # Initialisation
    self.detector = detector
    self.load_video()
    self.load_model()

  def load_model(self):
    """Load the detector from TF Hub unless one was supplied to __init__."""
    if getattr(self, "detector", None) is None:
      self.detector = hub.load(self.module_handle).signatures['default']

  def load_video(self):
    """Open the input video and an mp4 writer with the same frame size and FPS.

    Raises:
      IOError: if the input video cannot be opened.
    """
    self.cap = cv2.VideoCapture(self.video_path)
    if not self.cap.isOpened():
      raise IOError(f"Could not open video {self.video_path!r}")
    self.width = self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)
    self.height = self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)

    # Prepare writer for frames with bounding boxes
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    self.video_writer = cv2.VideoWriter(
        self.output_path,
        fourcc,
        float(self.cap.get(cv2.CAP_PROP_FPS)),
        (int(self.width), int(self.height))
        )

  def main(self):
    """Process the video frame by frame and report the number of boats seen.

    Stops at the end of the stream or once the frame index exceeds
    ``max_frames``. The capture and the writer are released afterwards, even
    if processing raises.
    """
    try:
      while self.cap.isOpened():
        grabbed, frame = self.cap.read()
        if not grabbed:
          # End of stream (or read error): `frame` is None and cvtColor would raise.
          break
        print(f"frame {self.frame_count}")
        self.current_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        self.run_single_inference()
        self.postprocess_inference()
        self.draw_all_boxes()
        self.save_box_frames_to_video()

        if self.frame_count > 0:
          self.run_all_IOU_trackers()
          self.resolve_tracking_conflicts()

        self.assign_new_labels()

        if self.max_frames is not None and self.frame_count > self.max_frames:
          break

        self.frame_count += 1
    finally:
      self.cap.release()
      self.video_writer.release()

    # Each label handed out is counted as one boat.
    self.boat_count = self.last_object_label
    print(f"Boat count so far is {self.boat_count}")

  def run_single_inference(self):
    """Run the detector on ``self.current_frame`` and store the raw output."""
    converted_frame = self.preprocess_frame_for_inference()
    self.current_output = self.detector(converted_frame)

  def preprocess_frame_for_inference(self):
    """Convert the current RGB frame to a float32 tensor with a leading batch axis."""
    return tf.image.convert_image_dtype(tf.convert_to_tensor(self.current_frame), dtype=tf.float32)[tf.newaxis, ...]

  def postprocess_inference(self):
    """Keep only confident "Boat" detections that pass the area filter.

    Fills ``self.postprocess_outputs`` with the selected entries of
    ``detection_boxes`` (``[ymin, xmin, ymax, xmax]``).
    """
    numpy_outputs = {key: value.numpy() for key, value in self.current_output.items()}
    scores = numpy_outputs["detection_scores"]
    boxes = numpy_outputs["detection_boxes"]
    self.postprocess_outputs = []

    for index, entity in enumerate(numpy_outputs["detection_class_entities"]):
      # Filter for boats over the confidence threshold
      if entity == b"Boat" and scores[index] > self.confidence_threshold:
        # Check the size to exclude the boat carrying the camera
        if self.check_area(boxes[index]):
          self.postprocess_outputs.append(boxes[index])

    print(f"Postprocess outputs on frame {self.frame_count} : {self.postprocess_outputs}")

  def check_area(self, box_coord):
    """Heuristic filter that dismisses the boat on which the camera is mounted.

    Args:
      box_coord: ``[ymin, xmin, ymax, xmax]`` normalised to the frame size.

    Returns:
      True if the box area in pixels is below ``self.max_area``.
    """
    ymin, xmin, ymax, xmax = tuple(box_coord)
    area = (xmax - xmin) * self.width * (ymax - ymin) * self.height
    return bool(area < self.max_area)

  def draw_bounding_box_on_image(self, image, ymin, xmin, ymax, xmax):
    """Draw one rectangle on a PIL image.

    Coordinates are normalised to the frame size and scaled by
    ``self.width`` / ``self.height`` here.
    """
    draw = ImageDraw.Draw(image)
    (left, right, top, bottom) = (xmin * self.width, xmax * self.width,
                                  ymin * self.height, ymax * self.height)
    draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
              (left, top)],
              width=self.thickness,
              fill=self.color)

  def draw_all_boxes(self):
    """Draw every kept detection box onto ``self.current_frame`` in place (no labels)."""
    if not self.postprocess_outputs:
      return
    # Convert once for all boxes instead of once per box.
    image_pil = Image.fromarray(np.uint8(self.current_frame)).convert("RGB")
    for ymin, xmin, ymax, xmax in self.postprocess_outputs:
      self.draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax)
    np.copyto(self.current_frame, np.array(image_pil))

  def run_single_IOU_tracker(self, index, current_object_coord):
    """Compare one detection with every tracked box.

    Records the IoU against each tracked label in ``object_ious``. The detection
    is registered in ``matched_iou`` as a candidate for every label whose IoU
    reaches ``iou_threshold``. If it reaches the threshold for no label, it is
    appended to ``new_objects``.
    """
    ious = {key: self.compute_IOU(current_object_coord, box_coord)
            for key, box_coord in self.tracked_objects.items()}
    self.object_ious.append(ious)

    is_candidate = False
    for key, iou in ious.items():
      # `>=` here and "no candidate" below are exact complements, so no
      # detection can fall through both tests.
      if iou >= self.iou_threshold:
        is_candidate = True
        # Potential match; conflicts are settled in resolve_tracking_conflicts.
        self.matched_iou.setdefault(key, []).append((index, iou))

    if not is_candidate:
      self.new_objects.append(index)

  def compute_IOU(self, box1, box2):
    """Compute Intersection over Union of two ``[ymin, xmin, ymax, xmax]`` boxes.

    Returns 0.0 for disjoint boxes and for degenerate boxes with zero union.
    """
    y_min1, x_min1, y_max1, x_max1 = tuple(box1)
    y_min2, x_min2, y_max2, x_max2 = tuple(box2)

    # Clamp at 0: without it two negative extents multiply into a positive
    # "intersection" for disjoint boxes.
    inter_width = max(0.0, min(x_max1, x_max2) - max(x_min1, x_min2))
    inter_height = max(0.0, min(y_max1, y_max2) - max(y_min1, y_min2))
    inter_area = inter_width * inter_height

    area1 = (x_max1 - x_min1) * (y_max1 - y_min1)
    area2 = (x_max2 - x_min2) * (y_max2 - y_min2)
    union = area1 + area2 - inter_area
    if union <= 0:
      return 0.0
    return inter_area / float(union)

  def resolve_tracking_conflicts(self):
    """Turn candidate matches into a one-to-one label assignment.

    Candidate (label, detection) pairs from ``matched_iou`` are taken greedily
    in order of decreasing IoU. A pair is accepted only if neither the label nor
    the detection is already taken. Detections that had candidates but lost all
    of them are appended to ``new_objects``. Labels without an accepted match
    are dropped from ``tracked_objects``.
    """
    candidates = [(iou, key, index)
                  for key, matches in self.matched_iou.items()
                  for index, iou in matches]
    # Stable sort: equal IoUs keep their insertion order.
    candidates.sort(key=lambda candidate: candidate[0], reverse=True)

    assignment = {}
    matched_indices = set()
    for _, key, index in candidates:
      if key in assignment or index in matched_indices:
        continue
      assignment[key] = index
      matched_indices.add(index)

    # Rebuilding the dict drops labels that found no match on this frame.
    self.tracked_objects = {key: self.postprocess_outputs[assignment[key]]
                            for key in self.matched_iou if key in assignment}

    already_new = set(self.new_objects)
    unmatched = sorted({index for _, _, index in candidates}
                       - matched_indices - already_new)
    self.new_objects.extend(unmatched)
    self.object_ious.clear()

  def assign_new_labels(self):
    """Give a fresh string label to every new object and start tracking it.

    On frame 0 every kept detection is new. Afterwards only the indices
    collected in ``new_objects`` are labelled, and the list is then cleared.
    """
    if self.frame_count == 0:
      for bounding_box in self.postprocess_outputs:
        self.tracked_objects[str(self.last_object_label)] = bounding_box
        self.last_object_label += 1
    else:
      for index in self.new_objects:
        self.tracked_objects[str(self.last_object_label)] = self.postprocess_outputs[index]
        self.last_object_label += 1

      self.new_objects.clear()
    print(f"tracked_objects on frame {self.frame_count} : {self.tracked_objects}")

  def run_all_IOU_trackers(self):
    """Reset the candidate table and compare every kept detection with the tracked boxes."""
    self.matched_iou = {key: [] for key in self.tracked_objects}

    for index, box_coord in enumerate(self.postprocess_outputs):
      self.run_single_IOU_tracker(index, box_coord)

    print(f"new_objects on frame {self.frame_count} : {self.new_objects}")

  def save_box_frames_to_video(self):
    """Write the (annotated) current RGB frame to the output video as BGR."""
    self.video_writer.write(cv2.cvtColor(self.current_frame, cv2.COLOR_RGB2BGR))

In [54]:
# Constructing the counter opens the input video and the output writer and sets up the detector.
my_boat_counter = BoatCounter()

In [55]:
# Run detection + IoU tracking over the video; prints per-frame logs and the final boat count.
my_boat_counter.main()

frame 0
Postprocess outputs on frame 0 : [array([0.39550692, 0.50152516, 0.95147216, 0.6990435 ], dtype=float32), array([0.32380667, 0.4695601 , 0.39550516, 0.49616212], dtype=float32)]
tracked_objects on frame 0 : {'0': array([0.39550692, 0.50152516, 0.95147216, 0.6990435 ], dtype=float32), '1': array([0.32380667, 0.4695601 , 0.39550516, 0.49616212], dtype=float32)}
frame 1
Postprocess outputs on frame 1 : [array([0.40170017, 0.4929987 , 0.95041835, 0.69908595], dtype=float32)]
new_objects on frame 1 : []
tracked_objects on frame 1 : {'0': array([0.40170017, 0.4929987 , 0.95041835, 0.69908595], dtype=float32)}
frame 2
Postprocess outputs on frame 2 : [array([0.39744034, 0.50335056, 0.9563618 , 0.69912595], dtype=float32)]
new_objects on frame 2 : []
tracked_objects on frame 2 : {'0': array([0.39744034, 0.50335056, 0.9563618 , 0.69912595], dtype=float32)}
frame 3
Postprocess outputs on frame 3 : [array([0.3771881 , 0.4831279 , 0.9429524 , 0.69843453], dtype=float32)]
new_objects on fra

In [None]:
# Show the last processed frame; the detected boat boxes were drawn into it in place.
fig, ax = plt.subplots(1, 1, figsize=(12, 8))
ax.imshow(my_boat_counter.current_frame)
ax.set_title("Last processed frame with detected boats")
ax.axis("off")
plt.show()

In [None]:
my_boat_counter.