In [14]:
import os
import torch
import torchvision
import numpy as np

from matching import Matcher

In [15]:
# --- Configuration and input data for the proposal/target matching cells ---
# Directory holding the dumped tensors. The default is the original local path;
# set the MASK_RCNN_DATA_DIR environment variable to run the notebook on another
# machine without editing this cell.
data_dirpath=os.environ.get('MASK_RCNN_DATA_DIR', 'D:/data/mask_rcnn')
# Thresholds passed positionally to Matcher below.
# NOTE(review): presumably Matcher's high / low IoU thresholds -- confirm in matching.py.
fg_iou_thresh=0.5
bg_iou_thresh=0.5

device=torch.device("cpu")
# map_location moves every stored tensor onto `device`; weights_only=True keeps
# the unpickler restricted to tensors and plain containers.
roi_head_assign_targets_to_proposal=torch.load(
    os.path.join(data_dirpath, "roi_head_assign_targets_to_proposal.pt"),
    map_location=device,
    weights_only=True,
)
# Each entry is a per-image sequence (one tensor per image).
gt_boxes=roi_head_assign_targets_to_proposal['gt_boxes']
gt_labels=roi_head_assign_targets_to_proposal['gt_labels']
proposals=roi_head_assign_targets_to_proposal['proposals']
print('gt_boxes ', [g.shape for g in gt_boxes])
print('gt_labels ', [g.shape for g in gt_labels])
print('proposals ', [p.shape for p in proposals])

# Iterating these sequences together with zip() would silently stop at the
# shortest one, so fail loudly here if the dump is inconsistent.
assert len(gt_boxes)==len(gt_labels)==len(proposals), (
    "per-image lists differ in length: "
    f"gt_boxes={len(gt_boxes)}, gt_labels={len(gt_labels)}, proposals={len(proposals)}"
)

proposal_matcher = Matcher(fg_iou_thresh, bg_iou_thresh, allow_low_quality_matches=False)

gt_boxes  [torch.Size([2, 4]), torch.Size([5, 4])]
gt_labels  [torch.Size([2]), torch.Size([5])]
proposals  [torch.Size([217415, 4]), torch.Size([217418, 4])]


In [18]:
# Assign a ground-truth index and a class label to every proposal, image by image.
# Both result lists end up with one int64 tensor per image.
matched_idxs=[]
labels=[]

for proposals_in_image, gt_boxes_in_image, gt_labels_in_image in zip(proposals, gt_boxes, gt_labels):
    print('gt_boxes_in_image ', gt_boxes_in_image.shape)
    print('proposals_in_image ', proposals_in_image.shape)

    if gt_boxes_in_image.numel()!=0:
        # Pairwise IoU: rows are ground-truth boxes (M), columns are proposals (N).
        match_quality_matrix=torchvision.ops.box_iou(gt_boxes_in_image, proposals_in_image)
        # One entry per proposal; may hold the matcher's sentinel values
        # (BELOW_LOW_THRESHOLD / BETWEEN_THRESHOLDS) instead of a gt index.
        matched_idxs_in_image=proposal_matcher(match_quality_matrix)

        # Raise negative entries to 0 so the tensor can be used as an index.
        clamped_matched_idxs_in_image=torch.clamp(matched_idxs_in_image, min=0)

        # Gather the class of the matched ground-truth box for every proposal.
        labels_in_image=gt_labels_in_image[clamped_matched_idxs_in_image].to(dtype=torch.int64)

        # Both masks are derived from the unclamped matcher output.
        bg_inds=matched_idxs_in_image==proposal_matcher.BELOW_LOW_THRESHOLD
        ignore_inds=matched_idxs_in_image==proposal_matcher.BETWEEN_THRESHOLDS

        # Background proposals get class 0.
        labels_in_image[bg_inds]=0
        # "Ignore" proposals get -1 (original note: -1 is ignored by sampler).
        labels_in_image[ignore_inds]=-1
    else:
        # Image without any ground-truth box: every proposal is background.
        device=proposals_in_image.device
        clamped_matched_idxs_in_image=torch.zeros(
            (proposals_in_image.shape[0],),
            dtype=torch.int64,
            device=device,
        )
        # Same shape, dtype and device as the index tensor above.
        labels_in_image=torch.zeros_like(clamped_matched_idxs_in_image)

    matched_idxs.append(clamped_matched_idxs_in_image)
    labels.append(labels_in_image)

# Summary per image: shape plus smallest and largest value.
print('\nmatched_idxs ', [(t.shape, t.min(), t.max()) for t in matched_idxs])
print('labels ', [(t.shape, t.min(), t.max()) for t in labels])

gt_boxes_in_image  torch.Size([2, 4])
proposals_in_image  torch.Size([217415, 4])
gt_boxes_in_image  torch.Size([5, 4])
proposals_in_image  torch.Size([217418, 4])

matched_idxs  [(torch.Size([217415]), tensor(0), tensor(1)), (torch.Size([217418]), tensor(0), tensor(4))]
labels  [(torch.Size([217415]), tensor(0), tensor(1)), (torch.Size([217418]), tensor(0), tensor(1))]
