# FCOS_TARGET

In [1]:
import torch

def compute_locations(features, strides, dense_points=1):
    """Build one grid of point coordinates per feature level.

    Args:
        features: sequence of feature maps; only each map's last two
            dimensions (height, width) and its device are used.
        strides: per-level strides, indexed in the same order as ``features``.
        dense_points: forwarded unchanged to ``compute_locations_per_lever``.

    Returns:
        list holding, for every level, the tensor returned by
        ``compute_locations_per_lever`` for that level.
    """
    per_level = []
    for idx, fmap in enumerate(features):
        height, width = fmap.shape[-2:]
        per_level.append(
            compute_locations_per_lever(height, width, strides[idx], fmap.device, dense_points)
        )
    return per_level

def compute_locations_per_lever(h, w, stride, device, dense_points=1):
    """Return the (x, y) image coordinates of every cell of an h x w feature map.

    Cell (row i, column j) maps to
    ``(j * stride + stride // 2, i * stride + stride // 2)``. Points are laid
    out row-major: x varies fastest, then y.

    Args:
        h: feature-map height (number of rows).
        w: feature-map width (number of columns).
        stride: distance in image pixels between neighbouring cells.
        device: device on which the result is allocated.
        dense_points: currently ignored; kept for interface compatibility.

    Returns:
        float32 tensor of shape (h * w, 2) whose columns are (x, y).
    """
    shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device)
    shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device)
    # Build the row-major grid by broadcasting instead of torch.meshgrid:
    # calling meshgrid without ``indexing=`` warns on recent PyTorch, while the
    # ``indexing`` keyword is missing on older releases. This form is
    # equivalent to meshgrid(..., indexing="ij") on every version.
    shift_x = shifts_x[None, :].expand(shifts_y.numel(), -1).reshape(-1)
    shift_y = shifts_y[:, None].expand(-1, shifts_x.numel()).reshape(-1)
    return torch.stack((shift_x, shift_y), dim=1) + stride // 2

# Demo: point grid for one 48x48 feature map at the first pyramid stride.
features = torch.rand(2, 128, 48, 48)
# Strides double per pyramid level (the original 22 broke that pattern).
strides = [8, 16, 32]
# Take the grid size from the feature map rather than hard-coding 48x48.
feat_h, feat_w = features.shape[-2:]
compute_locations_per_lever(feat_h, feat_w, strides[0], device=features.device)

tensor([[  4.,   4.],
        [ 12.,   4.],
        [ 20.,   4.],
        ...,
        [364., 380.],
        [372., 380.],
        [380., 380.]])

In [2]:

# Two random point sets stacked along the point axis: (6, 2) + (4, 2) -> (10, 2).
loc1 = torch.rand(6, 2)
loc2 = torch.rand(4, 2)
locations = torch.cat((loc1, loc2))
print(locations.shape)

torch.Size([10, 2])


In [18]:

# Allowed regression range [min, max] for each of the two demo levels.
loc_ranges = [[-1, 64], [64, 128]]
locations = [loc1, loc2]
# Give every point its level's range (one row per point), then stack the
# levels into a single (num_points, 2) tensor.
expanded_loc_ranges = [
    level_points.new_tensor(level_range)[None].expand(len(level_points), -1)
    for level_points, level_range in zip(locations, loc_ranges)
]
loc_ranges = torch.cat(expanded_loc_ranges, dim=0)
loc_ranges

tensor([[ -1.,  64.],
        [ -1.,  64.],
        [ -1.,  64.],
        [ -1.,  64.],
        [ -1.,  64.],
        [ -1.,  64.],
        [ 64., 128.],
        [ 64., 128.],
        [ 64., 128.],
        [ 64., 128.]])

In [16]:
loc1.new_tensor([64, 128]).unsqueeze(0).expand(2, -1)  # (2,) -> (1, 2) -> broadcast to (2, 2)

tensor([[ 64., 128.],
        [ 64., 128.]])

# fcos_match

In [27]:
gt = torch.rand(2, 4)
points = torch.randn(4, 2)
# Column 0 holds x (column 1 is y), matching `xs, ys = points[:, 0], points[:, 1]`
# in fckp_match.
xs = points[:, 0]
xs.shape

torch.Size([4])

In [29]:
# Broadcast the (K,) x-coordinates to (K, 6). -1 keeps K from xs instead of
# hard-coding it; 6 appears to stand in for num_gts as in fckp_match.
gt_xs = xs[:, None].expand(-1, 6)
gt_xs.shape

torch.Size([4, 6])

In [33]:
# Same as points[None]: add a leading axis, (K, 2) -> (1, K, 2).
points.unsqueeze(0)

tensor([[[-0.4014,  0.3728],
         [ 0.5106, -0.2264],
         [ 1.0799, -0.5658],
         [-0.3043, -0.4544]]])

In [None]:
def fckp_match(
        points,    # (K, 2)
        gt,        # (n, 8)
        loc_ranges,  # (25, 2)
        num_points_per,  # [5] W*H
        cfg,
        strides=None,
        ig=None):
    """Match every point of one image to at most one ground-truth box.

    A point is a candidate for a GT box when it lies strictly inside the box
    (or inside the region returned by ``get_sample_region`` when
    ``cfg['center_sample']`` is truthy) and its largest distance to the box
    edges falls inside that point's ``loc_ranges`` row. When several boxes
    qualify, the one with the smallest area wins.

    Args:
        points: point coordinates, column 0 = x, column 1 = y; shape (K, 2).
        gt: GT rows; columns 0-3 are read as x1, y1, x2, y2 and column 4 as
            the label. NOTE(review): the original comment says (n, 8); the
            remaining columns are not read here.
        loc_ranges: per-point [min, max] allowed distance; shape (K, 2).
        num_points_per: only forwarded to ``get_sample_region``.
        cfg: dict-like; reads ``'center_sample'`` (default False) and
            ``'pos_radius'`` (default 1).
        strides: forwarded to ``get_sample_region``; defaults to
            [8, 16, 32, 64, 128].
        ig: unused.

    Returns:
        labels: (K,) label of the matched box, 0 where no box matched.
        bbox_targets: (K, 4) distances (left, top, right, bottom) from each
            point to its matched box's edges. Rows of unmatched points are not
            meaningful and should be ignored.
    """
    if strides is None:
        # A fresh list per call instead of a shared mutable default argument.
        strides = [8, 16, 32, 64, 128]
    INF = 1e10
    num_gts = gt.shape[0]
    K = points.shape[0]
    if num_gts == 0:
        # No GT boxes: every point is background. areas.min(dim=1) below
        # cannot reduce over an empty dimension, so return early.
        return gt.new_zeros(K), gt.new_zeros((K, 4))

    gt_labels = gt[:, 4]
    xs, ys = points[:, 0], points[:, 1]
    gt_bboxes = gt[:, :4]
    areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0] + 1) * (gt_bboxes[:, 3] - gt_bboxes[:, 1] + 1)

    # Broadcast to (K, num_gts, ...) so every point is compared with every box.
    # areas uses repeat() (a real copy) because it is written to below.
    areas = areas[None].repeat(K, 1)
    loc_ranges = loc_ranges[:, None, :].expand(K, num_gts, 2)
    gt_bboxes = gt_bboxes[None].expand(K, num_gts, 4)
    gt_xs = xs[:, None].expand(K, num_gts)
    gt_ys = ys[:, None].expand(K, num_gts)  # repeat along a new axis: (K,) -> (K, num_gts)

    left = gt_xs - gt_bboxes[..., 0]
    right = gt_bboxes[..., 2] - gt_xs
    top = gt_ys - gt_bboxes[..., 1]
    bottom = gt_bboxes[..., 3] - gt_ys
    bbox_targets = torch.stack((left, top, right, bottom), -1)

    if cfg.get('center_sample', False):
        sample_mask = get_sample_region(gt_bboxes, strides, num_points_per, gt_xs, gt_ys, radius=cfg.get('pos_radius', 1))  # noqa E501
    else:
        # Candidate only when the point lies strictly inside the box.
        sample_mask = bbox_targets.min(-1)[0] > 0

    max_loc_distance = bbox_targets.max(-1)[0]
    inside_loc_range = (max_loc_distance >= loc_ranges[..., 0]) & (max_loc_distance <= loc_ranges[..., 1])

    # Knock out non-candidates with an INF area; if more than one box is still
    # left for a point, the minimal-area one is chosen.
    areas[sample_mask == 0] = INF
    areas[inside_loc_range == 0] = INF
    min_area, min_area_inds = areas.min(dim=1)
    labels = gt_labels[min_area_inds]
    # Every box was knocked out -> background label 0.
    labels[min_area == INF] = 0
    bbox_targets = bbox_targets[torch.arange(K, device=min_area_inds.device), min_area_inds]

    return labels, bbox_targets