Skip to content

Commit

Permalink
Fix deprecation warning in nonzero (#2705)
Browse files Browse the repository at this point in the history
Replace nonzero by where, now that it works with just a condition
  • Loading branch information
fmassa committed Sep 24, 2020
1 parent 6a43a1f commit 15848ed
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 16 deletions.
8 changes: 4 additions & 4 deletions torchvision/models/detection/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ def __call__(self, matched_idxs):
pos_idx = []
neg_idx = []
for matched_idxs_per_image in matched_idxs:
positive = torch.nonzero(matched_idxs_per_image >= 1).squeeze(1)
negative = torch.nonzero(matched_idxs_per_image == 0).squeeze(1)
positive = torch.where(matched_idxs_per_image >= 1)[0]
negative = torch.where(matched_idxs_per_image == 0)[0]

num_pos = int(self.batch_size_per_image * self.positive_fraction)
# protect against not enough positive examples
Expand Down Expand Up @@ -317,7 +317,7 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
# For each gt, find the prediction with which it has highest quality
highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1)
# Find highest quality match available, even if it is low, including ties
gt_pred_pairs_of_highest_quality = torch.nonzero(
gt_pred_pairs_of_highest_quality = torch.where(
match_quality_matrix == highest_quality_foreach_gt[:, None]
)
# Example gt_pred_pairs_of_highest_quality:
Expand All @@ -334,7 +334,7 @@ def set_low_quality_matches_(self, matches, all_matches, match_quality_matrix):
# Each row is a (gt index, prediction index)
# Note how gt items 1, 2, 3, and 5 each have two ties

pred_inds_to_update = gt_pred_pairs_of_highest_quality[:, 1]
pred_inds_to_update = gt_pred_pairs_of_highest_quality[1]
matches[pred_inds_to_update] = all_matches[pred_inds_to_update]


Expand Down
2 changes: 1 addition & 1 deletion torchvision/models/detection/generalized_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def forward(self, images, targets=None):
degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
if degenerate_boxes.any():
# print the first degenerate box
bb_idx = degenerate_boxes.any(dim=1).nonzero().view(-1)[0]
bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
degen_bb: List[float] = boxes[bb_idx].tolist()
raise ValueError("All bounding boxes should have positive height and width."
" Found invalid box {} for target at index {}."
Expand Down
12 changes: 6 additions & 6 deletions torchvision/models/detection/roi_heads.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def fastrcnn_loss(class_logits, box_regression, labels, regression_targets):
# get indices that correspond to the regression targets for
# the corresponding ground truth labels, to be used with
# advanced indexing
sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
sampled_pos_inds_subset = torch.where(labels > 0)[0]
labels_pos = labels[sampled_pos_inds_subset]
N, num_classes = class_logits.shape
box_regression = box_regression.reshape(N, -1, 4)
Expand Down Expand Up @@ -296,7 +296,7 @@ def keypointrcnn_loss(keypoint_logits, proposals, gt_keypoints, keypoint_matched

keypoint_targets = torch.cat(heatmaps, dim=0)
valid = torch.cat(valid, dim=0).to(dtype=torch.uint8)
valid = torch.nonzero(valid).squeeze(1)
valid = torch.where(valid)[0]

# torch.mean (in binary_cross_entropy_with_logits) does'nt
# accept empty tensors, so handle it sepaartely
Expand Down Expand Up @@ -604,7 +604,7 @@ def subsample(self, labels):
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
zip(sampled_pos_inds, sampled_neg_inds)
):
img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
img_sampled_inds = torch.where(pos_inds_img | neg_inds_img)[0]
sampled_inds.append(img_sampled_inds)
return sampled_inds

Expand Down Expand Up @@ -700,7 +700,7 @@ def postprocess_detections(self,
labels = labels.reshape(-1)

# remove low scoring boxes
inds = torch.nonzero(scores > self.score_thresh).squeeze(1)
inds = torch.where(scores > self.score_thresh)[0]
boxes, scores, labels = boxes[inds], scores[inds], labels[inds]

# remove empty boxes
Expand Down Expand Up @@ -784,7 +784,7 @@ def forward(self,
mask_proposals = []
pos_matched_idxs = []
for img_id in range(num_images):
pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
pos = torch.where(labels[img_id] > 0)[0]
mask_proposals.append(proposals[img_id][pos])
pos_matched_idxs.append(matched_idxs[img_id][pos])
else:
Expand Down Expand Up @@ -832,7 +832,7 @@ def forward(self,
pos_matched_idxs = []
assert matched_idxs is not None
for img_id in range(num_images):
pos = torch.nonzero(labels[img_id] > 0).squeeze(1)
pos = torch.where(labels[img_id] > 0)[0]
keypoint_proposals.append(proposals[img_id][pos])
pos_matched_idxs.append(matched_idxs[img_id][pos])
else:
Expand Down
4 changes: 2 additions & 2 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,8 @@ def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets)
"""

sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)
sampled_pos_inds = torch.where(torch.cat(sampled_pos_inds, dim=0))[0]
sampled_neg_inds = torch.where(torch.cat(sampled_neg_inds, dim=0))[0]

sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

Expand Down
2 changes: 1 addition & 1 deletion torchvision/ops/boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def remove_small_boxes(boxes: Tensor, min_size: float) -> Tensor:
"""
ws, hs = boxes[:, 2] - boxes[:, 0], boxes[:, 3] - boxes[:, 1]
keep = (ws >= min_size) & (hs >= min_size)
keep = keep.nonzero().squeeze(1)
keep = torch.where(keep)[0]
return keep


Expand Down
4 changes: 2 additions & 2 deletions torchvision/ops/poolers.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _onnx_merge_levels(levels: Tensor, unmerged_results: List[Tensor]) -> Tensor
first_result.size(2), first_result.size(3)),
dtype=dtype, device=device)
for level in range(len(unmerged_results)):
index = (levels == level).nonzero().view(-1, 1, 1, 1)
index = torch.where(levels == level)[0].view(-1, 1, 1, 1)
index = index.expand(index.size(0),
unmerged_results[level].size(1),
unmerged_results[level].size(2),
Expand Down Expand Up @@ -234,7 +234,7 @@ def forward(

tracing_results = []
for level, (per_level_feature, scale) in enumerate(zip(x_filtered, scales)):
idx_in_level = torch.nonzero(levels == level).squeeze(1)
idx_in_level = torch.where(levels == level)[0]
rois_per_level = rois[idx_in_level]

result_idx_in_level = roi_align(
Expand Down

0 comments on commit 15848ed

Please sign in to comment.