# Helper functions

In [2]:
import math
import os

import cv2
import matplotlib.colors as mcolors
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np


In [19]:
# Detection class names (the standard 80-class COCO label set). A label's
# position in this list is the numeric class id used in the `class` column.
objects = [
    'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
    'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag',
    'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite',
    'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
    'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon',
    'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
    'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant',
    'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote',
    'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink',
    'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
    'hair drier', 'toothbrush',
]

# Reverse lookup: numeric class id -> human-readable class name.
objects_dict = {class_id: name for class_id, name in enumerate(objects)}

In [1]:
def check():
    """Print a short confirmation that these helper functions are loaded."""
    message = "loaded helper functions"
    print(message)
    

In [17]:
def getColor(i):
    """Return the ``i``-th Tableau palette colour as an RGB float array.

    Indices wrap around the palette, so any integer ``i`` is accepted.

    Args:
        i: Index into ``matplotlib.colors.TABLEAU_COLORS``; taken modulo
            the palette size.

    Returns:
        ``numpy.ndarray`` of three floats (R, G, B) in ``[0, 1]``, as
        produced by ``matplotlib.colors.to_rgb``.
    """
    # Build the palette list once and reuse it for the lookup (the old code
    # computed `colors` and then rebuilt the same list again).
    colors = list(mcolors.TABLEAU_COLORS)
    return np.array(mcolors.to_rgb(colors[i % len(colors)]))

In [2]:
def getFrame(video_file, frame_no):
    """Return frame number ``frame_no`` (0-based) from ``video_file``.

    The video is decoded from the start. Earlier frames are skipped with
    ``cap.grab()``, which advances without retrieving the image. The target
    frame is then returned by ``cap.read()``.

    Args:
        video_file: Path to a video that ``cv2.VideoCapture`` can open.
        frame_no: 0-based index of the wanted frame (int-like).

    Returns:
        The frame exactly as ``cv2.VideoCapture.read`` produces it (OpenCV
        BGR channel order). Returns ``None`` if the file cannot be opened,
        ``frame_no`` is negative, or the video ends before that frame.
    """
    cap = cv2.VideoCapture(video_file)
    try:
        if not cap.isOpened() or frame_no < 0:
            return None
        # The previous read-until-match loop never terminated once the
        # stream was exhausted (isOpened() stays True after EOF), so an
        # out-of-range frame_no hung forever. Stop as soon as a grab fails.
        for _ in range(int(frame_no)):
            if not cap.grab():
                return None
        ok, frame = cap.read()
        return frame if ok else None
    finally:
        # Release the capture on every path, including exceptions.
        # No cv2.destroyAllWindows(): this function opens no windows, and
        # the call raises on headless OpenCV builds.
        cap.release()

In [10]:
def bboxVis(df, img, thickness=1, grid='off'):
    """Draw every bounding box in ``df`` on a copy of ``img`` and show it.

    Boxes whose ``class`` is 0 are outlined in red and all others in green.
    This matches the ``'r'``/``'g'`` edge colours used by
    ``view_objects_grouped``. Colours are RGB triples, i.e. how
    ``plt.imshow`` interprets the array.

    Args:
        df: DataFrame with ``bbox_left``, ``bbox_top``, ``bbox_w``,
            ``bbox_h`` and ``class`` columns (presumably pixel units —
            TODO confirm).
        img: Image array to draw on. It is copied, not modified.
        thickness: Rectangle line thickness passed to ``cv2.rectangle``.
        grid: ``'on'`` turns the axis on with major and dashed minor grid
            lines. Any other value leaves the current axis settings alone.
    """
    red = (255, 0, 0)
    # Previously (0, 0, 255), which plt.imshow renders as blue, not green.
    green = (0, 255, 0)
    annotated = img.copy()

    boxes = df[["bbox_left", "bbox_top", "bbox_w", "bbox_h", "class"]].values.tolist()
    for left, top, width, height, cls in boxes:
        # cv2.rectangle rejects non-integer points; .values.tolist() turns
        # every value into a float as soon as any of these columns is float.
        left, top, width, height = (int(v) for v in (left, top, width, height))
        color = red if cls == 0 else green
        # Draws in place on `annotated`.
        cv2.rectangle(annotated, (left, top), (left + width, top + height),
                      color=color, thickness=thickness)

    if grid == 'on':
        plt.axis('on')
        plt.grid(True)

        # Show the minor grid as well.
        plt.minorticks_on()
        plt.grid(which='minor', linestyle='--', linewidth=0.5, color='gray', alpha=0.5)

    plt.imshow(annotated)

In [14]:
def bboxVisGroups(df, bg_img, group_by='class', groups=None, thickness=1):
    """Plot one panel per group, each showing that group's boxes on ``bg_img``.

    Panels are laid out four per row. Each group gets its own colour from
    ``getColor`` and is titled with its value. When grouping by ``'class'``,
    the title is the class name from ``objects_dict``, falling back to the
    raw id for unknown classes.

    Args:
        df: DataFrame with ``bbox_left``, ``bbox_top``, ``bbox_w``,
            ``bbox_h`` and the ``group_by`` column.
        bg_img: 3-D image array (height, width, channels) used as every
            panel's background. It is copied, not modified.
        group_by: Column whose values define the panels.
        groups: Values of ``group_by`` to plot, in panel order. ``None``
            plots every distinct value in ``df[group_by]``.
        thickness: Rectangle line thickness passed to ``cv2.rectangle``.
    """
    # `is None`: `groups == None` is ambiguous when groups is an array/Series.
    if groups is None:
        groups = df[group_by].unique()
    groups = list(groups)

    plt.clf()
    if not groups:
        # plt.subplots cannot build a grid with zero rows.
        return

    cols = 4
    rows = math.ceil(len(groups) / cols)
    h, w, _ = bg_img.shape
    panel_size = 3  # figure inches per panel column
    # squeeze=False keeps `ax` 2-D for every grid size. The old 1-D special
    # case (len(groups) < 4) crashed for exactly four groups, where a single
    # row was still indexed as ax[row, col].
    fig, ax = plt.subplots(rows, cols,
                           figsize=(cols * panel_size, rows * panel_size * h / w),
                           squeeze=False)

    for i, group in enumerate(groups):
        row, col = divmod(i, cols)
        sel_ax = ax[row, col]
        # getColor returns RGB in [0, 1]; scale to 0-255 and add opaque alpha.
        color = tuple((getColor(i) * 255).tolist()) + (255,)

        # Filter on the grouping column itself (previously always 'class').
        boxes = df.loc[df[group_by] == group,
                       ["bbox_left", "bbox_top", "bbox_w", "bbox_h"]].values.tolist()

        im_dots = bg_img.copy()
        for left, top, width, height in boxes:
            # cv2.rectangle rejects non-integer points.
            left, top, width, height = (int(v) for v in (left, top, width, height))
            cv2.rectangle(im_dots, (left, top), (left + width, top + height),
                          color=color, thickness=thickness)

        # im_dots is an opaque copy of bg_img, so no separate background
        # layer is needed underneath it.
        sel_ax.imshow(im_dots)
        label = objects_dict.get(group, group) if group_by == 'class' else group
        sel_ax.set_title(label=f'{label}', loc='left', fontsize='small')

    # Hide the axes on every panel, including unused trailing ones.
    for sel_ax in ax.flat:
        sel_ax.axis('off')

    plt.show()

In [6]:
def view_objects(uids, df, maxCount=100, interval=5):
    """Show sampled video frames with each detection's box outlined in red.

    Rows of ``df`` whose ``uid`` is in ``uids`` are reduced to the first row
    per (uid, timestamp rounded to the second). Every ``interval``-th
    resulting row is then displayed, up to ``maxCount`` figures.

    Args:
        uids: Collection of object ids matched against ``df['uid']``.
        df: Detections DataFrame with ``uid``, ``timeStamp`` (datetime),
            ``video``, ``frame`` and ``bbox_*`` columns.
        maxCount: Maximum number of frames to display.
        interval: Only rows whose position in the reduced table is a
            multiple of ``interval`` are shown.

    Note:
        Video paths are built from a global ``videos`` directory that is not
        defined in this code. It is presumably set elsewhere in the notebook.
    """
    # Use the `df` argument (previously the global `df_In` was used instead).
    # .copy() so the new column is not assigned onto a view of `df`.
    selected = df.loc[df['uid'].isin(uids)].copy()
    selected['timeStamp_s'] = selected['timeStamp'].dt.round("s")
    reduced = selected.groupby(['uid', 'timeStamp_s']).first().reset_index()

    shown = 0
    for i, row in reduced.iterrows():
        # `>=` so at most maxCount frames are shown (was maxCount + 1).
        if shown >= maxCount:
            break
        if i % interval != 0:
            continue

        video_file = f'{videos}/{row["video"]}.mp4'
        print(row[['uid']])

        frame = getFrame(video_file, row["frame"])
        if frame is None:
            # Video could not be opened or frame is out of range.
            continue

        fig, ax = plt.subplots()
        # OpenCV decodes frames as BGR; matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        rect = patches.Rectangle((row['bbox_left'], row['bbox_top']), row['bbox_w'], row['bbox_h'],
                                 linewidth=1, edgecolor='r', facecolor='none')
        plt.axis('off')
        ax.add_patch(rect)
        plt.show()
        shown += 1

In [1]:
def view_objects_grouped(uids, df, maxCount=10, interval=5, save_loc=None):
    """Show or save one annotated frame per time bucket for the given uids.

    Procedure:
      * Rows of ``df`` whose ``uid`` is in ``uids`` are bucketed by
        ``timeStamp`` rounded to ``interval`` seconds.
      * For each of the first ``maxCount`` buckets, take the video of the
        bucket's first row and decode that video's most frequent frame
        number.
      * Outline every detection in that frame: red for class 0, green
        otherwise.

    Args:
        uids: Collection of object ids matched against ``df['uid']``.
        df: Detections DataFrame with ``uid``, ``timeStamp`` (datetime),
            ``video``, ``frame``, ``class`` and ``bbox_*`` columns.
        maxCount: Maximum number of time buckets to process. Buckets whose
            frame cannot be read still count toward this limit.
        interval: Bucket width in seconds.
        save_loc: Directory (created if missing) to write
            ``<bucket timestamp>.png`` files into. ``None`` displays the
            figures instead.

    Note:
        Video paths are built from a global ``videos`` directory that is not
        defined in this code. It is presumably set elsewhere in the notebook.
    """
    # Use the `df` argument (previously the global `df_In` was used instead).
    # .copy() so the new column is not assigned onto a view of `df`.
    selected = df.loc[df['uid'].isin(uids)].copy()
    selected['timeStamp_s'] = selected['timeStamp'].dt.round(f'{interval}s')

    if save_loc is not None:
        os.makedirs(save_loc, exist_ok=True)

    # Group by the scalar column name: with a 1-element list, pandas >= 2.0
    # yields 1-tuple keys, which would end up verbatim in the PNG file names.
    for j, (time, group) in enumerate(selected.groupby('timeStamp_s'), start=1):
        if j > maxCount:
            break

        video = group['video'].iloc[0]
        # Restrict to the chosen video, so the frame number and the boxes
        # drawn both come from the video that is actually decoded.
        in_video = group[group['video'] == video]
        frame_no = in_video['frame'].value_counts().index[0]

        frame = getFrame(f'{videos}/{video}.mp4', frame_no)
        if frame is None:
            # Video could not be opened or frame is out of range.
            continue

        fig, ax = plt.subplots()
        # OpenCV decodes frames as BGR; matplotlib expects RGB.
        ax.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

        for _, row in in_video[in_video['frame'] == frame_no].iterrows():
            c = 'r' if row['class'] == 0 else 'g'
            rect = patches.Rectangle((row['bbox_left'], row['bbox_top']), row['bbox_w'], row['bbox_h'],
                                     linewidth=1, edgecolor=c, facecolor='none')
            ax.add_patch(rect)

        plt.axis('off')

        if save_loc is not None:
            plt.savefig(f'{save_loc}/{time}.png', bbox_inches='tight')
            plt.close(fig)
        else:
            plt.show()

In [12]:
def flat_column_names(df_In):
    """Return a copy of ``df_In`` with tuple column labels joined by ``'_'``.

    A MultiIndex column such as ``('bbox_w', 'mean')`` becomes
    ``'bbox_w_mean'``. Non-string levels are converted with ``str``.
    Labels that are already single values are kept unchanged; joining a
    plain string would otherwise split it into its characters. The row
    index is then moved into regular columns via ``reset_index``. The input
    DataFrame is not modified.
    """
    df = df_In.copy()
    df.columns = [
        '_'.join(map(str, col)) if isinstance(col, tuple) else col
        for col in df.columns.to_flat_index()
    ]
    return df.reset_index()