This notebook runs the model over the test data and generates the submission file.

Define main classes

In [1]:
# Copyright 2021 Joao Phellip de Mello @joaophellip
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import List, Tuple
import pandas as pd
import numpy as np
import cv2
from sklearn.cluster import KMeans
from skimage import color


"""
Main module. Contains the implementation of classes that make predictions over videos and, in turn, frames.
"""


class VideoPredictor:

    def __init__(self,
                 helmets: pd.DataFrame,
                 tracking: pd.DataFrame,
                 video_name: str,
                 file_path: str,
                 verbose = True
                 ) -> None:

        # data
        self.helmets = helmets
        self.tracking = tracking.reset_index(drop=True)
        self.video_name = video_name
        self.file_path = file_path
        self.game_play = f'{video_name.split("_")[0]}_{video_name.split("_")[1]}'
        self.game = int(self.game_play.split('_')[0])
        self.playID = int(self.game_play.split('_')[1])
        self.orientation = video_name.split('_')[2]        
        self.game_orientation = f'{self.game_play}_{self.orientation}'
        self.VERBOSE = verbose
        self.video_bboxes_lab_color = None

        # params
        self.is_camera_reversed = None
        self.num_of_frames = None

        # tracking info
        self.sync_frame = None
        self.all_players = self.tracking.player.drop_duplicates().values.tolist()
        self.current_players = self.all_players.copy()
        self.outside_players = []
        self.init_player_prev_pos = {}
        self.init_current_players = []
        self.init_outside_players = []

        # outputs
        self.all_predictions = pd.DataFrame(columns=['video_frame','label','left','width','top','height'])
        self.predictions_by_frame = {}        
        self.frames = None

    def assign(self) -> pd.DataFrame:

        # extract relevant pixel data from video
        self.extract_bboxes_pixels_from_video()

        if self.num_of_frames > 0:

            # estimate initial parameters
            self.estimate_init_parameters()

            # process frames forward/backwards
            if self.sync_frame == 1:
                self.process_frames(1, True)
            elif self.sync_frame == self.num_of_frames:
                self.process_frames(self.num_of_frames, False)
            else:
                self.process_frames(self.sync_frame, False)
                self.process_frames(self.sync_frame, True)

            # update outputs
            self.predictions_by_frame = dict(sorted(self.predictions_by_frame.items(), key=lambda item: item[0]))
            self.frames = np.arange(1, self.num_of_frames + 1)

    def extract_bboxes_pixels_from_video(self) -> None:
        """Read the video and cluster the colours of the most confident helmet boxes.

        For each of the first 50 frames, the (up to) 22 highest-confidence
        helmet boxes are cropped from the RGB frame, and their pixels are
        clustered into 2 colours with KMeans. Results are stored in
        ``self.video_bboxes_lab_color`` as::

            {video_frame: {'<left>_<width>_<top>_<height>_<conf>_<video_frame>':
                           (lab_centroids, cluster_fractions, '')}}

        ``lab_centroids`` is ``rgb2lab`` of a ``(1, n_clusters, 3)`` array, and
        ``cluster_fractions`` is a list with the share of pixels in each
        cluster. Frames after the 50th, and boxes with 5 pixels or fewer, get no
        entries. The number of frames read is stored in ``self.num_of_frames``.
        """
        max_color_frames = 50  # colour features are only computed for frames 1..50
        max_bboxes = 22        # keep at most the 22 most confident boxes per frame
        n_clusters = 2
        min_pixels = 5         # boxes with <= 5 pixels are skipped

        video_name = self.video_name
        frame = 0
        frame_colors = {}
        vidcap = cv2.VideoCapture(self.file_path)
        try:
            while True:
                it_worked, img = vidcap.read()
                if not it_worked:
                    break
                frame += 1
                video_frame = f'{video_name}_{frame}'

                bbox_colors = {}
                # later frames only get an empty entry, so skip the conversion/query work
                if frame <= max_color_frames:
                    rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    frame_bboxes = self.helmets.query('video_frame == @video_frame')
                    frame_bboxes = frame_bboxes.sort_values(by=['conf'], ascending=False).head(max_bboxes)
                    for _, fb in frame_bboxes.iterrows():
                        l, w, t, h = int(fb.left), int(fb.width), int(fb.top), int(fb.height)
                        # flatten the crop into an (n_pixels, 3) array of RGB values
                        pixels = rgb_img[t:t+h, l:l+w, :].reshape(-1, 3)
                        if pixels.shape[0] <= min_pixels:
                            continue
                        # NOTE: no random_state, so clusters may vary between runs;
                        # MiniBatchKMeans could be tried to speed this up
                        kmeans = KMeans(n_clusters=n_clusters)
                        kmeans.fit(pixels)
                        n_centroids = len(kmeans.cluster_centers_)
                        counts = np.bincount(kmeans.labels_, minlength=n_centroids)
                        percent = (counts / len(kmeans.labels_)).tolist()
                        centroid = kmeans.cluster_centers_.astype(int).reshape(-1, 3)
                        # shape (1, n_centroids, 3), scaled to [0, 1]
                        normalized_centroid = centroid[np.newaxis, :, :] / 255
                        bbox_colors[f'{l}_{w}_{t}_{h}_{fb.conf}_{video_frame}'] = (color.rgb2lab(normalized_centroid), percent, '')
                frame_colors[video_frame] = bbox_colors
        finally:
            # always free the capture handle, even if processing a frame fails
            vidcap.release()

        self.video_bboxes_lab_color = frame_colors
        self.num_of_frames = frame

    def estimate_init_parameters(self) -> None:
        """Estimate sync frame, camera orientation, rotation angle and initial player sets.

        Steps:
          1. Pick a synchronisation frame among frames 1..50:
             - Sideline: the earliest frame with exactly 22 helmets of
               conf >= 0.65 (defaults to frame 1).
             - Endzone: the frame with the most helmets of conf >= 0.72. Ties
               are broken by the fewest helmets below 0.72, then by the
               earliest frame (defaults to frame 1).
          2. Coarse estimate of camera orientation and rotation angle:
             ``find_helmets`` for Sideline, or ``find_players`` plus a
             FramePredictor grid search for Endzone. Endzone also updates
             ``self.current_players`` / ``self.outside_players``.
          3. Fine-grained FramePredictor run in +/-10 degrees around the coarse
             angle. For Sideline, a residual > 100 triggers a retry restricted
             to the y-range 144..576.
          4. Store the predicted positions in ``self.init_player_prev_pos`` and
             copies of the player sets in ``self.init_current_players`` /
             ``self.init_outside_players``.

        Requires ``extract_bboxes_pixels_from_video`` to have run first.
        """
        sideline = self.video_name.split('_')[2] == 'Sideline'

        # candidate sync frames: the first 50 frames of the video
        candidates = [f'{self.video_name}_{k}' for k in range(1, 50 + 1)]
        ct_helmets = self.helmets.query('video_frame in @candidates')
        # groupby sorts its string keys lexicographically ('..._10' < '..._2');
        # iterate in numeric frame order so "first match" and tie-breaking mean
        # "earliest frame"
        frame_groups = sorted(ct_helmets.groupby('video_frame'),
                              key=lambda group: int(group[0].split('_')[3]))

        if sideline:
            # earliest frame with exactly 22 high-confidence helmets
            best_frame = None
            for k, h in frame_groups:
                if h.query('conf >= 0.65').shape[0] == 22:
                    best_frame = k
                    break
            if best_frame is not None:
                self.sync_frame = int(best_frame.split('_')[3])
            else:
                self.sync_frame = 1
                if self.VERBOSE:
                    print('WARNING: no suitable frame found. defaulting to F1')
        else:
            # frame with most high-confidence helmets, then fewest low-confidence ones
            helms_dist = {}
            for k, h in frame_groups:
                num_hc_helms = h.query('conf >= 0.72').shape[0]
                num_lc_helms = h.query('conf < 0.72').shape[0]
                helms_dist[k.split('_')[3]] = (num_hc_helms, num_lc_helms)
            if len(helms_dist) > 0:
                conf_by_num_helm = sorted(helms_dist.items(), key=lambda item: item[1][0], reverse=True)
                top_cand = [h for h in conf_by_num_helm if h[1][0] == conf_by_num_helm[0][1][0]]
                top_cand = sorted(top_cand, key=lambda item: item[1][1])
                self.sync_frame = int(top_cand[0][0])
            else:
                if self.VERBOSE:
                    print('WARNING: no proper frame was found. Defaulting to frame 1 but it might lead to incorrect predictions')
                self.sync_frame = 1
            # whole 1280x720 frame as [x_min, y_min, x_max, y_max] (only reported in verbose output)
            safe_bh = [0, 0, 1280, 720]

        video_frame = f'{self.video_name}_{self.sync_frame}'

        # get tracking data (rows of the tracking frame closest to the sync frame)
        est_idx = abs(self.tracking['est_frame'] - self.sync_frame).idxmin()
        est_frame = self.tracking.iloc[est_idx, :]['est_frame']
        target_tracking = self.tracking.query('est_frame == @est_frame').reset_index(drop=True)

        # get frame data
        target_helmets = self.helmets.query('video_frame == @video_frame').reset_index(drop=True)

        # convert dataframes to dict
        tracking_data = VideoPredictor.pre_process_tracking_data(target_tracking)
        helmets_data = VideoPredictor.pre_process_helmets_data(target_helmets)

        # get L*a*b data (same dict object as in self.video_bboxes_lab_color; updated in place below)
        bboxes_lab = self.video_bboxes_lab_color[video_frame]

        # defaults kept if no Endzone candidate improves on an infinite residual
        is_reversed, angle = False, 0.0

        # estimate initial parameters
        if sideline:
            # find in-frame helmets
            bh, angle, is_reversed = self.find_helmets(helmets=helmets_data.copy(), tracking=tracking_data.copy())
            # apply margins to account for distortions in distances
            safety_margin_top = 4*[h[6] for h in helmets_data.values() if h[1] == bh[1]][0]
            safety_margin_bottom = 4*[h[6] for h in helmets_data.values() if h[1] == bh[3]][0]
            safe_bh = [bh[0], max(0, bh[1] - safety_margin_top), bh[2], min(720, bh[3] + safety_margin_bottom)]
            # filter helmets
            helmets_data = {k:h for k,h in helmets_data.items() if h[0] >= safe_bh[0] and \
                h[1] >= safe_bh[1] and h[0] <= safe_bh[2] and h[1] <= safe_bh[3]}
        else:
            # find two sets of in/out players
            current_players1, current_players2, outside_players1, outside_players2, = self.find_players(helmets=helmets_data.copy(),
                                                                            tracking=tracking_data.copy(), bboxes_lab=bboxes_lab,
                                                                            video_frame=video_frame, all_players=self.all_players)
            # pick set of players that min full prediction
            min_res = float('inf')
            best_current_players, best_outside_players = None, None
            for current_players, outside_players in [(current_players1, outside_players1), (current_players2, outside_players2)]:
                for is_camera_reversed in [True, False]:
                    frame_predictor = FramePredictor(helmets=helmets_data.copy(), tracking=tracking_data.copy(),
                                                     reversed_camera=is_camera_reversed, bboxes_lab=bboxes_lab,
                                                     current_players=current_players, out_players=outside_players,
                                                     video_frame=video_frame,
                                                     grid_rotation_angles=np.arange(-4.5,5.0,0.5))
                    frame_predictor.predict()
                    if frame_predictor.residual < min_res:
                        min_res = frame_predictor.residual
                        is_reversed = is_camera_reversed
                        angle = frame_predictor.best_angle
                        best_current_players = current_players
                        best_outside_players = outside_players

            # keep the current sets (all players in frame) if no candidate was accepted
            if best_current_players is not None:
                self.current_players = best_current_players
                self.outside_players = best_outside_players

        self.is_camera_reversed = is_reversed
        self.rotation_angle = angle

        # run fine-grained prediction to update parameters
        frame_predictor = FramePredictor(helmets=helmets_data.copy(), tracking=tracking_data.copy(),
                                         reversed_camera=self.is_camera_reversed, bboxes_lab=bboxes_lab,
                                         current_players=self.current_players, out_players=self.outside_players,
                                         video_frame=video_frame,
                                         grid_rotation_angles=np.arange(self.rotation_angle - 10.0, self.rotation_angle + 10.1, 2.5))
        frame_predictor.predict()

        # when sideline and performance of fine-grained predictor collapses, fall back to a cut-bordered version
        if frame_predictor.residual > 100 and sideline:
            if self.VERBOSE:
                print('initial bounds lead to a very poor performance. Defaulting to 20% margin instead')
            safe_bh = [0, 144, 1280, 576]
            helmets_data = {k:h for k,h in helmets_data.items() if h[0] >= safe_bh[0] and \
                h[1] >= safe_bh[1] and h[0] <= safe_bh[2] and h[1] <= safe_bh[3]}
            min_res = float('inf')
            for is_camera_reversed in [True, False]:
                frame_predictor = FramePredictor(helmets=helmets_data.copy(), tracking=tracking_data.copy(),
                                                reversed_camera=is_camera_reversed, bboxes_lab=bboxes_lab,
                                                current_players=self.current_players, out_players=self.outside_players,
                                                video_frame=video_frame,
                                                grid_rotation_angles=np.arange(-25.0, 25.1, 2.5))
                frame_predictor.predict()
                if frame_predictor.residual < min_res:
                    min_res = frame_predictor.residual
                    self.is_camera_reversed = is_camera_reversed

        # update fine-grained rotation angle
        self.rotation_angle = frame_predictor.best_angle

        # update tracking info
        # tuple: (left, top, bbox id, width, height, is_stale, frame, residual)
        for _, p in frame_predictor.predictions.iterrows():
            self.init_player_prev_pos[p.label] = (p.left, p.top, \
                    f'{p.left}_{p.width}_{p.top}_{p.height}_{p.conf}', p.width, p.height, False, self.sync_frame, \
                        frame_predictor.residual)
        # copied once, also when there are no predictions
        self.init_current_players = self.current_players.copy()
        self.init_outside_players = self.outside_players.copy()

        # update L*a*b
        # NOTE(review): this replaces (lab, percents, '') with a 2-tuple (lab, team letter),
        # dropping the cluster percentages. Confirm this is the layout FramePredictor expects.
        for _, p in frame_predictor.predictions.iterrows():
            bbox = f'{p.left}_{p.width}_{p.top}_{p.height}_{p.conf}_{video_frame}'
            if bbox in bboxes_lab:
                bboxes_lab[bbox] = (bboxes_lab[bbox][0], p.label[0])

        if self.VERBOSE:
            print(f'estimated video frame: {video_frame}')
            print(f'estimated camera orientation (reversed): {self.is_camera_reversed}')
            print(f'estimated init angle (degrees): {self.rotation_angle}')
            print(f'estimated init current players: {len(self.current_players)}', self.current_players)
            print(f'estimated init outside players: {len(self.outside_players)}', self.outside_players)
            print(f'estimated bounds: {safe_bh}')
            print(f'init: {frame_predictor.predictions}')

    def process_frames(self, init_frame: int, forward: bool) -> None:

        if forward:
            self.frames = np.arange(init_frame, self.num_of_frames + 1)
        else:
            self.frames = np.arange(1, init_frame)[::-1]

        angle = self.rotation_angle if self.rotation_angle is not None else 0.0 
        if f'{self.video_name}_{self.frames[0]}' in self.video_bboxes_lab_color:
            bboxes_lab = self.video_bboxes_lab_color[f'{self.video_name}_{self.frames[0]}']
        else:
            bboxes_lab = {}

        predictions_by_frame = {}
        player_prev_pos = self.init_player_prev_pos.copy()
        current_players = self.init_current_players.copy()
        outside_players = self.init_outside_players.copy()

        for frame in self.frames:

            video_frame = f'{self.video_name}_{frame}'

            # get L*a*b data
            if forward and frame > 1 and video_frame in self.video_bboxes_lab_color:
                bboxes_lab1 = self.video_bboxes_lab_color[video_frame]
                bboxes_lab = {**{k:p for k,p in bboxes_lab.items() if int(k.split('_')[8]) == frame-1}, \
                    **bboxes_lab1}
            elif not forward and frame < self.num_of_frames and video_frame in self.video_bboxes_lab_color:
                bboxes_lab1 = self.video_bboxes_lab_color[video_frame]
                bboxes_lab = {**{k:p for k,p in bboxes_lab.items() if int(k.split('_')[8]) == frame+1}, \
                    **bboxes_lab1}

            # get frame data
            target_helmets = self.helmets.query('video_frame == @video_frame').reset_index(drop=True)
            if target_helmets.shape[0] == 0:
                predictions_by_frame[frame] = pd.DataFrame(columns=['video_frame','label','left','width','top','height','conf','dist','team'])
                continue 
            helmets_data = VideoPredictor.pre_process_helmets_data(target_helmets)

            # get tracking data
            est_idx = abs(self.tracking['est_frame'] - frame).idxmin()
            est_frame = self.tracking.iloc[est_idx, :]['est_frame']
            target_tracking = self.tracking.query('est_frame == @est_frame')
            if target_tracking.shape[0] == 0:
                predictions_by_frame[frame] = pd.DataFrame(columns=['video_frame','label','left','width','top','height','conf','dist','team'])
                continue
            tracking_data = VideoPredictor.pre_process_tracking_data(target_tracking)

            # predict
            frame_predictor = FramePredictor(helmets=helmets_data, tracking=tracking_data,
                                             bboxes_lab=bboxes_lab,
                                             reversed_camera=self.is_camera_reversed, video_frame=video_frame,
                                             current_players=current_players, out_players=outside_players, 
                                             player_prev_pos=player_prev_pos,
                                             grid_rotation_angles=[angle])
            frame_predictor.predict()

            # get updated L*a*b mapping
            bboxes_lab = frame_predictor.bboxes_lab

            # get updated list of current players from FramePredictor; update list of outside players accordingly
            current_players = frame_predictor.current_players
            outside_players = [p for p in self.all_players if p not in current_players]

            # update tracking info
            player_prev_pos = frame_predictor.player_prev_pos.copy()
            for _, p in frame_predictor.predictions.iterrows():
                if p.label in player_prev_pos:
                    if player_prev_pos[p.label][5]:
                        # stale tracking info
                        min_dd, pmin = float('inf'), None
                        for k, dp in frame_predictor.player_prev_pos.items():
                            if k != p.label:
                                dd = math.sqrt((p.left - dp[0])**2 + (p.top - dp[1])**2)/max(p.width, p.height, 1)
                                if dd < min_dd:
                                    min_dd = dd
                                    pmin = k
                        if p.conf > 0.65 and (pmin is None or pmin == p.label or min_dd > 2.0):
                            if self.VERBOSE:
                                print(p.label, 'remove tracking from stale')
                            player_prev_pos[p.label] = (p.left, p.top, f'{p.left}_{p.width}_{p.top}_{p.height}_{p.conf}', \
                                p.width, p.height, False, frame, frame_predictor.residual)
                    else:
                        # hot tracking info
                        prev_p = player_prev_pos[p.label]
                        dd = math.sqrt((p.left - prev_p[0])**2 + (p.top - prev_p[1])**2)/max(p.width, p.height, 1)
                        if dd < 0.5:
                            player_prev_pos[p.label] = (p.left, p.top, f'{p.left}_{p.width}_{p.top}_{p.height}_{p.conf}', \
                                p.width, p.height, False, frame, frame_predictor.residual)
                        else:
                            if self.VERBOSE:
                                print(p.label, 'add tracking to stale', player_prev_pos[p.label][6], f'{p.left}_{p.width}_{p.top}_{p.height}_{p.conf}')
                            player_prev_pos[p.label] = (*player_prev_pos[p.label][:5], True, \
                                *player_prev_pos[p.label][6:])
                else:
                    player_prev_pos[p.label] = (p.left, p.top, f'{p.left}_{p.width}_{p.top}_{p.height}_{p.conf}', \
                        p.width, p.height, False, frame, frame_predictor.residual)

            # save predictions dataframe
            predictions_by_frame[frame] = frame_predictor.predictions.copy()

            if self.VERBOSE:
                print(video_frame, frame_predictor.residual)
                print(video_frame, outside_players)

        for p in predictions_by_frame.values():
            self.all_predictions = self.all_predictions.append(p)
        self.predictions_by_frame = {**predictions_by_frame, **self.predictions_by_frame}

    def find_players(self, helmets: dict, tracking: dict, bboxes_lab: dict,
                     video_frame: str, all_players: List) -> Tuple[List, List, List, List]:
        """Find two candidate partitions of the players into in-frame / out-of-frame sets.

        Only helmets with ``h[4] > 0.7`` (confidence) are used. The summed x/y
        distances of the left-, right-, top- and bottom-most helmets are
        compared with those of tracked players inside a candidate rectangle.
        The search runs over rotation angles -4.5..4.5 (step 0.5) and both
        camera orientations. The best admissible rectangle is kept. A rectangle
        is admissible only if it holds exactly as many players as helmets and
        enough players of the teams whose colours match the top- and
        bottom-most helmets.

        Two searches are run:
          * set 1: the bottom edge of the rectangle may move
            (``allow_bottom_player_to_move=True``);
          * set 2: the bottom-most tracked player is always inside.

        Args:
            helmets: bbox id -> helmet tuple; ``h[0]``/``h[1]`` are used as x/y
                and ``h[4]`` as confidence (layout comes from
                ``pre_process_helmets_data`` — confirm there).
            tracking: player label -> tuple whose first two items are used as
                coordinates; the first character of a label is treated as the team.
            bboxes_lab: colour features keyed by ``'<bbox id>_<video_frame>'``
                holding ``(lab_centroids, cluster_fractions, ...)``.
            video_frame: ``'<video_name>_<frame>'`` of the helmets.
            all_players: every player label of the play.

        Returns:
            ``(in_players_1, in_players_2, out_players_1, out_players_2)``.
            If a search finds no admissible rectangle, its out-of-frame list is
            empty, so every player is considered in frame.
        """
        helmets = {k:h for k,h in helmets.items() if h[4] > 0.7}

        lr_helmets = [h[0] for h in list(sorted(helmets.items(), key=lambda item: item[1][0]))]
        rl_helmets = [h[0] for h in list(sorted(helmets.items(), key=lambda item: item[1][0], reverse=True))]
        td_helmets = [h[0] for h in list(sorted(helmets.items(), key=lambda item: item[1][1]))]
        bu_helmets = [h[0] for h in list(sorted(helmets.items(), key=lambda item: item[1][1], reverse=True))]

        dh = Distances({k: h[0] for k,h in helmets.items()}, {k: h[1] for k,h in helmets.items()})
        left_bbox = lr_helmets[0]
        right_bbox = rl_helmets[0]
        top_bbox = td_helmets[0]
        bottom_bbox = bu_helmets[0]

        h_abs_x1 = sum(list(dh.distances_x[left_bbox].values()))
        h_abs_x2 = sum(list(dh.distances_x[right_bbox].values()))
        h_abs_y1 = sum(list(dh.distances_y[top_bbox].values()))
        h_abs_y2 = sum(list(dh.distances_y[bottom_bbox].values()))

        def dominant_lab(bbox_key):
            # L*a*b centroid of the cluster covering the largest share of the box
            lab, perc = bboxes_lab[bbox_key][0], bboxes_lab[bbox_key][1]
            return lab[0, perc.index(max(perc)), :]

        # predominant L*a*b colours of the top-most and bottom-most helmets
        top_bbox_lab = dominant_lab(f'{top_bbox}_{video_frame}')
        bottom_bbox_lab = dominant_lab(f'{bottom_bbox}_{video_frame}')

        # Assign each helmet to the top or bottom team by the Euclidean L*a*b
        # distance between dominant colours. Thresholds are multiples of 2.3
        # (presumably the ~2.3 just-noticeable difference).
        top_same_team, bottom_same_team = [], []
        for k in helmets:
            bbox_key = f'{k}_{video_frame}'
            if bbox_key not in bboxes_lab:
                continue
            k_lab = dominant_lab(bbox_key)
            top_total_metric = np.linalg.norm(k_lab - top_bbox_lab)
            bottom_total_metric = np.linalg.norm(k_lab - bottom_bbox_lab)
            clearly_apart = abs(top_total_metric - bottom_total_metric) > 2.3*3
            if top_total_metric < bottom_total_metric and top_total_metric < 2.3*5 and clearly_apart:
                top_same_team.append(k)
            elif bottom_total_metric < top_total_metric and bottom_total_metric < 2.3*5 and clearly_apart:
                bottom_same_team.append(k)
            # otherwise the team is ambiguous and the helmet is not counted

        num_top_team_bboxes = len(top_same_team)
        num_bottom_team_bboxes = len(bottom_same_team)

        # Swap/flip the axes and shift x by |min x|. This does not depend on the
        # search angle, so it is computed once.
        flipped_tracking = {k: (-d[1], d[0]) for k,d in tracking.items()}
        delta = abs(sorted(d[0] for d in flipped_tracking.values())[0])
        shifted_tracking = {k: (d[0] + delta, d[1]) for k,d in flipped_tracking.items()}

        best_out_players_1, best_out_players_2 = None, None
        for allow_bottom_player_to_move in [False, True]:
            best_min_dist = float('inf')
            for angle in np.arange(-4.5,5.0,0.5):

                angle_in_rad = angle*(math.pi/180)
                rot_tracking = {k: (t[0]*math.cos(angle_in_rad) - t[1]*math.sin(angle_in_rad), \
                    t[0]*math.sin(angle_in_rad) + t[1]*math.cos(angle_in_rad)) for k,t in shifted_tracking.items()}

                for reversed_camera in [False, True]:

                    min_dist, out_players = float('inf'), []

                    if reversed_camera:
                        tracking = { k: (53.3 - d[0], d[1]) for k,d in rot_tracking.items()}
                    else:
                        tracking = { k: (d[0], 120 - d[1]) for k,d in rot_tracking.items()}

                    lr_tracking = dict(sorted(tracking.items(), key=lambda item: item[1][0]))
                    rl_tracking = dict(sorted(tracking.items(), key=lambda item: item[1][0], reverse=True))
                    bu_tracking = dict(sorted(tracking.items(), key=lambda item: item[1][1], reverse=True))

                    if allow_bottom_player_to_move:
                        for pl1, t1 in lr_tracking.items():
                            for pl2, t2 in rl_tracking.items():
                                for pl4, t4 in bu_tracking.items():
                                    middle_tracking = {k: t for k,t in tracking.items() if t[0] >= t1[0] and t[0] <= t2[0] and \
                                        t[1] <= t4[1]}
                                    middle_tracking = dict(sorted(middle_tracking.items(), key=lambda item: item[1][1]))
                                    for pl3, t3 in middle_tracking.items():
                                        new_ref = {k: t for k,t in middle_tracking.items() if t[1] >= t3[1]}
                                        if pl1 in new_ref and pl2 in new_ref and pl3 in new_ref and pl4 in new_ref and \
                                            len(list(new_ref.keys())) == len(list(helmets.keys())):
                                            dt = Distances({k: t[0] for k,t in new_ref.items()}, {k: t[1] for k,t in new_ref.items()})
                                            t_abs_x1 = sum(list(dt.distances_x[pl1].values()))
                                            t_abs_x2 = sum(list(dt.distances_x[pl2].values()))
                                            t_abs_y1 = sum(list(dt.distances_y[pl3].values()))
                                            t_abs_y2 = sum(list(dt.distances_y[pl4].values()))
                                            total_dist = abs(h_abs_x1 - t_abs_x1) + abs(h_abs_x2 - t_abs_x2) + \
                                                abs(h_abs_y1 - t_abs_y1) + abs(h_abs_y2 - t_abs_y2)
                                            same_team_top = [pl for pl in new_ref.keys() if pl[0] == pl3[0]]
                                            same_team_bottom = [pl for pl in new_ref.keys() if pl[0] == pl4[0]]
                                            if total_dist < min_dist and len(same_team_top) >= num_top_team_bboxes \
                                                and len(same_team_bottom) >= num_bottom_team_bboxes:
                                                min_dist = total_dist
                                                out_players = [pl for pl in tracking.keys() if tracking[pl][0] < t1[0] \
                                                    or tracking[pl][0] > t2[0] or tracking[pl][1] < t3[1] \
                                                        or tracking[pl][1] > t4[1]]
                    else:
                        bottom_player = list(sorted(tracking.items(), key=lambda item: item[1][1]))[-1][0]
                        for pl1, t1 in lr_tracking.items():
                            for pl2, t2 in rl_tracking.items():
                                middle_tracking = {k: t for k,t in tracking.items() if t[0] >= t1[0] and t[0] <= t2[0]}
                                middle_tracking = dict(sorted(middle_tracking.items(), key=lambda item: item[1][1]))
                                for pl3, t3 in middle_tracking.items():
                                    new_ref = {k: t for k,t in middle_tracking.items() if t[1] >= t3[1]}
                                    num_pl_below = len([t for t in middle_tracking.values() if t[1] > tracking[bottom_player][1]])
                                    if pl1 in new_ref and pl2 in new_ref and pl3 in new_ref and bottom_player in new_ref and \
                                        len(list(new_ref.keys())) == len(list(helmets.keys())) and num_pl_below == 0:
                                        dt = Distances({k: t[0] for k,t in new_ref.items()}, {k: t[1] for k,t in new_ref.items()})
                                        t_abs_x1 = sum(list(dt.distances_x[pl1].values()))
                                        t_abs_x2 = sum(list(dt.distances_x[pl2].values()))
                                        t_abs_y1 = sum(list(dt.distances_y[pl3].values()))
                                        t_abs_y2 = sum(list(dt.distances_y[bottom_player].values()))
                                        total_dist = abs(h_abs_x1 - t_abs_x1) + abs(h_abs_x2 - t_abs_x2) + \
                                            abs(h_abs_y1 - t_abs_y1) + abs(h_abs_y2 - t_abs_y2)
                                        same_team_top = [pl for pl in new_ref.keys() if pl[0] == pl3[0]]
                                        same_team_bottom = [pl for pl in new_ref.keys() if pl[0] == bottom_player[0]]
                                        if total_dist < min_dist and len(same_team_top) >= num_top_team_bboxes \
                                            and len(same_team_bottom) >= num_bottom_team_bboxes:
                                            min_dist = total_dist
                                            out_players = [pl for pl in tracking.keys() if tracking[pl][0] < t1[0] \
                                                or tracking[pl][0] > t2[0] or tracking[pl][1] < t3[1]]

                    if min_dist < best_min_dist:
                        best_min_dist = min_dist
                        if allow_bottom_player_to_move:
                            best_out_players_1 = out_players
                        else:
                            best_out_players_2 = out_players

        # no admissible rectangle found: nobody is considered out of frame
        # (previously this crashed with a TypeError on `pl not in None`)
        if best_out_players_1 is None:
            best_out_players_1 = []
        if best_out_players_2 is None:
            best_out_players_2 = []

        in_players_1 = [pl for pl in all_players if pl not in best_out_players_1]
        in_players_2 = [pl for pl in all_players if pl not in best_out_players_2]

        return in_players_1, in_players_2, best_out_players_1, best_out_players_2

    def find_helmets(self, helmets: dict, tracking: dict) -> Tuple[List, float, bool]:
        """
        Grid-search camera orientation and rotation to find the helmet box that frames the tracked players.

        For each camera orientation (reversed or not) and each rotation angle in [-25, 25] degrees (step 5), the
        tracking positions are mirrored and rotated, and the left/right/top/bottom-most players are located.
        Among the 30 most confident helmets, every combination of a left (bb1), right (bb2), bottom (bb3) and
        top (bb4) helmet is scored when the 22 most confident helmets inside it include all four, and no helmet
        with conf > 0.7 lies to its left or right within its vertical span. The score compares the summed x/y
        distances (from Distances) of the extreme helmets with those of the extreme players.

        Args:
            helmets: bbox key -> helmet tuple (left at [0], top at [1], detector confidence at [4]).
            tracking: player label -> (x, y) position.

        Returns:
            (best_bounds, best_angle, is_reversed): best_bounds is [left, top, right, bottom] in bbox coordinates
            ([] when no combination qualifies or tracking is empty), best_angle the winning rotation in degrees
            and is_reversed whether the reversed-camera orientation won.
        """
        if not tracking:
            # nothing to align against; same result as when no helmet combination qualifies
            return [], 0.0, False

        # keep the 30 most confident helmets (detector confidence is slot [4])
        helmets = dict(sorted(helmets.items(), key=lambda item: item[1][4], reverse=True)[:30])

        lr_helmets = dict(sorted(helmets.items(), key=lambda item: item[1][0]))
        rl_helmets = dict(sorted(helmets.items(), key=lambda item: item[1][0], reverse=True))
        bu_helmets = dict(sorted(helmets.items(), key=lambda item: item[1][1], reverse=True))

        best_min_dist, best_bounds, best_angle, is_reversed = float('inf'), [], 0.0, False
        for reversed_camera in [True, False]:

            # mirror the field to match the camera view; this does not depend on the rotation angle
            if reversed_camera:
                mirrored = {k: (120 - d[0], d[1]) for k, d in tracking.items()}
            else:
                mirrored = {k: (d[0], 53.3 - d[1]) for k, d in tracking.items()}

            for angle in np.arange(-25.0, 30.0, 5.0):

                angle_in_rad = angle*(math.pi/180)
                cos_a, sin_a = math.cos(angle_in_rad), math.sin(angle_in_rad)
                rotated = {k: (t[0]*cos_a - t[1]*sin_a, t[0]*sin_a + t[1]*cos_a) for k, t in mirrored.items()}

                dt = Distances({k: t[0] for k, t in rotated.items()}, {k: t[1] for k, t in rotated.items()})
                by_x = sorted(rotated.items(), key=lambda item: item[1][0])
                by_y = sorted(rotated.items(), key=lambda item: item[1][1])
                left_player, right_player = by_x[0][0], by_x[-1][0]
                top_player, bottom_player = by_y[0][0], by_y[-1][0]

                p_abs_sum_x1 = sum(dt.distances_x[left_player].values())
                p_abs_sum_y1 = sum(dt.distances_y[top_player].values())
                p_abs_sum_x2 = sum(dt.distances_x[right_player].values())
                p_abs_sum_y2 = sum(dt.distances_y[bottom_player].values())

                for bb1, h1 in lr_helmets.items():
                    for bb2, h2 in rl_helmets.items():
                        for bb3, h3 in bu_helmets.items():
                            # helmets horizontally between h1 and h2 and not below h3
                            middle_helmets = {k: h for k, h in helmets.items()
                                              if h1[0] <= h[0] <= h2[0] and h[1] <= h3[1]}
                            if bb1 not in middle_helmets:
                                continue
                            middle_helmets = dict(sorted(middle_helmets.items(), key=lambda item: item[1][1]))
                            for bb4, h4 in middle_helmets.items():
                                # the 22 most confident helmets that are not above h4
                                new_ref = {k: h for k, h in middle_helmets.items() if h[1] >= h4[1]}
                                new_ref = dict(sorted(new_ref.items(), key=lambda item: item[1][4], reverse=True)[:22])
                                if len(new_ref) != 22 or not all(bb in new_ref for bb in (bb1, bb2, bb3, bb4)):
                                    continue
                                # reject boxes with a confident helmet just outside their left or right edge
                                if any(h[4] > 0.7 and h4[1] <= h[1] <= h3[1] and (h[0] < h1[0] or h[0] > h2[0])
                                       for h in helmets.values()):
                                    continue
                                hd = Distances({k: h[0] for k, h in new_ref.items()},
                                               {k: h[1] for k, h in new_ref.items()})
                                h_abs_x1 = sum(hd.distances_x[bb1].values())
                                h_abs_y1 = sum(hd.distances_y[bb4].values())
                                h_abs_x2 = sum(hd.distances_x[bb2].values())
                                h_abs_y2 = sum(hd.distances_y[bb3].values())
                                total_dist = abs(p_abs_sum_x1 - h_abs_x1) + abs(p_abs_sum_y1 - h_abs_y1) + \
                                    abs(p_abs_sum_x2 - h_abs_x2) + abs(p_abs_sum_y2 - h_abs_y2)
                                if total_dist < best_min_dist:
                                    best_min_dist = total_dist
                                    is_reversed = reversed_camera
                                    best_angle = angle
                                    best_bounds = [h1[0], h4[1], h2[0], h3[1]]

        return best_bounds, best_angle, is_reversed

    @staticmethod
    def pre_process_tracking_data(tracking: pd.DataFrame) -> dict:
        """
        Extract data relavant to FramePrediction from dataframe and store it in a simple dict 
        """
        data = {}
        for _, t in tracking.iterrows():
            data[f'{t.player}'] = (t.x, t.y)
        return data

    @staticmethod
    def pre_process_helmets_data(helmets: pd.DataFrame) -> dict:
        """
        Extract data relavant to FramePrediction from dataframe and store it in a simple dict 

        A helmet dict has a key and a value. The key is a combination of a helmet bbox coordinates with the detector
        confidence level. The value is a tuple with the following data: bbox left coordinate; bbox top coordinate;
        default distance from closest bbox in frame; default team assignment; detector confidence level; bbox width;
        bbox height; team assignment confidence level; tracking confidence level.
        """
        data = {}
        for _, h in helmets.iterrows():
            data[f'{h.bbox}_{h.conf}'] = (h.left, h.top, 0.0, '', h.conf, h.width, h.height, 0.5, -99999)
        return data


class FramePredictor:

    def __init__(self,
                 helmets: dict,
                 tracking: dict,
                 reversed_camera: bool,
                 video_frame: str,
                 grid_rotation_angles: np.ndarray,
                 current_players: List[str],
                 out_players: List[str],
                 player_prev_pos: dict = {},
                 full_mode: bool = True,
                 bboxes_lab: dict = None) -> None:

        # data
        self.helmets = helmets
        self.tracking = tracking
        self.reversed_camera = reversed_camera
        self.video_frame = video_frame
        self.bboxes_lab = {k: (p[0].copy(), p[1]) for k,p in bboxes_lab.items() }
        self.player_prev_pos = player_prev_pos
        self.current_players = current_players.copy()
        self.out_players = out_players.copy()

        # caching data structures
        self.helmets_map = {}
        self.tracking_map = {}

        # params
        self.sideline = True if video_frame.split('_')[2] == 'Sideline' else False
        self.grid_rotation_angles = grid_rotation_angles
        self.full_mode = full_mode

        # outputs
        self.predictions = None
        self.residual = None
        self.best_angle = None
        self.low_conf_sum_dist = None

        # camera orientation align
        if self.sideline:
            if self.reversed_camera:
                self.tracking = { k: (120 - d[0], d[1]) for k,d in self.tracking.items()}
            else:
                self.tracking = { k: (d[0], 53.3 - d[1]) for k,d in self.tracking.items()}
        if not self.sideline:
            self.tracking = { k: (-d[1], d[0]) for k,d in self.tracking.items()}
            delta = abs(list(sorted(self.tracking.items(), key=lambda item: item[1][0]))[0][1][0])
            self.tracking = { k: (d[0] + delta, d[1]) for k,d in self.tracking.items()}

            if self.reversed_camera:
                self.tracking = { k: (53.3 - d[0], d[1]) for k,d in self.tracking.items()}
            else:
                self.tracking = { k: (d[0], 120 - d[1]) for k,d in self.tracking.items()}

    def predict(self):

        predictions = None
        best_residual = float('inf')

        # grid search rotation angles
        for angle in self.grid_rotation_angles:

            # rotate x,y points in tracking
            angle_in_rad = angle*(math.pi/180)
            tracking = {k: (p[0]*math.cos(angle_in_rad) - p[1]*math.sin(angle_in_rad), \
                p[0]*math.sin(angle_in_rad) + p[1]*math.cos(angle_in_rad)) for k,p in self.tracking.items()}

            # run a single prediction
            grid_prediction, grid_guess1, grid_guess2 = self.run_single_prediction(self.helmets, tracking)

            # jump to next if no prediction was made
            if grid_prediction is None:
                continue

            # combine residuals
            resid1, resid2 = grid_guess1[0], grid_guess2[0]
            if resid1 == float('inf') and resid2 == float('inf'):
                resid == 0
            elif resid1 == float('inf'):
                resid = resid2
            elif resid2 == float('inf'):
                resid = resid1
            else:
                resid = resid1 + resid2

            # keep prediction that minimizes residual distances
            if resid < best_residual:
                best_residual = resid
                predictions = grid_prediction
                self.best_angle = angle

        if predictions is None:
            predictions = pd.DataFrame(columns=['video_frame','label','left','width','top','height','conf','dist','team'])

        # update outputs
        self.predictions = predictions
        self.residual = best_residual
        if len(list(self.bboxes_lab.keys())) > 0:
            for _, p in predictions.iterrows():
                if p.conf > 0.8:
                    bbox = f'{p.left}_{p.width}_{p.top}_{p.height}_{p.conf}_{self.video_frame}'
                    if bbox in self.bboxes_lab:
                        self.bboxes_lab[bbox] = (self.bboxes_lab[bbox][0], p.label[0])

    def run_single_prediction(self, helmets: dict, tracking: dict) -> Tuple[pd.DataFrame, float, float]:

        # split bboxes into overall high confidence and low confidence helmets
        helmets, tracking = self.sort_helmets(helmets, tracking)

        high_conf_helmets = {k:h for k,h in helmets.items() if h[8] != -99999}
        low_conf_helmets = {k:h for k,h in helmets.items() if k not in high_conf_helmets}

        if len(list(tracking.keys())) == 0 or len(list(high_conf_helmets.keys())) == 0 or \
            len(list(high_conf_helmets.keys())) > len(list(tracking.keys())):
            return None, None, None

        # run prediction for high confidence helmets
        node1, node2, best_guess1, best_guess2, data = self.run_prediction(high_conf_helmets, tracking)
        predictions1, predictions2 = {}, {}
        best_res = 0
        if best_guess1[0] != float('inf'):
            predictions1 = {k: (p,1) for k,p in node1.predict(best_guess1[1]).items()}
            best_res += best_guess1[0]
        if best_guess2[0] != float('inf'):
            best_res += best_guess2[0]
            predictions2 = {k: (p,2) for k,p in node2.predict(best_guess2[1]).items()}
        best_predictions = {**predictions1, **predictions2}
        best_predictions = {k: (p[0], data[k], p[1]) for k,p in best_predictions.items()}

        # run predictions for each low confidence helmet - limited to a max number of executions
        max_runs = 20
        for idx, (k, lc) in enumerate(low_conf_helmets.items()):

            copy_hc_helmets = high_conf_helmets.copy()
            del copy_hc_helmets[list(high_conf_helmets.keys())[-1]]
            copy_hc_helmets[k] = lc
        
            # run prediction
            node1, node2, guess1, guess2, data = self.run_prediction(copy_hc_helmets, tracking)

            predictions1, predictions2 = {}, {}
            res = 0
            if guess1[0] != float('inf'):
                predictions1 = {k: (p,1) for k,p in node1.predict(guess1[1]).items()}
                res += guess1[0]
            if guess2[0] != float('inf'):
                res += guess2[0]
                predictions2 = {k: (p,2) for k,p in node2.predict(guess2[1]).items()}
            predictions = {**predictions1, **predictions2}
            predictions = {k: (p[0], data[k], p[1]) for k,p in predictions.items()}

            # keep predictions of smaller residual
            if res < best_res:
                best_res = res
                best_predictions = predictions.copy()
                best_guess1 = guess1
                best_guess2 = guess2
                high_conf_helmets = copy_hc_helmets.copy()
                if idx + 1>= max_runs:
                    break

        # build predictions dataframe
        out_pred = { idx: [self.video_frame, p[0], p[1][0], p[1][5], p[1][1], p[1][6], p[1][4], p[1][2], p[1][3], p[1][7], p[2]] \
            for idx,p in enumerate(best_predictions.values())}
        predictions_df = pd.DataFrame.from_dict(out_pred, orient='index', \
            columns=['video_frame','label','left','width','top','height','conf','dist','team','team_conf','tier'])

        return predictions_df, best_guess1, best_guess2

    def run_prediction(self, data: dict, tracking: dict) -> Tuple['OptTreeNode', dict]:

        # TODO: fix cache if hkey not in self.helmets_map:
        hkey, tkey = 'h_margin_iter_id', f't_angle'

        bboxes_dist, upd_data = self.estimate_helmet_distance_and_team(data)
        self.helmets_map[hkey] = (bboxes_dist, upd_data)

        labels_dist = self.estimate_tracking_distance(tracking)
        self.tracking_map[tkey] = labels_dist

        node1, node2, best_guess1, best_guess2 = self.optimize(self.tracking_map[tkey], self.helmets_map[hkey])
        return node1, node2, best_guess1, best_guess2, self.helmets_map[hkey][1]

    def estimate_helmet_distance_and_team(self, data: dict) -> Tuple['Distances', dict]:
        """
        Update helmet tuples with nearest-neighbour distance and home-team similarity, then build their Distances.

        Args:
            data: bbox key -> helmet tuple. It is updated in place by calc_helmet_distance.

        Returns:
            (Distances over the helmet (left, top) positions, updated helmet dict ordered by decreasing
            nearest-neighbour distance).
        """
        data = self.calc_helmet_distance(data)
        data = self.estimate_home_team_similarity(data)

        bboxes_x = {k: d[0] for k, d in data.items()}
        bboxes_y = {k: d[1] for k, d in data.items()}

        return Distances(bboxes_x, bboxes_y), data

    def calc_helmet_distance(self, data: dict) -> dict:
        """
        calculate distances btwn points in a set of helmets
        """
        for k1, d1 in data.items():
            min_dist = float('inf')
            for k2, d2 in data.items():
                if k1 != k2:
                    dist = math.sqrt( (d1[0] - d2[0])**2 + (d1[1] - d2[1])**2 ) / max(d1[5], d1[6], 1)
                    if dist < min_dist:
                        min_dist = dist
            if min_dist != float('inf') :
                data[k1] = (data[k1][0], data[k1][1], min_dist, data[k1][3], data[k1][4], data[k1][5], data[k1][6], data[k1][7])

        if len(list(self.player_prev_pos.items())) > 0:

            taken = []
            same_dist = []
            for k,h in data.items():
                if h[2] not in taken:
                    taken.append(h[2])
                else:
                    same_dist.append(h[2])
            for d1 in same_dist:
                equidistant = {k:h for k,h in data.items() if data[k][2] == d1}
                min_dist, closest_k = float('inf'), None
                for k2, e in equidistant.items():
                    dist = float('inf')
                    for k3, d3 in self.player_prev_pos.items():
                        if not d3[5]:
                            dd = math.sqrt( (e[0] - d3[0])**2 + (e[1] - d3[1])**2 ) / max(e[5], e[6], 1)
                            if dd < dist:
                                dist = dd
                    if dist != float('inf') and dist < min_dist:
                        min_dist = dist
                        closest_k = k2
                if closest_k is not None:
                    data[closest_k] = (*data[closest_k][:2], data[closest_k][2] + 0.001, *data[closest_k][3:])

        return dict(sorted(data.items(), key=lambda item: item[1][2], reverse=True))

    def estimate_tracking_distance(self, tracking: dict) -> 'Distances':
        """
        Wrap tracking positions (label -> (x, y)) into a Distances object, split by axis.
        """
        xs, ys = {}, {}
        for label, pos in tracking.items():
            xs[label] = pos[0]
        for label, pos in tracking.items():
            ys[label] = pos[1]
        return Distances(xs, ys)

    def optimize(self, distances_labels: 'Distances', bboxes_data: tuple) -> Tuple['OptTreeNode', 'OptTreeNode', tuple, tuple]:
        """
        Assign tracking labels to helmet bboxes with a greedy, level-by-level search over OptTreeNode trees.

        Bboxes are split on their nearest-neighbour distance (slot [2] of each helmet tuple): "distant" bboxes
        (>= 5 in full mode, >= 3 otherwise) are assigned first in one tree. The remaining "near" bboxes are then
        assigned in a second tree that only offers the labels not used by the first assignment.
        A path is a list of child indices, one per tree level. After walking a path to a leaf, the leaf's
        ``metric`` is compared and the smallest one is kept.

        Args:
            distances_labels: Distances built from the (rotated) tracking positions.
            bboxes_data: (Distances of the bboxes, helmet dict) as stored in self.helmets_map by run_prediction.

        Returns:
            (node, near_node, best_guess, near_best_guess): the root of each tree and, for each, a
            (metric, path) tuple. near_best_guess keeps metric float('inf') when there are no near bboxes.
            NOTE(review): when no distant bbox is processed, best_guess takes the root's metric, which may be
            None (see the ``parent.metric is not None`` check below). Confirm against OptTreeNode, since
            callers compare the metric against float('inf').
        """

        distances_bboxes = bboxes_data[0]
        # nearest-neighbour distance of each bbox (slot [2] of the helmet tuple)
        bboxes_dist = {k: d[2] for k,d in bboxes_data[1].items()}

        sum_residuals_x, sum_residuals_y = self.calculate_residuals(distances_labels, distances_bboxes, bboxes_data[1])
        labels = list(distances_labels.distances_x.keys())
        # NOTE(review): presumably ordered by decreasing nearest-neighbour distance (calc_helmet_distance returns
        # that order), which would put the distant bboxes on the first tree levels; confirm Distances keeps key order
        bboxes = list(distances_bboxes.distances_x.keys())

        # isolated ("distant") bboxes are assigned first, crowded ("near") ones afterwards
        if self.full_mode:
            distant_bboxes = [b for b in bboxes if bboxes_dist[b] >= 5]
            near_bboxes = [b for b in bboxes if bboxes_dist[b] < 5]
        else:
            distant_bboxes = [b for b in bboxes if bboxes_dist[b] >= 3]
            near_bboxes = [b for b in bboxes if bboxes_dist[b] < 3]

        # distant bboxes
        # the tree receives all bboxes; only the first number_of_bboxes_processed levels are kept at the end
        node = OptTreeNode(available_labels=labels, sum_residuals_x=sum_residuals_x,
                           sum_residuals_y=sum_residuals_y, bboxes=bboxes)
        total_number_bboxes = len(distant_bboxes) + len(near_bboxes)
        number_bboxes, number_players = len(distant_bboxes), len(labels)

        path = [0] * total_number_bboxes
        best_guess = (float('inf'), path)

        # used only in non-full mode: largest per-level metric increase accepted so far
        largest_metric = float('-inf')
        number_of_bboxes_processed, number_of_bboxes_skipped = 0, 0

        if self.full_mode:
            # for each level idx, try every child index k (bounded by remaining bboxes and players), keep the rest
            # of the current best path, walk to a leaf and keep the leaf with the smallest metric
            for idx in range(number_bboxes):
                best_guess = (float('inf'), path)
                for k in range(min(number_bboxes - idx, number_players)):
                    if k != min(number_bboxes - idx, number_players) - 1:
                        # min(..., 2) - 1 is at most 1, so j only ever takes the value 0 here
                        for j in range(min(number_bboxes - idx, number_players, 2) - 1):
                            shortest_path = best_guess[1].copy()
                            shortest_path[idx] = k
                            shortest_path[min(idx+1,number_bboxes-1)] = j

                            # walk from the root to a leaf along shortest_path; node.step() is called when a node
                            # has no children yet (presumably creates them)
                            for _ in range(len(path)):
                                if len(node.child) == 0:
                                    node.step()
                                node = node.child[shortest_path.pop(0)]

                            if node.metric < best_guess[0]:
                                path, _ = node.get_path()
                                best_guess = (node.metric, path)
                            node = node.get_root()
                    else:
                        # last candidate at this level: no look-ahead on the next level
                        shortest_path = best_guess[1].copy()
                        shortest_path[idx] = k

                        for _ in range(len(path)):
                            if len(node.child) == 0:
                                node.step()
                            node = node.child[shortest_path.pop(0)]

                        if node.metric < best_guess[0]:
                            path, _ = node.get_path()
                            best_guess = (node.metric, path)
                        node = node.get_root()
            number_of_bboxes_processed = number_bboxes
        else:
            for idx in range(number_bboxes):

                best_guess = (float('inf'), path)

                # min(..., 1, ...) limits k to 0: the current best path is re-evaluated without branching
                for k in range(min(number_bboxes - idx, 1, number_players)):
                    shortest_path = best_guess[1].copy()

                    # navigate until leaf by taking shortest_path
                    for _ in range(total_number_bboxes - number_of_bboxes_skipped):
                        if len(node.child) == 0:
                            node.step()
                        idx_path = shortest_path.pop(0)
                        node = node.child[idx_path]

                    if node.metric < best_guess[0]:
                        path, _ = node.get_path()
                        best_guess = (node.metric, path)
                    node = node.get_root()

                # verify residual at level idx, i.e. sum of residuals to label bboxes from root to current in loop
                for i in range(number_of_bboxes_processed + 1):
                    node = node.child[path[i]]
                # metric increase contributed by this level alone
                if node.parent.metric is not None:
                    metric = node.metric - node.parent.metric
                else:
                    metric = node.metric
                # an increase above twice the largest accepted one: drop this bbox from the tree and move it to
                # the near bboxes
                if metric > 2*largest_metric and largest_metric != float('-inf'):
                    number_of_bboxes_skipped += 1
                    node.parent.child = []
                    node.parent.bboxes.remove(node.bbox)
                    near_bboxes.append(node.bbox)
                    del path[0] # TODO: be careful if k ranges to more than one, index to delete is not idx
                else:
                    number_of_bboxes_processed += 1
                    if metric > largest_metric:
                        largest_metric = metric
                node = node.get_root()

        # keep only the processed prefix of the best path; its metric is the metric of the node at that depth
        for i in range(number_of_bboxes_processed):
            node = node.child[best_guess[1][i]]
        best_guess = (node.metric, best_guess[1][:number_of_bboxes_processed])
        node = node.get_root()

        # near bboxes
        # only labels not taken by the distant assignment are available
        remaining_labels = [l for l in labels if l not in node.get_label_from_path(best_guess[1])]
        near_node = OptTreeNode(available_labels=remaining_labels, bboxes=near_bboxes,
                                sum_residuals_x=sum_residuals_x, sum_residuals_y=sum_residuals_y)
        number_bboxes, number_players = len(near_bboxes), len(remaining_labels)

        path = [0] * number_bboxes
        near_best_guess = (float('inf'), path)

        # same level-by-level search as the full-mode branch above, on the near tree
        for idx in range(number_bboxes):
            near_best_guess = (float('inf'), path)

            for k in range(min(number_bboxes - idx, number_players)):

                if k != min(number_bboxes - idx, number_players) - 1:
                    # min(..., 2) - 1 is at most 1, so j only ever takes the value 0 here
                    for j in range(min(number_bboxes - idx, number_players, 2) - 1):
                        shortest_path = near_best_guess[1].copy()
                        shortest_path[idx] = k
                        shortest_path[min(idx+1,number_bboxes-1)] = j

                        for _ in range(len(path)):
                            if len(near_node.child) == 0:
                                near_node.step()
                            near_node = near_node.child[shortest_path.pop(0)]

                        if near_node.metric < near_best_guess[0]:
                            path, _ = near_node.get_path()
                            near_best_guess = (near_node.metric, path)
                        near_node = near_node.get_root()
                else:
                    shortest_path = near_best_guess[1].copy()
                    shortest_path[idx] = k

                    for _ in range(len(path)):
                        if len(near_node.child) == 0:
                            near_node.step()
                        near_node = near_node.child[shortest_path.pop(0)]

                    if near_node.metric < near_best_guess[0]:
                        path, _ = near_node.get_path()
                        near_best_guess = (near_node.metric, path)
                    near_node = near_node.get_root()

        return node, near_node, best_guess, near_best_guess

    def calculate_residuals(self, distances_labels: 'Distances', distances_bboxes: 'Distances',
                            bboxes_data: dict) -> Tuple[dict, dict]:

        # on X
        bbox_dx = {}
        for bbox, dists in distances_bboxes.distances_x.items():
            sd = 0
            for _, d in dists.items():
                sd += d
            bbox_dx[bbox] = sd

        label_dx = {}
        for label, dists in distances_labels.distances_x.items():
            sd = 0
            for _, d in dists.items():
                sd += d
            label_dx[label] = sd

        sum_residuals_x = {}
        for bbox, sb in bbox_dx.items():
            sum_res = {}
            for label, sl in label_dx.items():
                tracking_penalization = self.add_tracking_penalization(label=label, bbox_data=bboxes_data[bbox])
                team_penalization = self.add_team_penalization(label=label, bbox_data=bboxes_data[bbox])
                sum_res[label] = abs(sl - sb) + tracking_penalization #+ team_penalization
            sum_residuals_x[bbox] = sum_res

        # on Y
        bbox_dy = {}
        for bbox, dists in distances_bboxes.distances_y.items():
            sd = 0
            for _, d in dists.items():
                sd += d
            bbox_dy[bbox] = sd

        label_dy = {}
        for label, dists in distances_labels.distances_y.items():
            sd = 0
            for _, d in dists.items():
                sd += d
            label_dy[label] = sd

        sum_residuals_y = {}
        for bbox, sb in bbox_dy.items():
            sum_res = {}
            for label, sl in label_dy.items():
                tracking_penalization = self.add_tracking_penalization(label=label, bbox_data=bboxes_data[bbox])
                team_penalization = self.add_team_penalization(label=label, bbox_data=bboxes_data[bbox])                
                sum_res[label] = abs(sl - sb) + tracking_penalization #+ team_penalization
            sum_residuals_y[bbox] = sum_res
    
        return sum_residuals_x, sum_residuals_y

    def add_tracking_penalization(self, label: str, bbox_data: str) -> float:

        if label not in self.player_prev_pos:
            return 1.0  # prioritize close players (< 1 normalized helmet away)
            # TODO: the ultimate solution is probably to not del from player_prev_pos but keep the last seen position
            # and use it as rough estimation.
        
        px, py = self.player_prev_pos[label][0], self.player_prev_pos[label][1]
        bl, bw, bt, bh = bbox_data[0], bbox_data[5], bbox_data[1], bbox_data[6]

        if self.player_prev_pos[label][5]:
            # stale position
            k = abs(int(self.video_frame.split('_')[3]) - self.player_prev_pos[label][6])
            distance = (1/k)*math.sqrt((bl - px)**2 + (bt - py)**2)/max(bw, bh, 1)
            distance = 1.0
        else:
            distance = math.sqrt((bl - px)**2 + (bt - py)**2)/max(bw, bh, 1)
        
        return distance

    def add_team_penalization(self, label: str, bbox_data: tuple, weight: float = 5.0) -> float:
        """
        Penalty for matching ``label`` with a bbox, given the bbox's home-team confidence (slot [7]).

        Visitor labels (first character 'V') are penalized in proportion to the home-team confidence; every other
        label is penalized in proportion to its complement. ``weight`` scales the penalty (5.0 by default).
        """
        home_team_conf = bbox_data[7]
        if label[0] == 'V':
            return weight*home_team_conf
        return weight*(1 - home_team_conf)

    def estimate_home_team_similarity(self, data: dict) -> dict:

        home_bboxes_lab = {k: l for k,l in self.bboxes_lab.items() if l[1] == 'H'}

        if len(list(home_bboxes_lab.keys())) == 0:
            return data

        # TODO: change to better aggregate (not only first bbox)
        home_lab = home_bboxes_lab[next(iter(home_bboxes_lab))][0]

        values = {}
        for k, _ in data.items():
            if f'{k}_{self.video_frame}' in self.bboxes_lab:
                delta_e = np.linalg.norm(self.bboxes_lab[f'{k}_{self.video_frame}'][0] - home_lab)
                values[k] = delta_e
        if len(list(values.keys())) > 0:
            max_value = max(list(values.values()))
        for k, _ in data.items():
            if k in values:
                data[k] = (data[k][0], data[k][1], data[k][2], data[k][3], data[k][4], data[k][5], data[k][6], 1 - values[k]/max_value)

        return data

    def sort_helmets(self, helmets: dict, tracking: dict) -> dict:

        filtered_helmets = dict(sorted(helmets.items(), key=lambda item: item[1][4], reverse=True))

        if len(list(self.player_prev_pos.keys())) > 0:

            # process outbound players
            mutable_prev_pos = self.player_prev_pos.copy()
            for p, dp in self.player_prev_pos.items():
                if dp[0] <= 2 or dp[0] + dp[3] >= 1278 or dp[1] <= 2 or dp[1] + dp[4] >= 718:
                    min_dd, pmin = float('inf'), None
                    for h in self.helmets.values():
                        dd = math.sqrt((h[0] - dp[0])**2 + (h[1] - dp[1])**2)/max(dp[3], dp[4], 1)
                        if dd < min_dd:
                            min_dd = dd
                            pmin = p
                    if min_dd != float('inf') and min_dd > 1:
                        if pmin in self.current_players:
                            self.current_players.remove(pmin)
                            del mutable_prev_pos[pmin]
            self.player_prev_pos = mutable_prev_pos.copy()

            # process inbound players
            inbound_helmets = {}
            if len(self.out_players) > 0:
                boundary_helmets, inbound_directions = FramePredictor.get_boundary_helmets(self.helmets)
                for k1, h in boundary_helmets.items():
                    # check if boundary helmet wasn't there already
                    min_dd, pmin = float('inf'), None
                    for p, dp in self.player_prev_pos.items():
                        dd = math.sqrt((h[0] - dp[0])**2 + (h[1] - dp[1])**2)/max(dp[3], dp[4], 1)
                        if dd < min_dd:
                            min_dd = dd
                            pmin = p
                    if min_dd != float('inf') and min_dd <= 1:
                        continue
                    
                    # list candidates based on inbound direction
                    if inbound_directions[k1] == 'left':
                        pl_at_the_left = list(sorted(self.player_prev_pos.items(), key=lambda item: item[1][0]))[0][0]
                        candidates = [pl for pl in self.out_players if pl in tracking \
                            and pl_at_the_left in tracking and tracking[pl][0] < tracking[pl_at_the_left][0]]
                    if inbound_directions[k1] == 'top':
                        pl_at_the_top = list(sorted(self.player_prev_pos.items(), key=lambda item: item[1][1]))[0][0]
                        candidates = [pl for pl in self.out_players if pl in tracking \
                            and pl_at_the_top in tracking and tracking[pl][1] < tracking[pl_at_the_top][1]]
                    if inbound_directions[k1] == 'right':
                        pl_at_the_right = list(sorted(self.player_prev_pos.items(), key=lambda item: item[1][0], reverse=True))[0][0]
                        candidates = [pl for pl in self.out_players if pl in tracking \
                            and pl_at_the_right in tracking and tracking[pl][0] > tracking[pl_at_the_right][0]]
                    if inbound_directions[k1] == 'bottom':
                        pl_at_the_bottom = list(sorted(self.player_prev_pos.items(), key=lambda item: item[1][1], reverse=True))[0][0]
                        candidates = [pl for pl in self.out_players if pl in tracking \
                            and pl_at_the_bottom in tracking and tracking[pl][1] > tracking[pl_at_the_bottom][1]]

                    # consider player to be within frame if the residual of the new matching player/helmet decreases the
                    # average residual from previous frame
                    hc_prev_pos_hel = {p[2]: (p[0], p[1]) for k,p in self.player_prev_pos.items() if not p[5]}
                    hc_prev_pos_pl = {k: tracking[k] for k,p in self.player_prev_pos.items() if k in tracking and not p[5]}                
                    composed_hel = {**hc_prev_pos_hel, **{k:h for k,h in boundary_helmets.items() if k==k1}}
                    dth = Distances({k:d[0] for k,d in composed_hel.items()}, {k:d[1] for k,d in composed_hel.items()})
                    reference_res = next(iter(self.player_prev_pos.values()))[7]/len(list(self.player_prev_pos.keys()))
                    sumx_h, sumy_h = sum(list(dth.distances_x[k1].values())), sum(list(dth.distances_y[k1].values()))
                    min_res, best_cand = reference_res, None
                    for cand in candidates:
                        b = {**hc_prev_pos_pl, **{k:p for k,p in tracking.items() if k == cand}}
                        dtp = Distances({k:d[0] for k,d in b.items()}, {k:d[1] for k,d in b.items()})
                        sumx_p, sumy_p = sum(list(dtp.distances_x[cand].values())), sum(list(dtp.distances_y[cand].values()))
                        cand_res = abs(sumx_h - sumx_p) + abs(sumy_h - sumy_p)
                        if cand_res < min_res:
                            min_res = cand_res
                            best_cand = cand
                    if best_cand != None:
                        self.current_players.append(best_cand)
                        self.out_players.remove(best_cand)
                        inbound_helmets[k1] = h

            # filter tracking with update list of players
            tracking = dict(filter(lambda elem: elem[0] in self.current_players, tracking.items()))
            num_players = len(list(tracking.keys()))
            if num_players == 0:
                return filtered_helmets, tracking

            sorted_prev_pos = dict(sorted(self.player_prev_pos.items(), key=lambda item: item[1][5], reverse=False))

            additional_helmets = {}
            data = {}

            for k, h in inbound_helmets.items():
                data[k] = (*h[:8], 1.0)

            for k1, d1 in sorted_prev_pos.items():
                best_k, min_dist = None, float('inf')
                for k2, d2 in {k:h for k,h in helmets.items() if k not in data}.items():
                    dist = math.sqrt( (d1[0] - d2[0])**2 + (d1[1] - d2[1])**2 ) / max(d2[5], d2[6], 1)
                    if dist < min_dist:
                        min_dist = dist
                        best_k = k2
                if min_dist != float('inf'):
                    if not d1[5]:
                        data[best_k] = (*filtered_helmets[best_k][:8], min_dist)
                    else:
                        # test when player outs frame: it might pick a helmet very distant and give it a HC status
                        data[best_k] = (*filtered_helmets[best_k][:8], -1/filtered_helmets[best_k][4])
                        if min_dist > 1 and abs(int(self.video_frame.split('_')[3]) - d1[6]) < 10:
                            # test 'cause if player left, it shouldnt impact as res of not being there smaller
                            additional_helmets[k1] = d1

            picked = {k:p for k,p in helmets.items() if k in data}
            ignored = {k:p for k,p in helmets.items() if k not in data}

            for k1, d1 in ignored.items():
                if k1 not in inbound_helmets.items():
                    min_dist = float('inf')
                    for k2, d2 in picked.items():
                        dist = math.sqrt( (d1[0] - d2[0])**2 + (d1[1] - d2[1])**2 ) / max(d2[5], d2[6], 1)
                        if dist < min_dist:
                            min_dist = dist
                    # pick points somewhat close to those already there
                    if min_dist != float('inf') and min_dist < 10:
                        data[k1] = (*d1[:8], -99999)

            list_bboxes = [ f'{k.split("_")[0]}_{k.split("_")[1]}_{k.split("_")[2]}_{k.split("_")[3]}' for k in data.keys()]
            for k1, d1 in additional_helmets.items():
                key = d1[2]
                if f'{key.split("_")[0]}_{key.split("_")[1]}_{key.split("_")[2]}_{key.split("_")[3]}' not in list_bboxes:
                    data[d1[2]] = (d1[0], d1[1], 0.0, '', float(d1[2].split('_')[4]), d1[3], d1[4], 0.5, -99999)

            filtered_helmets = dict(sorted(data.items(), key=lambda item: item[1][8], reverse=True))        
        else:
            # update tracking confidence level of top helmets to be considered HC
            tracking = dict(filter(lambda elem: elem[0] in self.current_players, tracking.items()))
            hc_keys = [k for k in filtered_helmets.keys()][:len(self.current_players)]
            hc = {k:(*h[:8], -1.0) for k,h in filtered_helmets.items() if k in hc_keys}
            lc = {k:h for k,h in filtered_helmets.items() if k not in hc_keys}
            filtered_helmets = {**hc, **lc}

        return filtered_helmets, tracking

    @staticmethod
    def get_boundary_helmets(helmets: dict) -> Tuple[dict, dict]:
        bd_helmets = {}
        inbound_direction = {}
        for k,h in helmets.items():
            if h[0] <= 2 or h[0] + h[5] >= 1278 or h[1] <= 2 or h[1] + h[6] >= 718:
                bd_helmets[k] = h
                if h[0] <= 2:
                    inbound_direction[k] = 'left'
                if h[0] + h[5] >= 1278:
                    inbound_direction[k] = 'right'
                if h[1] <= 2:
                    inbound_direction[k] = 'top'
                if h[1] + h[6] >= 718:
                    inbound_direction[k] = 'bottom'
        return bd_helmets, inbound_direction


class Distances:
    """Per-anchor, normalized signed coordinate offsets between a set of points.

    For every anchor key, stores the signed offset of each point's coordinate from the
    anchor's coordinate, divided by the largest absolute offset for that anchor, so values
    lie in [-1, 1]. When every offset for an anchor is zero they are left unscaled.

    Attributes:
        distances_x: ``{anchor: {key: normalized x offset}}``
        distances_y: ``{anchor: {key: normalized y offset}}``
    """

    def __init__(self,
                 points_x: dict,
                 points_y: dict) -> None:
        """Build the normalized offset tables.

        Args:
            points_x: mapping of key -> x coordinate.
            points_y: mapping of key -> y coordinate.
        """
        # x and y are handled independently, so the two mappings may have different keys
        self.distances_x = self._normalized_offsets(points_x)
        self.distances_y = self._normalized_offsets(points_y)

    @staticmethod
    def _normalized_offsets(points: dict) -> dict:
        """Return ``{anchor: {key: (points[key] - points[anchor]) / max_abs_offset}}``."""
        table = {}
        for anchor, anchor_value in points.items():
            offsets = {k: p - anchor_value for k, p in points.items()}
            norm_factor = max(abs(d) for d in offsets.values())
            # all points coincide with the anchor: avoid dividing by zero, keep the raw zeros
            if norm_factor != 0:
                offsets = {k: d / norm_factor for k, d in offsets.items()}
            table[anchor] = offsets
        return table


class OptTreeNode:
    """Node of a search tree that assigns labels to bounding boxes one bbox at a time.

    A node at depth d (root = depth 0) has assigned ``label`` to ``bboxes[d - 1]``. Expanding a
    node with ``step`` creates one child per still-available label for the next bbox in
    ``bboxes``. Children are ordered by ascending combined (x + y) residual. A node's
    ``metric`` is the sum of those residuals along the path from the root.
    """

    def __init__(self, label: str = None, bbox: str = None,
                 parent: 'OptTreeNode' = None, metric: float = 0,
                 available_labels: List[str] = None, idx_from_parent: int = None,
                 sum_residuals_x: dict = None, sum_residuals_y: dict = None, bboxes: List[str] = None) -> None:
        """
        Args:
            label: label assigned at this node (None for the root).
            bbox: bbox that ``label`` is assigned to (None for the root).
            parent: parent node (None for the root).
            metric: cumulative residual from the root to this node.
            available_labels: labels not yet assigned on the path to this node.
            idx_from_parent: position of this node in ``parent.child``.
            sum_residuals_x: ``{bbox: {label: residual}}``; only used for the root, children
                share the parent's mapping.
            sum_residuals_y: same as ``sum_residuals_x`` for the y coordinate.
            bboxes: ordered bboxes to assign; only used for the root, children get a copy.
        """
        self.label = label
        self.bbox = bbox
        self.parent = parent
        self.child = []
        self.metric = metric
        self.idx_from_parent = idx_from_parent
        self.available_labels = available_labels
        if self.parent is None:
            self.sum_residuals_x = sum_residuals_x
            self.sum_residuals_y = sum_residuals_y
            self.bboxes = bboxes
        else:
            # residual tables are shared by reference (never written by this class);
            # the bbox list is copied per node
            self.sum_residuals_x = parent.sum_residuals_x
            self.sum_residuals_y = parent.sum_residuals_y
            self.bboxes = parent.bboxes.copy()

    def _generate(self, weights: list, labels: list, bbox: str) -> None:
        """Append one child per (weight, label) pair, each assigning ``label`` to ``bbox``."""
        for j, (weight, label) in enumerate(zip(weights, labels)):
            remaining = self.available_labels[:]
            remaining.remove(label)
            self.child.append(OptTreeNode(label=label, bbox=bbox, parent=self,
                                          metric=self.metric + weight,
                                          available_labels=remaining, idx_from_parent=j))

    def step(self) -> None:
        """Expand this node with children for the next bbox, ordered by ascending residual.

        Raises:
            IndexError: if this node already assigned the last bbox (or ``bboxes`` is empty).
        """
        if self.bbox is None:
            bbox = self.bboxes[0]
        else:
            bbox = self.bboxes[self.bboxes.index(self.bbox) + 1]

        res_x = self.sum_residuals_x[bbox]
        res_y = self.sum_residuals_y[bbox]
        # combined residual of assigning each still-available label to this bbox
        combined = {label: res_x[label] + res_y[label] for label in self.available_labels}
        # stable sort: ties keep the order of available_labels
        ranked = sorted(combined.items(), key=lambda item: item[1])
        self._generate(weights=[w for _, w in ranked], labels=[lab for lab, _ in ranked], bbox=bbox)

    def get_root(self) -> 'OptTreeNode':
        """Return the root of the tree this node belongs to."""
        node = self
        while node.parent is not None:
            node = node.parent
        return node

    def get_path(self) -> Tuple[List[int], str]:
        """Return the child indices from the root down to this node, plus a string form.

        NOTE(review): ``path`` is ordered root -> this node, but the string concatenates the
        indices in the opposite order (this node -> root) without separators. This is kept
        unchanged for compatibility with callers that may use it as a key.
        """
        path = []
        node = self
        while node.parent is not None:
            path.append(node.idx_from_parent)
            node = node.parent
        # built while walking upwards, i.e. leaf-to-root order
        path_str = ''.join(str(idx) for idx in path)
        path.reverse()
        return path, path_str

    def get_label_from_path(self, path: list) -> List[str]:
        """Return the labels met when following ``path`` (child indices) down from this node."""
        labels = []
        node = self
        for idx in path:
            node = node.child[idx]
            labels.append(node.label)
        return labels

    def predict(self, path: list) -> dict:
        """Return ``{bbox: label}`` for the nodes met when following ``path`` from this node."""
        prediction = {}
        node = self
        for idx in path:
            node = node.child[idx]
            prediction[node.bbox] = node.label
        return prediction


class Util:
    """Feature-engineering helpers for the helmet-detection and player-tracking tables."""

    @staticmethod
    def add_helmet_features(helmets: pd.DataFrame) -> pd.DataFrame:
        """Return a copy of ``helmets`` with a ``bbox`` column ``'{left}_{width}_{top}_{height}'``."""
        helmets = helmets.copy()
        helmets['bbox'] = (helmets['left'].astype(str) + '_' + helmets['width'].astype(str)
                           + '_' + helmets['top'].astype(str) + '_' + helmets['height'].astype(str))
        return helmets

    @staticmethod
    def add_track_features(tracks: pd.DataFrame, fps=59.94, snap_frame=10) -> pd.DataFrame:
        """Return a copy of ``tracks`` with play, snap-timing and estimated video-frame columns.

        Added/overwritten columns:
            game_play: ``'{gameKey}_{playID zero-padded to 6 digits}'``.
            time: parsed with ``pd.to_datetime``.
            snap: time of the first ``ball_snap`` event of the row's play.
            isSnap: whether the row's time equals the snap time.
            team: ``'Home'``/``'Away'`` for players whose id starts with 'H'/'V'.
            snap_offset: seconds from the snap (negative before it).
            est_frame: ``round(snap_offset * fps + snap_frame)`` as int.

        Args:
            tracks: tracking table with gameKey, playID, time, event and player columns.
            fps: video frame rate used to convert seconds into frames.
            snap_frame: video frame at which the snap is assumed to happen.

        Note:
            The final int cast fails if some play has no ``ball_snap`` event
            (its offsets are NaN).
        """
        tracks = tracks.copy()
        tracks['game_play'] = (
            tracks['gameKey'].astype('str')
            + "_"
            + tracks['playID'].astype('str').str.zfill(6)
        )
        tracks['time'] = pd.to_datetime(tracks['time'])
        snap_dict = (
            tracks.query('event == "ball_snap"')
            .groupby('game_play')['time']
            .first()
            .to_dict()
        )
        tracks['snap'] = tracks['game_play'].map(snap_dict)
        tracks['isSnap'] = tracks['snap'] == tracks['time']
        tracks['team'] = tracks['player'].str[0].replace({'H': 'Home', 'V': 'Away'})
        # .dt.total_seconds() is version-independent; `astype('timedelta64[ms]')` only yields
        # float milliseconds on pandas < 2.0 (it keeps a timedelta dtype on pandas >= 2.0).
        tracks['snap_offset'] = (tracks['time'] - tracks['snap']).dt.total_seconds()
        tracks['est_frame'] = (
            ((tracks['snap_offset'] * fps) + snap_frame).round().astype('int')
        )
        return tracks


Load data

In [2]:
BASE_DIR = '/kaggle/input/nfl-health-and-safety-helmet-assignment'

# Baseline helmet detections (with a bbox key column) and player tracking
# (with snap-relative frame estimates) for the test set.
test_helmets = Util.add_helmet_features(pd.read_csv(f'{BASE_DIR}/test_baseline_helmets.csv'))
test_tracking = Util.add_track_features(pd.read_csv(f'{BASE_DIR}/test_player_tracking.csv'))

Define helper functions

Process game plays

In [3]:
import glob
import os

# NOTE(review): 6 presumably equals the number of sample videos shipped in the public test
# folder; with no more than that, prediction is skipped (confirm against the competition data).
PUBLIC_TEST_VIDEO_COUNT = 6

# sorted for a deterministic processing (and submission row) order
files = sorted(glob.glob(f'{BASE_DIR}/test/*.mp4'))

predictions = []

if len(files) > PUBLIC_TEST_VIDEO_COUNT:
    for f in files:

        # Take the name from the file itself instead of a fixed path position
        # (f.split('/')[5] only worked for this exact BASE_DIR depth).
        video_name = os.path.splitext(os.path.basename(f))[0]
        # first two '_'-separated fields: gameKey and zero-padded playID
        game_play = '_'.join(video_name.split('_')[:2])

        video_predictor = VideoPredictor(test_helmets.query('video_frame.str.startswith(@video_name)', engine = 'python'),
                                         test_tracking.query('game_play == @game_play'),
                                         video_name, f, False)
        video_predictor.assign()

        if video_predictor.num_of_frames > 0:
            predictions.extend(video_predictor.predictions_by_frame.values())

else:
    print('skipping over train data')

Export submission file

In [4]:
# Concatenate all per-frame prediction frames at once: DataFrame.append was removed in
# pandas 2.0, and appending inside a loop is quadratic.
# (assumes each element of `predictions` is a DataFrame — TODO confirm against VideoPredictor)
all_predictions = pd.concat(predictions) if predictions else pd.DataFrame()

out_path = "/kaggle/working/submission.csv"
if all_predictions.shape[0] > 0:
    all_predictions[['video_frame','label','left','width','top','height']].to_csv(out_path, index=False)
else:
    # nothing was predicted: still emit a valid file by copying the sample submission
    sample_sub = pd.read_csv(f'{BASE_DIR}/sample_submission.csv')
    sample_sub.to_csv(out_path, index=False)