From 723c2408654032babd078de10f81abf28ebb4222 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 24 Nov 2020 16:40:11 +0000 Subject: [PATCH] Adding Python type hints, correcting incorrect types, removing unnecessary vars and simplifying code. --- torchvision/models/detection/anchor_utils.py | 24 ++++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 9cf27834f75..defee6e5ae2 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import torch -from torch import nn +from torch import nn, Tensor from torch.jit.annotations import List, Optional, Dict from .image_list import ImageList @@ -56,8 +56,8 @@ def __init__( # For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values. # (scales, aspect_ratios) are usually an element of zip(self.scales, self.aspect_ratios) # This method assumes aspect ratio = height / width for an anchor. - def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="cpu"): - # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821 + def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype: torch.dtype = torch.float32, + device: torch.device = torch.device("cpu")): scales = torch.as_tensor(scales, dtype=dtype, device=device) aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device) h_ratios = torch.sqrt(aspect_ratios) @@ -69,8 +69,7 @@ def generate_anchors(self, scales, aspect_ratios, dtype=torch.float32, device="c base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2 return base_anchors.round() - def set_cell_anchors(self, dtype, device): - # type: (int, Device) -> None # noqa: F821 + def set_cell_anchors(self, dtype: torch.dtype, device: torch.device): if self.cell_anchors is not None: cell_anchors = self.cell_anchors assert cell_anchors is not None @@ -95,8 +94,7 @@ def num_anchors_per_location(self): # For every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:2), # output g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a. - def grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] + def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]: anchors = [] cell_anchors = self.cell_anchors assert cell_anchors is not None @@ -134,8 +132,7 @@ def grid_anchors(self, grid_sizes, strides): return anchors - def cached_grid_anchors(self, grid_sizes, strides): - # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor] + def cached_grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) -> List[Tensor]: key = str(grid_sizes) + str(strides) if key in self._cache: return self._cache[key] @@ -143,8 +140,7 @@ def cached_grid_anchors(self, grid_sizes, strides): self._cache[key] = anchors return anchors - def forward(self, image_list, feature_maps): - # type: (ImageList, List[Tensor]) -> List[Tensor] + def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]: grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps]) image_size = image_list.tensors.shape[-2:] dtype, device = feature_maps[0].dtype, feature_maps[0].device @@ -153,10 +149,8 @@ def forward(self, image_list, feature_maps): self.set_cell_anchors(dtype, device) anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides) anchors = torch.jit.annotate(List[List[torch.Tensor]], []) - for i, (image_height, image_width) in enumerate(image_list.image_sizes): - anchors_in_image = [] - for anchors_per_feature_map in anchors_over_all_feature_maps: - anchors_in_image.append(anchors_per_feature_map) + for i in range(len(image_list.image_sizes)): + anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps] anchors.append(anchors_in_image) anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors] # Clear the cache in case that memory leaks.