In [10]:
import numpy as np
from skimage import io,color
import cv2 as cv
from scipy.signal import convolve2d

In [4]:
def make_correction(channel:np.ndarray,mu:float)->np.ndarray:
    """Contrast-stretch one colour channel around its mean.

    The channel is scaled to [0, 1] by dividing by 255, then mapped linearly
    so that the window ``[mean - mu*var, mean + mu*var]`` spans [0, 255].
    Values outside the window are clipped.

    Args:
        channel: 2-D channel, assumed to hold values in [0, 255]
            (TODO confirm for non-uint8 inputs).
        mu: half-width of the stretching window, in units of the channel
            variance.

    Returns:
        The corrected channel as ``uint8``. If the window has zero width
        (constant channel, or ``mu == 0``), stretching is undefined, so the
        channel is returned unchanged (clipped to [0, 255]).
    """
    scaled = channel / 255.0
    mean = scaled.mean()
    half_width = mu * scaled.var()
    channel_max = mean + half_width
    channel_min = mean - half_width
    if channel_max == channel_min:
        # Zero-width window: the division below would be 0/0 -> NaN, and
        # casting NaN to uint8 is undefined. Keep the original values instead.
        return np.clip(channel, 0, 255).astype(np.uint8)
    corrected = 255.0 * (scaled - channel_min) / (channel_max - channel_min)
    return np.clip(corrected, 0, 255).astype(np.uint8)

In [3]:
def shrink(x:any,eps:float)->np.ndarray:
    """Element-wise soft-thresholding (shrinkage) operator.

    Computes ``sign(x) * max(|x| - eps, 0)`` independently for every entry:
    magnitudes at or below ``eps`` become 0, and larger ones move ``eps``
    towards zero while keeping their sign.

    Args:
        x: scalar or array of any shape (integer input is promoted to float64).
        eps: threshold; expected to be non-negative.

    Returns:
        Float64 array (or numpy scalar) with the same shape as ``x``.
    """
    values = np.asarray(x, dtype=np.float64)
    # np.maximum (element-wise) instead of np.max (a reduction). Using sign()
    # instead of x/|x| also avoids 0/0 = NaN for zero entries, e.g. gradients
    # of an all-zero image.
    return np.sign(values) * np.maximum(np.abs(values) - eps, 0.0)

In [5]:
def get_corrected_image(img:np.ndarray,mu:float)->np.ndarray:
    """Colour-correct an image by stretching each of its first three channels.

    Each of channels 0, 1 and 2 goes through ``make_correction`` with the same
    ``mu``. The results are stacked back along the last axis. Channels beyond
    index 2 (e.g. alpha) are not included in the output.

    Args:
        img: H x W x C image with at least three channels.
        mu: window width forwarded to ``make_correction``.

    Returns:
        H x W x 3 array of the corrected channels.
    """
    corrected_channels = [make_correction(img[:, :, idx], mu) for idx in (0, 1, 2)]
    return np.stack(corrected_channels, axis=-1)


In [6]:
def RGB2LAB(img:np.ndarray)->tuple[np.ndarray]:
    """Convert an RGB image to CIE-LAB and split it into separate channels.

    Args:
        img: RGB image accepted by ``skimage.color.rgb2lab``.

    Returns:
        Tuple ``(L, A, B)`` of 2-D arrays; each one is a slice of the
        array produced by ``rgb2lab``.
    """
    lab = color.rgb2lab(img)
    return tuple(lab[:, :, k] for k in range(3))

In [8]:
def initialize_RI(L_channel:np.ndarray,sigma:float=2)->tuple[np.ndarray]:
    """Build the starting reflectance and illumination estimates.

    Args:
        L_channel: 2-D lightness channel.
        sigma: standard deviation of the Gaussian blur in both directions
            (the kernel size is derived from it by OpenCV because ksize is (0, 0)).

    Returns:
        ``(R, I0)``: ``R`` is an all-zero float array of the same height and
        width; ``I0`` is the Gaussian-blurred ``L_channel``.
    """
    reflectance = np.zeros((L_channel.shape[0], L_channel.shape[1]))
    illumination = cv.GaussianBlur(L_channel, (0, 0), sigmaX=sigma, sigmaY=sigma)
    return reflectance, illumination

In [9]:
# 3x3 Sobel derivative kernels, written as separable outer products of the
# smoothing taps [1, 2, 1] and the central-difference taps [-1, 0, 1].
SOBEL_KERNEL_X=np.outer([1, 2, 1], [-1, 0, 1])
SOBEL_KERNEL_Y=np.outer([-1, 0, 1], [1, 2, 1])

In [11]:
def gradient_x(img:np.ndarray)->np.ndarray:
    """Absolute horizontal Sobel response of a 2-D image, as uint8.

    The image is convolved with ``SOBEL_KERNEL_X`` (``mode='same'``, so the
    output has the input's shape; scipy zero-pads the borders by default).
    The absolute value is taken and saturated to [0, 255] before the uint8
    cast. A plain cast wraps values above 255 (e.g. 300 -> 44), and for
    float input that cast is undefined.

    NOTE(review): abs() discards the gradient sign, and the uint8 cast drops
    sub-integer detail (e.g. for a reflectance map in [0, 1]). Kept for
    compatibility with existing callers.
    """
    response = np.abs(convolve2d(img, SOBEL_KERNEL_X, mode='same'))
    return np.clip(response, 0, 255).astype(np.uint8)

In [12]:
def gradient_y(img:np.ndarray)->np.ndarray:
    """Absolute vertical Sobel response of a 2-D image, as uint8.

    The image is convolved with ``SOBEL_KERNEL_Y`` (``mode='same'``, so the
    output has the input's shape; scipy zero-pads the borders by default).
    The absolute value is taken and saturated to [0, 255] before the uint8
    cast. A plain cast wraps values above 255 (e.g. 300 -> 44), and for
    float input that cast is undefined.

    NOTE(review): abs() discards the gradient sign, and the uint8 cast drops
    sub-integer detail (e.g. for a reflectance map in [0, 1]). Kept for
    compatibility with existing callers.
    """
    response = np.abs(convolve2d(img, SOBEL_KERNEL_Y, mode='same'))
    return np.clip(response, 0, 255).astype(np.uint8)

In [13]:
# `delta` plays the role of the paper's lambda; `lambda` itself cannot be used as a variable name because it is a reserved keyword in Python.
def optimizer(R:np.ndarray,I:np.ndarray,I0:np.ndarray,L:np.ndarray,d:np.ndarray,alpha:float=100,beta:float=0.1,gamma:float=1,delta:float=10)->float:
    """Evaluate the variational Retinex energy for the current (R, I, d).

        E = ||R*I - L||^2
            + alpha * ||grad I||^2
            + beta * (delta * ||grad R - d||^2 + ||d||_1)
            + gamma * ||I - I0||^2

    All norms are entrywise over whole arrays: squared L2 is the sum of
    squares, and L1 is the sum of absolute values. ``grad X`` is the vertical
    stack of ``gradient_x(X)`` and ``gradient_y(X)``.

    Args:
        R: reflectance estimate.
        I: illumination estimate.
        I0: initial (prior) illumination.
        L: observed lightness channel.
        d: auxiliary variable with the same shape as the stacked gradients.
        alpha, beta, gamma, delta: energy weights (``delta`` is the paper's lambda).

    Returns:
        The energy as a Python float.
    """
    def sq_norm(a):
        # Entrywise squared L2 norm. With ord=None, np.linalg.norm flattens any
        # shape and casts integer (uint8) input to float, so squaring cannot
        # overflow. ord=2 on a 2-D array would be the spectral norm instead.
        return np.linalg.norm(a) ** 2

    grad_I = np.vstack((gradient_x(I), gradient_y(I)))
    # Cast to float so subtracting d never wraps around in uint8 arithmetic.
    grad_R = np.vstack((gradient_x(R), gradient_y(R))).astype(np.float64)

    fidelity = sq_norm(R * I - L)
    illumination_smoothness = alpha * sq_norm(grad_I)
    # Entrywise L1 of d; np.linalg.norm(d, ord=1) on a matrix is the max column sum.
    reflectance_tv = beta * (delta * sq_norm(grad_R - d) + np.abs(d).sum())
    illumination_prior = gamma * sq_norm(I - I0)
    return float(fidelity + illumination_smoothness + reflectance_tv + illumination_prior)

In [14]:
def update_d(R:np.ndarray,delta:float=10)->np.ndarray:
    """Update the auxiliary gradient variable d by shrinking the gradients of R.

    ``d`` minimises ``||d||_1 + delta * ||d - grad R||^2`` entrywise. Its
    closed form is soft-thresholding with threshold ``1 / (2 * delta)``.
    The previous expression ``1.0/2*delta`` parsed as ``(1/2) * delta``
    (5.0 instead of 0.05 for the default).

    NOTE(review): split-Bregman schemes usually also add a Bregman variable b
    to grad R before shrinking; none is used here.

    Args:
        R: reflectance estimate (2-D).
        delta: weight of the coupling term (the paper's lambda); must be non-zero.

    Returns:
        Vertical stack of the shrunk x- and y-gradients, shape (2H, W).
    """
    threshold = 1.0 / (2.0 * delta)
    d_x = shrink(gradient_x(R), threshold)
    d_y = shrink(gradient_y(R), threshold)
    return np.vstack((d_x, d_y))

In [15]:
def denom(D_x:np.ndarray=SOBEL_KERNEL_X,D_y:np.ndarray=SOBEL_KERNEL_Y,shape:"tuple[int, int] | None"=None)->np.ndarray:
    """Frequency response of ``D_x^T D_x + D_y^T D_y``.

    Returns ``|F(D_x)|^2 + |F(D_y)|^2``, where F is the 2-D DFT. The previous
    version applied a 1-D FFT along each row of the 2-D kernel, which is not
    the transform of the 2-D operator.

    Args:
        D_x, D_y: 2-D derivative kernels.
        shape: if given, the kernels are zero-padded to this (height, width)
            before the transform. This lets the result broadcast against
            spectra of images of that size. If None, the kernels' own shape
            is used.

    Returns:
        Real float array of the given ``shape`` (or the kernel shape).
    """
    fft_d_x = np.fft.fft2(D_x, s=shape)
    fft_d_y = np.fft.fft2(D_y, s=shape)
    # F * conj(F) = |F|^2 is real; drop the zero imaginary part explicitly.
    return (fft_d_x * np.conjugate(fft_d_x) + fft_d_y * np.conjugate(fft_d_y)).real

In [16]:
def update_R(L:np.ndarray,I:np.ndarray,beta:float=0.1,delta:float=10,eps:float=1e-8)->np.ndarray:
    """One FFT-domain update of the reflectance R for a fixed illumination I.

    Under periodic boundary conditions this computes

        R = F^-1[ (1 + beta*delta) F(L / I)
                  / (1 + beta*delta * (|F(Dx)|^2 + |F(Dy)|^2)) ]

    where F is the 2-D DFT and Dx, Dy are the Sobel kernels. F(1) (the
    transform of a unit impulse) is 1 at every frequency, so it is the
    scalar 1.0.

    Fixes over the previous version: the undefined name ``b`` (now ``beta``);
    ``np.fft.fft(1)`` on a scalar, which raises; 1-D transforms applied to a
    2-D image; a 3x3 denominator that could not broadcast against the image
    spectrum; and a complex return value.

    NOTE(review): split-Bregman R-updates usually also include the auxiliary
    variable d in the numerator; this version does not. Confirm against the paper.

    Args:
        L: observed lightness, 2-D (or stacked along leading axes).
        I: illumination estimate, broadcastable to ``L``.
        beta, delta: energy weights (``delta`` is the paper's lambda).
        eps: entries of ``I`` with magnitude below this are replaced by ``eps``
            to avoid division by zero.

    Returns:
        Real-valued reflectance with the shape of ``L``.
    """
    image_shape = L.shape[-2:]
    safe_I = np.where(np.abs(I) < eps, eps, I)
    numerator = (1 + beta * delta) * np.fft.fft2(L / safe_I)
    # |F|^2 of the zero-padded kernels is independent of where the kernel sits in the padded grid.
    gradient_response = (np.abs(np.fft.fft2(SOBEL_KERNEL_X, s=image_shape)) ** 2
                         + np.abs(np.fft.fft2(SOBEL_KERNEL_Y, s=image_shape)) ** 2)
    return np.real(np.fft.ifft2(numerator / (1.0 + beta * delta * gradient_response)))
    

In [17]:
def update_I(I0:np.ndarray,L:np.ndarray,R:np.ndarray,alpha:float=100,gamma:float=1,eps:float=1e-8)->np.ndarray:
    """One FFT-domain update of the illumination I for a fixed reflectance R.

    Under periodic boundary conditions this computes

        I = F^-1[ F(gamma*I0 + L / R)
                  / ((1 + gamma) + alpha * (|F(Dx)|^2 + |F(Dy)|^2)) ]

    where F is the 2-D DFT and Dx, Dy are the Sobel kernels. F(1 + gamma)
    (a scaled unit impulse) is the constant ``1 + gamma`` at every frequency.
    The previous ``np.fft.fft(1+gamma)`` on a scalar raises.

    Args:
        I0: initial (prior) illumination, broadcastable to ``L``.
        L: observed lightness, 2-D (or stacked along leading axes).
        R: reflectance estimate, broadcastable to ``L``.
        alpha, gamma: energy weights.
        eps: entries of ``R`` with magnitude below this are replaced by
            ``eps`` to avoid division by zero. ``initialize_RI`` starts R as
            all zeros, and a single inf would turn the whole FFT into NaN.
            NOTE(review): the result is then finite but not meaningful;
            confirm R is updated before I.

    Returns:
        Real-valued illumination with the shape of ``L``.
    """
    image_shape = L.shape[-2:]
    safe_R = np.where(np.abs(R) < eps, eps, R)
    numerator = np.fft.fft2(gamma * I0 + L / safe_R)
    # |F|^2 of the zero-padded kernels is independent of where the kernel sits in the padded grid.
    gradient_response = (np.abs(np.fft.fft2(SOBEL_KERNEL_X, s=image_shape)) ** 2
                         + np.abs(np.fft.fft2(SOBEL_KERNEL_Y, s=image_shape)) ** 2)
    return np.real(np.fft.ifft2(numerator / ((1.0 + gamma) + alpha * gradient_response)))

In [18]:
def enhance_R(R:np.ndarray,clip_limit:float=2,tile_grid_size:tuple=(8,8))->np.ndarray:
    """Apply OpenCV's CLAHE (contrast-limited adaptive histogram equalisation).

    Args:
        R: single-channel image. OpenCV's CLAHE accepts 8-bit or 16-bit
            unsigned input, so a floating-point reflectance map must be
            converted before it is passed here (TODO confirm the caller does this).
        clip_limit: contrast-limiting threshold passed as ``clipLimit``
            (default 2, the previously hard-coded value).
        tile_grid_size: number of tiles (rows, cols) passed as ``tileGridSize``
            (default (8, 8), the previously hard-coded value).

    Returns:
        The equalised image, as returned by ``CLAHE.apply``.
    """
    clahe = cv.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(R)