From e9fc63bc9b6aaf5acc10201af6468822315691c2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 18 Oct 2021 15:47:22 +0100 Subject: [PATCH] Pass indexing param to meshgrid --- test/test_io.py | 2 +- torchvision/models/detection/anchor_utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_io.py b/test/test_io.py index 3c4de195285..c45180571f0 100644 --- a/test/test_io.py +++ b/test/test_io.py @@ -23,7 +23,7 @@ def _create_video_frames(num_frames, height, width): - y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width)) + y, x = torch.meshgrid(torch.linspace(-2, 2, height), torch.linspace(-2, 2, width), indexing="ij") data = [] for i in range(num_frames): xc = float(i) / num_frames diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 2e433958715..ec6f7dfa8e1 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -104,7 +104,7 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]]) # For output anchor, compute [x_center, y_center, x_center, y_center] shifts_x = torch.arange(0, grid_width, dtype=torch.int32, device=device) * stride_width shifts_y = torch.arange(0, grid_height, dtype=torch.int32, device=device) * stride_height - shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij") shift_x = shift_x.reshape(-1) shift_y = shift_y.reshape(-1) shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) @@ -222,7 +222,7 @@ def _grid_default_boxes( shifts_x = ((torch.arange(0, f_k[1]) + 0.5) / x_f_k).to(dtype=dtype) shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype) - shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij") shift_x = shift_x.reshape(-1) shift_y = shift_y.reshape(-1)