Load Stalls

In [8]:
import os

import yaml

def load_stalls(cam_label, dtype_float=False, version="detection", weight=(1, 1)):
    """Load the stall reference points of one camera from its YAML settings file.

    The settings are read from ``../zoning_data/<cam_label>.yaml``. Each stall entry
    stores two points as fractions of the frame size: ``point`` (the detection point)
    and ``stall_center``. Fractions are converted to pixels with the ``resolution``
    stored in the same file unless ``dtype_float`` is set.

    Args:
        cam_label: camera label, used as the YAML file name (without extension).
        dtype_float: when True return the fractional coordinates rounded to 5
            decimals; when False (default) return integer pixel coordinates.
        version: which point to return for each stall (case-insensitive):
            * 'detection' returns the detection point, which is the point at the
              centerish of the yolo car detections
            * 'center' returns the center of the stall, which is calculated as the
              intersection of the vertices of the stall lines
            * 'average' returns the weighted average of the detection and center points
        weight: ``(detection_weight, center_weight)`` used only by 'average'. The
            default ``(1, 1)`` is a plain mean; pass ``(3, 1)`` for
            ``(3*detection_loc + 1*stall_center_loc) / 4``.

    Returns:
        dict mapping each spot ID to an ``(x, y)`` tuple.

    Raises:
        ValueError: if ``version`` is not one of the three supported values, or if
            ``version`` is 'average' and the weights sum to zero.

    Notes:
        The algorithm that matches detections to stalls basically computes the
        distance between each stall point and a point in the yolo bbox. The point in
        the yolo bbox is the point in between the center of the bbox and the middle of
        the bottom edge of the bbox. Because of this, the detection reference point is
        closest to the average between the stall center and the center of the yolo
        bbox. Sometimes the stall location isn't actually inside the yolo detection if
        the camera is at a lower angle, so the 'center' version can sometimes be
        inaccurate.

    Example:
    load_stalls('000004-4-cb_northeast_north')
    {
        "004-0003": (1856, 796),
        "004-0004": (1839, 703),
        ...
    }
    load_stalls('000004-4-cb_northeast_north', dtype_float=True)
    {
        "004-0003": (0.96707, 0.73705),
        "004-0004": (0.95783, 0.65163),
        ...
    }
    """
    normalized_version = version.lower()
    if normalized_version not in ('detection', 'center', 'average'):
        raise ValueError(f"Invalid version {version}. Must be 'detection', 'center', or 'average'")
    # Compare against the normalized value from here on, so that e.g. "Detection"
    # selects the same branch that the validation above accepted.
    version = normalized_version

    # Reject unusable weights before touching the file system instead of failing
    # with a ZeroDivisionError half way through the stall loop.
    total_weight = None
    if version == 'average':
        total_weight = sum(weight)
        if total_weight == 0:
            raise ValueError(f"weight must not sum to zero, got {weight}")

    settings_path = f"../zoning_data/{cam_label}.yaml"
    with open(settings_path, 'r') as f:
        settings = yaml.safe_load(f)
    height = settings['resolution']['height']
    width = settings['resolution']['width']

    output_stalls_dict = {}
    for spotID, stall_entry in settings['stalls'].items():
        x, y = stall_entry['point']
        stall_x, stall_y = stall_entry['stall_center']
        if dtype_float:
            detection = (float(x), float(y))
            center = (float(stall_x), float(stall_y))
            if version == 'detection':
                point = detection
            elif version == 'center':
                point = center
            else:
                point = tuple(
                    (d * weight[0] + c * weight[1]) / total_weight
                    for d, c in zip(detection, center)
                )
            res = (round(point[0], 5), round(point[1], 5))
        else:
            # Both points are truncated to pixels first and then averaged with floor
            # division, so the result is always an integer pixel position.
            detection = (int(x * width), int(y * height))
            center = (int(stall_x * width), int(stall_y * height))
            if version == 'detection':
                res = detection
            elif version == 'center':
                res = center
            else:
                res = tuple(
                    int((d * weight[0] + c * weight[1]) // total_weight)
                    for d, c in zip(detection, center)
                )
        output_stalls_dict[spotID] = res
    return output_stalls_dict

def get_stalls(cam_label="000004-4-cb_northeast_north", frame_width=1920, frame_height=1080):
    """Return the stall points of a camera in pixel coordinates of the cropped stream.

    The stalls are loaded as fractional 'average' points via ``load_stalls`` and then
    shifted, because the camera stream is cropped down to just the right half of the
    frame: the fractional x is remapped from [0.5, 1] to [0, 1] and scaled to half of
    ``frame_width``; y is scaled to ``frame_height`` unchanged. Stalls whose x lies in
    the left half therefore get a negative pixel x.

    Args:
        cam_label: camera label passed on to ``load_stalls``.
        frame_width: width in pixels of the full (uncropped) frame.
        frame_height: height in pixels of the frame.

    Returns:
        dict mapping each spot ID to an ``(x_pix, y_pix)`` tuple of ints.
    """
    stalls = load_stalls(cam_label, dtype_float=True, version='average')
    cropped_stalls = {}
    for spotID, (x, y) in stalls.items():
        # shift the stalls since we crop the camera stream down to just the right half
        x_shifted = (x - 0.5) / 0.5
        x_pix = int(x_shifted * frame_width // 2)
        y_pix = int(y * frame_height)
        cropped_stalls[spotID] = (x_pix, y_pix)
    return cropped_stalls

Show Video

In [4]:
from time import sleep
from IPython.display import clear_output

import cv2
import matplotlib.pyplot as plt

# Spot yellow, kept in each colour convention the notebook needs.
SPOT_YELLOW_RGB = (255, 227, 116)
# NOTE(review): "#ffd92e" decodes to (255, 217, 46), which is not the same colour as
# SPOT_YELLOW_RGB above - confirm which of the two values is the intended one.
SPOT_YELLOW_HEX = "#ffd92e"
# Same colour with the channel order reversed (BGR), for drawing on OpenCV frames.
SPOT_YELLOW_BGR = tuple(reversed(SPOT_YELLOW_RGB))

def _select_tracks(track_record, only_this_track_id, track_id_col):
    """Return the rows of track_record that belong to the requested track id(s)."""
    if only_this_track_id is None:
        return track_record
    if isinstance(only_this_track_id, (str, float, int)):
        wanted_id = int(only_this_track_id)
        return track_record.loc[track_record[track_id_col].astype(int) == wanted_id, :]
    if isinstance(only_this_track_id, (list, tuple, set)):
        wanted_ids = set(map(int, only_this_track_id))
        return track_record.loc[track_record[track_id_col].astype(int).isin(wanted_ids), :]
    # Any other type is ignored and the record is used unfiltered.
    return track_record


def _tracks_by_frame(track_record, track_id_col):
    """Map frame number -> float array with one row per confirmed track.

    Each row is (track_id, confidence, tl_x, tl_y, br_x, br_y). The 'iteration'
    column is shifted by one so that iteration i is drawn on frame number i - 1.
    """
    # assign() builds a new DataFrame, so the caller's record is never modified.
    boxes = track_record.assign(
        br_x=track_record['tl_x'] + track_record['w'],
        br_y=track_record['tl_y'] + track_record['h'],
    )
    bbox_columns = [track_id_col, 'confidence', 'tl_x', 'tl_y', 'br_x', 'br_y']
    frame_tracks = {}
    for iteration, rows in boxes.groupby('iteration'):
        confirmed_rows = rows.loc[rows['confirmed'], :]
        frame_tracks[int(iteration) - 1] = confirmed_rows[bbox_columns].to_numpy().astype(float)
    return frame_tracks


def _draw_stalls(frame, stalls):
    """Draw every stall as a filled dot labelled with its stall number; returns frame."""
    for spotID, (x, y) in stalls.items():
        position = (int(x), int(y))
        cv2.circle(frame, position, 5, (30, 130, 250), -1)
        cv2.putText(
            frame,
            str(int(spotID.replace("004-", ""))),
            position,
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (50, 200, 50),
            2
        )
    return frame


def _draw_tracks(frame, tracks):
    """Draw each track's bounding box with a 'track_id (confidence)' label; returns frame."""
    for track_id, confidence, tl_x, tl_y, br_x, br_y in tracks:
        track_id, tl_x, tl_y, br_x, br_y = map(int, (track_id, tl_x, tl_y, br_x, br_y))
        cv2.rectangle(frame, (tl_x, tl_y), (br_x, br_y), SPOT_YELLOW_BGR, 2)
        cv2.putText(
            frame,
            f"{track_id} ({confidence:.2f})",
            (tl_x, tl_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.8,
            (200, 50, 50),
            2
        )
    return frame


def show_video(video_path, track_record, show_stalls=True, show_tracks=False, only_this_track_id=None,
               title_string=None, skip_frames=None, plt_show=True, frame_delay=None, pause=True, track_id_col='track_id'):
    """Play a video frame by frame, optionally overlaying stall points and track boxes.

    Args:
        video_path: path of the video, passed to ``cv2.VideoCapture``.
        track_record: pandas DataFrame with the columns ``track_id_col``,
            'confidence', 'tl_x', 'tl_y', 'w', 'h', 'iteration' and 'confirmed'.
            Only used when ``show_tracks`` is True; it is not modified.
        show_stalls: draw the stall points returned by ``get_stalls()``.
        show_tracks: draw the confirmed track boxes of each frame.
        only_this_track_id: a single id (str, float or int) or a list/tuple/set of
            ids; when given, only those tracks are drawn.
        title_string: plot title in matplotlib mode, window name in OpenCV mode.
            In OpenCV mode the video path is used as window name when this is None.
        skip_frames: show only frames whose number is a multiple of this value;
            None or 0 shows every frame.
        plt_show: True renders each frame inline with matplotlib, False shows the
            frames in an OpenCV window (press 'q' or close the window to stop).
        frame_delay: seconds to sleep after each displayed frame, or None.
        pause: wait for Enter at the end of the video and, in OpenCV mode, after
            displaying a frame whose number is 0 or 1.
        track_id_col: name of the track id column in ``track_record``.

    Returns:
        False when ``show_tracks`` is True but ``track_record`` is None, otherwise None.
    """
    if plt_show:
        clear_output(wait=True)  # Clear the previous frame

    # load the stalls from the YAML file
    stalls = get_stalls() if show_stalls else None

    frame_tracks = {}
    if show_tracks:
        if track_record is None:
            print(f"Track record does not exist for this video {video_path}")
            return False
        selected_tracks = _select_tracks(track_record, only_this_track_id, track_id_col)
        frame_tracks = _tracks_by_frame(selected_tracks, track_id_col)

    # imshow and getWindowProperty identify the window by name, so always use a real string.
    window_name = title_string if title_string is not None else str(video_path)

    # loop through every frame of the video
    frame_number = 0
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            # Check for the end of the video before anything else, so a failed read
            # is never treated as a skippable frame.
            if not ret:
                if pause:
                    input("Press Enter to end...")
                break
            if skip_frames and frame_number % skip_frames != 0:
                frame_number += 1
                continue

            if show_stalls:
                frame = _draw_stalls(frame, stalls)
            tracks = frame_tracks.get(frame_number)
            if tracks is not None:
                _draw_tracks(frame, tracks)

            if plt_show:
                if title_string is not None:
                    plt.title(title_string)
                # Convert frame from BGR to RGB
                plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                plt.axis('off')
                plt.gcf().set_size_inches(15, 8)
                plt.show()
                if frame_delay is not None:
                    sleep(frame_delay)
                clear_output(wait=True)  # Clear the previous frame
            else:
                cv2.imshow(window_name, frame)
                if frame_delay is not None:
                    sleep(frame_delay)
                if cv2.waitKey(1) & 0xFF == ord('q') or cv2.getWindowProperty(window_name, cv2.WND_PROP_VISIBLE) < 1:
                    break
                if frame_number <= 1 and pause:
                    input("Press Enter to continue...")

            frame_number += 1
    finally:
        # Release the capture and close the window even when playback is interrupted.
        cap.release()
        if not plt_show:
            try:
                cv2.destroyWindow(window_name)
            except cv2.error:
                # The window was never created or was already closed by the user.
                pass

Example Usage

In [11]:
from datetime import datetime, timezone

import pandas as pd

# NOTE(review): version_name is not used anywhere in the visible cells - confirm it is still needed.
version_name = "base"
video_name = "rec-2024-12-20-17-36-45_small.mp4"

# The recording timestamp is encoded in the file name: rec-YYYY-MM-DD-HH-MM-SS_small.mp4
date_str = video_name.replace("rec-", "").replace("_small.mp4", "")
date = datetime.strptime(date_str, "%Y-%m-%d-%H-%M-%S").replace(tzinfo=timezone.utc)

# Derive every path from video_name, so that switching to another recording only
# requires changing that one value.
recording_dir = f"../zoning_data/rec-{date:%Y-%m-%d}"
csv_name = video_name.replace(".mp4", ".csv")
video_path = f"{recording_dir}/MP4s/{video_name}"

record = pd.read_csv(f"{recording_dir}/CSVs/{csv_name}")

# plt_show=False plays the video in an OpenCV window, which is identified by its
# title, so pass one explicitly.
show_video(video_path, record, show_stalls=True, show_tracks=True, only_this_track_id=40, frame_delay=0.2, skip_frames=None, plt_show=False, pause=True, title_string=video_name)