In [None]:
# Mount Google Drive (Colab only) so the paths under /gdrive used below resolve.
from google.colab import drive
drive.mount('/gdrive')

In [None]:
import copy
import os
import pickle
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

# Event Detection
This notebook gathers the image-level predictions made by the model and aggregates them into _events_, i.e., application events that span multiple images. The aggregation is deliberately naive: within one location and season, detections in consecutive images (ordered by date) are treated as the same event if their bounding boxes intersect.

## Helper Functions for aggregating predictions

In [None]:
def rect_intersect(r1, r2):
  """Report whether two axis-aligned rectangles overlap.

  Each rectangle is given as [x_center, y_center, width, height].
  Rectangles that only touch along an edge or at a corner count as
  intersecting, because the separation tests use strict inequalities.
  """

  def _apart(center_a, size_a, center_b, size_b):
    # 1-D separation test: True when one interval ends before the other starts.
    lo_a, hi_a = center_a - size_a/2, center_a + size_a/2
    lo_b, hi_b = center_b - size_b/2, center_b + size_b/2
    return (hi_a < lo_b) or (hi_b < lo_a)

  cx_a, cy_a, w_a, h_a = r1
  cx_b, cy_b, w_b, h_b = r2

  # The boxes intersect unless they are separated along x or along y.
  return not (_apart(cx_a, w_a, cx_b, w_b) or _apart(cy_a, h_a, cy_b, h_b))

  


In [None]:
def aggregate_instances(labels_dict, threshold=0):
  """Aggregate image-level detections into multi-image events (instances).

  Within each location and season the dates are visited in sorted order. A
  detection extends an active instance if its box intersects any box already
  in that instance; detections that extend nothing start a new instance. An
  instance ends as soon as an image entry contains no matching detection.

  Args:
    labels_dict: nested dict ``{loc: {season: {date: [entry, ...]}}}``. Each
      entry is ``(boxes, image_name)`` or ``(boxes, image_name, confidences)``
      and every box is ``[x_center, y_center, width, height]``. Dates only
      need to be sortable and comparable.
    threshold: minimum confidence for a detection to be kept. It is applied
      only to entries that carry confidences; entries without them keep all
      boxes, each with confidence 1.

  Returns:
    ``{loc: {season: [instance, ...]}}``. Each instance is a dict with keys
    'start', 'end', 'active' (always False on return), 'ims', 'coords' and
    'conf' (the maximum confidence over the detections in the instance).
  """

  instance_dict = {}
  for loc, sdict in labels_dict.items():
    for season, ddict in sdict.items():

      dates = sorted(ddict.keys())

      # keep track of instances for this season
      all_instances = []

      for i, date in enumerate(dates):  # all images in season, in date order
        entries = ddict[date]

        if not len(entries):
          # this date has no labels, so every open instance ended on the
          # previous date
          for inst in all_instances:
            if inst['active']:
              inst['active'] = False
              inst['end'] = dates[i-1]
          continue

        # grab coordinates (labels), images, and confidences
        for labels_ims in entries:
          label_list, im = labels_ims[0], labels_ims[1]

          if len(labels_ims) == 2:
            # no confidences supplied: keep everything with confidence 1
            confs = [1]*len(label_list)
          else:
            confs = labels_ims[2]
            # dtype=bool keeps the mask usable as an index even when the
            # entry has no detections (an empty list would default to float)
            conf_bools = np.array([c >= threshold for c in confs], dtype=bool)
            label_list = np.array(label_list)[conf_bools]
            confs = np.array(confs)[conf_bools]
            assert len(confs) == len(label_list) == np.sum(conf_bools)

          # Check which instances get extended
          used_labels = set()  # indices of labels absorbed by some instance
          for inst in all_instances:
            if not inst['active']:
              continue  # instance already ended
            continued = False
            for j, coords in enumerate(label_list):
              # the label belongs to this instance if it intersects any box
              # already in it (including boxes appended earlier in this loop)
              if any(rect_intersect(coords, c) for c in inst['coords']):
                inst['ims'].append(im)
                inst['coords'].append(coords)
                # confidence is maximum conf over all detections
                inst['conf'] = max(inst['conf'], confs[j])
                continued = True
                used_labels.add(j)
            if not continued:  # this instance ended, no labels matched
              inst['active'] = False
              # an instance that started on this very date ends on it too
              inst['end'] = dates[i-1] if inst['start'] < date else date

          # Make new instances from any unused labels
          for j, coords in enumerate(label_list):
            if j in used_labels:
              continue
            all_instances.append({'start': date,
                                  'active': True,
                                  'ims': [im],
                                  'coords': [coords],
                                  'conf': confs[j]})

      # End of season, so we end all instances
      for inst in all_instances:
        if inst['active']:
          inst['active'] = False
          inst['end'] = dates[-1]

      instance_dict.setdefault(loc, {})[season] = all_instances

  return instance_dict


In [None]:
# Root folder of the boxing task, and the image folder whose file names define
# which (location, date) pairs exist.
boxing_pth = '/gdrive/Shareddrives/land-app-groundtruth/boxing_task'
test_ims = os.path.join(boxing_pth, 'classification_task/images')

# Skeleton {loc: {season: {date: []}}}; later cells deep-copy it and fill in labels.
date_dict = {}
for im in os.listdir(test_ims):
  fields = im.split('_')  # field 1 is the location, field 2 the date as YYYYMMDD
  loc = fields[1]
  date = datetime.strptime(fields[2], '%Y%m%d')

  # July-December belong to the season named after the same year,
  # January-June to the season that started the year before.
  if date.month > 6:
    season = date.year
  else:
    season = date.year - 1

  # every location starts with the same three (initially empty) seasons
  date_dict.setdefault(loc, {2018: {}, 2019: {}, 2020: {}})
  date_dict[loc][season][date] = []


## Ground truth Events

In [None]:
# Create events from ground truth boxes

with open(os.path.join(boxing_pth, 'test.txt'), 'r') as f:
  # one image path per line; blank lines would break the file-name parsing below
  test_set = [line.strip() for line in f if line.strip()]

true_labels = copy.deepcopy(date_dict)

for im in test_set:
  im_name = im.split(os.sep)[-1]
  txt_name = im_name.replace('.png', '.txt')
  fields = txt_name.split('_')  # field 1 is the location, field 2 the date as YYYYMMDD
  loc = fields[1]
  date = datetime.strptime(fields[2], '%Y%m%d')
  season = date.year if date.month > 6 else date.year - 1

  with open(os.path.join(boxing_pth, 'labels', txt_name)) as f:
    lbls = []
    for line in f:
      arr = line.split()  # tolerant of repeated or trailing whitespace
      if not arr:
        continue  # a blank line must not become an empty box
      # the first token is skipped (presumably the class id - TODO confirm);
      # the remaining tokens are the box coordinates
      lbls.append([float(x) for x in arr[1:]])

  # ground-truth entries are (boxes, image_name) with no confidences
  true_labels[loc][season][date].append((lbls, im_name))

In [None]:
# Aggregate the ground-truth boxes into events. Ground-truth entries carry no
# confidences, so aggregate_instances keeps every box and the threshold has no
# effect on this call.
instance_dict = aggregate_instances(true_labels, threshold=0.5)

## Predicted events

In [None]:
# Predictions

exp = '' # model output folder

pred_labels = copy.deepcopy(date_dict)

for txt in os.listdir(os.path.join(exp, 'labels')):
  if not txt.endswith('.txt'):
    continue  # skip anything that is not a label file (e.g. .ipynb_checkpoints)

  fields = txt.split('_')  # field 1 is the location, field 2 the date as YYYYMMDD
  loc = fields[1]
  date = datetime.strptime(fields[2], '%Y%m%d')
  season = date.year if date.month > 6 else date.year - 1

  with open(os.path.join(exp, 'labels', txt), 'r') as f:
    plbls, confs = [], []
    for line in f:
      arr = line.split()  # tolerant of repeated or trailing whitespace
      if not arr:
        continue  # a blank line must not become an empty box
      # the first token is skipped (presumably the class id - TODO confirm),
      # the last token is the confidence, the tokens in between are the box
      plbls.append([float(x) for x in arr[1:-1]])
      confs.append(float(arr[-1]))

  # prediction entries are (boxes, image_name, confidences)
  pred_labels[loc][season][date].append((plbls, txt.replace('.txt', '.png'), confs))

In [None]:
# Aggregate the predicted boxes into events, keeping only detections whose
# confidence is at least 0.5.
p_instance_dict = aggregate_instances(pred_labels, threshold=0.5)

## Compare ground truth and predictions

- True positive if predicted instance is the first to overlap with true instance 
- False positive if predicted instance overlaps with true instance that has already been counted 

In [None]:
def is_match(inst, p_inst):
  """Return True iff any box of `p_inst` intersects any box of `inst`.

  Both arguments are instance dicts whose 'coords' entry is a list of
  [x_center, y_center, width, height] boxes.
  """
  return any(rect_intersect(pred_box, true_box)
             for pred_box in p_inst['coords']
             for true_box in inst['coords'])

In [None]:
def pr(threshold, true_labels, pred_labels):
  """Calculate event-level precision and recall at a confidence threshold.

  Both label dicts are aggregated into events with aggregate_instances. A true
  event is a true positive if at least one predicted event overlaps it in time
  (their [start, end] ranges intersect) and in space (is_match); otherwise it
  is a false negative. A predicted event that overlaps no true event is a
  false positive.

  NOTE(review): when several predicted events overlap the same true event they
  are all marked as matched, so none of them counts as a false positive. The
  notebook text describes counting such duplicates as false positives -
  confirm which behaviour is intended.

  Args:
    threshold: confidence threshold passed to aggregate_instances and used to
      filter the predicted events.
    true_labels: ground-truth labels, ``{loc: {season: {date: [...]}}}``.
    pred_labels: predicted labels with the same nesting; every location and
      season present in `true_labels` must also be present here.

  Returns:
    Tuple ``(precision, recall)``. Each is defined as 1 when its denominator
    is zero.
  """

  instance_dict = aggregate_instances(true_labels, threshold=threshold)
  p_instance_dict = aggregate_instances(pred_labels, threshold=threshold)

  tp, fp, fn = 0, 0, 0
  for loc, sdict in instance_dict.items():
    for season, instances in sdict.items():
      p_instances = [p for p in p_instance_dict[loc][season]
                     if p['conf'] >= threshold]

      for p_inst in p_instances:
        p_inst['matched'] = False

      for inst in instances:
        inst['matched'] = False
        for p_inst in p_instances:  # did predictions find it
          if p_inst['start'] > inst['end'] or inst['start'] > p_inst['end']:
            continue  # dates don't correspond
          if is_match(inst, p_inst):
            # found it; keep scanning so every overlapping prediction is
            # flagged as matched too
            p_inst['matched'] = True
            inst['matched'] = True

        if inst['matched']:
          tp += 1
        else:
          fn += 1

      # predictions that matched no true event
      fp += sum(1 for p_inst in p_instances if not p_inst['matched'])

  precision = tp / (tp + fp) if tp + fp > 0 else 1
  recall = tp / (tp + fn) if tp + fn > 0 else 1

  return precision, recall

        

In [None]:
# Find all confidences
# Collect the confidence of every predicted box; the unique values (rounded to
# 3 decimals) are the candidate thresholds. `pred_labels` is the prediction
# dict built earlier in this notebook.
all_confs = []
for sdict in pred_labels.values():
  for ddict in sdict.values():
    for entries in ddict.values():
      for entry in entries:
        all_confs.extend(entry[2])  # entry is (boxes, image_name, confidences)
all_confs = np.unique([round(c, 3) for c in all_confs])

In [None]:
# Sweep the confidence threshold over every observed confidence value and
# record the event-level precision and recall at each one.
precisions, recalls = [], []
for c in tqdm(np.sort(all_confs)):

  p, r = pr(c, true_labels, pred_labels)
  precisions.append(p)
  recalls.append(r)


In [None]:
# save statistics
stats = {'precisions': precisions, 'recalls': recalls}

stats_pth = '/gdrive/MyDrive/land-app/stats/yolov5_event_pr.p'
# open() does not create missing folders, so make sure the target exists
os.makedirs(os.path.dirname(stats_pth), exist_ok=True)
with open(stats_pth, 'wb') as f:
  pickle.dump(stats, f)

In [None]:
# Load the saved precision/recall lists and draw the precision-recall curve.
# NOTE: pickle.load can execute arbitrary code; only load files you produced.
with open('/gdrive/MyDrive/land-app/stats/yolov5_event_pr.p', 'rb') as f:
  yolo_res = pickle.load(f)

fig, ax = plt.subplots()
# NOTE(review): the AUC in the legend is hard-coded, not computed from the data.
ax.plot(yolo_res['recalls'], yolo_res['precisions'],
        lw=3, color='tab:olive', label='YOLOv5, AUC 0.63')
ax.legend()
ax.set_xlabel('Recall')
ax.set_ylabel("Precision")
