diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 92651b8d884..594cf3441a9 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -520,11 +520,10 @@ def adjust_gamma(img, gamma, gain=1): input_mode = img.mode img = img.convert('RGB') - np_img = np.array(img, dtype=np.float32) - np_img = 255 * gain * ((np_img / 255) ** gamma) - np_img = np.uint8(np.clip(np_img, 0, 255)) + gamma_map = [255 * gain * pow(ele / 255., gamma) for ele in range(256)] * 3 + img = img.point(gamma_map) # use PIL's point-function to accelerate this part - img = Image.fromarray(np_img, 'RGB').convert(input_mode) + img = img.convert(input_mode) return img