From fc1a060521f5d121c4e7a80347d61feab9c13c52 Mon Sep 17 00:00:00 2001 From: Sofiane Abbar Date: Tue, 11 May 2021 17:19:22 +0100 Subject: [PATCH] replaced deprecated call to ByteTensor with from_numpy replaced byteTensor with from_numpy fixed lint issues and copy related worning --- torchvision/transforms/functional.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 5bbd91b3fd8..6a86a000d65 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -124,17 +124,13 @@ def to_tensor(pic): return torch.from_numpy(nppic).to(dtype=default_float_dtype) # handle PIL Image - if pic.mode == 'I': - img = torch.from_numpy(np.array(pic, np.int32, copy=False)) - elif pic.mode == 'I;16': - img = torch.from_numpy(np.array(pic, np.int16, copy=False)) - elif pic.mode == 'F': - img = torch.from_numpy(np.array(pic, np.float32, copy=False)) - elif pic.mode == '1': - img = 255 * torch.from_numpy(np.array(pic, np.uint8, copy=False)) - else: - img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) + mode_to_nptype = {'I': np.int32, 'I;16': np.int16, 'F': np.float32} + img = torch.from_numpy( + np.array(pic, mode_to_nptype.get(pic.mode, np.uint8), copy=True) + ) + if pic.mode == '1': + img = 255 * img img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) # put it from HWC to CHW format img = img.permute((2, 0, 1)).contiguous()