In [553]:
import torchvision
from PIL import Image

In [554]:
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import numpy as np
import torch

In [555]:
from torchvision.datasets import MNIST

In [556]:
try:
    import accimage
except ImportError:
    accimage = None
    

def _is_pil_image(img):
    """Return True when ``img`` is a PIL image, or an accimage image if accimage is installed."""
    accepted_types = (Image.Image,) if accimage is None else (Image.Image, accimage.Image)
    return isinstance(img, accepted_types)
    
    
def _rgb2hsv(testing):
    max_vals, _ = torch.max(testing, dim=0)
    min_vals, _ = torch.min(testing, dim=0)
    v = max_vals
    s = (max_vals - min_vals) / max_vals
    df = max_vals - min_vals
    r, g, b = testing[0], testing[1], testing[2]
    
    
    h = max_vals != min_vals
    hr =  (max_vals == r) * ((g - b) / df) 
    hg =  (max_vals == g) * ((b - r) / df + 2.0)
    hb =  (max_vals == b) * ((r - g) / df + 4.0)    
    h = h * (hr + hg + hb)
    h = (h / 6.0) % 1.0
    h[h != h] = 0
    return torch.stack((h, s, v))
 
def _hsv2rgb(testing):
    h, s, v = testing[0], testing[1], testing[2]
    i = (h * 6.0).type(torch.IntTensor)
    f = (h * 6.0) - i
    p = v * (1.0 - s)
    q = v * (1.0 - f * s)
    t = v * (1.0 - (1 - f) * s)
    i = i % 6
    r = (i == 0) * v + (i == 1) * q +  (i == 2) * p + (i == 3) * p + (i == 4) * t + (i == 5) * v 
    g = (i == 0) * t + (i == 1) * v +  (i == 2) * v + (i == 3) * q + (i == 4) * p + (i == 5) * p
    b = (i == 0) * p + (i == 1) * p +  (i == 2) * t + (i == 3) * v + (i == 4) * v + (i == 5) * q

    return torch.stack((r, g, b))
    
def adjust_hue2(img, hue_factor):
    """Adjust hue of an image.

    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.

    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.

    See `Hue`_ for more details.

    .. _Hue: https://en.wikipedia.org/wiki/Hue

    Args:
        img (PIL Image or torch.Tensor): Image to be adjusted.
            A PIL image is processed with PIL's own HSV conversion; images in
            mode 'L', '1', 'I' or 'F' are returned unchanged. A tensor must be
            3-D with 3 channels in its first dimension; its values are
            expected to be floats in [0, 1], as produced by
            ``transforms.ToTensor()`` (which scales RGB from [0, 255]).
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.

    Returns:
        PIL Image or torch.Tensor: Hue adjusted image.
        If input is PIL Image, return PIL Image
        If input is torch.Tensor, return torch.Tensor

    Raises:
        ValueError: if ``hue_factor`` is outside [-0.5, 0.5].
        TypeError: if ``img`` is neither a PIL image nor a torch.Tensor.
        AssertionError: if a tensor ``img`` is not 3-D with 3 channels.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if _is_pil_image(img):
        input_mode = img.mode
        # These modes have no colour channels, so there is no hue to rotate.
        if input_mode in {'L', '1', 'I', 'F'}:
            return img

        h, s, v = img.convert('HSV').split()
        np_h = np.array(h, dtype=np.uint8)
        # PIL stores hue in 0..255. Reduce the shift modulo 256 with Python
        # ints first: casting a negative float straight to np.uint8 is undefined.
        shift = np.uint8(int(hue_factor * 255) % 256)
        # uint8 addition wraps 255 -> 0, which is exactly the cyclic hue rotation.
        with np.errstate(over='ignore'):
            np_h += shift
        h = Image.fromarray(np_h, 'L')
        return Image.merge('HSV', (h, s, v)).convert(input_mode)

    if not isinstance(img, torch.Tensor):
        raise TypeError('img should be PIL Image or torch.Tensor. Got {}'.format(type(img)))

    assert img.dim() == 3, 'input img must be a 3D torch.Tensor'
    assert img.shape[0] == 3, 'input img must have 3 channels'
    hsv = _rgb2hsv(img)
    h, s, v = hsv[0], hsv[1], hsv[2]
    # Rotate the hue in floating point on img's own device (no CPU IntTensor
    # round trip). Tensor `%` follows Python semantics, so the result stays in
    # [0, 1) for negative hue_factor too.
    new_h = (h + hue_factor) % 1.0
    return _hsv2rgb(torch.stack((new_h, s, v)))

In [557]:
# Open the sample image. The default is the original author's local path;
# set the CAT_IMAGE_PATH environment variable to point somewhere else.
import os

CAT_IMAGE_PATH = os.environ.get("CAT_IMAGE_PATH", "/Users/xni/Documents/pytorch_task/cat.jpg")
# Load the pixels inside the `with` so the file handle is closed afterwards;
# the loaded image remains usable after the block exits.
with Image.open(CAT_IMAGE_PATH) as img_cat:
    img_cat.load()


In [558]:
# PIL input: adjust_hue2 takes its PIL branch (PIL's own HSV conversion).
img_new = adjust_hue2(img=img_cat, hue_factor=0.5)


In [559]:
# Display the PIL-path result inline: PIL images have a PNG rich repr, so
# making the image the cell's last expression renders it in the notebook.
# Image.show() would instead open an external OS viewer and leave no output.
img_new

In [560]:
# Convert the PIL image to a channels-first (C, H, W) float tensor;
# for 8-bit images ToTensor scales pixel values into [0, 1].
transform = transforms.ToTensor()
img_cat_tensor = transform(img_cat)
img_cat_tensor.size()

torch.Size([3, 559, 838])

In [561]:
# Tensor input: adjust_hue2 takes its torch branch (_rgb2hsv / _hsv2rgb).
new_cat = adjust_hue2(img=img_cat_tensor, hue_factor=0.5)

In [562]:
# Turn the hue-shifted (C, H, W) tensor back into a PIL image for display.
transform2 = transforms.ToPILImage(mode=None)
img_cal_pil = transform2(new_cat)

In [563]:
# Display the tensor-path result inline (PIL images have a PNG rich repr),
# so it can be compared with the PIL-path result inside the notebook rather
# than in an external viewer opened by Image.show().
img_cal_pil