# Match function: return, for each prior box, the index of its matched ground-truth (gt) box

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [32]:
def point_form(boxes):
    """
    Convert boxes from center form (cx, cy, w, h) to corner form
    (xmin, ymin, xmax, ymax).
        args:
            boxes: tensor of shape (..., 4) in (cx, cy, w, h) order. Any number
                of leading (batch) dimensions is accepted; the coordinates are
                always read from the last dimension.
        return:
            tensor with the same shape as `boxes`, in (xmin, ymin, xmax, ymax) order.
    """
    # Index the last dimension with `...` so batched inputs work as well as 2-D ones.
    center = boxes[..., :2]
    half_size = boxes[..., 2:] / 2
    return torch.cat([center - half_size, center + half_size], dim=-1)

def jaccard(priors, truths):
    """
    Pairwise IoU (Jaccard overlap) between prior boxes and ground-truth boxes.
        args:
            priors: tensor of shape (num_priors, 4) in (xmin, ymin, xmax, ymax) form
            truths: tensor of shape (num_objects, 4) in (xmin, ymin, xmax, ymax) form
        return:
            tensor of shape (num_priors, num_objects); entry [i, j] is the IoU of
            priors[i] with truths[j]. A small epsilon in the denominator guards
            against division by zero for degenerate boxes.
    """
    # Insert singleton axes so the two box sets broadcast against each other:
    # rows index priors, columns index ground-truth boxes.
    p = priors[:, None, :]  # (num_priors, 1, 4)
    t = truths[None, :, :]  # (1, num_objects, 4)

    # Corners of the intersection rectangle; clamping yields 0 area for disjoint boxes.
    top_left = torch.max(t[..., :2], p[..., :2])
    bottom_right = torch.min(t[..., 2:], p[..., 2:])
    overlap = (bottom_right - top_left).clamp(min=0).prod(dim=-1)

    area_t = (t[..., 2] - t[..., 0]) * (t[..., 3] - t[..., 1])  # (1, num_objects)
    area_p = (p[..., 2] - p[..., 0]) * (p[..., 3] - p[..., 1])  # (num_priors, 1)

    return overlap / (area_p + area_t - overlap + 1e-7)

In [33]:
def match(threshold, truths, priors):
    """
    Assign a ground-truth box to every prior box.

        Args:
            threshold: IoU below which a prior is labelled background (index 0).
            truths: (tensor) shape [num_objects, 5] (xmin, ymin, xmax, ymax, label);
                the last (label) column is dropped before computing IoU.
            priors: (tensor) shape [num_priors, 4] (cx, cy, w, h), the center
                form consumed by point_form.
        Returns:
            LongTensor of shape [num_priors]. Entry i is 1 + the row index in
            `truths` of the gt box matched to prior i, or 0 (background) when
            prior i's best IoU is below `threshold`. Every gt box keeps at
            least one matched prior (its highest-IoU prior).
    """
    num_priors = priors.size(0)
    # No ground-truth objects: every prior is background. Without this guard
    # the max() reductions below fail on an empty dimension.
    if truths.size(0) == 0:
        return torch.zeros(num_priors, dtype=torch.long, device=priors.device)

    truths = truths[..., :-1]  # drop the label column

    overlaps = jaccard(point_form(priors), truths)  # [num_priors, num_objects]

    # Best gt for each prior (reduce over objects) ...
    best_gt_scores, best_gt_indexes = overlaps.max(dim=1)
    # ... and best prior for each gt (reduce over priors).
    _, best_prior_indexes = overlaps.max(dim=0)

    # Guarantee each gt box matches at least one prior: pin its best prior to it
    # with a score of 2.0, above any IoU, so the threshold test below keeps it.
    # A sequential loop makes the later gt win deterministically when two gt
    # boxes share the same best prior.
    for gt_idx, prior_idx in enumerate(best_prior_indexes.tolist()):
        best_gt_scores[prior_idx] = 2.0
        best_gt_indexes[prior_idx] = gt_idx

    # Shift to 1-based so that 0 can denote background.
    best_gt_indexes = best_gt_indexes + 1
    # Keep only matches whose score reaches the threshold.
    best_gt_indexes[best_gt_scores < threshold] = 0
    return best_gt_indexes
    
    

In [34]:
# Ground-truth boxes, one row per object: (xmin, ymin, xmax, ymax, label).
truths = torch.tensor([[0.1, 0.1, 0.4, 0.4, 1],
                       [0.6, 0.6, 0.9, 0.9, 2]])

# Prior boxes in center form: (cx, cy, w, h).
priors = torch.tensor([[0.25, 0.25, 0.3, 0.3],
                       [0.75, 0.75, 0.3, 0.3],
                       [0.50, 0.50, 0.2, 0.2],
                       [0.10, 0.90, 0.2, 0.2]])

In [35]:
# IoU matrix: rows are priors, columns are gt boxes (label column of truths dropped).
scores = jaccard(point_form(priors), truths[..., :-1])

In [23]:
scores  # display the (num_priors, num_objects) IoU matrix

tensor([[1.0000, 0.0000],
        [0.0000, 1.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000]])

In [22]:
scores.max(dim=1)  # best IoU and best gt index for each prior (first step of match)

torch.return_types.max(
values=tensor([1.0000, 1.0000, 0.0000, 0.0000]),
indices=tensor([0, 1, 0, 0]))