In [1]:
import os
import traceback
import sys
import pprint
import numpy as np
import cv2
from tqdm import tqdm
import glob
import tensorflow as tf
import configparser
from PIL import Image
import tensorflow.keras.backend as K
from tensorflow.keras.losses import binary_crossentropy, BinaryCrossentropy


import os
import sys
import datetime
import shutil
import sys
import glob
import traceback
import random
import numpy as np
import cv2
import tensorflow as tf
import tensorflow.keras.backend as K
from PIL import Image, ImageFilter, ImageOps
from tensorflow.keras.layers import Lambda
from tensorflow.keras.layers import Input

from tensorflow.keras.layers import Conv2D, Dropout, Conv2D, MaxPool2D, BatchNormalization

from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import concatenate
from tensorflow.keras.activations import relu
from tensorflow.keras import Model
from tensorflow.keras.losses import  BinaryCrossentropy
from tensorflow.keras.metrics import BinaryAccuracy
#from tensorflow.keras.metrics import Mean
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

# Section names of the train_eval_infer config file.
MODEL      = "model"
TRAIN      = "train"
EVAL       = "eval"
MASK       = "mask"
INFER      = "infer"
TILEDINFER = "tiledinfer"

# File name, inside [train] model_dir, of the checkpoint saved by ModelCheckpoint.
BEST_MODEL_FILE = "best_model.h5"

# TensorFlow GPU runtime switches (memory growth on, GPU garbage collection off).
os.environ.update({
    "TF_FORCE_GPU_ALLOW_GROWTH": "true",
    "TF_ENABLE_GPU_GARBAGE_COLLECTION": "false",
})

In [2]:
class ConfigParser:
    """Thin wrapper around :mod:`configparser` whose values are Python literals.

    Options are stored as raw strings and turned into Python objects with
    ``eval`` when read through :meth:`get`; e.g. ``image_width = 512`` yields
    the int ``512`` and ``loss = "bce_iou_loss"`` yields the str ``"bce_iou_loss"``.
    """

    def __init__(self, config_path):
        """Parse *config_path* and pretty-print its contents.

        Raises:
            Exception: if *config_path* does not exist.
        """
        print("==== ConfigParser {}".format(config_path))
        if not os.path.exists(config_path):
            raise Exception("Not found config_path {}".format(config_path))

        # Start empty so get() still works (returning defaults) if parsing fails below.
        self.dict = {}
        try:
            self.parse(config_path)
            self.dump_all()
        except Exception as ex:
            print("==== ConfigParser Exception -----------------------{}".format(ex))
            traceback.print_exc()

    def parse(self, config_path):
        """Read *config_path* into ``self.dict`` as ``{section: {option: raw_string}}``."""
        config = configparser.ConfigParser()
        config.read(config_path)
        self.dict = {s: dict(config.items(s)) for s in config.sections()}

    def dump_all(self):
        """Pretty-print every parsed section with its raw option strings."""
        pprint.pprint(self.dict)

    def get(self, section, name, dvalue=None):
        """Return the evaluated value of option *name* in *section*, or *dvalue*.

        A warning is printed and *dvalue* returned when the option is missing
        or its text cannot be evaluated.

        SECURITY: the raw option text is passed to ``eval`` and may run
        arbitrary Python; only load config files you trust.
        """
        try:
            raw = self.dict[section][name]
        except KeyError:
            print("=== WARNING: Not found [{}]  {}, return default value {}".format(section, name, dvalue))
            return dvalue

        try:
            return eval(raw)
        except Exception as ex:
            print("=== WARNING: Cannot evaluate [{}]  {} = {!r} ({}), return default value {}".format(
                section, name, raw, ex, dvalue))
            return dvalue

In [3]:
class ImageMaskDataset:
    """Loads image/mask pairs named in a train_eval_infer config into numpy arrays.

    Reads ``[model] image_width/image_height/image_channels``,
    ``[train]``/``[eval] image_datapath/mask_datapath`` and
    ``[mask] binarize/threshold/blur``.
    """

    # File extensions collected from the image and mask directories.
    IMAGE_EXTENSIONS = (".jpg", ".png", ".bmp", ".tif")

    def __init__(self, config_file):
        config = ConfigParser(config_file)
        self.image_width = config.get(MODEL, "image_width")
        self.image_height = config.get(MODEL, "image_height")
        self.image_channels = config.get(MODEL, "image_channels")
        self.train_dataset = [config.get(TRAIN, "image_datapath"),
                              config.get(TRAIN, "mask_datapath")]

        self.eval_dataset = [config.get(EVAL, "image_datapath"),
                             config.get(EVAL, "mask_datapath")]

        self.binarize = config.get(MASK, "binarize")
        self.threshold = config.get(MASK, "threshold")
        self.blur_mask = config.get(MASK, "blur")

        # Fixed blur kernel size, used when [mask] blur is enabled.
        self.blur_size = (3, 3)

    @staticmethod
    def _list_files(datapath):
        """Return the sorted files in *datapath* whose extension is in IMAGE_EXTENSIONS."""
        files = []
        for ext in ImageMaskDataset.IMAGE_EXTENSIONS:
            files += glob.glob(datapath + "/*" + ext)
        return sorted(files)

    def create(self, dataset=TRAIN, debug=False):
        """Build ``(X, Y)`` arrays for the *dataset* split (TRAIN or EVAL).

        Returns:
            X: uint8 array of shape (N, image_height, image_width, image_channels).
            Y: bool array of shape (N, image_height, image_width, 1); stays all
               False when the mask directory is not configured or does not exist.

        *debug* is currently unused and kept for backward compatibility.

        Raises:
            Exception: on an invalid split, missing image path, no images,
                an image/mask count mismatch, or an unreadable file.
        """
        if dataset not in [TRAIN, EVAL]:
            raise Exception("Invalid dataset")

        image_datapath, mask_datapath = self.eval_dataset if dataset == EVAL else self.train_dataset
        if image_datapath is None:
            raise Exception("FATAL: Not found image_datapath for [{}]".format(dataset))

        image_files = self._list_files(image_datapath)

        mask_files = None
        if mask_datapath is not None and os.path.exists(mask_datapath):
            mask_files = self._list_files(mask_datapath)
            if len(image_files) != len(mask_files):
                raise Exception("FATAL: Images and masks unmatched")

        num_images = len(image_files)
        if num_images == 0:
            raise Exception("FATAL: Not found image files")

        # cv2.resize expects dsize as (width, height).
        dsize = (self.image_width, self.image_height)

        X = np.zeros((num_images, self.image_height, self.image_width, self.image_channels), dtype=np.uint8)
        Y = np.zeros((num_images, self.image_height, self.image_width, 1), dtype=bool)

        for n, image_file in tqdm(enumerate(image_files), total=num_images):
            image = cv2.imread(image_file)
            if image is None:
                raise Exception("FATAL: Cannot read image file {}".format(image_file))
            X[n] = cv2.resize(image, dsize=dsize, interpolation=cv2.INTER_NEAREST)

            if mask_files is None:
                continue

            mask = cv2.imread(mask_files[n])
            if mask is None:
                raise Exception("FATAL: Cannot read mask file {}".format(mask_files[n]))
            mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
            mask = cv2.resize(mask, dsize=dsize, interpolation=cv2.INTER_NEAREST)

            # Binarize mask
            if self.binarize:
                mask[mask < self.threshold] = 0
                mask[mask >= self.threshold] = 255

            # Blur mask
            if self.blur_mask:
                mask = cv2.blur(mask, self.blur_size)

            # Y is bool, so every non-zero mask pixel (including blurred edges) becomes True.
            Y[n] = np.expand_dims(mask, axis=-1)

        return X, Y

In [4]:
def sensitivity(y_true, y_pred):
    """Recall (true-positive rate): TP / (TP + FN).

    Element-wise products are clipped to [0, 1] and rounded before counting,
    so for binary y_true a pixel counts as predicted positive when y_pred rounds to 1.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    actual_positives = K.round(K.clip(y_true, 0, 1))
    return K.sum(hits) / (K.sum(actual_positives) + K.epsilon())

def specificity(y_true, y_pred):
    """True-negative rate: TN / (TN + FP), counted on clipped and rounded values."""
    rejections = K.round(K.clip((1 - y_true) * (1 - y_pred), 0, 1))
    actual_negatives = K.round(K.clip(1 - y_true, 0, 1))
    return K.sum(rejections) / (K.sum(actual_negatives) + K.epsilon())

def dice_coef(y_true, y_pred):
    """Smoothed Dice coefficient over all elements: (2*|A∩B| + 1) / (|A| + |B| + 1)."""
    smooth = 1.0
    truth = K.cast(tf.reshape(y_true, [-1]), 'float32')
    guess = K.cast(tf.reshape(y_pred, [-1]), 'float32')

    overlap = tf.reduce_sum(truth * guess)
    denominator = tf.reduce_sum(truth) + tf.reduce_sum(guess) + smooth
    return (2.0 * overlap + smooth) / denominator


def dice_loss(y_true, y_pred):
    """Dice loss: 1 - dice_coef(y_true, y_pred)."""
    return 1.0 - dice_coef(y_true, y_pred)

def bce_dice_loss(y_true, y_pred):
    """Mean of Keras binary cross-entropy and the Dice loss."""
    return (binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)) / 2.0

def jacard_similarity(y_true, y_pred):
    """
    Intersection-Over-Union (IoU), also known as the Jaccard Index.

    IoU = sum(t * p) / (sum(t + p - t * p) + eps), summed over every element of
    the batch. K.epsilon() in the denominator keeps the result finite (0 rather
    than NaN from 0/0) when y_true and y_pred are both all zeros.
    """
    y_true_f = K.cast(K.flatten(y_true), 'float32')
    y_pred_f = K.cast(K.flatten(y_pred), 'float32')
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum((y_true_f + y_pred_f) - (y_true_f * y_pred_f))
    return intersection / (union + K.epsilon())


def jacard_loss(y_true, y_pred):
    """
    Intersection-Over-Union (IoU) loss, also known as the Jaccard loss:
    1 - jacard_similarity(y_true, y_pred).
    """
    return 1 - jacard_similarity(y_true, y_pred)

def iou_coef(y_true, y_pred):
    """Alias of jacard_similarity, so it can be named as a metric in the config."""
    return jacard_similarity(y_true, y_pred)

def iou_loss(y_true, y_pred):
    """Alias of jacard_loss."""
    return jacard_loss(y_true, y_pred)

def ssim_loss(y_true, y_pred):
    """
    Structural Similarity Index (SSIM) loss: 1 - SSIM(y_true, y_pred) with max_val=1.

    Both inputs are cast to float32; the result holds the per-image values
    returned by tf.image.ssim.
    """
    truth = K.cast(y_true, 'float32')
    guess = K.cast(y_pred, 'float32')
    similarity = tf.image.ssim(truth, guess, max_val=1)
    return 1 - similarity

def basnet_hybrid_loss(y_true, y_pred):
    """
    Hybrid loss proposed in BASNET (https://arxiv.org/pdf/2101.04704.pdf)
    The hybrid loss is a combination of the binary cross entropy, structural similarity
    and intersection-over-union losses, which guide the network to learn
    three-level (i.e., pixel-, patch- and map- level) hierarchy representations.
    Returns the mean of the three terms.
    """
    pixel_term = BinaryCrossentropy(from_logits=False)(y_true, y_pred)
    patch_term = ssim_loss(y_true, y_pred)
    map_term = jacard_loss(y_true, y_pred)
    return (pixel_term + patch_term + map_term) / 3.0

def bce_iou_loss(y_true, y_pred):
    """Mean of binary cross-entropy and IoU (Jaccard) loss."""
    pixel_term = BinaryCrossentropy(from_logits=False)(y_true, y_pred)
    region_term = iou_loss(y_true, y_pred)
    return (pixel_term + region_term) / 2.0

In [5]:
class EpochChangeCallback(tf.keras.callbacks.Callback):
    """Keras callback that appends per-epoch losses and metrics to CSV files.

    On construction *eval_dir* is deleted and recreated, and two files are
    started inside it:
      * ``train_losses.csv``  - header ``epoch, loss, val_loss``
      * ``train_metrics.csv`` - header ``epoch,<metrics[0]>,<metrics[1]>``
    """

    def __init__(self, eval_dir, metrics=("accuracy", "val_accuracy")):
        """
        Args:
            eval_dir: directory for the CSV files; wiped and recreated here.
            metrics: (training_key, validation_key) Keras log keys to record.
                None falls back to ("accuracy", "val_accuracy").
        """
        # Initialize the Keras Callback base-class attributes.
        super().__init__()
        if metrics is None:
            metrics = ("accuracy", "val_accuracy")
        self.eval_dir = eval_dir
        self.metrics = list(metrics)

        if os.path.exists(self.eval_dir):
            shutil.rmtree(self.eval_dir)
        os.makedirs(self.eval_dir, exist_ok=True)

        self.train_losses_file = os.path.join(self.eval_dir, "train_losses.csv")
        self.train_accuracies_file = os.path.join(self.eval_dir, "train_metrics.csv")

        self._write_header(self.train_losses_file, "epoch, loss, val_loss\n")
        self._write_header(self.train_accuracies_file,
                           "epoch," + self.metrics[0] + "," + self.metrics[1] + "\n")

    @staticmethod
    def _write_header(path, header):
        """Create *path* containing *header* unless it exists; errors are logged, not raised."""
        try:
            if not os.path.exists(path):
                with open(path, "w") as f:
                    f.write(header)
        except Exception:
            traceback.print_exc()

    @staticmethod
    def _append_line(path, line):
        """Append *line* plus a newline to *path*; errors are logged, not raised."""
        try:
            with open(path, "a") as f:
                f.write(line + "\n")
        except Exception:
            traceback.print_exc()

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's loss/val_loss and the two tracked metrics (0 if absent)."""
        logs = logs or {}
        acc = logs.get(self.metrics[0], logs.get('acc', 0))
        val_acc = logs.get(self.metrics[1], logs.get('val_acc', 0))
        loss = logs.get('loss', 0)
        val_loss = logs.get('val_loss', 0)

        self._append_line(self.train_losses_file,
                          "{}, {:.4f}, {:.4f}".format(epoch, loss, val_loss))
        self._append_line(self.train_accuracies_file,
                          "{}, {:.4f}, {:.4f}".format(epoch, acc, val_acc))

In [6]:
class GrayScaleImageWriter:
    """Saves prediction arrays (values nominally in [0, 1]) as 8-bit grayscale images."""

    def __init__(self, image_format=".jpg"):
        # Extension appended to every saved name; PIL chooses the encoder from it.
        self.image_format = image_format

    @staticmethod
    def _to_gray_array(data, factor=255.0):
        """Scale *data* by *factor* into an (H, W) uint8 array.

        For 3-D input (H, W, C) only channel 0 is used. Each value is truncated
        toward zero, like ``int(z * factor)``, then clipped to [0, 255].
        """
        values = np.asarray(data)
        if values.ndim == 3:
            values = values[..., 0]
        # Multiply in the input dtype (e.g. float32 model output), truncate like int().
        scaled = (values * factor).astype(np.int64)
        return np.clip(scaled, 0, 255).astype(np.uint8)

    def _to_image(self, data, factor):
        """Build a grayscale PIL image with the same height and width as *data*."""
        h, w = data.shape[:2]
        print(f" image w: {w} h: {h}")
        return Image.fromarray(self._to_gray_array(data, factor))

    def save(self, data, output_dir, name, factor=255.0):
        """Save *data* as ``output_dir/name + image_format``."""
        image = self._to_image(data, factor)
        image_filepath = os.path.join(output_dir, name + self.image_format)
        image.save(image_filepath)
        print(f"=== Saved {image_filepath}")

    def save_resized(self, data, resized, output_dir, name, factor=255.0):
        """Save *data* resized to *resized* = (width, height); return it as an RGB uint8 array."""
        image = self._to_image(data, factor)
        image_filepath = os.path.join(output_dir, name + self.image_format)
        print(f"== resized to {resized}")
        image = image.resize(resized)
        image.save(image_filepath)
        print(f"=== Saved {image_filepath}")
        return np.array(image.convert("RGB"))

In [7]:
class TensorflowUNet:
    """Configurable 2-D UNet for binary segmentation driven by a train_eval_infer config.

    Depth, filters, kernel sizes, dilations, dropout, loss and metrics come from
    ``[model]``; training, inference and tiled inference read ``[train]``,
    ``[infer]`` and ``[tiledinfer]``.
    """

    def __init__(self, config_file):
        """Build and compile the model described by *config_file*."""
        self.set_seed()
        self.config_file = config_file
        self.config    = ConfigParser(config_file)
        image_height   = self.config.get(MODEL, "image_height")
        image_width    = self.config.get(MODEL, "image_width")
        image_channels = self.config.get(MODEL, "image_channels")
        num_classes    = self.config.get(MODEL, "num_classes")
        base_filters   = self.config.get(MODEL, "base_filters")
        num_layers     = self.config.get(MODEL, "num_layers")
        self.model     = self.create(num_classes, image_height, image_width, image_channels,
                                     base_filters=base_filters, num_layers=num_layers)

        learning_rate  = self.config.get(MODEL, "learning_rate")
        clipvalue      = self.config.get(MODEL, "clipvalue", 0.2)

        # epsilon=1e-07 is the Keras backend default that epsilon=None resolved to;
        # 'decay' is omitted (0.0 was a no-op and newer Keras optimizers reject it).
        self.optimizer = Adam(learning_rate=learning_rate,
                              beta_1=0.9, beta_2=0.999, epsilon=1e-07,
                              clipvalue=clipvalue,
                              amsgrad=False)
        print("=== Optimizer Adam learning_rate {} clipvalue {}".format(learning_rate, clipvalue))

        self.model_loaded = False

        # Local names that the [model] loss / metrics entries may refer to,
        # in addition to the module-level loss and metric functions.
        binary_crossentropy = tf.keras.metrics.binary_crossentropy
        binary_accuracy     = tf.keras.metrics.binary_accuracy

        # Defaults, kept when the config does not name a loss / metrics.
        self.loss    = binary_crossentropy
        self.metrics = [binary_accuracy]

        # SECURITY: loss and metric names are eval'ed; config files must be trusted.
        # e.g. loss = "binary_crossentropy"
        loss_name = self.config.get(MODEL, "loss")
        if loss_name is not None:
            self.loss = eval(loss_name)

        # e.g. metrics = ["binary_accuracy"]
        # A plain for-loop (not a comprehension) so eval() sees this method's locals.
        metric_names = self.config.get(MODEL, "metrics")
        if metric_names is not None:
            self.metrics = []
            for metric in metric_names:
                self.metrics.append(eval(metric))

        print("--- loss    {}".format(self.loss))
        print("--- metrics {}".format(self.metrics))

        self.model.compile(optimizer=self.optimizer, loss=self.loss, metrics=self.metrics)

        if self.config.get(MODEL, "show_summary"):
            self.model.summary()

    def set_seed(self, seed=137):
        """Seed Python's ``random``, NumPy's global RNG and TensorFlow with *seed*."""
        print("=== set seed {}".format(seed))
        # Call the seeding functions (assigning to them would replace them without seeding).
        random.seed(seed)
        np.random.seed(seed)
        tf.random.set_seed(seed)

    @staticmethod
    def _decreasing_pairs(start, step, minimum, count):
        """Return *count* square pairs (v, v), starting at *start*, decreasing by *step*, floored at *minimum*."""
        pairs = []
        value = start
        for _ in range(count):
            pairs.append((value, value))
            value = max(value - step, minimum)
        return pairs

    def create(self, num_classes, image_height, image_width, image_channels,
               base_filters=16, num_layers=5):
        """Build (without compiling) the UNet Keras model.

        Encoder level i uses ``base_filters * 2**i`` filters with kernel_sizes[i]
        and dilations[i]; decoder level i uses ``base_filters * 2**(num_layers-2-i)``
        filters with the reversed lists minus the deepest entry. The output is a
        1x1 sigmoid convolution with *num_classes* channels.
        """
        print("Input image_height {} image_width {} image_channels {}".format(image_height, image_width, image_channels))
        inputs = Input((image_height, image_width, image_channels))
        # Divide pixel values by 255 inside the graph (inputs presumably 0..255).
        s = Lambda(lambda x: x / 255)(inputs)

        # normalization is False on default.
        normalization = self.config.get(MODEL, "normalization", dvalue=False)
        print("--- normalization {}".format(normalization))
        dropout_rate = self.config.get(MODEL, "dropout_rate")
        pool_size = (2, 2)
        strides = (1, 1)

        # base_kernels = (7,7) gives kernel_sizes [(7,7), (5,5), (3,3), (3,3), ...].
        base_kernels = self.config.get(MODEL, "base_kernels", dvalue=(3, 3))
        _, k = base_kernels
        kernel_sizes = self._decreasing_pairs(k, 2, 3, num_layers)
        rkernel_sizes = kernel_sizes[::-1][1:]
        print("--- kernel_size   {}".format(kernel_sizes))
        print("--- rkernel_size  {}".format(rkernel_sizes))

        # Use the configured dilation only if it is a square pair; otherwise (2, 2).
        dilation = (2, 2)
        configured_dilation = self.config.get(MODEL, "dilation")
        if (isinstance(configured_dilation, (tuple, list)) and len(configured_dilation) == 2
                and configured_dilation[0] == configured_dilation[1]):
            dilation = tuple(configured_dilation)

        _, d = dilation
        dilations = self._decreasing_pairs(d, 1, 1, num_layers)
        rdilations = dilations[::-1][1:]

        print("=== dilations  {}".format(dilations))
        print("=== rdilations {}".format(rdilations))

        # --- Encoder
        enc = []
        for i in range(num_layers):
            filters = base_filters * (2 ** i)
            kernel_size = kernel_sizes[i]
            dilation = dilations[i]

            print("--- kernel_size {}".format(kernel_size))
            print("--- dilation {}".format(dilation))

            c = Conv2D(filters, kernel_size, strides=strides, activation=relu,
                       kernel_initializer='he_normal', dilation_rate=dilation, padding='same')(s)
            if normalization:
                c = BatchNormalization()(c)
            c = Dropout(dropout_rate * i)(c)
            c = Conv2D(filters, kernel_size, strides=strides, activation=relu,
                       kernel_initializer='he_normal', dilation_rate=dilation, padding='same')(c)
            if normalization:
                c = BatchNormalization()(c)
            if i < (num_layers - 1):
                s = MaxPool2D(pool_size=pool_size)(c)
            enc.append(c)

        enc_len = len(enc)
        enc.reverse()
        c = enc[0]

        # --- Decoder
        for i in range(num_layers - 1):
            kernel_size = rkernel_sizes[i]
            dilation = rdilations[i]
            print("+++ kernel_size {}".format(kernel_size))
            print("+++ dilation {}".format(dilation))

            f = enc_len - 2 - i
            filters = base_filters * (2 ** f)
            u = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(c)
            # Skip connection from the encoder level at the same resolution.
            u = concatenate([u, enc[i + 1]])
            u = Conv2D(filters, kernel_size, strides=strides, activation=relu,
                       kernel_initializer='he_normal', dilation_rate=dilation, padding='same')(u)
            if normalization:
                u = BatchNormalization()(u)
            u = Dropout(dropout_rate * f)(u)
            u = Conv2D(filters, kernel_size, strides=strides, activation=relu,
                       kernel_initializer='he_normal', dilation_rate=dilation, padding='same')(u)
            if normalization:
                u = BatchNormalization()(u)
            c = u

        # outputs
        outputs = Conv2D(num_classes, (1, 1), activation='sigmoid')(c)

        return Model(inputs=[inputs], outputs=[outputs])

    @staticmethod
    def _reset_dir(path, create_backup, suffix):
        """Empty *path*: move it to ``<path>_<suffix>_bak`` if *create_backup*, else delete; then recreate it."""
        if os.path.exists(path):
            if create_backup:
                moved_dir = path + "_" + suffix + "_bak"
                shutil.move(path, moved_dir)
                print("--- Moved to {}".format(moved_dir))
            else:
                shutil.rmtree(path)
        os.makedirs(path, exist_ok=True)

    def create_dirs(self, eval_dir, model_dir):
        """(Re)create empty *eval_dir* and *model_dir*, backing them up when ``[train] create_backup`` is True."""
        dt_now = str(datetime.datetime.now()).replace(":", "_").replace(" ", "_")
        create_backup = self.config.get(TRAIN, "create_backup", False)
        for target in (eval_dir, model_dir):
            self._reset_dir(target, create_backup, dt_now)

    def train(self, x_train, y_train):
        """Fit on (x_train, y_train) holding out 20% for validation; return the Keras History.

        The best weights (by val_loss) are written to ``[train] model_dir``/BEST_MODEL_FILE,
        per-epoch CSV logs to ``[train] eval_dir``.
        """
        batch_size = self.config.get(TRAIN, "batch_size")
        epochs = self.config.get(TRAIN, "epochs")
        patience = self.config.get(TRAIN, "patience")
        eval_dir = self.config.get(TRAIN, "eval_dir")
        model_dir = self.config.get(TRAIN, "model_dir")
        metrics = self.config.get(TRAIN, "metrics", dvalue=["accuracy", "val_accuracy"])

        self.create_dirs(eval_dir, model_dir)
        # Copy current config_file to model_dir
        shutil.copy2(self.config_file, model_dir)
        print("-- Copied {} to {}".format(self.config_file, model_dir))

        weight_filepath = os.path.join(model_dir, BEST_MODEL_FILE)

        early_stopping = EarlyStopping(patience=patience, verbose=1)
        check_point = ModelCheckpoint(weight_filepath, verbose=1, save_best_only=True)
        epoch_change = EpochChangeCallback(eval_dir, metrics)

        return self.model.fit(x_train, y_train,
                              validation_split=0.2, batch_size=batch_size, epochs=epochs,
                              shuffle=False,
                              callbacks=[early_stopping, check_point, epoch_change],
                              verbose=1)

    def load_model(self):
        """Load ``[train] model_dir``/BEST_MODEL_FILE (default dir ./models) into the model once.

        Returns True if weights were loaded by this call, False if already loaded.

        Raises:
            Exception: if the weight file does not exist.
        """
        if self.model_loaded:
            return False
        model_dir = self.config.get(TRAIN, "model_dir", dvalue="./models")
        weight_filepath = os.path.join(model_dir, BEST_MODEL_FILE)
        if not os.path.exists(weight_filepath):
            raise Exception("Not found a weight_file " + weight_filepath)
        self.model.load_weights(weight_filepath)
        self.model_loaded = True
        print("=== Loaded a weight_file {}".format(weight_filepath))
        return True

    @staticmethod
    def _list_image_files(input_dir):
        """Return the .png, .jpg, .tif and .bmp files directly inside *input_dir*."""
        image_files = []
        for ext in ("png", "jpg", "tif", "bmp"):
            image_files += glob.glob(input_dir + "/*." + ext)
        return image_files

    @staticmethod
    def _prepare_dir(path):
        """Recreate *path* empty and return it; return None if *path* is None or cannot be created."""
        if path is None:
            return None
        try:
            if os.path.exists(path):
                shutil.rmtree(path)
            os.makedirs(path)
        except OSError as ex:
            print("=== WARNING: Cannot prepare {} ({}), merged output disabled".format(path, ex))
            return None
        return path

    def infer(self, input_dir, output_dir, expand=True):
        """Predict a mask for each image in *input_dir* and save it to *output_dir*.

        Each image is resized to the model input size, predicted, and the mask is
        resized back to the original size and saved as grayscale. When
        ``[infer] merged_dir`` is set, the mask is also added onto the original
        image and written there.
        """
        writer = GrayScaleImageWriter()
        image_files = self._list_image_files(input_dir)

        width = self.config.get(MODEL, "image_width")
        height = self.config.get(MODEL, "image_height")

        merged_dir = self._prepare_dir(self.config.get(INFER, "merged_dir"))

        for image_file in image_files:
            basename = os.path.basename(image_file)
            name = os.path.splitext(basename)[0]
            bgr = cv2.imread(image_file)
            if bgr is None:
                print("=== WARNING: Cannot read image file {}, skipped".format(image_file))
                continue
            # NOTE(review): the model is fed RGB here, while ImageMaskDataset.create loads
            # images in cv2's BGR order; confirm which channel order the weights expect.
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

            h, w = img.shape[:2]
            # Resize the input image to match the input size of the model.
            img = cv2.resize(img, (width, height))
            predictions = self.predict([img], expand=expand)
            image = predictions[0][0]
            # Resize the predicted mask back to the original (w, h) and save it as grayscale.
            mask = writer.save_resized(image, (w, h), output_dir, name)

            print("--- image_file {}".format(image_file))
            if merged_dir is not None:
                # Merge onto the original-size BGR image (cv2.imwrite expects BGR);
                # cv2.add saturates at 255 where uint8 '+=' would wrap around.
                merged = cv2.add(bgr, mask)
                merged_file = os.path.join(merged_dir, basename)
                cv2.imwrite(merged_file, merged)

    def predict(self, images, expand=True):
        """Run the model on each image (adding a batch axis if *expand*); return the raw outputs."""
        self.load_model()
        predictions = []
        for image in images:
            if expand:
                image = np.expand_dims(image, 0)
            predictions.append(self.model.predict(image))
        return predictions

    def infer_tiles(self, input_dir, output_dir, expand=True):
        """Predict masks tile by tile (tile size = ``[model] image_width``) for large images.

        Each tile is extended by ``[tiledinfer] overlapping`` pixels of context
        (except past the top/left edges), predicted, and the central tile region
        is pasted into a grayscale mask saved to *output_dir*. When
        ``[tiledinfer] merged_dir`` is set, the mask is also added onto the image.
        """
        image_files = self._list_image_files(input_dir)
        MARGIN = self.config.get(TILEDINFER, "overlapping", dvalue=0)
        print("MARGIN {}".format(MARGIN))

        merged_dir = self._prepare_dir(self.config.get(TILEDINFER, "merged_dir"))
        split_size = self.config.get(MODEL, "image_width")
        print("---split_size {}".format(split_size))

        for image_file in image_files:
            image = Image.open(image_file)
            w, h = image.size

            # Ceiling division: number of tiles needed to cover the image.
            vert_split_num = (h + split_size - 1) // split_size
            horiz_split_num = (w + split_size - 1) // split_size

            bgcolor = self.config.get(TILEDINFER, "background", dvalue=0)
            print("=== bgcolor {}".format(bgcolor))
            background = Image.new("L", (w, h), bgcolor)

            for j in range(vert_split_num):
                for i in range(horiz_split_num):
                    left = split_size * i
                    upper = split_size * j
                    right = left + split_size
                    lower = upper + split_size

                    if left >= w or upper >= h:
                        continue

                    left_margin = MARGIN if left - MARGIN >= 0 else 0
                    upper_margin = MARGIN if upper - MARGIN >= 0 else 0

                    cropped = image.crop((left - left_margin, upper - upper_margin, right + MARGIN, lower + MARGIN))
                    cw, ch = cropped.size
                    cropped = cropped.resize((split_size, split_size))
                    predictions = self.predict([cropped], expand=expand)
                    mask = predictions[0][0]

                    img = self.mask_to_image(mask).resize((cw, ch)).convert("L")
                    # Drop the context margins so exactly the split_size x split_size tile is pasted.
                    img = img.crop((left_margin, upper_margin,
                                    left_margin + split_size, upper_margin + split_size))

                    ww, hh = img.size
                    background.paste(img, (left, upper))
                    print("---paste j:{} i:{} ww:{} hh:{}".format(j, i, ww, hh))

            basename = os.path.basename(image_file)
            output_file = os.path.join(output_dir, basename)
            background.save(output_file)

            if merged_dir is not None:
                # convert("RGB") so grayscale/RGBA inputs also give a 3-channel image.
                img = cv2.cvtColor(np.array(image.convert("RGB")), cv2.COLOR_RGB2BGR)
                mask = cv2.cvtColor(np.array(background), cv2.COLOR_GRAY2BGR)
                # cv2.add saturates at 255 where uint8 '+=' would wrap around.
                img = cv2.add(img, mask)

                merged_file = os.path.join(merged_dir, basename)
                cv2.imwrite(merged_file, img)

    def mask_to_image(self, data, factor=255.0):
        """Convert an (H, W, 1) or (H, W) float prediction into a PIL image scaled by *factor*.

        For float input the result is a float ("F") image; callers convert it to "L".
        """
        h = data.shape[0]
        w = data.shape[1]

        data = data * factor
        # Rows first: (h, w). Reshaping to (w, h) scrambled non-square masks.
        data = data.reshape([h, w])
        return Image.fromarray(data)

    def evaluate(self, x_test, y_test):
        """Evaluate the loaded model on (x_test, y_test), print loss and first metric, and return the scores."""
        self.load_model()
        score = self.model.evaluate(x_test, y_test, verbose=1)
        # model.evaluate returns a scalar when the model has no metrics.
        scores = score if isinstance(score, (list, tuple)) else [score]
        print("Test loss    :{}".format(round(scores[0], 4)))
        if len(scores) > 1:
            print("Test accuracy:{}".format(round(scores[1], 4)))
        return score

    def inspect(self, image_file='./model.png', summary_file="./summary.txt"):
        """Save the model graph as *image_file* and the text summary as *summary_file*."""
        tf.keras.utils.plot_model(self.model, to_file=image_file, show_shapes=True)
        print("=== Saved model graph as an image_file {}".format(image_file))
        with open(summary_file, 'w') as f:
            self.model.summary(print_fn=lambda x: f.write(x + '\n'))
        print("=== Saved model summary as a text_file {}".format(summary_file))

In [8]:
# Run inference on [infer] images_dir and write predicted masks to [infer] output_dir.
config_file = "./train_eval_infer.config"
# Allow overriding the config path with a single command-line argument.
if len(sys.argv) == 2:
    config_file = sys.argv[1]
config = ConfigParser(config_file)
images_dir = config.get(INFER, "images_dir")
output_dir = config.get(INFER, "output_dir")

# Validate paths before the (expensive) model construction.
if images_dir is None or not os.path.exists(images_dir):
    raise Exception("Not found " + str(images_dir))
if output_dir is None:
    raise Exception("Not found [infer] output_dir")

model = TensorflowUNet(config_file)

# Start from an empty output_dir.
if os.path.exists(output_dir):
    shutil.rmtree(output_dir)
os.makedirs(output_dir)

model.infer(images_dir, output_dir, expand=True)

==== ConfigParser ./train_eval_infer.config
{'eval': {'image_datapath': '"./Mammogram/valid/images"',
          'mask_datapath': '"./Mammogram/valid/masks"',
          'output_dir': '"./eval_output"'},
 'infer': {'images_dir': '".//Mammogram/test/images"',
           'merged_dir': '"./mini_test_output_merged"',
           'output_dir': '"./mini_test_output"'},
 'mask': {'binarize': 'True', 'blur': 'True', 'threshold': '74'},
 'model': {'base_filters': '16',
           'base_kernels': '(7,7)',
           'clipvalue': '0.5',
           'dilation': '(2,2)',
           'dropout_rate': '0.06',
           'image_channels': '3',
           'image_height': '512',
           'image_width': '512',
           'learning_rate': '0.0001',
           'loss': '"bce_iou_loss"',
           'metrics': '["iou_coef"]',
           'num_classes': '1',
           'num_layers': '7',
           'show_summary': 'False'},
 'train': {'batch_size': '2',
           'create_backup': 'False',
           'epochs': '1',

 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\flipped_P299_R_CM_CC.jpg
--- image_file .//Mammogram/test/images\flipped_P299_R_CM_CC.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\flipped_P310_L_CM_CC.jpg
--- image_file .//Mammogram/test/images\flipped_P310_L_CM_CC.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\flipped_P312_L_DM_MLO.jpg
--- image_file .//Mammogram/test/images\flipped_P312_L_DM_MLO.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\flipped_P314_L_DM_MLO.jpg
--- image_file .//Mammogram/test/images\flipped_P314_L_DM_MLO.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\flipped_P318_R_CM_CC.jpg
--- image_file .//Mammogram/test/images\flipped_P318_R_CM_CC.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\flipped_P321_R_DM_MLO.jpg
--- image_file .//Mammogram/test/images\flipped_P321_R_DM_MLO.jpg
 ima

 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\mirrored_P85_R_DM_MLO.jpg
--- image_file .//Mammogram/test/images\mirrored_P85_R_DM_MLO.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\P113_L_DM_CC.jpg
--- image_file .//Mammogram/test/images\P113_L_DM_CC.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\P12_L_CM_MLO.jpg
--- image_file .//Mammogram/test/images\P12_L_CM_MLO.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\P137_L_CM_MLO.jpg
--- image_file .//Mammogram/test/images\P137_L_CM_MLO.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\P138_R_CM_CC.jpg
--- image_file .//Mammogram/test/images\P138_R_CM_CC.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\P165_L_DM_MLO.jpg
--- image_file .//Mammogram/test/images\P165_L_DM_MLO.jpg
 image w: 512 h: 512
== resized to (512, 512)
=== Saved ./mini_test_output\P173_L_CM