24 changes: 9 additions & 15 deletions torchvision/models/detection/anchor_utils.py
@@ -1,6 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 import torch
-from torch import nn
+from torch import nn, Tensor
 
 from torch.jit.annotations import List, Optional, Dict
 from .image_list import ImageList
@@ -56,8 +56,8 @@ def __init__(
     # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
     # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios)
     # This method assumes aspect ratio = height / width for an anchor.
-    def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"):
-        # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
Contributor Author: Any reason why the dtype was defined as int earlier?

Member: Because TorchScript didn't support torch.dtype before. Maybe it's supported now, so we can switch to using it.
+    def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32,
+                         device: torch.device = torch.device("cpu")):
         scales = torch.as_tensor(scales, dtype=dtype, device=device)
         aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
         h_ratios = torch.sqrt(aspect_ratios)
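For reference, here is a standalone sketch of the math generate_anchors implements (not torchvision's exact code; the scale and aspect-ratio values are illustrative only): each (scale, aspect_ratio) pair produces one zero-centered box whose height/width ratio equals the aspect ratio.

```python
import torch

# Illustrative values, not taken from the PR: one scale, three aspect ratios.
scales = torch.as_tensor([32], dtype=torch.float32)
aspect_ratios = torch.as_tensor([0.5, 1.0, 2.0], dtype=torch.float32)

h_ratios = torch.sqrt(aspect_ratios)  # aspect ratio = height / width
w_ratios = 1.0 / h_ratios
ws = (w_ratios[:, None] * scales[None, :]).view(-1)
hs = (h_ratios[:, None] * scales[None, :]).view(-1)

# Zero-centered (x1, y1, x2, y2) boxes, one per (aspect_ratio, scale) pair.
base_anchors = (torch.stack([-ws, -hs, ws, hs], dim=1) / 2).round()
print(base_anchors)
# tensor([[-23., -11.,  23.,  11.],
#         [-16., -16.,  16.,  16.],
#         [-11., -23.,  11.,  23.]])
```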
@@ -69,8 +69,7 @@ def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="c
         base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
         return base_anchors.round()
 
-    def set_cell_anchors(self, dtype, device):
-        # type: (int, Device) -> None # noqa: F821
Contributor Author: Unclear why dtype is declared as int and device as Device instead of torch.device.

Member: See above; those were limitations of TorchScript that might have since been fixed.

+    def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
         if self.cell_anchors is not None:
             cell_anchors = self.cell_anchors
             assert cell_anchors is not None
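As a quick check of the point above, a minimal sketch (an assumption about recent PyTorch, not part of the PR; to_tensor is a hypothetical helper) showing that TorchScript accepts torch.dtype and torch.device annotations directly, which is what allows the old # type: comments with int/Device placeholders to be dropped:

```python
import torch
from typing import List

# Hypothetical helper, not from torchvision: scripting only succeeds if
# TorchScript understands the torch.dtype / torch.device annotations.
@torch.jit.script
def to_tensor(values: List[int], dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    return torch.as_tensor(values, dtype=dtype, device=device)

print(to_tensor([1, 2, 3], torch.float32, torch.device("cpu")))
```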
@@ -95,8 +94,7 @@ def num_anchors_per_location(self):

     # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2),
     # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.
-    def grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
+    def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
         anchors = []
         cell_anchors = self.cell_anchors
         assert cell_anchors is not None
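To make the comment above concrete, a small standalone sketch (simplified, with made-up grid and stride values, not torchvision's exact implementation) of what grid_anchors computes for a single feature map: every cell anchor is translated to every grid position, spaced by the stride.

```python
import torch

# One zero-centered cell anchor and a tiny 2x3 grid with stride 16 (illustrative values).
base_anchors = torch.tensor([[-16.0, -16.0, 16.0, 16.0]])
grid_h, grid_w = 2, 3
stride_h, stride_w = 16, 16

# x varies fastest, matching a row-major sweep over the grid.
shift_x = (torch.arange(grid_w, dtype=torch.float32) * stride_w).repeat(grid_h)
shift_y = (torch.arange(grid_h, dtype=torch.float32) * stride_h).repeat_interleave(grid_w)
shifts = torch.stack([shift_x, shift_y, shift_x, shift_y], dim=1)

# Broadcast: every shift is applied to every base anchor.
anchors = (shifts[:, None, :] + base_anchors[None, :, :]).reshape(-1, 4)
print(anchors.shape)  # torch.Size([6, 4]): grid_h * grid_w anchors per cell anchor
```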
@@ -134,17 +132,15 @@ def grid_anchors(self, grid_sizes, strides):

         return anchors
 
-    def cached_grid_anchors(self, grid_sizes, strides):
-        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
+    def cached_grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]:
         key = str(grid_sizes) + str(strides)
         if key in self._cache:
             return self._cache[key]
         anchors = self.grid_anchors(grid_sizes, strides)
         self._cache[key] = anchors
         return anchors
 
-    def forward(self, image_list, feature_maps):
-        # type: (ImageList, List[Tensor]) -> List[Tensor]
+    def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
         grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
         image_size = image_list.tensors.shape[-2:]
         dtype, device = feature_maps[0].dtype, feature_maps[0].device
@@ -153,10 +149,8 @@ def forward(self, image_list, feature_maps):
         self.set_cell_anchors(dtype, device)
         anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)
         anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
-        for i, (image_height, image_width) in enumerate(image_list.image_sizes):
Contributor Author: Unused vars.

Member: Thanks! This is due to some last-minute refactoring before the release that kept those variables around.

-            anchors_in_image = []
-            for anchors_per_feature_map in anchors_over_all_feature_maps:
-                anchors_in_image.append(anchors_per_feature_map)
+        for i in range(len(image_list.image_sizes)):
+            anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
             anchors.append(anchors_in_image)
         anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
         # Clear the cache in case that memory leaks.
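Finally, a tiny standalone sketch (hypothetical shapes) of what the rewritten loop produces, and why image_height/image_width were never needed: every image receives the same list of per-feature-map anchors, which forward then concatenates per image.

```python
import torch

# Hypothetical inputs: two images, two feature maps contributing 12 and 3 anchors.
image_sizes = [(480, 640), (600, 800)]
anchors_over_all_feature_maps = [torch.rand(12, 4), torch.rand(3, 4)]

anchors = []
for i in range(len(image_sizes)):
    anchors_in_image = [a for a in anchors_over_all_feature_maps]
    anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
print([a.shape for a in anchors])  # [torch.Size([15, 4]), torch.Size([15, 4])]
```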