In [1]:
from glob import glob
import os
from PIL import Image
import numpy as np
from scipy.sparse import csr_matrix, eye
import cv2
import math
from p_tqdm import p_umap
import subprocess
from util import recreate_dir, load_config

In [2]:
# Dataset processed by this notebook, and the DEVIGNETTING section of its config file.
# NOTE(review): CFG is not read by devignette(), which hard-codes its own parameters — confirm intent.
DATA_DIR = "data/20220324_kocinka_rybna"
CFG = load_config(DATA_DIR + "/config.py").DEVIGNETTING

In [3]:
# Input TIFFs (produced by 1_conversion.ipynb) and the output directory for corrected TIFFs.
TIF_RAW_DIR = f"{DATA_DIR}/tif_raw"
# Explicit check rather than `assert`, which is skipped when Python runs with -O.
if not os.path.isdir(TIF_RAW_DIR):
    raise FileNotFoundError(f"Directory {TIF_RAW_DIR} does not exist, run 1_conversion.ipynb first")
TIF_DEVIGNETTE_DIR = f"{DATA_DIR}/tif_devignette"

In [4]:
# Collect the inputs *before* wiping the output directory, so a missing or empty
# input set fails loudly instead of silently deleting previous results.
if not os.path.isdir(TIF_RAW_DIR):
    raise FileNotFoundError(f"Directory {TIF_RAW_DIR} does not exist")
file_paths = sorted(glob(f"{TIF_RAW_DIR}/*.tif"))  # sorted: glob order is arbitrary
if not file_paths:
    raise FileNotFoundError(f"No .tif files found in {TIF_RAW_DIR}")
recreate_dir(TIF_DEVIGNETTE_DIR)

In [5]:
def devignette(file_path):
    """Remove radial vignetting from one TIFF and write the corrected copy.

    A radially symmetric bias field is estimated in the log domain on a
    downsampled copy of the image. Iteratively reweighted least squares fits
    the radial derivative of a radial profile ``B_r`` to the image's measured
    radial gradient, regularised by a small Tikhonov term and a
    second-difference smoothness term on ``B_r``.

    The bias is then upsampled (bicubic) and divided out of the
    full-resolution image. Finally, a constant offset is added so that the
    mean over the central 50% x 50% window matches the input.

    The result is saved with PIL as ``TIF_DEVIGNETTE_DIR/<basename of file_path>``,
    and exiftool then copies the tags of ``file_path`` onto it.

    Args:
        file_path: path of a single-channel TIFF image.

    Returns:
        None; the output is written to disk.

    Raises:
        ValueError: if the loaded image is not 2-D.
        subprocess.CalledProcessError: if exiftool exits with a non-zero status.
    """
    # Algorithm parameters.
    # NOTE(review): the notebook loads CFG (DEVIGNETTING config) but these are hard-coded — confirm intent.
    lambda_ = 0.25       # weight of the second-difference smoothness prior on B_r
    itr_num = 4          # number of reweighting iterations (must be >= 1)
    alpha_ = 0.6         # exponent in the per-pixel weight update
    epsilon_ = 0.000001  # added under the sqrt of the radial-gradient denominator (avoids 0/0 at the centre)
    epsilon = 0.03       # Tikhonov weight; eps^2 * I keeps the normal equations non-singular
    dsfact = 0.25        # downsampling factor for the bias estimation
    shift = 273.15       # added before log(); presumably degC -> K (TODO confirm). log() needs im + shift > 0

    def diag_sparse(v):
        """CSR diagonal matrix with the entries of ``v`` on its diagonal."""
        idx = np.arange(v.size)
        return csr_matrix((v.ravel(), (idx, idx)), shape=(v.size, v.size))

    def radius_map(shape):
        """Distance of every pixel to the image centre, rounded to an integer radius bin."""
        cent_x = round((shape[1] - 1) / 2)
        cent_y = round((shape[0] - 1) / 2)
        X, Y = np.meshgrid(np.arange(-cent_x, shape[1] - cent_x),
                           np.arange(-cent_y, shape[0] - cent_y))
        return np.round(np.sqrt(X**2 + Y**2)).astype(np.int64)

    def get_C(R):
        """(num_pixels x num_R) indicator matrix mapping a radial profile onto the pixel grid."""
        R = np.asarray(R).ravel()
        num_pixels = R.size
        num_R = int(R.max()) + 1
        return csr_matrix((np.ones(num_pixels), (np.arange(num_pixels), R)),
                          shape=(num_pixels, num_R))

    def get_Lyy(n):
        """(n x n) second-difference operator; its first and last rows are zero (all zero if n < 3)."""
        mid = np.arange(1, n - 1)
        ii = np.concatenate([mid, mid, mid])
        jj = np.concatenate([mid, mid - 1, mid + 1])
        ss = np.concatenate([np.full(mid.size, -2.0), np.ones(mid.size), np.ones(mid.size)])
        return csr_matrix((ss, (ii, jj)), shape=(n, n))

    def radial_gradient(Z):
        """Backward-difference gradient of ``Z`` projected on the direction away from the centre.

        Row 0 and column 0 have no backward neighbour and are set to 0.
        """
        h, w = Z.shape
        jj, ii = np.mgrid[0:h, 0:w]
        # Centre offsets truncated toward zero, like int().
        cx = np.trunc(ii - (w - 1) * 0.5)
        cy = np.trunc(jj - (h - 1) * 0.5)
        dx = np.zeros_like(Z)
        dy = np.zeros_like(Z)
        dx[:, 1:] = Z[:, 1:] - Z[:, :-1]
        dy[1:, :] = Z[1:, :] - Z[:-1, :]
        rg = (cx * dx + cy * dy) / np.sqrt(cx * cx + cy * cy + epsilon_)
        rg[0, :] = 0.0
        rg[:, 0] = 0.0
        return rg

    im_given = np.array(Image.open(file_path))
    if im_given.ndim != 2:
        raise ValueError(f"{file_path}: expected a single-channel image, got shape {im_given.shape}")

    # --- Bias estimation on a downsampled copy ---
    im_small = cv2.resize(im_given, dsize=(0, 0), fx=dsfact, fy=dsfact)
    sz = im_small.shape
    num_pixels = sz[0] * sz[1]
    Z_shift = np.log(im_small.astype(np.float64) + shift)
    rg = radial_gradient(Z_shift).ravel()

    R = radius_map(sz).ravel()  # per-pixel radius bin, shared by C and A
    C = get_C(R)
    num_R = C.shape[1]

    # A maps the profile to its radial derivative per pixel: (A @ B_r)[p] = B_r[R_p] - B_r[R_p - 1].
    # Rows are filled only where rg was actually measured (not row 0 / column 0) and R_p >= 1.
    measured = np.ones(sz, dtype=bool)
    measured[0, :] = False
    measured[:, 0] = False
    rows = np.flatnonzero(measured.ravel() & (R >= 1))
    A = csr_matrix(
        (np.concatenate([np.ones(rows.size), -np.ones(rows.size)]),
         (np.concatenate([rows, rows]), np.concatenate([R[rows], R[rows] - 1]))),
        shape=(num_pixels, num_R),
    )

    # Regularisation terms (constant across iterations).
    Gamma1 = epsilon * eye(num_R, num_R)
    Gamma2 = lambda_ * (2 * num_pixels / num_R) * get_Lyy(num_R)
    regulariser = Gamma1.T @ Gamma1 + Gamma2.T @ Gamma2

    # Iteratively reweighted least squares:
    #   min_{B_r} ||W (A B_r - rg)||^2 + ||Gamma1 B_r||^2 + ||Gamma2 B_r||^2,  W = diag(weights)
    weights = np.ones(num_pixels)
    for _ in range(itr_num):
        W2 = diag_sparse(weights ** 2)
        lhs = (A.T @ W2 @ A + regulariser).toarray()
        rhs = A.T @ (weights ** 2 * rg)
        B_r = np.linalg.solve(lhs, rhs)
        S1 = np.abs(A @ B_r - rg)
        with np.errstate(divide='ignore'):
            S2 = alpha_ * S1 ** (alpha_ - 1)  # inf where S1 == 0, giving weight exp(0) * (1 - 0) = 1
        weights = np.exp(-S1) * (1 - np.exp(-S2))

    B = C @ B_r
    b = np.exp(B - np.max(B)).reshape(sz)  # multiplicative bias, max value 1

    # --- Correction of the full-resolution image ---
    bias = cv2.resize(b, dsize=(im_given.shape[1], im_given.shape[0]), interpolation=cv2.INTER_CUBIC)
    im_corrected = (im_given + shift) / bias - shift

    # Restore the original level: add the mean difference over the central 50% x 50% window.
    height, width = im_given.shape
    cy, cx = height // 2, width // 2
    h2 = int(height * 0.5) // 2
    w2 = int(width * 0.5) // 2
    center = (slice(cy - h2, cy + h2), slice(cx - w2, cx + w2))
    im_corrected = im_corrected + np.mean(im_given[center] - im_corrected[center])

    # --- Save and copy the source file's tags onto the output ---
    file_name = os.path.basename(file_path)
    out_path = os.path.join(TIF_DEVIGNETTE_DIR, file_name)
    Image.fromarray(im_corrected).save(out_path)
    subprocess.check_call(
        ['exiftool', '-tagsfromfile', file_path, out_path, '-overwrite_original_in_place'],
        stdout=subprocess.DEVNULL, stderr=subprocess.STDOUT,
    )

In [6]:
# Run devignette() over every input file with p_tqdm's p_umap (parallel, unordered; shows a progress bar).
# devignette() has no return statement, so `results` is a list of Nones; the real output is the files it writes.
results = p_umap(devignette, file_paths)

  0%|          | 0/453 [00:00<?, ?it/s]