In [1]:
import os
import torch
import numpy as np

In [3]:
class BalancedPositiveNegativeSampler:
    """
    Sample, per image, a fixed-size batch of elements with a target share of positives.

    For every image, up to ``int(batch_size_per_image * positive_fraction)`` positives are
    drawn uniformly at random from the elements labelled >= 1. The remaining budget is
    filled with negatives (label == 0), also drawn at random. Elements labelled -1 are
    never selected. When an image has too few positives or negatives, fewer elements are
    selected; the batch is not padded.
    """
    def __init__(self, batch_size_per_image, positive_fraction):
        """
        Args:
            batch_size_per_image (int): maximum number of elements to be selected per image
            positive_fraction (float): target fraction of positive elements per batch
        """
        self.batch_size_per_image=batch_size_per_image
        self.positive_fraction=positive_fraction

    def __call__(self, matched_idxs):
        """
        Select positive and negative elements for each image and return them as masks.

        Args:
            matched_idxs: list of tensors containing -1, 0, or positive values.
                Each tensor corresponds to a specific image (assumed 1-D).
                -1 values are ignored, 0 are considered as negatives and >0 as positives.
                Note: positives are detected with ``>= 1``, so fractional values in
                (0, 1) are treated as neither positive nor negative.

        Returns:
            pos_idx (list[Tensor]): one uint8 mask per image, shaped like the corresponding
                input tensor, with 1 at the selected positive elements and 0 elsewhere
            neg_idx (list[Tensor]): same as ``pos_idx``, for the selected negative elements
        """
        pos_idx, neg_idx=[], []

        # The positive budget does not depend on the image, so compute it once.
        max_pos=int(self.batch_size_per_image*self.positive_fraction)

        for matched_idxs_per_image in matched_idxs:
            positive=torch.nonzero(matched_idxs_per_image>=1, as_tuple=True)[0]
            negative=torch.nonzero(matched_idxs_per_image==0, as_tuple=True)[0]

            # protect against not enough positive examples
            num_pos=min(positive.numel(), max_pos)
            # negatives fill the rest of the budget, capped by how many are available
            num_neg=min(negative.numel(), self.batch_size_per_image-num_pos)

            # randomly select positive and negative examples
            perm_pos=torch.randperm(positive.numel(), device=positive.device)[:num_pos]
            perm_neg=torch.randperm(negative.numel(), device=negative.device)[:num_neg]

            # turn the selected indices into binary masks shaped like the input
            pos_mask=torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
            neg_mask=torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
            pos_mask[positive[perm_pos]]=1
            neg_mask[negative[perm_neg]]=1

            pos_idx.append(pos_mask)
            neg_idx.append(neg_mask)

        return pos_idx, neg_idx

In [5]:
# Directory containing `labels.pt`. Set the MASK_RCNN_DATA_DIR environment variable to
# point elsewhere instead of editing this machine-specific default.
data_dirpath=os.environ.get("MASK_RCNN_DATA_DIR", "D:/data/mask_rcnn")

device=torch.device("cpu")
# weights_only=True limits unpickling to tensors and other allowlisted types,
# which avoids executing arbitrary code from the file.
# Expected to hold one label tensor per image (-1 ignore, 0 negative, >=1 positive).
labels=torch.load(os.path.join(data_dirpath, "labels.pt"),map_location=device, weights_only=True)
print('labels ', len(labels), [(lbl.shape, lbl.min(), lbl.max()) for lbl in labels])

# Per-image sampling budget and target share of positives.
batch_size_per_image, positive_fraction=256, 0.5
fg_bg_sampler=BalancedPositiveNegativeSampler(batch_size_per_image, positive_fraction)

labels  2 [(torch.Size([198249]), tensor(-1.), tensor(1.)), (torch.Size([198249]), tensor(-1.), tensor(1.))]


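Inside the model, the sampler is called as shown below; the next cell reproduces that call inline, one image at a time:

```python
sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
```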

In [32]:
# Inline prototype of the per-image balanced sampling, run on the loaded labels.
# Seed the RNG so that re-running this cell reproduces the same sample.
SEED=0
torch.manual_seed(SEED)

matched_idxs=labels

pos_idx, neg_idx=[],[]

# The positive budget is the same for every image, so compute it once.
max_pos=int(fg_bg_sampler.batch_size_per_image*fg_bg_sampler.positive_fraction)

for matched_idxs_per_image in matched_idxs:
    # matched_idxs_per_image is 1D tensor: -1 ignored, 0 negative, >=1 positive
    positive=torch.nonzero(matched_idxs_per_image>=1, as_tuple=True)[0]
    negative=torch.nonzero(matched_idxs_per_image==0, as_tuple=True)[0]

    # protect against not enough positive examples
    num_pos=min(positive.numel(), max_pos)
    num_neg=fg_bg_sampler.batch_size_per_image-num_pos
    # protect against not enough negative examples
    num_neg=min(negative.numel(), num_neg)

    # randomly select positive and negative examples
    perm1=torch.randperm(positive.numel(), device=positive.device)[:num_pos]
    perm2=torch.randperm(negative.numel(), device=negative.device)[:num_neg]

    pos_idx_per_image=positive[perm1]
    neg_idx_per_image=negative[perm2]

    # create binary mask from indices
    pos_idx_per_image_mask=torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
    neg_idx_per_image_mask=torch.zeros_like(matched_idxs_per_image, dtype=torch.uint8)
    pos_idx_per_image_mask[pos_idx_per_image]=1
    neg_idx_per_image_mask[neg_idx_per_image]=1

    pos_idx.append(pos_idx_per_image_mask)
    neg_idx.append(neg_idx_per_image_mask)

# Sanity checks: budgets respected, labels honoured, and no element picked as both.
for matched_idxs_per_image, pos_mask, neg_mask in zip(matched_idxs, pos_idx, neg_idx):
    assert int(pos_mask.sum())<=max_pos
    assert int(pos_mask.sum())+int(neg_mask.sum())<=fg_bg_sampler.batch_size_per_image
    assert bool((matched_idxs_per_image[pos_mask.bool()]>=1).all())
    assert bool((matched_idxs_per_image[neg_mask.bool()]==0).all())
    assert not bool((pos_mask.bool() & neg_mask.bool()).any())

print('\npos_idx ', [(p.shape, p.min(), p.max()) for p in pos_idx])
print('\nneg_idx ', [(p.shape, p.min(), p.max()) for p in neg_idx])


pos_idx  [(torch.Size([198249]), tensor(0, dtype=torch.uint8), tensor(1, dtype=torch.uint8)), (torch.Size([198249]), tensor(0, dtype=torch.uint8), tensor(1, dtype=torch.uint8))]

neg_idx  [(torch.Size([198249]), tensor(0, dtype=torch.uint8), tensor(1, dtype=torch.uint8)), (torch.Size([198249]), tensor(0, dtype=torch.uint8), tensor(1, dtype=torch.uint8))]
