In [6]:
import cv2
import os
from joblib import Parallel, delayed
from tqdm import tqdm
import re
import numpy as np
import scipy
import skimage
import ants
# downscale = 4
# img_path = "/storage/valis_reg/BFIW_Block/msrcr"
# img_paths = [os.path.join(img_path, img) for img in os.listdir(img_path)]
# save_paths = [os.path.join("/storage/valis_reg/BFIW_Block_low_res/msrcr", img) for img in os.listdir(img_path)]


# def store_downscaled_image(img_path, save_path):
#     img = cv2.imread(img_path)
#     img = cv2.resize(img, (img.shape[1]//downscale, img.shape[0]//downscale))
#     cv2.imwrite(save_path, img)

# _ = Parallel(n_jobs=16)(delayed(store_downscaled_image)(img_path, save_path) for img_path, save_path in tqdm(zip(img_paths, save_paths)))

In [5]:
import numpy as np
import cv2

def get_ksize(sigma):
    """Return an odd, positive Gaussian kernel size matching ``sigma``.

    OpenCV derives a sigma from a kernel size as
        sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8  =  0.15*ksize + 0.35
    so the inverse is
        ksize = (sigma - 0.8)/0.15 + 3.0

    The truncated result may be even, or non-positive for very small sigmas,
    while ``cv2.getGaussianKernel`` expects an odd, positive aperture size.
    The value is therefore clamped to at least 1 and even sizes are bumped up
    to the next odd number.

    Args:
        sigma: Gaussian standard deviation.

    Returns:
        int: odd kernel size >= 1.
    """
    ksize = int(((sigma - 0.8) / 0.15) + 3.0)
    # `| 1` sets the lowest bit: odd values are unchanged, even values become
    # the next odd value, keeping the kernel centred on a pixel.
    return max(ksize, 1) | 1

def get_gaussian_blur(img, ksize=0, sigma=5):
    """Blur ``img`` with a dense 2-D Gaussian kernel.

    Args:
        img: image passed straight to ``cv2.filter2D``.
        ksize: kernel size; ``0`` means "derive it from ``sigma``" through
            ``get_ksize``.
        sigma: Gaussian standard deviation.

    Returns:
        The result of ``cv2.filter2D`` called with ``ddepth=-1``.
    """
    kernel_size = get_ksize(sigma) if ksize == 0 else ksize

    # A 2-D Gaussian is separable: the full kernel is the outer product of
    # the 1-D kernel with itself.
    kernel_1d = cv2.getGaussianKernel(kernel_size, sigma)
    kernel_2d = np.outer(kernel_1d, kernel_1d)

    # For kernels of roughly 11x11 and larger the convolution is computed by
    # applying a Fourier transform.
    return cv2.filter2D(img, -1, kernel_2d)

def ssr(img, sigma):
    """Single-scale retinex of an image.

    SSR(x, y) = log10(I(x, y)) - log10((I * F)(x, y) + 1)

    where ``F`` is the Gaussian surround function of standard deviation
    ``sigma``. The logarithm of ``img`` is taken as-is, so callers are
    expected to pass strictly positive intensities.
    """
    log_image = np.log10(img)
    surround = get_gaussian_blur(img, ksize=0, sigma=sigma)
    # 1.0 is added to the blurred image before taking its logarithm.
    log_surround = np.log10(surround + 1.0)
    return log_image - log_surround

def msr(img, sigma_scales=[15, 80, 250], apply_normalization=True):
    """Multi-scale retinex of an image.

    MSR(x, y) = sum_i weight[i] * SSR(x, y, scale[i]), with equal weights,
    i.e. the mean of the single-scale retinex outputs over ``sigma_scales``.

    Args:
        img: image forwarded to ``ssr`` for every scale.
        sigma_scales: Gaussian sigmas, one SSR is computed per entry.
        apply_normalization: when True (default, the historical behaviour)
            the result is min-max normalised to [0, 255] and returned as
            8-bit. When False the raw log-domain MSR (which may take any
            real value, negative included) is returned so that callers can
            keep working in floating point.

    Returns:
        The 8-bit normalised MSR image, or the float MSR array when
        ``apply_normalization`` is False.
    """
    # Accumulate under a name that does not shadow this function.
    retinex_sum = np.zeros(img.shape)
    for sigma in sigma_scales:
        retinex_sum += ssr(img, sigma)

    # Equal weights: divide by the number of scales.
    retinex = retinex_sum / len(sigma_scales)

    if not apply_normalization:
        return retinex

    # The raw MSR can lie in any range [-k, +l]; stretch it to [0, 255].
    return cv2.normalize(retinex, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC3)

def color_balance(img, low_per, high_per):
    '''Contrast-stretch img per channel, clipping a share of pixels to black and white.

    low_per is the percentage of pixels pushed to 0, high_per the percentage
    pushed to 255. Levels in between are rescaled linearly to [0, 255]
    through a look-up table. A channel whose two cut-off levels coincide is
    passed through untouched.

    Returns a 2-D array for single-channel input, a merged multi-channel
    image otherwise, or None if there were no channels at all.
    '''
    pixel_total = img.shape[1] * img.shape[0]
    # pixel counts at which the black-out ends and the white-out starts
    dark_cutoff = pixel_total * low_per / 100
    bright_cutoff = pixel_total * (100 - high_per) / 100

    channels = [img] if len(img.shape) == 2 else cv2.split(img)
    levels = np.arange(0, 256)

    stretched = []
    for channel in channels:
        # cumulative histogram of the channel over the 256 intensity bins
        cumulative = np.cumsum(cv2.calcHist([channel], [0], None, [256], (0, 256)))

        # first levels whose cumulative count reaches each cut-off
        lo, hi = np.searchsorted(cumulative, (dark_cutoff, bright_cutoff))
        if lo == hi:
            stretched.append(channel)
            continue

        # min-max ramp between lo and hi, saturated to 0 below and 255 above
        ramp = np.rint((levels - lo) / (hi - lo) * 255)
        lut = np.where(levels < lo, 0,
                       np.where(levels > hi, 255, ramp)).astype('uint8')
        stretched.append(cv2.LUT(channel, lut))

    if len(stretched) == 1:
        return np.squeeze(stretched)
    if len(stretched) > 1:
        return cv2.merge(stretched)
    return None

def msrcr(img, sigma_scales=[15, 80, 250], alpha=125, beta=46, G=192, b=-30, low_per=1, high_per=3):
    """Multi-scale retinex with colour restoration (MSRCR).

        MSRCR(x, y) = G * (MSR(x, y) * CRF(x, y) - b)      G = gain, b = offset
        CRF(x, y)   = beta * (log10(alpha * I(x, y)) - log10(I'(x, y)))
        I'(x, y)    = sum over channels c of I_c(x, y)

    Args:
        img: multi-channel image with channels on axis 2.
        sigma_scales: Gaussian sigmas used for the multi-scale retinex.
        alpha, beta: colour-restoration constants.
        G, b: gain and offset applied to MSR * CRF.
        low_per, high_per: percentages of pixels clipped to black / white by
            the final ``color_balance`` step.

    Returns:
        The colour-balanced, 8-bit MSRCR image.
    """
    # Shift intensities by one so the logarithms never see a zero.
    img = img.astype(np.float64) + 1.0

    # Multi-scale retinex WITHOUT normalisation: the formula above needs the
    # raw log-domain MSR. msr() min-max rescales its result to 8-bit
    # [0, 255], which would quantise the values and add a spurious positive
    # offset before the multiplication by CRF, so the equal-weight mean of
    # the SSR outputs is computed here directly.
    msr_img = np.zeros(img.shape)
    for sigma in sigma_scales:
        msr_img += ssr(img, sigma)
    msr_img /= len(sigma_scales)

    # Colour-restoration function
    crf = beta * (np.log10(alpha * img) - np.log10(np.sum(img, axis=2, keepdims=True)))
    # MSRCR
    msrcr_ = G * (msr_img * crf - b)
    # normalise MSRCR to 8-bit [0, 255]
    msrcr_ = cv2.normalize(msrcr_, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8UC3)  # type: ignore
    # colour-balance the result to flatten the histogram and clip both tails
    msrcr_ = color_balance(msrcr_, low_per, high_per)

    return msrcr_

In [None]:
class BFIWSlide:
    """A single slide image together with its MSRCR-enhanced copy and mask.

    Construction runs the whole per-slide pipeline:

    1. read ``slide_path`` with OpenCV and convert BGR -> RGB,
    2. enhance the image with ``msrcr`` (``apply_msrcr``),
    3. derive a mask from the gray enhanced image (``get_mask``),
    4. paint every pixel outside the mask white in both images
       (``apply_mask``).

    Attributes:
        slide_path: path the image was read from.
        key: caller-supplied identifier, stored unchanged.
        is_ref: caller-supplied flag, stored unchanged.
        img: RGB image; 255 wherever the mask is not 1.
        msr_img: MSRCR output; 255 wherever the mask is not 1.
        msr_img_gray: gray version of the MSRCR output (not masked).
        mask: uint8 array produced by ``get_mask``.

    Raises:
        OSError: if OpenCV cannot read ``slide_path``.
    """

    def __init__(self, slide_path, key=None, is_ref=False):
        self.slide_path = slide_path
        self.key = key
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        image = cv2.imread(slide_path)
        if image is None:
            raise OSError(f"Could not read image (missing or unsupported file): {slide_path!r}")
        self.img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        self.msr_img = None
        self.msr_img_gray = None
        self.mask = None
        self.is_ref = is_ref
        self.apply_msrcr()
        self.get_mask()
        self.apply_mask(self.mask)

    def apply_msrcr(self):
        """Compute ``msr_img`` from ``img`` and its gray version ``msr_img_gray``."""
        self.msr_img = msrcr(self.img)
        self.msr_img_gray = cv2.cvtColor(self.msr_img, cv2.COLOR_RGB2GRAY)

    def apply_mask(self, mask):
        """Replace ``img`` and ``msr_img`` by copies that are white outside ``mask``.

        Pixels where ``mask == 1`` are kept; every other pixel becomes 255.
        Both results are uint8 arrays with the shape of ``img``.
        """
        keep = mask == 1  # computed once and reused for both images

        masked_img = np.full(self.img.shape, 255, dtype=np.uint8)
        masked_img[keep] = self.img[keep]

        masked_msr = np.full(self.img.shape, 255, dtype=np.uint8)
        masked_msr[keep] = self.msr_img[keep]

        self.img = masked_img
        self.msr_img = masked_msr
        # Earlier versions built the outputs on ``self.temp_img`` and left it
        # aliasing ``msr_img``; the alias is kept only for compatibility.
        self.temp_img = masked_msr

    def get_mask(self):
        """Compute ``self.mask`` from the gray MSRCR image with ANTs ``get_mask``.

        Runs ``apply_msrcr`` first if the gray image has not been computed.
        The mask is stored as a uint8 NumPy array; nothing is returned.
        """
        if self.msr_img_gray is None:
            self.apply_msrcr()
        gray_ants = ants.from_numpy(self.msr_img_gray)
        self.mask = gray_ants.get_mask(cleanup=4).numpy().astype(np.uint8)  # type: ignore

In [None]:
class BFIWReg:
    """Collects the section images of a directory and wraps each in a BFIWSlide.

    Files in ``src_dir`` named ``<prefix>-SE_<number>_original.jpg`` are
    indexed by their section number, zero-padded to at least four digits
    (``"0007"``, ``"0123"``, ``"12345"``). Any other directory entry is
    ignored.

    Attributes:
        src_dir: directory the images were listed from.
        dest_dir: destination directory, stored as given.
        imgs: dict mapping section key -> image path, in ascending section
            order.
        ref_slide: the ``BFIWSlide`` built with ``is_ref=True``.
        slides: dict mapping section key -> ``BFIWSlide``; the non-reference
            slides come first in section order and the reference slide is
            added last.
    """

    def __init__(self, src_dir, dest_dir, ref_idx) -> None:
        """Index ``src_dir`` and load every slide.

        Args:
            src_dir: directory containing the section images.
            dest_dir: destination directory (only stored here).
            ref_idx: section of the reference slide, given either as the
                zero-padded key (``"0012"``) or as a plain number / numeric
                string (``12`` or ``"12"``).

        Raises:
            KeyError: if no image exists for ``ref_idx``.
        """
        self.src_dir = src_dir
        self.dest_dir = dest_dir

        # Dot escaped and end anchored so only real "..._original.jpg" names
        # match; unrelated entries (e.g. .DS_Store, notes) are skipped
        # instead of crashing on a failed match.
        regex = re.compile(r".*-SE_(\d+)_original\.jpg$")
        names_by_section = {}
        for name in os.listdir(self.src_dir):
            found = regex.match(name)
            if found is None:
                continue
            names_by_section[int(found.group(1))] = name

        self.imgs = {
            str(section_num).zfill(4): os.path.join(self.src_dir, names_by_section[section_num])
            for section_num in sorted(names_by_section)
        }

        # Accept 12, "12" and "0012" alike; already-padded keys are unchanged.
        ref_key = str(ref_idx).zfill(4)
        if ref_key not in self.imgs:
            raise KeyError(f"reference section {ref_idx!r} not found in {self.src_dir!r}")

        self.ref_slide = BFIWSlide(self.imgs[ref_key], key=ref_key, is_ref=True)
        self.slides = {key: BFIWSlide(path, key=key) for key, path in self.imgs.items() if key != ref_key}
        self.slides[ref_key] = self.ref_slide