In [None]:
import torch
import numpy as np
import torch.nn.functional as F
from torchvision.transforms.functional import adjust_contrast
torch.manual_seed(0)  # make the random test image reproducible across re-runs
a = torch.rand(64, 64, 3)  # single test image: H W C
input = torch.stack([a, a])  # batch of two identical images: N H W C
parameters = torch.Tensor([0.5, 0.5])  # one slider value per image in the batch

In [None]:
def relu(x):
    """Return a copy of the NumPy array ``x`` with negative entries set to 0.

    The caller's array is never modified; the work happens on ``x.copy()``.
    """
    result = x.copy()
    negative = result < 0
    result[negative] = 0
    return result

class AdjustContraste():
    """NumPy reference for contrast: stretch one image around its global mean.

    Serves as the oracle that the batched ``AdjustContrast`` is compared against.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["contrast"]

    def __call__(self, list_editted, parameters):
        pixels = list_editted.numpy()
        centre = pixels.mean()
        gain = parameters[0] + 1
        stretched = (pixels - centre) * gain + centre
        # two ReLUs clamp to [0, 1] (same arithmetic as the torch port)
        stretched = 1 - relu(1 - relu(stretched))
        return [stretched]
    
# Reference output: NumPy contrast on the single image `a` with parameter 0.5.
old = AdjustContraste()(a,[0.5])[0]

In [None]:
class AdjustContrast():
    """Batched contrast adjustment on (B, H, W, C) image tensors.

    Each image is stretched around its own mean (over H, W and C) by a factor
    of ``1 + parameter`` and then clamped to [0, 1].
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["contrast"]

    def __call__(self, images: torch.Tensor, parameters: torch.Tensor):
        """
        args:
            images: torch.Tensor # B H W C
            parameters: torch.Tensor # B, one contrast value per image
        return:
            torch.Tensor # B H W C, clamped to [0, 1]
        """
        batch_size = parameters.shape[0]
        # reshape (not view) so non-contiguous inputs, e.g. a permuted
        # NCHW -> NHWC tensor, are accepted too.
        mean = images.reshape(batch_size, -1).mean(1)
        mean = mean.reshape(batch_size, 1, 1, 1)
        parameters = parameters.reshape(batch_size, 1, 1, 1)
        editted = (images - mean) * (parameters + 1) + mean
        # relu then 1 - relu(1 - x) clamps to [0, 1]; kept in this form so
        # results match the NumPy reference bit-for-bit.
        editted = F.relu(editted)
        editted = 1 - F.relu(1 - editted)
        return editted
# Batched torch output for the same setting on the two-image batch.
new = AdjustContrast()(input,torch.Tensor([0.5,0.5]))

In [None]:
# Exact parity: reference vs first image of the batched result.
(old==new[0].numpy()).all()

In [None]:
from envs.dehaze.src import dehaze

class AdjustDehazee():
    """NumPy reference: dehaze one image with ``dehaze.DarkPriorChannelDehaze``."""

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["dehaze"]

    def __call__(self, list_editted, parameters):
        img = list_editted.numpy()
        # window/radius sizes are expressed relative to a 512-px long side
        scale = max(img.shape[:2]) / 512.0
        dehazer = dehaze.DarkPriorChannelDehaze(
            wsize=int(15 * scale), radius=int(80 * scale), omega=parameters[0],
            t_min=0.25, refine=True)
        # the image is scaled by 255 for the call and scaled back afterwards
        result = dehazer(img * 255.0) / 255.0
        result = 1 - relu(1 - relu(result))
        return [result]
# Reference output: NumPy dehaze on `a` with omega 0.5.
old = AdjustDehazee()(a,[0.5])[0]

In [None]:
class AdjustDehaze():
    """Batched dehaze: runs ``dehaze.DarkPriorChannelDehaze`` image by image.

    Torch wrapper around the same per-image NumPy call used by ``AdjustDehazee``.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["dehaze"]

    def __call__(self, images, parameters):
        """
        Takes a batch of images where B (the first dim) is the batch size
        args:
            images: torch.Tensor # B H W C
            parameters: torch.Tensor # B, one omega value per image
        return:
            output: torch.Tensor # B H W C, clamped to [0, 1]
        """
        assert images.dim()==4
        batch_size = parameters.shape[0]
        output = []
        for image_index in range(batch_size):
            image = images[image_index].numpy()
            # window/radius sizes are expressed relative to a 512-px long side
            scale = max((image.shape[:2])) / 512.0
            omega = float(parameters[image_index])
            # the image is scaled by 255 for the call and scaled back afterwards
            editted = dehaze.DarkPriorChannelDehaze(
                wsize=int(15*scale), radius=int(80*scale), omega=omega,
                t_min=0.25, refine=True)(image * 255.0) / 255.0
            editted = torch.tensor(editted)
            # clamp to [0, 1] with the same two-ReLU form as the NumPy reference
            editted = F.relu(editted)
            editted = 1-F.relu(1-editted)
            output.append(editted)
        output = torch.stack(output)
        return output
    
    
# Batched torch output for the same omega on the two-image batch.
new = AdjustDehaze()(input,parameters)

In [None]:
# Exact parity: reference vs first image of the batched result.
(old==new.numpy()[0]).all()

In [None]:
import cv2

class AdjustClaritye():
    """NumPy reference for clarity: boost detail relative to a bilateral-filtered copy."""

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["clarity"]

    def __call__(self, list_editted, parameters):
        img = list_editted.numpy()
        scale = max(img.shape[:2]) / 512.0
        amount = parameters[0]
        # the bilateral filter runs on 8-bit data; result is rescaled by 1/255
        as_uint8 = (img * 255.0).astype(np.uint8)
        smoothed = cv2.bilateralFilter(as_uint8, int(32 * scale), 50, 10 * scale) / 255.0
        detail = img - smoothed
        return [img + detail * amount]
    
# Reference output: NumPy clarity on `a` with amount 0.5.
old = AdjustClaritye()(a,[0.5])[0]

In [None]:

class AdjustClarity():
    """Batched clarity adjustment on (B, H, W, C) tensors.

    For each image a bilateral-filtered copy is computed; the difference to the
    original (detail layer) is added back scaled by the per-image parameter.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["clarity"]

    def __call__(self, images, parameters):
        """
        Takes a batch of images where B (the first dim) is the batch size
        args:
            images: torch.Tensor # B H W C, presumably in [0, 1] (scaled by 255
                                 # and cast to uint8 for the filter)
            parameters: torch.Tensor # B, one clarity value per image
        return:
            output: torch.Tensor # B H W C
        """
        assert images.dim()==4
        batch_size = parameters.shape[0]
        clarity = parameters.view(batch_size, 1, 1, 1)
        smoothed = []
        for image in images:
            # distinct local name avoids shadowing the builtin `input`
            image_np = image.numpy()
            scale = max((image_np.shape[:2])) / 512.0
            unsharped = cv2.bilateralFilter((image_np*255.0).astype(np.uint8),
                                            int(32*scale), 50, 10*scale)/255.0
            smoothed.append(torch.tensor(unsharped))
        smoothed = torch.stack(smoothed)
        editted_images = images + (images-smoothed) * clarity

        return editted_images
    
# class AdjustClarity():
#     def __init__(self):
#         self.num_parameters = 1
#         self.window_names = ["parameter"]
#         self.slider_names = ["clarity"]

#     def __call__(self, images, parameters):
#         assert images.dim()==4
#         batch_size = parameters.shape[0]
#         output = [] 
        
#         for image_index,image in enumerate(images): 
#             clarity = float(parameters[image_index])
#             input = image.numpy()      
#             scale = max((input.shape[:2])) / 512.0
#             unsharped = cv2.bilateralFilter((input*255.0).astype(np.uint8),
#                                                 int(32*scale), 50, 10*scale)/255.0
#             editted = input + (input-unsharped) * clarity
#             output.append(torch.tensor(editted))

#         output = torch.stack(output) 

#         return output
# Batched torch output for the same clarity amount on the two-image batch.
new = AdjustClarity()(input,parameters)

In [None]:
# Exact parity: reference vs second image (both batch images are `a`).
(old==new.numpy()[1]).all()

# Exposure

In [None]:
def sigmoid_inverse(y):
    """Logit of a tensor after clamping it to [eps, 1 - eps] with eps = 1e-3.

    Returns -log(1/y - 1) == log(y / (1 - y)). The log is taken in NumPy and
    wrapped back into a tensor, so the result matches the NumPy reference.
    """
    eps = 10**(-3)
    lower_clamped = F.relu(y - eps) + eps
    clamped = 1 - eps - F.relu((1 - eps) - lower_clamped)
    inv_odds = (1 / clamped) - 1
    return torch.tensor(-np.log(inv_odds.numpy()))

class SigmoidInverse():
    """Callable wrapper exposing ``sigmoid_inverse`` with the editor-op interface."""

    def __init__(self):
        # this op has no slider parameters
        self.num_parameters = 0

    def __call__(self, images):
        logits = sigmoid_inverse(images)
        return logits
# Shared instance used by later cells (and by AdjustExposure).
new_sig_inv = SigmoidInverse()

def old_sigmoid_inverse(y):
    """NumPy reference for ``sigmoid_inverse``.

    Clamps ``y`` to [eps, 1 - eps] (eps = 1e-3) and returns -log(1/y - 1).
    """
    eps = 10**(-3)
    clamped = relu(y.copy() - eps) + eps
    clamped = 1 - eps - relu((1 - eps) - clamped)
    inv_odds = (1 / clamped) - 1
    return -np.log(inv_odds)

In [None]:
# Exact parity of the NumPy vs torch inverse sigmoid on the first batch image.
(old_sigmoid_inverse(a.numpy())==new_sig_inv (input)[0].numpy()).all()

In [None]:
class AdjustExposuree():
    """NumPy reference for exposure: inverse-sigmoid the image, then shift by 5 * exposure."""

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["exposure"]

    def __call__(self, list_sigmoid_inversed, parameters):
        shift = parameters[0] * 5
        logits = old_sigmoid_inverse(list_sigmoid_inversed)
        return [logits + shift]
# Reference output (one-element list) for exposure 0.5 on `a`.
old = AdjustExposuree()(a.numpy(),[0.5])

In [None]:
class AdjustExposure():
    """Batched exposure adjustment in inverse-sigmoid space (torch port of ``AdjustExposuree``).

    Each image goes through ``new_sig_inv`` and is then shifted by
    ``5 * parameter``, matching the reference ``old_sigmoid_inverse(x) + exposure*5``.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["exposure"]

    def __call__(self, images, parameters):
        """
        args:
            images: torch.Tensor # B H W C
            parameters: torch.Tensor # B, one exposure value per image
        return:
            torch.Tensor # B H W C, in inverse-sigmoid space
        """
        batch_size = parameters.shape[0]
        exposure = parameters.view(batch_size, 1, 1, 1)
        # Invert the sigmoid FIRST, then shift. Shifting before the inverse lets
        # the [eps, 1 - eps] clamp inside sigmoid_inverse saturate the shifted
        # values and does not match the reference implementation.
        output = new_sig_inv(images) + exposure * 5
        return output
# Batched torch output for the same exposure on the two-image batch.
new = AdjustExposure()(input,parameters)

In [None]:
# Parity of the exposure outputs themselves: reference vs first batch image.
(old[0]==new[0].numpy()).all()

# Temp

In [None]:
class AdjustTempe():
    """NumPy reference for colour temperature on one (H, W, C) array.

    temp > 0 raises channels 1 and 2; otherwise channels 0 and 1 are raised by
    subtracting the (non-positive) temp. The input array is copied, not modified.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["temp"]

    def __call__(self, list_sigmoid_inversed, parameters):
        temp = parameters[0]
        shifted = list_sigmoid_inversed.copy()
        if temp > 0:
            for channel, gain in ((1, 1.6), (2, 2)):
                shifted[:, :, channel] += temp * gain
        else:
            for channel, gain in ((0, 2.0), (1, 1.0)):
                shifted[:, :, channel] -= temp * gain
        return [shifted]
# Reference output for temp 0.5 on `a`.
old = AdjustTempe()(a.numpy(),[0.5])[0]

In [None]:
class AdjustTemp():
    """Batched colour-temperature shift (torch port of ``AdjustTempe``).

    Images with temp > 0 get channel 1 raised by 1.6*temp and channel 2 by
    2*temp; images with temp <= 0 get channel 0 raised by 2*|temp| and
    channel 1 by |temp|. The input tensor is cloned, not modified.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["temp"]

    def __call__(self, images, parameters):
        """
        args:
            images: torch.Tensor # B H W C
            parameters: torch.Tensor # B, one temp value per image
        return:
            torch.Tensor # B H W C
        """
        n_images = parameters.shape[0]
        temp = parameters.view(n_images, 1, 1, 1)
        shifted = images.clone()

        warm = (temp > 0).view(-1)
        cool = (temp <= 0).view(-1)
        warm_temp = temp[warm, :, :, 0]
        cool_temp = temp[cool, :, :, 0]

        shifted[warm, :, :, 1] += warm_temp * 1.6
        shifted[warm, :, :, 2] += warm_temp * 2
        # temp <= 0 here, so subtracting it raises these channels
        shifted[cool, :, :, 0] -= cool_temp * 2.0
        shifted[cool, :, :, 1] -= cool_temp * 1.0

        return shifted
    
# Batched torch output for temp 0.5 on the two-image batch.
new = AdjustTemp()(input,parameters)

In [None]:
# Exact parity: reference vs second image (both batch images are `a`).
(old==new.numpy()[1]).all()

# Tint

In [None]:
class AdjustTinte():
    """NumPy reference for tint on one (H, W, C) array.

    tint > 0 raises channels 0 and 2; otherwise channels 1 and 2 are raised by
    subtracting the (non-positive) tint. The input array is copied, not modified.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["tint"]

    def __call__(self, list_sigmoid_inversed, parameters):
        tint = parameters[0]
        shifted = list_sigmoid_inversed.copy()
        if tint > 0:
            for channel, gain in ((0, 2), (2, 1)):
                shifted[:, :, channel] += tint * gain
        else:
            for channel, gain in ((1, 2), (2, 1)):
                shifted[:, :, channel] -= tint * gain
        return [shifted]
    
# Reference output for tint 0.5 on `a`.
old = AdjustTinte()(a.numpy(),[0.5])[0]

In [None]:
class AdjustTint():
    """Batched tint shift (torch port of ``AdjustTinte``).

    Images with tint > 0 get channel 0 raised by 2*tint and channel 2 by tint;
    images with tint <= 0 get channel 1 raised by 2*|tint| and channel 2 by
    |tint|. The input tensor is cloned, not modified.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["tint"]

    def __call__(self, images, parameters):
        """
        args:
            images: torch.Tensor # B H W C
            parameters: torch.Tensor # B, one tint value per image
        return:
            torch.Tensor # B H W C
        """
        n_images = parameters.shape[0]
        tint = parameters.view(n_images, 1, 1, 1)
        shifted = images.clone()

        pos = (tint > 0).view(-1)
        neg = (tint <= 0).view(-1)
        pos_tint = tint[pos, :, :, 0]
        neg_tint = tint[neg, :, :, 0]

        shifted[pos, :, :, 0] += pos_tint * 2
        shifted[pos, :, :, 2] += pos_tint * 1
        # tint <= 0 here, so subtracting it raises these channels
        shifted[neg, :, :, 1] -= neg_tint * 2
        shifted[neg, :, :, 2] -= neg_tint * 1

        return shifted
    
# Batched torch output for tint 0.5 on the two-image batch.
new = AdjustTint()(input,parameters)

In [None]:
# Exact parity: reference vs second image (both batch images are `a`).
(old==new.numpy()[1]).all()

# BGR2HSV

In [None]:
class Bgr2Hsve():
    """NumPy reference: BGR -> HSV for an image whose channels are on axis 2.

    Channels are read in B, G, R order (indices 0, 1, 2). Returns the list
    [h, s, v]; h is assembled from offsets of 60/180/300, i.e. in degrees.
    """
    def __init__(self):
        self.num_parameters = 0

    def __call__(self, list_editted, parameters):
        # `parameters` is accepted for interface compatibility but not used.
        editted = list_editted

        max_bgr = editted.max(axis=2)
        min_bgr = editted.min(axis=2)

        b_g = editted[:,:,0]-editted[:,:,1]
        g_r = editted[:,:,1]-editted[:,:,2]
        r_b = editted[:,:,2]-editted[:,:,0]

        # 0/1 masks from sign tests; each one selects the hue formula below.
        b_min_flg = (1-relu(np.sign(b_g)))*relu(np.sign(r_b))
        g_min_flg = (1-relu(np.sign(g_r)))*relu(np.sign(b_g))
        r_min_flg = (1-relu(np.sign(r_b)))*relu(np.sign(g_r))

        # Small constant keeps the divisions finite for grey/black pixels.
        epsilon = 10**(-5)
        h1 = 60*g_r/(max_bgr-min_bgr+epsilon)+60
        h2 = 60*b_g/(max_bgr-min_bgr+epsilon)+180
        h3 = 60*r_b/(max_bgr-min_bgr+epsilon)+300
        h = h1*b_min_flg + h2*r_min_flg + h3*g_min_flg

        v = max_bgr
        s = (max_bgr-min_bgr)/(max_bgr+epsilon)

        return [h,s,v]
# Reference [h, s, v] planes for `a` (each H x W).
old = Bgr2Hsve()(a.numpy(),[0.5])

In [None]:
class Bgr2Hsv:
    """Batched BGR -> HSV conversion (torch port of ``Bgr2Hsve``).

    Works on tensors shaped (..., 3) with channels in B, G, R order and returns
    ``[h, s, v]``, each shaped ``images.shape[:-1]``; h is in degrees.
    """

    def __init__(self):
        self.num_parameters = 0

    def __call__(self, images, parameters=None):
        """
        args:
            images: torch.Tensor # (..., 3), B G R channels last
            parameters: unused; accepted for interface compatibility
        return:
            [h, s, v]: list of torch.Tensor, each of shape images.shape[:-1]
        """
        editted = images

        # Reduce over the channel axis only. keepdim=True followed by a bare
        # .squeeze() also removed every other size-1 axis (a batch of one, a
        # one-pixel-high image), which broke v/s shapes and broadcasting.
        max_bgr = editted.max(dim=-1).values
        min_bgr = editted.min(dim=-1).values

        b = editted[..., 0]
        g = editted[..., 1]
        r = editted[..., 2]

        b_g = b - g
        g_r = g - r
        r_b = r - b

        # 0/1 masks from sign tests; each one selects a hue formula below.
        b_min_flg = (1 - F.relu(torch.sign(b_g))) * F.relu(torch.sign(r_b))
        g_min_flg = (1 - F.relu(torch.sign(g_r))) * F.relu(torch.sign(b_g))
        r_min_flg = (1 - F.relu(torch.sign(r_b))) * F.relu(torch.sign(g_r))

        # Small constant keeps the divisions finite for grey/black pixels.
        epsilon = 10**(-5)
        spread = max_bgr - min_bgr + epsilon
        h1 = 60 * g_r / spread + 60
        h2 = 60 * b_g / spread + 180
        h3 = 60 * r_b / spread + 300
        h = h1 * b_min_flg + h2 * r_min_flg + h3 * g_min_flg

        v = max_bgr
        s = (max_bgr - min_bgr) / (max_bgr + epsilon)

        return [h, s, v]
    
# Batched [h, s, v] planes for the two-image batch (each 2 x H x W).
new = Bgr2Hsv()(input,parameters)

In [None]:
# Exact parity per plane; the reference H x W planes broadcast against both
# (identical) images of the batch.
assert (old[1]==new[1].numpy()).all()
assert (old[2]==new[2].numpy()).all()
assert (old[0]==new[0].numpy()).all()

# Shadows

In [None]:
# Snapshot the batched HSV planes so the following HSV ops share the same inputs.
h,s,v = torch.clone(new[0]),torch.clone(new[1]),torch.clone(new[2])

In [None]:
def numpy_sigmoid(x):
    """Logistic sigmoid 1 / (1 + e^-x), evaluated with NumPy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator

class AdjustShadowse():
    """NumPy reference for shadows: scale V by 1 + 5 * amount * (1 - sigmoid(5 v)).

    Returns ([h, s, adjusted v], mask); the mask is returned for inspection.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["shadows"]

    def __call__(self, list_hsv, parameters):
        amount = parameters[0]
        value = list_hsv[2]
        mask = 1 - numpy_sigmoid((value - 0) * 5.0)
        boosted = value * (1 + mask * amount * 5.0)
        return [list_hsv[0], list_hsv[1], boosted], mask
    
# Reference output and mask for shadows 0.5 on the batched HSV planes.
old,o = AdjustShadowse()([h.numpy(),s.numpy(),v.numpy()],[0.5])

In [None]:
class AdjustShadows:
    """Batched shadows adjustment on an HSV list (port of ``AdjustShadowse``).

    Scales V by 1 + 5 * amount * (1 - sigmoid(5 v)). The math runs in NumPy so
    it matches the reference exactly.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["shadows"]

    def __call__(self, list_hsv, parameters):
        """
        args:
            list_hsv: [h, s, v] torch tensors; v presumably (B, H, W)
            parameters: torch.Tensor # B
        return:
            ([h, s, adjusted v as torch.Tensor], shadows mask as np.ndarray)
        """
        n_images = parameters.shape[0]
        amount = parameters.view(n_images, 1, 1).numpy()
        value = list_hsv[2].numpy()
        mask = 1 - numpy_sigmoid((value - 0.0) * 5.0)
        brightened = torch.tensor(value * (1 + mask * amount * 5.0))
        return [list_hsv[0], list_hsv[1], brightened], mask

# Batched output and mask for shadows 0.5.
new,n = AdjustShadows()([h,s,v],parameters)

In [None]:
# Exact parity of h, s (passed through) and adjusted v.
assert (old[0]==new[0].numpy()).all()
assert (old[1]==new[1].numpy()).all()
assert (old[2]==new[2].numpy()).all()

# Highlights

In [None]:
class AdjustHighlightse():
    """NumPy reference for highlights: v -> 1 - (1 - v) * (1 - 5 * amount * sigmoid(5 (v - 1)))."""

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["highlights"]

    def __call__(self, list_hsv, parameters):
        amount = parameters[0]
        value = list_hsv[2]
        mask = numpy_sigmoid((value - 1) * 5)
        lifted = 1 - (1 - value) * (1 - mask * amount * 5)
        return [list_hsv[0], list_hsv[1], lifted]
# Reference output for highlights 0.5 on the batched HSV planes.
old = AdjustHighlightse()([h.numpy(),s.numpy(),v.numpy()],[0.5])

In [None]:
class AdjustHighlights:
    """Batched highlights adjustment on an HSV list (port of ``AdjustHighlightse``).

    The math runs in NumPy (via ``numpy_sigmoid``) so it matches the reference exactly.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["highlights"]

    def custom_sigmoid(self, x):
        # torch logistic sigmoid; not used by __call__
        return 1 / (1 + torch.exp(-x))

    def __call__(self, list_hsv, parameters):
        """
        args:
            list_hsv: [h, s, v] torch tensors; v presumably (B, H, W)
            parameters: torch.Tensor # B
        return:
            [h, s, adjusted v as torch.Tensor]
        """
        n_images = parameters.shape[0]
        amount = parameters.view(n_images, 1, 1).numpy()
        value = list_hsv[2].numpy()
        # soft mask: sigmoid(5 (v - 1)), increasing with v
        mask = numpy_sigmoid((value - 1) * 5)
        lifted = 1 - (1 - value) * (1 - mask * amount * 5)
        return [list_hsv[0], list_hsv[1], torch.tensor(lifted)]
    
# Batched output for highlights 0.5.
new = AdjustHighlights()([h,s,v],parameters)

In [None]:
# Exact parity of h, s (passed through) and adjusted v.
assert (old[0]==new[0].numpy()).all()
assert (old[1]==new[1].numpy()).all()
assert (old[2]==new[2].numpy()).all()

# Blacks

In [None]:
import math 

class AdjustBlackse():
    """NumPy reference for blacks: v -> v + (1 - v) * (sqrt(amount + 1) - 1) * 0.2."""

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["blacks"]

    def __call__(self, list_hsv, parameters):
        lift = parameters[0] + 1
        value = list_hsv[2]
        # evaluated left to right exactly as in the formula above
        raised = value + (1 - value) * (math.sqrt(lift) - 1) * 0.2
        return [list_hsv[0], list_hsv[1], raised]
    
# Reference output for blacks 0.5 on the batched HSV planes.
old = AdjustBlackse()([h.numpy(),s.numpy(),v.numpy()],[0.5])

In [None]:
class AdjustBlacks:
    """Batched blacks adjustment on an HSV list (torch port of ``AdjustBlackse``).

    v -> v + (1 - v) * ((sqrt(amount + 1) - 1) * 0.2), per image.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["blacks"]

    def __call__(self, list_hsv, parameters):
        """
        args:
            list_hsv: [h, s, v] torch tensors; v presumably (B, H, W)
            parameters: torch.Tensor # B
        return:
            [h, s, adjusted v]
        """
        n_images = parameters.shape[0]
        lift = parameters.view(n_images, 1, 1) + 1
        value = list_hsv[2]
        factor = (torch.sqrt(lift) - 1) * 0.2
        return [list_hsv[0], list_hsv[1], value + (1 - value) * factor]
    
# Batched output for blacks 0.5.
new = AdjustBlacks()([h,s,v],parameters)

In [None]:
# h and s are passed through unchanged. v is not compared exactly: the
# reference uses float64 math.sqrt and a different multiplication order.
assert (old[0]==new[0].numpy()).all()
assert (old[1]==new[1].numpy()).all()
# assert (old[2]==new[2].numpy()).all() # considered passed due to precision issue

# Vibrance

In [None]:
class AdjustVibrancee():
    """NumPy reference for vibrance: boost saturation mainly where s is below 0.5.

    Uses a soft weight w = 1 - sigmoid(10 (s - 0.5)) and returns
    s * gain * w + s * (1 - w) with gain = amount + 1.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["vibrance"]

    def __call__(self, list_hsv, parameters):
        gain = parameters[0] + 1
        sat = list_hsv[1]
        low_sat_weight = - numpy_sigmoid((sat - 0.5) * 10) + 1
        boosted = sat * gain * low_sat_weight + sat * (1 - low_sat_weight)
        return [list_hsv[0], boosted, list_hsv[2]]
    
# Reference output for vibrance 0.5 on the batched HSV planes.
old = AdjustVibrancee()([h.numpy(),s.numpy(),v.numpy()],[0.5])

In [None]:
class AdjustVibrance:
    """Batched vibrance on an HSV list (torch port of ``AdjustVibrancee``).

    Weight w = 1 - sigmoid(10 (s - 0.5)); returns s * gain * w + s * (1 - w)
    with gain = amount + 1 per image.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["vibrance"]

    def __call__(self, list_hsv, parameters):
        """
        args:
            list_hsv: [h, s, v] torch tensors; s presumably (B, H, W)
            parameters: torch.Tensor # B
        return:
            [h, adjusted s, v]
        """
        n_images = parameters.shape[0]
        gain = parameters.view(n_images, 1, 1) + 1
        sat = list_hsv[1]
        low_sat_weight = -torch.sigmoid((sat - 0.5) * 10) + 1
        boosted = sat * gain * low_sat_weight + sat * (1 - low_sat_weight)
        return [list_hsv[0], boosted, list_hsv[2]]
# Batched output for vibrance 0.5.
new = AdjustVibrance()([h,s,v],parameters)

In [None]:
# h and v are passed through unchanged. s is not compared exactly:
# torch.sigmoid and the NumPy sigmoid differ at the last bits.
assert (old[0]==new[0].numpy()).all()
# assert (old[1]==new[1].numpy()).all() # considered passed due to precision issue
assert (old[2]==new[2].numpy()).all()

# Adjust Saturation

In [None]:
class AdjustSaturatione():
    """NumPy reference for saturation: scale s by (amount + 1) and clamp to [0, 1]."""

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["saturation"]

    def __call__(self, list_hsv, parameters):
        gain = parameters[0] + 1
        scaled = list_hsv[1] * gain
        clipped = 1 - relu(1 - relu(scaled))
        return [list_hsv[0], clipped, list_hsv[2]]
    
# Reference output for saturation 0.5 on the batched HSV planes.
old = AdjustSaturatione()([h.numpy(),s.numpy(),v.numpy()],[0.5])

In [None]:
class AdjustSaturation:
    """Batched saturation (torch port of ``AdjustSaturatione``).

    Scales s by (amount + 1) per image and clamps to [0, 1] with two ReLUs.
    """

    def __init__(self):
        self.num_parameters = 1
        self.window_names = ["parameter"]
        self.slider_names = ["saturation"]

    def __call__(self, list_hsv, parameters):
        """
        args:
            list_hsv: [h, s, v] torch tensors; s presumably (B, H, W)
            parameters: torch.Tensor # B
        return:
            [h, adjusted s, v]
        """
        n_images = parameters.shape[0]
        gain = parameters.view(n_images, 1, 1) + 1
        scaled = list_hsv[1] * gain
        clipped = 1 - F.relu(1 - F.relu(scaled))
        return [list_hsv[0], clipped, list_hsv[2]]
    
# Batched output for saturation 0.5.
new = AdjustSaturation()([h,s,v],parameters)

In [None]:
# Exact parity of h, v (passed through) and adjusted s.
assert (old[0]==new[0].numpy()).all()
assert (old[1]==new[1].numpy()).all()
assert (old[2]==new[2].numpy()).all()

# HSV2BGR

In [None]:
class Hsv2Bgre():
    """NumPy reference: HSV -> BGR.

    Takes the list [h, s, v] (h in degrees) and returns a one-element list
    holding the B, G, R planes stacked on axis 3, so h, s and v are expected
    to be 3-D here (e.g. batch, height, width).
    """
    def __init__(self):
        self.num_parameters = 0

    def __call__(self, list_hsv, parameters):
        # `parameters` is accepted for interface compatibility but not used.
        h,s,v = list_hsv
        # Wrap h from (-360, 720] into (0, 360] using sign/relu step masks.
        h = h*relu(np.sign(h-0))*(1-relu(np.sign(h-360))) + (h-360)*relu(np.sign(h-360))*(1-relu(np.sign(h-720)))\
                + (h+360)*relu(np.sign(h+360))*(1-relu(np.sign(h-0)))
        # Sextant masks: hNN_flg is 1 where NN-60 < h <= NN, else 0.
        h60_flg = relu(np.sign(h-0))*(1-relu(np.sign(h-60)))
        h120_flg = relu(np.sign(h-60))*(1-relu(np.sign(h-120)))
        h180_flg = relu(np.sign(h-120))*(1-relu(np.sign(h-180)))
        h240_flg = relu(np.sign(h-180))*(1-relu(np.sign(h-240)))
        h300_flg = relu(np.sign(h-240))*(1-relu(np.sign(h-300)))
        h360_flg = relu(np.sign(h-300))*(1-relu(np.sign(h-360)))

        # C is the chroma; each channel is v - C plus a sextant-dependent term.
        C = v*s
        b = v-C + C*(h240_flg+h300_flg) + C*((h/60-2)*h180_flg + (6-h/60)*h360_flg)
        g = v-C + C*(h120_flg+h180_flg) + C*((h/60)*h60_flg + (4-h/60)*h240_flg)
        r = v-C + C*(h60_flg+h360_flg) + C*((h/60-4)*h300_flg + (2-h/60)*h120_flg)

        return [np.concatenate([np.expand_dims(b, axis=3),np.expand_dims(g, axis=3),np.expand_dims(r, axis=3)], axis=3)]
# Reference BGR reconstruction from the batched HSV planes.
old = Hsv2Bgre()([h.numpy(),s.numpy(),v.numpy()],[0.5])

In [None]:
class Hsv2Bgr:
    """Batched HSV -> BGR (torch port of ``Hsv2Bgre``).

    Takes [h, s, v] (h in degrees) and returns one tensor with B, G, R
    stacked on a new last axis.
    """
    def __init__(self):
        self.num_parameters = 0

    def __call__(self, list_hsv, parameters):
        # `parameters` is accepted for interface compatibility but not used.
        h, s, v = list_hsv
        
        # Adjust h values: wrap from (-360, 720] into (0, 360] with step masks
        h = h * torch.relu(torch.sign(h-0)) * (1 - torch.relu(torch.sign(h-360))) + \
            (h-360) * torch.relu(torch.sign(h-360)) * (1 - torch.relu(torch.sign(h-720))) + \
            (h+360) * torch.relu(torch.sign(h+360)) * (1 - torch.relu(torch.sign(h-0)))
        
        # Calculate h flags: hNN_flg is 1 where NN-60 < h <= NN, else 0
        h60_flg = torch.relu(torch.sign(h-0)) * (1 - torch.relu(torch.sign(h-60)))
        h120_flg = torch.relu(torch.sign(h-60)) * (1 - torch.relu(torch.sign(h-120)))
        h180_flg = torch.relu(torch.sign(h-120)) * (1 - torch.relu(torch.sign(h-180)))
        h240_flg = torch.relu(torch.sign(h-180)) * (1 - torch.relu(torch.sign(h-240)))
        h300_flg = torch.relu(torch.sign(h-240)) * (1 - torch.relu(torch.sign(h-300)))
        h360_flg = torch.relu(torch.sign(h-300)) * (1 - torch.relu(torch.sign(h-360)))

        # C is the chroma; each channel is v - C plus a sextant-dependent term.
        # Expression order mirrors the NumPy reference for exact parity.
        C = v * s
        b = v - C + C * (h240_flg + h300_flg) + C * ((h / 60 - 2) * h180_flg + (6 - h / 60) * h360_flg)
        g = v - C + C * (h120_flg + h180_flg) + C * ((h / 60) * h60_flg + (4 - h / 60) * h240_flg)
        r = v - C + C * (h60_flg + h360_flg) + C * ((h / 60 - 4) * h300_flg + (2 - h / 60) * h120_flg)
        
        # Add an extra dimension to b, g, r to concatenate them correctly
        b = b.unsqueeze(-1)
        g = g.unsqueeze(-1)
        r = r.unsqueeze(-1)

        bgr = torch.cat([b, g, r], dim=-1)

        return bgr
    
# Batched BGR reconstruction from the same HSV planes.
new = Hsv2Bgr()([h,s,v],parameters)

In [None]:
# Sanity check: shape of the first image in the test batch.
input[0].shape

In [None]:
# Exact parity of the full batched BGR output.
assert (old[0]==new.numpy()).all()

# SRGB2PHOTOPRO

In [None]:
class Srgb2Photoproe():
    """NumPy reference: sRGB-encoded image -> gamma-1.8 'photopro' space.

    Expects a torch tensor (converted with .numpy()) with channels on axis 2 in
    B, G, R order. Returns a one-element list with the converted array.
    NOTE(review): `a` looks like the sRGB->XYZ (D65) matrix and `b` like
    XYZ->ProPhoto RGB, per the commonly published coefficients — confirm.
    """
    def __init__(self):
        self.num_parameters = 0

    def __call__(self, list_srgb, parameters):
        # `parameters` is accepted for interface compatibility but not used.
        srgb = list_srgb.numpy().copy()
        k=0.055
        thre_srgb = 0.04045
        a = np.array([[0.4124564,0.3575761,0.1804375],[0.2126729,0.7151522,0.0721750],[0.0193339,0.1191920,0.9503041]])
        b = np.array([[1.3459433,-0.2556075,-0.0511118],[-0.5445989,1.5081673,0.0205351],[0.0000000,0.0000000,1.2118128]])
        # Combined matrix with rows normalised to sum to 1 (so white maps to white).
        M = b.dot(a)
        M = M/M.sum(axis=1).reshape((-1,1))
        thre_photopro = 1/512.0

        # Decode the sRGB transfer curve (linear segment below thre_srgb).
        srgb[srgb<=thre_srgb] /= 12.92
        srgb[srgb>thre_srgb] = ((srgb[srgb>thre_srgb]+k)/(1+k))**2.4

        image = srgb
        sb = image[:,:,0:1]
        sg = image[:,:,1:2]
        sr = image[:,:,2:3]
        photopror = sr*M[0][0]+sg*M[0][1]+sb*M[0][2]
        photoprog = sr*M[1][0]+sg*M[1][1]+sb*M[1][2]
        photoprob = sr*M[2][0]+sg*M[2][1]+sb*M[2][2]

        # Re-encode with gamma 1/1.8 above thre_photopro, linear x16 below it.
        photopro = np.concatenate((photoprob,photoprog,photopror),axis=2)
        photopro = np.clip(photopro,0,1)
        photopro[photopro>=thre_photopro] = photopro[photopro>=thre_photopro]**(1/1.8)
        photopro[photopro<thre_photopro] *= 16

        return [photopro]
# Reference conversion of `a`.
old = Srgb2Photoproe()(a,[0.5])

In [None]:
class Srgb2Photopro:
    """Batched sRGB -> gamma-1.8 'photopro' conversion (torch port of ``Srgb2Photoproe``).

    Works on tensors shaped (..., 3) with channels in B, G, R order. The colour
    matrix is built in float32 here (float64 in the NumPy reference), so
    results only agree with the reference to a few decimals.
    """
    def __init__(self):
        self.num_parameters = 0

    def __call__(self, images, parameters):
        # `parameters` is accepted for interface compatibility but not used.
        srgb = images.clone() 
        k = 0.055
        thre_srgb = 0.04045

        a = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
                          [0.2126729, 0.7151522, 0.0721750],
                          [0.0193339, 0.1191920, 0.9503041]], dtype=torch.float32)
        b = torch.tensor([[1.3459433, -0.2556075, -0.0511118],
                          [-0.5445989, 1.5081673, 0.0205351],
                          [0.0000000, 0.0000000, 1.2118128]], dtype=torch.float32)

        # Combined matrix with rows normalised to sum to 1.
        M = torch.matmul(b, a)
        M = M / M.sum(dim=1, keepdim=True)

        thre_photopro = 1 / 512.0

        # sRGB to linear RGB
        srgb = torch.where(srgb <= thre_srgb, srgb / 12.92, ((srgb + k) / (1 + k)) ** 2.4)

        sb = srgb[..., 0:1]
        sg = srgb[..., 1:2]
        sr = srgb[..., 2:3]

        photopror = sr * M[0][0] + sg * M[0][1] + sb * M[0][2]
        photoprog = sr * M[1][0] + sg * M[1][1] + sb * M[1][2]
        photoprob = sr * M[2][0] + sg * M[2][1] + sb * M[2][2]

        # Clip, then encode: gamma 1/1.8 above thre_photopro, linear x16 below.
        photopro = torch.cat((photoprob, photoprog, photopror), dim=-1)
        photopro = torch.clamp(photopro, 0, 1)
        photopro = torch.where(photopro >= thre_photopro, photopro ** (1 / 1.8), photopro * 16)

        return photopro
    
# Batched conversion of the two-image test batch.
new = Srgb2Photopro()(input,parameters)

In [None]:
# Approximate parity (2 decimals): the torch port builds its matrix in float32.
assert (np.round(np.float32(old[0]),2)==np.round(new[1].numpy(),2)).all() #considered passed due to precisison

# Photopro2Srgb

In [None]:
class Photopro2Srgbe():
    """NumPy reference: gamma-1.8 'photopro' image -> sRGB-encoded image.

    Inverse of ``Srgb2Photoproe``: expects a torch tensor (converted with
    .numpy()) with channels on axis 2 in B, G, R order; returns a one-element list.
    """
    def __init__(self):
        self.num_parameters = 0

    def __call__(self, list_photopro, parameters):
        # `parameters` is accepted for interface compatibility but not used.
        photopro = list_photopro.numpy().copy()
        thre_photopro = 1/512.0*16

        a = np.array([[0.4124564,0.3575761,0.1804375],[0.2126729,0.7151522,0.0721750],[0.0193339,0.1191920,0.9503041]])
        b = np.array([[1.3459433,-0.2556075,-0.0511118],[-0.5445989,1.5081673,0.0205351],[0.0000000,0.0000000,1.2118128]])
        # Same row-normalised matrix as the forward conversion, inverted.
        M = b.dot(a)
        M = M/M.sum(axis=1).reshape((-1,1))
        M = np.linalg.inv(M)
        k=0.055
        thre_srgb = 0.04045/12.92

        # Undo the photopro encoding (linear /16 below the threshold, ^1.8 above).
        photopro[photopro<thre_photopro] *= 1.0/16
        photopro[photopro>=thre_photopro] = photopro[photopro>=thre_photopro]**(1.8)

        photoprob = photopro[:,:,0:1]
        photoprog = photopro[:,:,1:2]
        photopror = photopro[:,:,2:3]
        sr = photopror*M[0][0]+photoprog*M[0][1]+photoprob*M[0][2]
        sg = photopror*M[1][0]+photoprog*M[1][1]+photoprob*M[1][2]
        sb = photopror*M[2][0]+photoprog*M[2][1]+photoprob*M[2][2]

        srgb = np.concatenate((sb,sg,sr),axis=2)

        # Clip, then apply the sRGB encoding curve (linear x12.92 near 0).
        srgb = np.clip(srgb,0,1)
        srgb[srgb>thre_srgb] = (1+k)*srgb[srgb>thre_srgb]**(1/2.4)-k
        srgb[srgb<=thre_srgb] *= 12.92

        return [srgb]
# Reference conversion of `a` (treated as a photopro image).
old = Photopro2Srgbe()(a,[0.5])

In [None]:
class Photopro2Srgb:
    """Batched gamma-1.8 'photopro' -> sRGB conversion (torch port of ``Photopro2Srgbe``).

    Accepts tensors shaped (..., 3) with channels in B, G, R order (e.g. one
    H W C image or a B H W C batch) and returns a tensor of the same shape.
    The input tensor is never modified.
    """

    def __init__(self):
        self.num_parameters = 0

    def __call__(self, photopro_tensor, parameters=None):
        """
        args:
            photopro_tensor: torch.Tensor # (..., 3), B G R channels last
            parameters: unused; accepted for interface compatibility
        return:
            torch.Tensor # same shape as photopro_tensor, sRGB-encoded
        """
        photopro = photopro_tensor
        thre_photopro = 1/512.0*16

        a = torch.tensor([[0.4124564, 0.3575761, 0.1804375],
                          [0.2126729, 0.7151522, 0.0721750],
                          [0.0193339, 0.1191920, 0.9503041]], dtype=torch.float32)
        b = torch.tensor([[1.3459433, -0.2556075, -0.0511118],
                          [-0.5445989, 1.5081673, 0.0205351],
                          [0.0000000, 0.0000000, 1.2118128]], dtype=torch.float32)
        # Same row-normalised matrix as the forward conversion, inverted.
        M = torch.matmul(b, a)
        M = M / M.sum(dim=1, keepdim=True)
        M = torch.linalg.inv(M)
        k = 0.055
        thre_srgb = 0.04045 / 12.92

        # Undo the photopro encoding: linear /16 below the threshold, ^1.8 above.
        # torch.where is out-of-place, so the caller's tensor is untouched.
        photopro = torch.where(photopro < thre_photopro,
                               photopro * (1.0 / 16),
                               photopro ** 1.8)

        # `...` indexing supports any leading dims, consistent with Srgb2Photopro.
        photoprob = photopro[..., 0:1]
        photoprog = photopro[..., 1:2]
        photopror = photopro[..., 2:3]

        sr = photopror * M[0, 0] + photoprog * M[0, 1] + photoprob * M[0, 2]
        sg = photopror * M[1, 0] + photoprog * M[1, 1] + photoprob * M[1, 2]
        sb = photopror * M[2, 0] + photoprog * M[2, 1] + photoprob * M[2, 2]

        srgb = torch.cat((sb, sg, sr), dim=-1)

        # Clip, then apply the sRGB encoding curve (linear x12.92 near 0).
        srgb = torch.clamp(srgb, 0, 1)
        srgb = torch.where(srgb > thre_srgb,
                           (1 + k) * srgb ** (1 / 2.4) - k,
                           srgb * 12.92)

        return srgb
# Batched conversion of the two-image test batch.
new = Photopro2Srgb()(input,parameters)

In [None]:
# Approximate parity (2 decimals): the torch port builds its matrix in float32.
assert (np.round(np.float32(old[0]),2)==np.round(new[1].numpy(),2)).all()

In [None]:
from envs.new_edit_photo import PhotoEditor
from envs.edit_photo import PhotoEditor as Old_editor
import matplotlib.pyplot as plt
N_stack = 32
IMAGE_PATH = "sample_images/a0676-kme_609.jpg"
image = cv2.imread(IMAGE_PATH)
# cv2.imread returns None (instead of raising) when the file is missing/unreadable
if image is None:
    raise FileNotFoundError(f"Could not read sample image: {IMAGE_PATH}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
photo_editor = PhotoEditor()
old_editor = Old_editor()
# one row of 12 slider values, shared by every image in the batch
parameters = torch.tensor([0.125, 0.125, 0.375, 0.125, 0., 0.0625, 0.9375, 0.375, 0.0625, 0., 0.125, 0.125])
stacked_parameters = parameters.repeat(N_stack, 1)  # (N_stack, 12)

In [None]:
# Show the loaded sample image (converted to RGB above).
plt.imshow(image)

In [None]:
from PIL import Image

def resize_length(image_array, size=512):
    """
    Resize the longer side of the image to the specified size while maintaining the aspect ratio.

    :param image_array: NumPy array representing the image (must be accepted by
                        ``PIL.Image.fromarray``, e.g. uint8 H x W or H x W x C).
    :param size: The target size for the longer side of the image.
    :return: Resized image as a NumPy array. The shorter side is at least 1 px,
             since PIL rejects zero-sized targets for very elongated inputs.
    """
    image = Image.fromarray(image_array)
    original_width, original_height = image.size
    if original_width > original_height:
        new_width = size
        new_height = max(1, int((original_height / original_width) * size))
    else:
        new_height = size
        new_width = max(1, int((original_width / original_height) * size))
    resized_image = image.resize((new_width, new_height), Image.LANCZOS)
    return np.array(resized_image)

In [None]:
# Resize the sample to 224x224 and scale to [0, 1] (float64 after / 255.0).
# Note: this rebinds `input`, replacing the earlier random test batch.
input = cv2.resize(image, (224, 224)) / 255.0
stacked = torch.tensor(input)
# Replicate the image N_stack times to form a batch: (N_stack, 224, 224, 3).
stacked =torch.stack([stacked for i in range(N_stack)])
output = photo_editor(stacked,stacked_parameters)

In [None]:
# Legacy editor, called with the NumPy image and NumPy parameters.
old_output = old_editor(input,parameters.numpy())

In [None]:
# First image of the batched editor output.
plt.imshow(output[0])

In [None]:
# Legacy editor output, for visual comparison.
plt.imshow(old_output)

In [None]:
# Numerical parity (3 decimals): last image of the batch vs the legacy output.
assert (np.round(output[-1].numpy(),3)==np.round(old_output,3)).all()