From cefc190a44ad5d92248a1359367cbeb5d54cb51a Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 9 Feb 2022 18:11:54 +0000 Subject: [PATCH] Refactor Augmentation Space calls to speed up. --- torchvision/transforms/autoaugment.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index d58077c9b14..a6109ad7030 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -268,9 +268,9 @@ def forward(self, img: Tensor) -> Tensor: transform_id, probs, signs = self.get_params(len(self.policies)) + op_meta = self._augmentation_space(10, F.get_image_size(img)) for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]): if probs[i] <= p: - op_meta = self._augmentation_space(10, F.get_image_size(img)) magnitudes, signed = op_meta[op_name] magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0 if signed and signs[i] == 0: @@ -350,8 +350,8 @@ def forward(self, img: Tensor) -> Tensor: elif fill is not None: fill = [float(f) for f in fill] + op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) for _ in range(self.num_ops): - op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img)) op_index = int(torch.randint(len(op_meta), (1,)).item()) op_name = list(op_meta.keys())[op_index] magnitudes, signed = op_meta[op_name]