In [0]:
import torch

def box_iou(box1, box2):
    '''Compute the pairwise intersection over union of two sets of boxes.

    The box order is (xmin, ymin, xmax, ymax). Every extent is computed as
    ``max - min + 1``.

    Args:
      box1: (tensor) bounding boxes, sized [N,4].
      box2: (tensor) bounding boxes, sized [M,4].
    Return:
      (tensor) iou, sized [N,M].
    Reference:
      https://github.com/chainer/chainercv/blob/master/chainercv/utils/bbox/bbox_iou.py
    '''
    # Broadcasting [N,1,2] against [M,2] yields the corners of every pairwise overlap.
    lt = torch.max(box1[:, None, :2], box2[:, :2])  # [N,M,2]
    rb = torch.min(box1[:, None, 2:], box2[:, 2:])  # [N,M,2]

    # Disjoint pairs give a non-positive extent; clamp so their intersection is 0.
    wh = (rb - lt + 1).clamp(min=0)    # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    area1 = (box1[:, 2] - box1[:, 0] + 1) * (box1[:, 3] - box1[:, 1] + 1)  # [N,]
    area2 = (box2[:, 2] - box2[:, 0] + 1) * (box2[:, 3] - box2[:, 1] + 1)  # [M,]

    # union = area1 + area2 - intersection, broadcast to [N,M]
    return inter / (area1[:, None] + area2 - inter)

def test():
  '''Print the pairwise IoU matrix of two small hand-written box sets.'''
  boxes_a = torch.tensor([
      [0, 0, 9, 9],
      [100, 100, 200, 200],
  ], dtype=torch.float)
  boxes_b = torch.tensor([
      [0, 0, 4, 4],
      [100, 100, 200, 200],
      [5, 5, 150, 150],
  ], dtype=torch.float)

  pairwise_iou = box_iou(boxes_a, boxes_b)
  print(pairwise_iou)
  
# Run the box_iou demo defined just above; it prints the pairwise IoU matrix.
test()


tensor([[0.2500, 0.0000, 0.0012],
        [0.0000, 1.0000, 0.0900]])


In [0]:
def box2loc(src_box,ref_box):
  '''Encode src_box(x1,y1,x2,y2) as loc offsets relative to ref_box.

    The first two output columns are the center shift divided by the
    reference width/height; the last two are the log of the size ratio.
    Sizes are computed as ``max - min + 1``.
    Args:
      src_box: (tensor) bounding boxes, sized [N,4].
      ref_box: (tensor) bounding boxes, sized [N,4].
    Return:
      (tensor) locs, sized [N,4].
    '''
  src_size = src_box[:, 2:] - src_box[:, :2] + 1.
  src_center = src_box[:, :2] + 0.5 * (src_size - 1.)

  ref_size = ref_box[:, 2:] - ref_box[:, :2] + 1.
  ref_center = ref_box[:, :2] + 0.5 * (ref_size - 1.)

  center_shift = (src_center - ref_center) / ref_size
  log_scale = (src_size / ref_size).log()

  return torch.cat((center_shift, log_scale), dim=1)

def loc2box(src_loc,ref_box):
  '''Decode loc offsets back to box(x1,y1,x2,y2) using ref_box as reference.

    The first two columns of src_loc are scaled by the reference size and
    added to the reference center; the last two are exponentiated and
    multiplied by the reference size.
    Args:
      src_loc: (tensor) locs, sized [N,4].
      ref_box: (tensor) bounding boxes, sized [N,4].
    Return:
      (tensor) bounding boxes, sized [N,4].
    '''
  ref_size = ref_box[:, 2:] - ref_box[:, :2] + 1.
  ref_center = ref_box[:, :2] + 0.5 * (ref_size - 1.)

  center = src_loc[:, :2] * ref_size + ref_center
  size = src_loc[:, 2:].exp() * ref_size

  # Distance from the center to either corner under the +1 size convention.
  half_extent = (size - 1.) * 0.5

  return torch.cat((center - half_extent, center + half_extent), dim=1)
  
def test():
  '''Round-trip two boxes through box2loc and loc2box and print the outcome.'''
  src_boxes = torch.tensor([
      [0, 0, 9, 9],
      [10, 10, 20, 20],
  ], dtype=torch.float)
  anchor_boxes = torch.tensor([
      [0, 0, 4, 4],
      [15, 15, 30, 30],
  ], dtype=torch.float)

  encoded = box2loc(src_boxes, anchor_boxes)
  decoded = loc2box(encoded, anchor_boxes)

  print(encoded)
  # True only when decoding reproduces the source boxes exactly.
  print(torch.equal(src_boxes, decoded))
  
  
# Run the box2loc/loc2box round-trip demo defined just above; it prints the
# encoded locs and the result of the equality check.
test()

tensor([[ 0.5000,  0.5000,  0.6931,  0.6931],
        [-0.4688, -0.4688, -0.3747, -0.3747]])
True


In [0]:
def mk_target(gt_boxes, gt_labels, ref_boxes):
  '''Make training target by matching every reference box to its best gt box.

    Each reference box is paired with the gt box of highest IoU; the pair is
    then encoded with box2loc. When there are no gt boxes, zero-filled
    placeholders with the same shapes, dtypes and device as the regular
    outputs are returned, so callers can index them the same way.
    Args:
      gt_boxes: (tensor) gt boxes, sized [M,4].
      gt_labels: (tensor) gt labels, sized [M,1].
      ref_boxes: (tensor) bounding boxes, sized [N,4].
    Return:
      (tensor) target_locs, sized [N,4].
      (tensor) target_cls, gt_labels gathered per reference box, i.e. sized
        [N] + gt_labels.shape[1:] ([N,1] for [M,1] labels).
      (tensor) target_boxes, sized [N,4].
      (tensor) iou, sized [N,M] ([N,0] when there are no gt boxes).
      (tensor) max_iou, sized [N,].
      (tensor) max_ids, sized [N,].
  
  '''
  if(gt_boxes.size(0) > 0):
    iou = box_iou(ref_boxes, gt_boxes)
    max_iou, max_ids = torch.max(iou,1)
    target_boxes = gt_boxes[max_ids]
    target_locs = box2loc(target_boxes,ref_boxes)
    target_cls = gt_labels[max_ids]
  else:
    num_refs = ref_boxes.size(0)
    # new_zeros inherits dtype and device from ref_boxes, like the outputs of
    # the branch above.
    target_locs = ref_boxes.new_zeros((num_refs, 4))
    target_boxes = ref_boxes.new_zeros((num_refs, 4))
    if torch.is_tensor(gt_labels):
      # Mirror what gt_labels[max_ids] would produce: label dtype, trailing dims kept.
      target_cls = torch.zeros((num_refs,) + tuple(gt_labels.shape[1:]),
                               dtype=gt_labels.dtype, device=ref_boxes.device)
    else:
      target_cls = ref_boxes.new_zeros((num_refs,))
    # [N,0] rather than a bare empty tensor, so row-wise indexing such as
    # iou[mask] with an [N] mask keeps working.
    iou = ref_boxes.new_zeros((num_refs, 0))
    max_iou = ref_boxes.new_zeros((num_refs,))
    max_ids = torch.zeros(num_refs, dtype=torch.long, device=ref_boxes.device)
  
  return target_locs, target_cls, target_boxes, iou, max_iou, max_ids
  
def select_target(ref_boxes, max_iou, iou,
                  num_sample=256, fg_ratio=0.25, 
                  positive_th=0.5, negative_th=0.5, negative_th_lo=0.0,
                  correct_region=[0,0,968,608],
                  large_size_ratio = 0.25, large_size_th = 256. * 256.,
                  use_gt_label=True):
  '''Select positive and negative training samples among the reference boxes.

    Args:
      ref_boxes: (tensor) reference boxes (x1,y1,x2,y2), sized [N,4].
      max_iou: (tensor) best IoU of every reference box, sized [N,].
      iou: (tensor) IoU of every reference box with every gt box, sized [N,M].
        A tensor without elements (no gt boxes) is accepted; the per-gt
        forced positives are skipped in that case.
      num_sample: (int) total number of samples aimed for.
      fg_ratio: (float) at most int(num_sample * fg_ratio) positives are kept.
      positive_th: (float) boxes with max_iou >= positive_th are positive.
      negative_th: (float) upper bound (exclusive) of max_iou for negatives.
      negative_th_lo: (float) lower bound (inclusive) of max_iou for negatives.
      correct_region: (x1,y1,x2,y2) or None. Only boxes with x1,y1 >= the
        first two values and x2,y2 < the last two are eligible; None keeps
        every box. The default list is only read, never modified.
      large_size_ratio: (float) int(large_size_ratio * negative budget) is the
        number of large negatives the sample should contain.
      large_size_th: (float) or None. Boxes whose area (+1 convention) is
        >= this value count as large; None disables the large-box quota.
      use_gt_label: (bool) when True, the eligible box with the highest IoU
        for each gt box is made positive regardless of positive_th.
    Return:
      (tensor) positive_ids, 1-D indices into ref_boxes.
      (tensor) negative_ids, 1-D indices into ref_boxes.
  '''
  
  # remove outside region boxes
  if correct_region is not None:
    inside_region_ids = (ref_boxes[:,0] >= correct_region[0]) & \
                        (ref_boxes[:,1] >= correct_region[1]) & \
                        (ref_boxes[:,2] < correct_region[2]) & \
                        (ref_boxes[:,3] < correct_region[3])
  else:
    # All-true mask built through a comparison, so it has the same dtype and
    # device as the masks produced by the branch above.
    inside_region_ids = torch.ones_like(ref_boxes[:,0]) > 0

  #select pos neg ids (masks over the eligible boxes only)
  negative_ids = (max_iou < negative_th) & (max_iou >= negative_th_lo)
  positive_ids = (max_iou >= positive_th)
  negative_ids = negative_ids[inside_region_ids]
  positive_ids = positive_ids[inside_region_ids]
  if use_gt_label and iou.numel() > 0:
    inside_iou = iou[inside_region_ids]
    if(inside_iou.size(0) > 0):
      _, gt_ids = torch.max(inside_iou,0)
      positive_ids[gt_ids] = 1
      # A box forced positive must leave the negative pool, otherwise it could
      # be returned with both labels.
      negative_ids[gt_ids] = 0

  negative_ids = negative_ids.nonzero().view(-1)
  positive_ids = positive_ids.nonzero().view(-1)

  #remap positions within the eligible subset back to indices into ref_boxes
  inside_region_ids = inside_region_ids.nonzero().view(-1)
  if(negative_ids.size(0)>0):
    negative_ids = inside_region_ids[negative_ids]
  if(positive_ids.size(0)>0):
    positive_ids = inside_region_ids[positive_ids]
    
  #select to match number
  positive_num = int(num_sample * fg_ratio)
  if(positive_ids.size(0) > positive_num):
    positive_ids = positive_ids[torch.randperm(positive_ids.size(0))[:positive_num]]
  # Negatives fill whatever part of the budget the positives left.
  negative_num = num_sample - positive_ids.size(0)
  
  if(negative_ids.size(0) > negative_num):
    # Uniform draw; it is the result unless the large-box quota below replaces it.
    sampled_ids = negative_ids[torch.randperm(negative_ids.size(0))[:negative_num]]

    #add large_region_to_train
    if(large_size_th is not None):
      areas = (ref_boxes[:,2]-ref_boxes[:,0]+1) * (ref_boxes[:,3]-ref_boxes[:,1]+1)
      large_region_ids = (areas >= large_size_th)
      other_region_ids = (areas < large_size_th)
      negative_large_region_ids = negative_ids[large_region_ids[negative_ids]]
      negative_other_region_ids = negative_ids[other_region_ids[negative_ids]]
      
      large_size_num = int(large_size_ratio * negative_num)
      
      if(negative_large_region_ids.size(0) > 0):
        # Re-draw only when the uniform sample holds fewer large boxes than the quota.
        if(large_region_ids[sampled_ids].sum() < large_size_num):
          negative_large_region_ids = negative_large_region_ids[torch.randperm(negative_large_region_ids.size(0))[:large_size_num]]
      
          other_size_num = negative_num - negative_large_region_ids.size(0)
          negative_other_region_ids = negative_other_region_ids[torch.randperm(negative_other_region_ids.size(0))[:other_size_num]]
          sampled_ids = torch.cat([negative_large_region_ids, negative_other_region_ids],0)

    negative_ids = sampled_ids
    
  return positive_ids, negative_ids
  
  
def test():
  '''Build two gt boxes and ten reference boxes, then print the sampled ids.

  select_target draws its samples with torch.randperm and no seed is set
  here, so the printed ids may differ from run to run.
  '''
  gt_boxes = torch.tensor([
      [172, 172, 428, 428],
      [-28, -28, 228, 228],
  ], dtype=torch.float)
  gt_labels = torch.tensor([[1], [2]], dtype=torch.long)

  ref_boxes = torch.tensor([
      [172, 172, 428, 428],
      [162, 72, 418, 418],
      [72, 72, 328, 328],
      [62, 62, 318, 318],
      [52, 52, 308, 308],
      [42, 42, 298, 298],
      [-28, -28, 228, 228],
      [42, 42, 298, 298],
      [42, 42, 298, 298],
      [42, 42, 398, 398],
  ], dtype=torch.float)

  # Only iou and max_iou feed the selection; the other outputs are unused here.
  _locs, _cls, _boxes, iou, max_iou, _ids = mk_target(gt_boxes, gt_labels, ref_boxes)

  positive_ids, negative_ids = select_target(
      ref_boxes, max_iou, iou, num_sample=5, large_size_th=300. * 300.)

  print(positive_ids)
  print(negative_ids)
  
# Run the target-selection demo defined just above; it prints the selected
# positive ids followed by the selected negative ids.
test()
  

tensor([8])
tensor([7, 5, 9, 2])


In [0]:
def box_nms(bboxes, scores, threshold=0.3, mode='union'):
    '''Non maximum suppression.

    Boxes are visited in descending score order; each visited box is kept and
    every remaining box whose overlap with it exceeds `threshold` is dropped.
    Args:
      bboxes: (tensor) bounding boxes (x1,y1,x2,y2), sized [N,4].
      scores: (tensor) bbox scores, sized [N,]. [N,1] and, for a single box,
        0-d scores are accepted as well; they are flattened first.
      threshold: (float) overlap threshold.
      mode: (str) 'union' (intersection / union) or 'min' (intersection /
        smaller area).
    Returns:
      keep: (tensor) selected indices, a CPU LongTensor in descending score order.
    Raises:
      TypeError: if `mode` is unknown (checked once two or more boxes are compared).
    Reference:
      https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/nms/py_cpu_nms.py
    '''
    # Flatten so that [N,1] or squeezed 0-d scores sort like a plain [N] vector.
    scores = scores.reshape(-1)

    x1 = bboxes[:,0]
    y1 = bboxes[:,1]
    x2 = bboxes[:,2]
    y2 = bboxes[:,3]

    areas = (x2-x1+1) * (y2-y1+1)
    _, order = scores.sort(0, descending=True)

    keep = []
    while order.numel() > 0:
        # Plain Python int, so `keep` is a list of ints rather than 0-d tensors.
        i = order[0].item()
        keep.append(i)

        if order.numel() == 1:
            break

        rest = order[1:]  # candidates still competing with box i

        # Corners of the intersection between box i and every candidate.
        xx1 = x1[rest].clamp(min=x1[i])
        yy1 = y1[rest].clamp(min=y1[i])
        xx2 = x2[rest].clamp(max=x2[i])
        yy2 = y2[rest].clamp(max=y2[i])

        w = (xx2-xx1+1).clamp(min=0)
        h = (yy2-yy1+1).clamp(min=0)
        inter = w*h

        if mode == 'union':
            ovr = inter / (areas[i] + areas[rest] - inter)
        elif mode == 'min':
            ovr = inter / areas[rest].clamp(max=areas[i])
        else:
            raise TypeError('Unknown nms mode: %s.' % mode)

        # Survivors are the candidates that do not overlap box i too much.
        ids = (ovr<=threshold).nonzero().view(-1)
        if ids.numel() == 0:
            break
        order = rest[ids]
    return torch.tensor(keep, dtype=torch.long)
  
def test():
  '''Run box_nms on seven hand-written boxes and print the kept ids and boxes.'''
  
  # Use the GPU when one is available; otherwise run the same demo on the CPU
  # instead of failing on machines without CUDA.
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  
  scores = torch.tensor([
      [0.9],
      [0.1],
      [0.8],
      [0.2],
      [0.7],
      [0.3],
      [0.6]
  ]).to(device)
  
  
  boxes = torch.tensor([
      [300-128,300-128,300+128,300+128],
      [290-128,200-128,290+128,290+128],
      [200-128,200-128,200+128,200+128],
      [190-128,190-128,190+128,190+128],
      [180-128,180-128,180+128,180+128],
      [170-128,170-128,170+128,170+128],
      [100-128,100-128,100+128,100+128]
  ]).float().to(device)
  
  # scores is [7,1]; squeeze it to the [7] vector box_nms documents.
  keep = box_nms(boxes,scores.squeeze())
  print(keep)
  print(boxes[keep])

# Run the NMS demo defined just above; it prints the kept indices and the
# corresponding boxes.
test()

tensor([0, 2, 6])
tensor([[172., 172., 428., 428.],
        [ 72.,  72., 328., 328.],
        [-28., -28., 228., 228.]], device='cuda:0')
