diff --git a/torchvision/utils.py b/torchvision/utils.py index 052215bcbd2..2e8fbccb243 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -90,6 +90,6 @@ def save_image(tensor, filename, nrow=8, padding=2, tensor = tensor.cpu() grid = make_grid(tensor, nrow=nrow, padding=padding, normalize=normalize, range=range, scale_each=scale_each) - ndarr = grid.mul(255).byte().transpose(0, 2).transpose(0, 1).numpy() + ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() im = Image.fromarray(ndarr) im.save(filename)