Skip to content

Commit

Permalink
Fixes device mismatch issue while building docs (#5428) (#5429)
Browse files Browse the repository at this point in the history
Co-authored-by: vfdev <vfdev.5@gmail.com>
  • Loading branch information
NicolasHug and vfdev-5 committed Feb 16, 2022
1 parent 2662797 commit 9cfb0b7
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions torchvision/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,9 @@ def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
"""

N, _, H, W = normalized_flow.shape
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8)
colorwheel = _make_colorwheel() # shape [55x3]
device = normalized_flow.device
flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
colorwheel = _make_colorwheel().to(device) # shape [55x3]
num_cols = colorwheel.shape[0]
norm = torch.sum(normalized_flow ** 2, dim=1).sqrt()
a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
Expand Down

0 comments on commit 9cfb0b7

Please sign in to comment.