# Fast and High Quality Highlight Removal from a Single Image

In [None]:
%matplotlib inline
import cv2
import numpy as np
from matplotlib import pyplot as plt
import scipy
import scipy.cluster.hierarchy as sch
from scipy.cluster.vq import vq,kmeans,kmeans2,whiten

from math import sqrt

def easyshow(src, title, ifcolorbar=True, cmap='gray'):
    """Draw ``src`` with ``plt.imshow`` under ``title`` and call ``plt.show()``.

    Args:
        src: array handed unchanged to ``plt.imshow``.
        title: text passed to ``plt.title``.
        ifcolorbar: when truthy, attach a colorbar for the drawn image.
        cmap: colormap forwarded to ``plt.imshow`` (defaults to 'gray').
    """
    image_handle = plt.imshow(src, cmap=cmap)
    plt.title(title)
    if not ifcolorbar:
        plt.show()
        return
    plt.colorbar(image_handle)
    plt.show()
src_path = '/mnt/MEDIA/Projects/anti_reflection/cubic.jpg'
# Downscale factor applied to both axes before processing.
resize_scale = 0.25

# cv2.imread reports failure by returning None instead of raising, so check
# explicitly rather than letting cv2.resize fail with an obscure assertion.
src_img = cv2.imread(src_path)
if src_img is None:
    raise IOError('could not read image: %s' % src_path)
# Bicubic downscale (presumably to keep the later per-pixel work cheap),
# then a 3x3 median filter to suppress pixel noise.
src_img = cv2.resize(src_img, None, fx=resize_scale, fy=resize_scale, interpolation=cv2.INTER_CUBIC)
src_img = cv2.medianBlur(src_img, 3)

# OpenCV stores BGR; matplotlib expects RGB.
plt.imshow(cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
plt.show()

srcimg_shape = src_img.shape

## Compute L2 chromaticity for material clustering

In [None]:
eps = 1e-5  # keeps the divisions below finite for zero-norm pixels
# Per-pixel Euclidean norm of the colour vector.
l2norm_img = np.sqrt(np.square(src_img.astype(np.float32)).sum(axis=2))
easyshow(l2norm_img, 'l2-norm', True)

# Unit-length (L2) chromaticity of every pixel.
l2normed_img = np.divide(src_img, eps + l2norm_img[:, :, np.newaxis])
# Assume the illumination is constant and white: the unit vector (1, 1, 1)/sqrt(3).
inv_sqrt3 = sqrt(1.0/3.0)
ill_l2vec = np.full((3,), inv_sqrt3)
ill_l2img = np.full(srcimg_shape, inv_sqrt3)

img_h, img_w = srcimg_shape[:2]
l2normed_pixels = l2normed_img.reshape((img_h * img_w, srcimg_shape[2]))
# Cosine between each chromaticity and the illuminant (shape (N, 1)), and its sine.
parrall_coe = l2normed_pixels.dot(ill_l2vec[:, np.newaxis]).clip(0, 1)
orthogonal_coe = np.sqrt(1 - parrall_coe ** 2)
easyshow(parrall_coe.reshape((img_h, img_w)), 'parrall_coe', True)
easyshow(orthogonal_coe.reshape((img_h, img_w)), 'orthogonal_coe', True)

# Remove the illuminant-parallel component, then divide by the sine so the
# remaining orthogonal direction is (approximately) unit length.
gamma_orth_ini = l2normed_img - parrall_coe.reshape((img_h, img_w, 1)) * inv_sqrt3
gamma_orth_src = gamma_orth_ini / (eps + orthogonal_coe.reshape((img_h, img_w, 1)))


## Material clustering

In [None]:
import time
# Cluster counts tried are range(k_ini, k_max).
k_ini = 1
k_max = 15
# Per-pixel tolerance on |gamma_orth^2 + gamma_par^2 - 1|.
out_circle_thr = 0.02
# Clustering stops once the fraction of pixels beyond that tolerance is below this.
out_circle_ratio_thr = 0.02
def get_random_color():
    """Return a random display colour: three uint8 channel values drawn uniformly from [0, 255)."""
    channels = np.random.uniform(low=0, high=255, size=3)
    return channels.astype(np.uint8)
# Material clustering: increase k until the k-means model explains the data.
# For every pixel, gamma_orth (projection of its L2 chromaticity onto its
# cluster centre) and gamma_par (projection onto the white illuminant) should
# satisfy gamma_orth^2 + gamma_par^2 ~= 1. Stop at the first k for which the
# fraction of pixels off that circle by more than out_circle_thr is below
# out_circle_ratio_thr.
from scipy.cluster.vq import ClusterError  # raised by kmeans2(missing='raise') on an empty cluster

kmeans_res_4show = np.zeros_like(gamma_orth_src).astype(np.uint8)
colors = [get_random_color() for i in range(k_max)]
# Lookup table: row i is the display colour of cluster label i.
color_table = np.array(colors)

l2chromatic_pixels = gamma_orth_src.reshape((srcimg_shape[0] * srcimg_shape[1], srcimg_shape[2]))
# The projection onto the (constant, white) illuminant does not depend on k,
# so compute it once instead of on every iteration.
gamma_par = np.sum(l2normed_pixels * np.sqrt(1.0/3.0), axis=1, keepdims=False)
# Bound the retries so a k that keeps producing empty clusters cannot hang.
max_kmeans_attempts = 10

for k in range(k_ini, k_max):
    tic = time.time()
    # minit='points' picks random initial centres, so retrying after an
    # empty-cluster ClusterError uses a different initialisation.
    for attempt in range(max_kmeans_attempts):
        print('tring to do %d-means clustering' % k)
        _tic = time.time()
        try:
            centors, labels = kmeans2(l2chromatic_pixels, k, iter=50, thresh=1e-05, minit='points', missing='raise', check_finite=True)
        except ClusterError:
            print('error occur during kmeans')
            continue
        _tac = time.time()
        print('%d-means clustering cost %.3fs' % (k, _tac - _tic))
        break
    else:
        raise RuntimeError('%d-means clustering failed %d times in a row' % (k, max_kmeans_attempts))

    # Replace every pixel's chromaticity by its cluster centre, then
    # reproject to compare the total fitting error.
    centors_l2chromatic_pixels = centors[labels]
    gamma_orth = np.sum(centors_l2chromatic_pixels * l2normed_pixels, axis=1, keepdims=False)
    fit_err = np.abs(np.square(gamma_orth) + np.square(gamma_par) - 1.0)

    easyshow(fit_err.reshape(l2norm_img.shape), 'fit error', True)

    # Show the k-means result. labels are in row-major pixel order, so a
    # reshape maps the per-pixel colours back onto the image grid.
    kmeans_res_4show[:] = color_table[labels].reshape(srcimg_shape)
    plt.imshow(cv2.cvtColor(kmeans_res_4show, cv2.COLOR_BGR2RGB))
    plt.title('kmeans-%d' % k)
    plt.show()

    tf_err_num = np.count_nonzero(fit_err > out_circle_thr)
    ratio = float(tf_err_num) / float(gamma_orth.shape[0])
    print('%d/%d pixels are not "near" circle, ratio is %f' % (tf_err_num, gamma_orth.shape[0], ratio))
    tac = time.time()
    print('%d-means clustering and analysis cost %fs' % (k, tac - tic))
    if ratio < out_circle_ratio_thr:
        break

print('%d custers found' % k)

## Diffuse Component Recovery

In [None]:
# Find the purely diffuse pixels of each cluster. For every material cluster,
# histogram the illumination-parallel coefficient parrall_coe, then scan the
# bins upwards. The first no_noise_ratio fraction of the cluster's pixels is
# treated as noise. The diffuse estimate is the first strict local maximum at
# or after the bin where the cumulative count exceeds that fraction, falling
# back to the last bin. Each cluster yields (a, b): b is the bin centre and
# a = sqrt(1 - b^2).
histbins = 64
no_noise_ratio = 1.0/10.0
ab_clusters = []
for cluster_id in range(k):
    cluster_par = parrall_coe[labels == cluster_id].ravel()
    # Tiny random jitter before binning (presumably to break ties between equal values).
    jitter = 1e-6 * np.random.random(cluster_par.shape)
    gamma_hist, binids = np.histogram(a=cluster_par + jitter, bins=histbins)

    # Plot the histogram against the bin centres.
    bin_centres = (binids[:histbins] + binids[1:]) / 2.0
    plt.stem(bin_centres, gamma_hist, linefmt='b-', markerfmt='bo', basefmt='r-')
    plt.show()

    noise_budget = int(len(cluster_par) * no_noise_ratio)
    running_total = gamma_hist[0]
    past_noise = False
    selected_binid = histbins - 1
    for binid in range(1, histbins - 1):
        if not past_noise:
            running_total += gamma_hist[binid]
            past_noise = running_total > noise_budget
        if past_noise and gamma_hist[binid - 1] < gamma_hist[binid] > gamma_hist[binid + 1]:
            selected_binid = binid
            break

    b = (binids[selected_binid] + binids[selected_binid + 1]) / 2.0
    print('selected bin id is %d, average coe is %f' % (selected_binid, b))
    a = sqrt(1 - b ** 2)
    ab_clusters.append((a, b))
print(ab_clusters)

In [None]:
# Recover the diffuse-only image. The specular term lies along the white
# illuminant, so it has no component along the per-pixel orthogonal direction
# gamma_orth_src. Projecting the pixel onto that direction and dividing by the
# cluster's weight `a` gives the diffuse intensity alpha. The diffuse colour is
# then alpha * (a * gamma_orth + b * illuminant). Computed for all pixels at
# once instead of a per-pixel Python loop.
ab_table = np.asarray(ab_clusters, dtype=np.float64)  # row i = (a, b) of cluster i
label_map = labels.reshape(srcimg_shape[:2])
a_map = ab_table[label_map, 0][..., np.newaxis]
b_map = ab_table[label_map, 1][..., np.newaxis]

alpha = np.sum(gamma_orth_src * src_img, axis=2, keepdims=True) / (a_map + eps)
reverted_lambda = alpha * (a_map * gamma_orth_src + b_map * ill_l2vec)


In [None]:
# Inspect each channel of the recovered diffuse image separately.
for channel in range(3):
    easyshow(reverted_lambda[:, :, channel], 'reverted_lambda-%d' % channel, True)


In [None]:
# Clip to the 8-bit range and convert OpenCV's BGR order to RGB for
# matplotlib. Then show the input followed by the recovered diffuse image.
img = cv2.cvtColor(np.clip(reverted_lambda, 0, 255).astype(np.uint8), cv2.COLOR_BGR2RGB)
for shown in (cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB), img):
    plt.imshow(shown)
    plt.show()