diff --git a/references/detection/presets.py b/references/detection/presets.py
index 1fac69ae356..04e0680043a 100644
--- a/references/detection/presets.py
+++ b/references/detection/presets.py
@@ -1,3 +1,5 @@
+import torch
+
 import transforms as T
 
 
@@ -6,7 +8,8 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
         if data_augmentation == 'hflip':
             self.transforms = T.Compose([
                 T.RandomHorizontalFlip(p=hflip_prob),
-                T.ToTensor(),
+                T.PILToTensor(),
+                T.ConvertImageDtype(torch.float),
             ])
         elif data_augmentation == 'ssd':
             self.transforms = T.Compose([
@@ -14,13 +17,15 @@ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123., 117., 104.)):
                 T.RandomZoomOut(fill=list(mean)),
                 T.RandomIoUCrop(),
                 T.RandomHorizontalFlip(p=hflip_prob),
-                T.ToTensor(),
+                T.PILToTensor(),
+                T.ConvertImageDtype(torch.float),
             ])
         elif data_augmentation == 'ssdlite':
             self.transforms = T.Compose([
                 T.RandomIoUCrop(),
                 T.RandomHorizontalFlip(p=hflip_prob),
-                T.ToTensor(),
+                T.PILToTensor(),
+                T.ConvertImageDtype(torch.float),
             ])
         else:
             raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 6659e82f01c..c65535750b5 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -1,10 +1,10 @@
+from typing import List, Tuple, Dict, Optional
+
 import torch
 import torchvision
-
 from torch import nn, Tensor
 from torchvision.transforms import functional as F
 from torchvision.transforms import transforms as T
-from typing import List, Tuple, Dict, Optional
 
 
 def _flip_coco_person_keypoints(kps, width):
@@ -52,6 +52,24 @@ def forward(self, image: Tensor,
         return image, target
 
 
+class PILToTensor(nn.Module):
+    def forward(self, image: Tensor,
+                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        image = F.pil_to_tensor(image)
+        return image, target
+
+
+class ConvertImageDtype(nn.Module):
+    def __init__(self, dtype: torch.dtype) -> None:
+        super().__init__()
+        self.dtype = dtype
+
+    def forward(self, image: Tensor,
+                target: Optional[Dict[str, Tensor]] = None) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
+        image = F.convert_image_dtype(image, self.dtype)
+        return image, target
+
+
 class RandomIoUCrop(nn.Module):
     def __init__(self, min_scale: float = 0.3, max_scale: float = 1.0, min_aspect_ratio: float = 0.5,
                  max_aspect_ratio: float = 2.0, sampler_options: Optional[List[float]] = None, trials: int = 40):
diff --git a/references/segmentation/presets.py b/references/segmentation/presets.py
index 3bf29c23751..96334356fcb 100644
--- a/references/segmentation/presets.py
+++ b/references/segmentation/presets.py
@@ -1,3 +1,5 @@
+import torch
+
 import transforms as T
 
 
@@ -11,7 +13,8 @@ def __init__(self, base_size, crop_size, hflip_prob=0.5, mean=(0.485, 0.456, 0.4
             trans.append(T.RandomHorizontalFlip(hflip_prob))
         trans.extend([
             T.RandomCrop(crop_size),
-            T.ToTensor(),
+            T.PILToTensor(),
+            T.ConvertImageDtype(torch.float),
             T.Normalize(mean=mean, std=std),
         ])
         self.transforms = T.Compose(trans)
@@ -24,7 +27,8 @@ class SegmentationPresetEval:
     def __init__(self, base_size, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
         self.transforms = T.Compose([
             T.RandomResize(base_size, base_size),
-            T.ToTensor(),
+            T.PILToTensor(),
+            T.ConvertImageDtype(torch.float),
             T.Normalize(mean=mean, std=std),
         ])
 
diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py
index 943694d3a5c..cf4846a1c27 100644
--- a/references/segmentation/transforms.py
+++ b/references/segmentation/transforms.py
@@ -1,7 +1,6 @@
-import numpy as np
-from PIL import Image
 import random
 
+import numpy as np
 import torch
 from torchvision import transforms as T
 from torchvision.transforms import functional as F
@@ -75,14 +74,22 @@ def __call__(self, image, target):
         return image, target
 
 
-class ToTensor(object):
+class PILToTensor:
     def __call__(self, image, target):
         image = F.pil_to_tensor(image)
-        image = F.convert_image_dtype(image)
         target = torch.as_tensor(np.array(target), dtype=torch.int64)
         return image, target
 
 
+class ConvertImageDtype:
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+    def __call__(self, image, target):
+        image = F.convert_image_dtype(image, self.dtype)
+        return image, target
+
+
 class Normalize(object):
     def __init__(self, mean, std):
         self.mean = mean
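
A minimal sketch (not part of the diff) of why the two-step pipeline preserves the old behavior: F.pil_to_tensor converts a PIL image to a uint8 CxHxW tensor without rescaling, and F.convert_image_dtype then maps it to float32 in [0, 1], which together matches what the removed ToTensor (F.to_tensor) did in one step. The sample image below is hypothetical.

    import torch
    from PIL import Image
    from torchvision.transforms import functional as F

    img = Image.new("RGB", (4, 4), color=(123, 117, 104))  # hypothetical test image

    old = F.to_tensor(img)  # PIL -> float32 tensor scaled to [0, 1] in one step
    new = F.convert_image_dtype(F.pil_to_tensor(img), torch.float)  # two-step equivalent

    assert torch.allclose(old, new)  # same values; only the decomposition changed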