<a href="https://colab.research.google.com/github/pegahsalehi/Stain-to_Stain-Translation/blob/master/Reinhard.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount Google Drive. Later cells use relative paths such as "drive/My Drive/...",
# which presumably resolve under this mount point from the notebook's working directory.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
pip install spams

Collecting spams
  Downloading https://files.pythonhosted.org/packages/0d/09/ae296293c992e6ec792952827b66ca97cee375af53585c85a1e6d4a54f02/spams-2.6.1.tar.gz (1.9MB)
     |████████████████████████████████| 1.9MB 2.7MB/s 
Building wheels for collected packages: spams
  Building wheel for spams (setup.py) ... done
  Created wheel for spams: filename=spams-2.6.1-cp36-cp36m-linux_x86_64.whl size=4512970 sha256=8aeead5983a41274760a20b76a6e194fa167c482f34c5ac5b6106cead630663a
  Stored in directory: /root/.cache/pip/wheels/76/a6/a8/5959872693a82d5497a91aee3665bb1676cee33304d86c1495
Successfully built spams
Installing collected packages: spams
Successfully installed spams-2.6.1


In [0]:
import cv2 as cv
import numpy as np

### Some functions ###

def lab_split(I):
    """
    Convert an RGB uint8 image to LAB and return its three channels as float32 arrays.

    OpenCV's 8-bit LAB encoding stores L scaled to [0, 255] and offsets a/b by +128;
    both are undone here, giving L in [0, 100] and a/b centred on 0.
    :param I: RGB uint8 image
    :return: tuple (L, a, b) of float32 arrays
    """
    lab = cv.cvtColor(I, cv.COLOR_RGB2LAB).astype(np.float32)
    lightness, chan_a, chan_b = cv.split(lab)
    return lightness / 2.55, chan_a - 128.0, chan_b - 128.0

def merge_back(I1, I2, I3):
    """
    Merge separate LAB channels back into an RGB uint8 image (inverse of lab_split).

    The channels are expected in the ranges lab_split produces: L roughly in [0, 100]
    and a/b centred on 0. They are mapped back to OpenCV's 8-bit LAB encoding
    (L * 2.55, a + 128, b + 128), clipped to [0, 255], truncated to uint8 and
    converted to RGB. The input arrays are left unmodified.
    :param I1: L channel (float array)
    :param I2: a channel (float array; cv.merge expects the same shape/dtype as I1)
    :param I3: b channel (float array; cv.merge expects the same shape/dtype as I1)
    :return: RGB uint8 image
    """
    # Out-of-place arithmetic so the caller's arrays are not rescaled as a side effect
    # (and integer inputs do not hit an in-place casting error).
    lightness = I1 * 2.55
    chan_a = I2 + 128.0
    chan_b = I3 + 128.0
    lab = np.clip(cv.merge((lightness, chan_a, chan_b)), 0, 255).astype(np.uint8)
    return cv.cvtColor(lab, cv.COLOR_LAB2RGB)

def get_mean_std(I):
    """
    Compute the per-channel mean and standard deviation of an image in LAB space.

    :param I: RGB uint8 image
    :return: (means, stds) -- two 3-tuples ordered (L, a, b); each entry is the
             corresponding value returned by cv.meanStdDev for that channel
    """
    stats = [cv.meanStdDev(channel) for channel in lab_split(I)]
    means = tuple(mean for mean, _ in stats)
    stds = tuple(std for _, std in stats)
    return means, stds

### Main class ###

class normalizer(object):
    """
    Reinhard stain normalizer.

    ``fit`` records the per-channel mean and standard deviation of a target image in the
    rescaled LAB space produced by ``lab_split``. ``transform`` shifts and scales each LAB
    channel of a source image so its statistics match the target. Both images are first
    passed through ``standardize_brightness``.
    """

    def __init__(self):
        # Per-channel (L, a, b) statistics of the fitted target; None until fit() runs.
        self.target_means = None
        self.target_stds = None

    def fit(self, target):
        """
        Learn the colour statistics to normalise towards.
        :param target: RGB uint8 reference image
        """
        target = standardize_brightness(target)
        means, stds = get_mean_std(target)
        self.target_means = means
        self.target_stds = stds

    def transform(self, I):
        """
        Map the LAB channel statistics of ``I`` onto those learned by ``fit``.
        :param I: RGB uint8 image
        :return: normalised RGB uint8 image
        :raises RuntimeError: if called before fit()
        """
        if self.target_means is None or self.target_stds is None:
            raise RuntimeError("normalizer.fit() must be called before transform()")
        I = standardize_brightness(I)
        channels = lab_split(I)
        means, stds = get_mean_std(I)
        normalized = []
        for channel, mean, std, target_mean, target_std in zip(
                channels, means, stds, self.target_means, self.target_stds):
            # A constant channel has zero spread; (channel - mean) is then 0 everywhere, so
            # any finite scale maps it onto target_mean. Substituting 1 avoids 0/0 -> NaN.
            safe_std = np.where(std > 0, std, 1.0)
            normalized.append(((channel - mean) * (target_std / safe_std)) + target_mean)
        return merge_back(*normalized)

In [0]:
import numpy as np
import cv2 as cv
import spams
import matplotlib.pyplot as plt

def read_image(path):
    """
    Read an image file from disk and return it as an RGB uint8 array.

    OpenCV decodes images in BGR channel order, so the result is converted to RGB.
    :param path: path to the image file
    :return: RGB uint8 image
    :raises IOError: if OpenCV cannot read or decode the file
    """
    im = cv.imread(path)
    if im is None:
        # cv.imread signals failure by returning None rather than raising; fail early with
        # the offending path instead of a cryptic error from cvtColor.
        raise IOError("Could not read image: %s" % path)
    return cv.cvtColor(im, cv.COLOR_BGR2RGB)

def standardize_brightness(I, percentile=90):
    """
    Rescale intensities so that the given percentile of ``I`` maps to 255.

    :param I: image array (RGB uint8 in this notebook)
    :param percentile: percentile of ``I`` mapped to 255 (default 90)
    :return: uint8 array with the same shape as ``I``; values are clipped to [0, 255]
             and truncated (not rounded)
    """
    p = np.percentile(I, percentile)
    if p <= 0:
        # Mostly-black image: dividing by 0 would produce NaN/inf (and garbage after the
        # uint8 cast). There is no brightness to standardise, so return the input as uint8.
        return np.clip(I, 0, 255).astype(np.uint8)
    return np.clip(I * 255.0 / p, 0, 255).astype(np.uint8)


In [0]:
import time, errno, cv2, os
from glob import glob
import numpy as np
import matplotlib.pyplot as plt
import datetime

# Experiment label (printed by the driver below) and Google Drive locations for this run.
exp_name = ""
media_url = "drive/My Drive/patches"   # root holding the Unorm/ input and Reinhard/ output folders
path_h = "drive/My Drive/"             # folder containing the reference (target) image

# Reference image whose LAB statistics the normalizer is fitted to.
i1 = read_image(os.path.join(path_h, "7.tiff"))

def assure_path_exists(path):
    """
    Make sure the directory part of ``path`` exists, creating it (and parents) if needed.

    ``path`` is treated as a file path: only ``os.path.dirname(path)`` is created, so pass
    a trailing separator (e.g. ``"out/dir/"``) when ``path`` itself names the directory.
    :param path: file path, or directory path ending with a separator
    :raises OSError: if the directory cannot be created or exists as a non-directory
    """
    directory = os.path.dirname(path)
    if not directory:
        # Bare file name: it lives in the current working directory, nothing to create.
        return
    # exist_ok tolerates an already-existing directory (including one created concurrently)
    # without hand-inspecting errno codes.
    os.makedirs(directory, exist_ok=True)

def transform_imgs(dir, output_dir, normalizer='', pattern="*.tiff"):
    """
    Normalise every image matching ``pattern`` directly inside ``dir`` and save the results.

    Only the top level of ``dir`` is searched (sub-directories are not descended into).
    Each output file keeps its input file name and is written into ``output_dir``,
    which must already exist (cv2.imwrite does not create directories).
    :param dir: input directory
    :param output_dir: destination directory
    :param normalizer: fitted object exposing ``transform(rgb_uint8) -> rgb_uint8``
                       (the '' default is a placeholder; a real normalizer is required)
    :param pattern: glob pattern selecting input files (default "*.tiff")
    :raises IOError: if an output image cannot be written
    """
    # A single glob of the top-level directory; sorted for a deterministic processing order.
    images = sorted(glob(os.path.join(dir, pattern)))

    for img_path in images:
        img_name = os.path.basename(img_path)
        transformed_img = normalizer.transform(read_image(img_path))
        out_path = os.path.join(output_dir, img_name)
        # OpenCV writes BGR, and reports failure via its return value rather than an exception.
        if not cv2.imwrite(out_path, cv2.cvtColor(transformed_img, cv2.COLOR_RGB2BGR)):
            raise IOError("Could not write image: %s" % out_path)

# start working here
print(exp_name)
input_dir = media_url + "/Unorm/"
output_dir = media_url + "/Reinhard/"
assure_path_exists(output_dir)
print(output_dir)

start_time = time.time()
start_time_p = datetime.datetime.now()

n = normalizer()
n.fit(i1)
transform_imgs(input_dir, output_dir, n)

elapsed = (time.time() - start_time)
elapsed_time_P = datetime.datetime.now() - start_time_p

print("--- %s seconds ---" % round((elapsed / 2), 2),'\n')
print ('time: ',elapsed_time_P)


drive/My Drive/thesis/pic/K/
--- 14.16 seconds --- 

time:  0:00:28.329338
