In [1]:
# TBD 1 : logger 추가
# TBD 2: flask github 참고, method, class, 파일의 맨 윗단 마다 pydoc 형식으로 달기
# TBD 3: 축약어를 자제할것 (특히 변수)

# -------------------------
#   done
# -------------------------

# 0. add data-setter, receiver system use python queue.Queue() class
# this will resolve i/o bottleneck
# 3. make iterable

# -------------------------
#   In Progress
# -------------------------

# 1. add logger
# 2. make image drawer overlay mask on image

# -------------------------
#   To be Done
# -------------------------

# 4. make verbose turn on and off
# 5. write pydoc

# python basic Module
import os
import sys
import types
import progressbar
from datetime import datetime
from shutil import copy
from pickle import dump, load

# math, image, plot Module
import numpy as np
import cv2
import matplotlib.pyplot as plt  # TBD

# tensorflow Module
import tensorflow as tf
from tensorflow.keras import backend as keras_backend
from tensorflow.keras.layers import GaussianNoise
from tensorflow.keras.layers import Input, Concatenate
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Nadam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras import losses

# keras segmentaion third-party Moudle
import segmentation_models as sm
import tensorflow_addons as tfa

# custom Module
from gan_module.data_loader.medical_segmentation_data_loader import DataLoader
from gan_module.data_loader.manage_batch import BatchQueueManager

from gan_module.model.build_model import build_generator_non_unet as build_generator
from gan_module.model.build_model import build_discriminator as build_discriminator
from gan_module.util.custom_loss import weighted_region_loss, dice_score, combined_loss, f1_loss
# from gan_module.util.custom_gradient import SGD_AGC
from gan_module.util.manage_learning_rate import learning_rate_scheduler
from gan_module.util.draw_images import ImageDrawer
from gan_module.util.logger import TrainLogger
from gan_module.config import CONFIG

USE_GPU = True

if USE_GPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    gpu_devices = tf.config.experimental.list_physical_devices("GPU")
    for device in gpu_devices:
        tf.config.experimental.set_memory_growth(device, True)
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'


class Pix2PixSegmentation:
    def __init__(
        self,
        generator_power=32,
        discriminator_power=32,
        generator_depth = None,
        discriminator_depth = None,
        generator_learning_rate=1e-4,
        discriminator_learning_rate=1e-4,
        temp_weights_path=".",
        on_memory=True,
        code_test=False
    ):
        # Input shape
        img_shape = CONFIG["img_shape"]
        input_channels = CONFIG["input_channels"]
        output_channels = CONFIG["output_channels"]

        self.input_img_shape = (*img_shape, input_channels)
        self.output_img_shape = (*img_shape, output_channels)
        # set parameter
        self.start_epoch = None
        self.on_memory = on_memory
        self.history = {"generator_loss": [],
                        "f1_loss_train": [], "f1_score_train": [],
                        "f1_loss_valid": [], "f1_score_valid": []}
        self.temp_weights_path = temp_weights_path

        # Configure data loader
        self.dataset_name = "glomerulus_0.65_512_not_filped_original"
        self.data_loader = DataLoader(
            dataset_name=self.dataset_name,
            config_dict=CONFIG,
            on_memory=self.on_memory, 
            code_test=code_test
        )
        
        self.train_logger = TrainLogger()
        
        self.loaded_data_index = {
            "train": np.arange(self.data_loader.data_length["train"]),
            "valid": np.arange(self.data_loader.data_length["valid"])
        }
        
        # Configure Image Drawer
        self.image_drawer = ImageDrawer(
            dataset_name=self.dataset_name, data_loader=self.data_loader
        )
        self.discriminator_loss_ratio = keras_backend.variable(0.1)
        self.f1_loss_ratio = keras_backend.variable(100)
        self.discriminator_losses = np.array(
            [1 for _ in range(self.data_loader.data_length["train"])], dtype=np.float32)
        self.discriminator_acc_previous = 0.5
        self.discriminator_acces = np.array(
            [0.5 for _ in range(self.data_loader.data_length["train"])])
        self.discriminator_acces_previous = self.discriminator_acces.copy()
        self.generator_losses = np.array(
            [1 for _ in range(self.data_loader.data_length["train"])], dtype=np.float32)
        self.generator_losses_previous = self.generator_losses.copy()
        self.generator_f1_losses = np.array(
            [1 for _ in range(self.data_loader.data_length["train"])], dtype=np.float32)
        self.generator_loss_min = 10000
        self.generator_loss_previous = 10000
        self.generator_loss_max_previous = 10000
        self.generator_valid_loss_min = 10000
        self.total_f1_loss_min = 2
        self.weight_save_stack = False
        self.training_end_stack = 0
        # Calculate output shape of D (PatchGAN)
        self.disc_patch = (img_shape[0] // (2 ** discriminator_depth), img_shape[1] // (2 ** discriminator_depth), 1)
        # Number of filters in the first layer of G and D
        self.generator_learning_rate = generator_learning_rate
        self.discriminator_learning_rate = discriminator_learning_rate
        self.patience_count = 0
        
        generator_optimizer = Nadam(self.generator_learning_rate)
        discriminator_optimizer = Nadam(self.discriminator_learning_rate)
#         generator_optimizer = SGD_AGC(lr=self.generator_learning_rate, momentum=0.9)
#         discriminator_optimizer = SGD_AGC(lr=self.discriminator_learning_rate, momentum=0.9)        
        # Build the generator
        self.generator = build_generator(
            input_img_shape=self.input_img_shape,
            output_channels=output_channels,
            generator_power=generator_power,
            depth=generator_depth,
        )
        self.generator.compile(
            loss=weighted_region_loss,
            optimizer=generator_optimizer,
            metrics=[dice_score],
        )
        # Build and compile the discriminator
        self.discriminator = build_discriminator(
            input_img_shape=self.input_img_shape,
            output_img_shape=self.output_img_shape,
            discriminator_power=discriminator_power,
            depth=discriminator_depth,
        )
        # 'mse' or tf.keras.losses.Huber() tf.keras.losses.LogCosh()
        self.discriminator.compile(
            loss=sm.losses.BinaryFocalLoss(alpha=0.25, gamma=4),
            optimizer=discriminator_optimizer,
            metrics=["accuracy"],
        )
        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # -------------------------
        # Construct Computational
        #   Graph of Generator
        # -------------------------

        # Input images and their conditioning images
        original_img = Input(shape=self.input_img_shape)
        # generate image from original_img for target masked_img
        model_masked_img = self.generator(original_img)
        
    
        # Discriminators determines validity of translated images / condition pairs
        model_validity = self.discriminator([original_img, model_masked_img])
        # give score by
        # 1. how generator trick discriminator
        # 2. how generator's image same as real photo in pixel
        # 3. if you want change loss, see doc https://keras.io/api/losses/
        # 4. 'mse', 'mae', tf.keras.losses.LogCosh(),  tf.keras.losses.Huber()
        self.combined = Model(
            inputs=original_img,
            outputs=[model_validity, model_masked_img],
        )
        
        self.combined.compile(
            loss=[
#                 tf.keras.losses.BinaryCrossentropy(label_smoothing=0.05),
                sm.losses.BinaryFocalLoss(alpha=0.25, gamma=4),
                weighted_region_loss
            ],
            loss_weights=[0.1, 100],
            optimizer=generator_optimizer,
        )

    def train(self, epochs, batch_size=1, epoch_shuffle_term=10):

        start_time = datetime.now()

        # Adversarial loss ground truths
        self.training_end_stack = 0
        self.batch_size = batch_size
        valid_patch = np.ones(
            (self.batch_size, *self.disc_patch), dtype=np.float32)
        fake_patch = np.zeros(
            (self.batch_size, *self.disc_patch), dtype=np.float32)
        # TBD : move batch_queue_manager to __init__
        self.batch_queue_manager = BatchQueueManager(self, batch_size, self.on_memory)
        
        if self.start_epoch is None:
            self.start_epoch = 0
        for epoch in range(self.start_epoch, epochs):
            batch_i = 0

            generator_loss_max_in_epoch = 0
            generator_loss_min_in_epoch = 1000
            generator_discriminator_losses = np.array(
            [1 for _ in range(self.data_loader.data_length["train"])], dtype=np.float32)
            # shffle data maybe
            if epoch % epoch_shuffle_term == 0:
                self.data_loader.shuffle_train_imgs()
            
            if self.discriminator_acc_previous < 0.95:
                discriminator_learning = True
            else:
                discriminator_learning = False
            generator_1_10_quantile = np.quantile(self.generator_losses, 0.1)
            
            
            generator_current_learning_rate = learning_rate_scheduler(
                self.generator_learning_rate,
                epoch+self.patience_count,
                warm_up=True
            )
            discriminator_current_learning_rate = learning_rate_scheduler(
                self.discriminator_learning_rate,
                epoch+self.patience_count,
                warm_up=True
            ) * (1 - self.discriminator_acc_previous)
            keras_backend.set_value(
                self.discriminator.optimizer.learning_rate,
                discriminator_current_learning_rate,
            )
            keras_backend.set_value(
                self.discriminator.optimizer.learning_rate,
                discriminator_current_learning_rate,
            )
            keras_backend.set_value(
                self.discriminator_loss_ratio,
                keras_backend.variable(0.01) + 0.5 * self.discriminator_acc_previous,
            )
            keras_backend.set_value(
                self.f1_loss_ratio,
                keras_backend.variable(100) - 0.5  * self.discriminator_acc_previous,
            )
            
            bar = progressbar.ProgressBar(
                maxval=self.data_loader.data_length["train"]).start()
            
            while batch_i + self.batch_size < self.data_loader.data_length["train"] + self.batch_size:

                batch_index = self.loaded_data_index["train"][batch_i: batch_i +
                                                              self.batch_size]

                original_img, masked_img = self.batch_queue_manager.get_batch(
                    data_mode="train")
                model_masked_img = self.generator.predict_on_batch(
                    original_img)
                
                valid_patch = np.ones(
                    (len(model_masked_img), *self.disc_patch), dtype=np.float32)
                fake_patch = np.zeros(
                    (len(model_masked_img), *self.disc_patch), dtype=np.float32)
                
                self.original_img = original_img
                self.masked_img = masked_img
                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Train Discriminator for valid image if it failed to detect fake image
                if discriminator_learning and self.discriminator_acc_previous < np.random.rand():
                    discriminator_loss = self.discriminator.train_on_batch([original_img, masked_img], valid_patch)
                else:
                    discriminator_loss = self.discriminator.test_on_batch([original_img, masked_img], valid_patch)
                    
                batch_discriminator_acc_previous = np.mean(
                    self.discriminator_acces_previous[batch_index])
                self.discriminator.trainable = False
                # -----------------
                #  Train Generator
                # -----------------
                
#                 if np.mean(self.generator_losses[batch_index]) >= generator_1_10_quantile:
#                     generator_loss = self.combined.train_on_batch(
#                         original_img,
#                         [valid_patch, masked_img]
#                     )
#                 else:
#                     generator_loss = self.combined.test_on_batch(
#                         original_img,
#                         [valid_patch, masked_img]
#                     )
                generator_loss = self.combined.train_on_batch(
                    original_img,
                    [valid_patch, masked_img]     
                )
                # train discriminator for fake image if it failed to detect fake image
                self.discriminator.trainable = True
                if (batch_discriminator_acc_previous <= 0.5 or epoch == 0) and discriminator_learning:
                    discriminator_loss += self.discriminator.train_on_batch(
                        [original_img,model_masked_img], fake_patch)
                else:
                    discriminator_loss += self.discriminator.test_on_batch(
                        [original_img,model_masked_img], fake_patch)

                self.discriminator_losses[batch_index] = discriminator_loss[0]
                self.discriminator_acces[batch_index] = discriminator_loss[1]
                self.generator_losses[batch_index] = generator_loss[0]
                self.generator_f1_losses[batch_index] = generator_loss[2]
                generator_discriminator_losses = generator_loss[1]
                # plot progress
                bar.update(batch_i)
                
                # 한 배치 끝
                batch_i += self.batch_size
            
            # training batch 사이클 끝
            
            #######################################
            # valid_loss, valid score 계산
            #######################################
            
            #if self.generator_loss_min > np.mean(self.generator_losses):
            valid_f1_loss_list = []
            valid_f1_score_list = []
            for index in range(0, self.data_loader.data_length["valid"], self.batch_size):

                valid_source_img, valid_masked_img = self.batch_queue_manager.get_batch(
                    data_mode="valid")

                valid_model_masked_img = self.generator.predict_on_batch(
                    valid_source_img)
                valid_f1_loss =  weighted_region_loss(valid_masked_img, valid_model_masked_img)
                valid_f1_score = dice_score(valid_masked_img, valid_model_masked_img)
                
                valid_f1_loss_list.append(valid_f1_loss)
                valid_f1_score_list.append(valid_f1_score)
            
            current_valid_f1_loss = np.mean(valid_f1_loss_list)
            
            # compute valid_f1_loss end    
            total_f1_loss = np.mean(self.generator_f1_losses) + current_valid_f1_loss
                
            self.discriminator_acc_previous = np.mean(self.discriminator_acces)
            self.discriminator_acces_previous = self.discriminator_acces.copy()
            self.generator_losses_previous = self.generator_losses.copy()
            # TBD: add epoch bigger than history length
            self.history["generator_loss"].append(
                np.mean(self.generator_losses))
            self.history["f1_loss_train"].append(
                np.mean(self.generator_f1_losses))
            self.history["f1_loss_valid"].append(
                np.mean(valid_f1_loss_list))
            

            self.image_drawer.sample_images(
                self.generator, epoch)

            # previous generator_loss 갱신
            self.generator_loss_previous = np.mean(self.generator_losses)
            self.generator_loss_max_previous = np.max(self.generator_losses)
            
            
            #######################################
            # 학습 상태 및 로그 출력
            #######################################
            self.train_logger.write_log(
                f"{epoch}/{epochs} ({epoch+self.patience_count})",
                np.mean(self.discriminator_acces),
                np.mean(self.generator_losses),
                np.max(self.generator_losses),
                np.min(self.generator_losses),
                f"{self.generator_loss_min - np.mean(self.generator_losses)}({np.mean(self.generator_losses) / self.generator_loss_min})",
                self.generator_loss_min,
                generator_current_learning_rate,
                datetime.now() - start_time
            )
            print(f"valid_loss : {self.generator_valid_loss_min} / {current_valid_f1_loss}")
            print(f"discriminator_loss : {np.mean(self.discriminator_losses)}")
            print(f"generator_discriminator_loss : {np.mean(generator_discriminator_losses)}")
            print(f"train_f1_loss : {np.mean(self.generator_f1_losses)}")
            print(f"valid_f1_score : {np.mean(valid_f1_score_list)}")
            print(f"current/min total_f1_loss = {self.total_f1_loss_min} / {total_f1_loss}")
            #######################################
            # 학습 상태 관찰 후 저장 여부 선택
            #######################################
            
            if current_valid_f1_loss / self.generator_valid_loss_min < 1.1:
                
                if self.generator_valid_loss_min > current_valid_f1_loss:
        
                    self.generator_valid_loss_min = current_valid_f1_loss
            
                    if self.generator_loss_min > np.mean(self.generator_losses):
                        self.generator_loss_min = np.mean(self.generator_losses)
                        self.generator_loss_max_min = generator_loss_max_in_epoch
                        self.generator_loss_min_min = generator_loss_min_in_epoch
                    
                    self.save_study_info()
                    self.weight_save_stack = True
                    print("save weights")    
                    
                if self.total_f1_loss_min > total_f1_loss: 
                    self.total_f1_loss_min = total_f1_loss
                
            else:
                print("loss decrease.")
                if epoch+self.patience_count < 20:
                    self.patience_count -= 1
                else:
                    self.patience_count += 1
                self.load_best_weights()
            
            if epoch >= 10 and self.weight_save_stack:
                copy(
                    "generator.h5",
                    "./generator_weights/generator_"
                    + str(round(self.generator_loss_min, 5))
                    + "_"
                    + str(round(self.generator_loss_max_min, 5))
                    + ".h5",
                )
                self.weight_save_stack = False
            
            # 한 epoch의 끝
        
            
    def get_info_folderPath(self):
        return (
            str(round(self.generator_loss_min, 5))
            + "_"
            + str(round(self.generator_loss_max_min, 5))
        )

    def save_study_info(self, path=None):

        if path is None:
            path = self.temp_weights_path

        generator_weigth_path = os.path.join(path, "generator.h5")
        discriminator_weigth_path = os.path.join(path, "discriminator.h5")
        combined_weigth_path = os.path.join(path, "combined.h5")

        self.generator.save_weights(generator_weigth_path)
        self.discriminator.save_weights(discriminator_weigth_path)
        self.combined.save_weights(combined_weigth_path)

        study_info = {}
        study_info["start_epoch"] = self.start_epoch
        study_info["train_loaded_data_index"] = self.loaded_data_index["train"]
        study_info["generator_loss_min"] = self.generator_loss_min
        study_info["generator_loss_max_min"] = self.generator_loss_max_min
        study_info["generator_loss_min_min"] = self.generator_loss_min_min
        study_info["generator_losses_previous"] = self.generator_losses_previous
        study_info["discriminator_acces"] = self.discriminator_acces
        study_info["history"] = self.history
        file = open(path + "/study_info.pkl", "wb")
        dump(study_info, file)
        file.close()

    def load_study_info(self):

        self.generator.load_weights("generator.h5")
        self.discriminator.load_weights("discriminator.h5")
#         self.combined.load_weights("combined.h5")

        if os.path.isfile("study_info.pkl"):
            file = open("study_info.pkl", "rb")
            study_info = load(file)
            file.close()
            self.start_epoch = study_info["start_epoch"]
            self.loaded_data_index["train"] = study_info["train_loaded_data_index"]
            self.generator_loss_min = study_info["generator_loss_min"]
            self.generator_loss_max_min = study_info["generator_loss_max_min"]
            self.generator_loss_min_min = study_info["generator_loss_min_min"]
            self.generator_losses_previous = study_info["generator_losses_previous"]
            self.discriminator_acces = study_info["discriminator_acces"]
            self.history = study_info["history"]
        else:
            print("No info pkl file!")

    def load_best_weights(self):
        self.generator.load_weights(self.temp_weights_path + "/generator.h5")
        self.discriminator.load_weights(
            self.temp_weights_path + "/discriminator.h5")
        self.combined.load_weights(self.temp_weights_path + "/combined.h5")

    def run_pretraining(self, epochs):
        if self.on_memory:
            self.generator.fit(
                x=self.data_loader.loaded_data_object["train"]["input"],
                y=self.data_loader.loaded_data_object["train"]["output"],
                validation_data=list(self.data_loader.loaded_data_object["valid"].values()),
                batch_size=self.batch_size, epochs=epochs
            )
        else:
            self.generator.fit_generator(
                x=self.data_loader.loaded_data_object["train"]["input"],
                y=self.data_loader.loaded_data_object["train"]["output"],
                validation_data=list(self.data_loader.loaded_data_object["valid"].values()),
                batch_size=self.batch_size, epochs=epochs
            )
        self.generator.save_weights("pretrained.h5")

Segmentation Models: using `tf.keras` framework.
{'img_shape': [512, 512], 'input_channels': 3, 'output_channels': 1}


In [2]:
generator_lr = 1e-4
discriminator_lr = 1e-4
batch_size = 4

g_lr = generator_lr * batch_size
d_lr = discriminator_lr * batch_size
gan = Pix2PixSegmentation(generator_power=8, discriminator_power=8, 
                          generator_depth = 3, discriminator_depth = 3,
                          generator_learning_rate=g_lr, discriminator_learning_rate=d_lr,
                          on_memory=False, code_test=False)

In [3]:
# gan.load_study_info()
# gan.start_epoch = 8
gan.train(epochs=200, batch_size=batch_size, epoch_shuffle_term=50)
# gan.train(epochs=20, batch_size=batch_size, epoch_shuffle_term=100)

 99% (4840 of 4841) |################### | Elapsed Time: 0:20:56 ETA:   0:00:00{
2021-05-25 00:54:04,516 - train - INFO - 
Epoch : 0/200 (0)
Discriminator_acces : 0.21222860599888968
Mean generator loss : 58.56317901611328
Max generator loss : 96.169189453125
Min generator loss : 41.4285888671875
Generator loss decrease : 9941.436820983887(0.005856317901611328)
Current lowest generator loss : 10000
Current Learning_rate : 2e-05
Elapsed_time : 0:21:12.521678
}


valid_loss : 10000 / 0.45354944467544556
discriminator_loss : 2.759866952896118
generator_discriminator_loss : 0.015859050676226616
train_f1_loss : 0.5010516047477722
valid_f1_score : 0.19480840861797333
current/min total_f1_loss = 2 / 0.9546010494232178


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:18 ETA:   0:00:00{
2021-05-25 01:15:40,841 - train - INFO - 
Epoch : 1/200 (1)
Discriminator_acces : 0.08862293340331931
Mean generator loss : 40.7171516418457
Max generator loss : 62.803009033203125
Min generator loss : 29.49576187133789
Generator loss decrease : 17.846027374267578(0.6952688097953796)
Current lowest generator loss : 58.56317901611328
Current Learning_rate : 4e-05
Elapsed_time : 0:42:48.847534
}


valid_loss : 0.45354944467544556 / 0.3700481355190277
discriminator_loss : 1.1551578044891357
generator_discriminator_loss : 0.012943366542458534
train_f1_loss : 0.3747846484184265
valid_f1_score : 0.46006709337234497
current/min total_f1_loss = 0.9546010494232178 / 0.7448327541351318


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:46 ETA:   0:00:00{
2021-05-25 01:37:44,544 - train - INFO - 
Epoch : 2/200 (2)
Discriminator_acces : 0.057181342774446135
Mean generator loss : 31.112979888916016
Max generator loss : 56.12519073486328
Min generator loss : 22.344329833984375
Generator loss decrease : 9.604171752929688(0.7641246914863586)
Current lowest generator loss : 40.7171516418457
Current Learning_rate : 6e-05
Elapsed_time : 1:04:52.549897
}


valid_loss : 0.3700481355190277 / 0.3003856837749481
discriminator_loss : 0.30137890577316284
generator_discriminator_loss : 0.01338400412350893
train_f1_loss : 0.2930181622505188
valid_f1_score : 0.7430033683776855
current/min total_f1_loss = 0.7448327541351318 / 0.5934038162231445


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:53 ETA:   0:00:00{
2021-05-25 01:59:55,324 - train - INFO - 
Epoch : 3/200 (3)
Discriminator_acces : 0.08487225307755371
Mean generator loss : 24.456432342529297
Max generator loss : 53.69671630859375
Min generator loss : 17.22233009338379
Generator loss decrease : 6.656547546386719(0.786052405834198)
Current lowest generator loss : 31.112979888916016
Current Learning_rate : 8e-05
Elapsed_time : 1:27:03.329202
}


valid_loss : 0.3003856837749481 / 0.22310322523117065
discriminator_loss : 0.09448113292455673
generator_discriminator_loss : 0.014721756801009178
train_f1_loss : 0.23032088577747345
valid_f1_score : 0.6931137442588806
current/min total_f1_loss = 0.5934038162231445 / 0.4534240961074829


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:46 ETA:   0:00:00{
2021-05-25 02:21:58,797 - train - INFO - 
Epoch : 4/200 (4)
Discriminator_acces : 0.06460926203263788
Mean generator loss : 19.360321044921875
Max generator loss : 55.410850524902344
Min generator loss : 12.957114219665527
Generator loss decrease : 5.096111297607422(0.791624903678894)
Current lowest generator loss : 24.456432342529297
Current Learning_rate : 0.0001
Elapsed_time : 1:49:06.802992
}


valid_loss : 0.22310322523117065 / 0.18650305271148682
discriminator_loss : 0.02988801896572113
generator_discriminator_loss : 0.015857044607400894
train_f1_loss : 0.18113401532173157
valid_f1_score : 0.7822464108467102
current/min total_f1_loss = 0.4534240961074829 / 0.3676370680332184


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:50 ETA:   0:00:00{
2021-05-25 02:44:06,518 - train - INFO - 
Epoch : 5/200 (5)
Discriminator_acces : 0.041680569048879366
Mean generator loss : 15.302199363708496
Max generator loss : 28.38779067993164
Min generator loss : 9.950250625610352
Generator loss decrease : 4.058121681213379(0.7903897762298584)
Current lowest generator loss : 19.360321044921875
Current Learning_rate : 0.00012
Elapsed_time : 2:11:14.522749
}


valid_loss : 0.18650305271148682 / 0.1341729313135147
discriminator_loss : 0.01975739188492298
generator_discriminator_loss : 0.016246750950813293
train_f1_loss : 0.1422041952610016
valid_f1_score : 0.824862539768219
current/min total_f1_loss = 0.3676370680332184 / 0.2763771414756775


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:59 ETA:   0:00:00{
2021-05-25 03:06:22,793 - train - INFO - 
Epoch : 6/200 (6)
Discriminator_acces : 0.026743963911769263
Mean generator loss : 12.675069808959961
Max generator loss : 28.8616943359375
Min generator loss : 7.774558067321777
Generator loss decrease : 2.627129554748535(0.828316867351532)
Current lowest generator loss : 15.302199363708496
Current Learning_rate : 0.00014
Elapsed_time : 2:33:30.799242
}


valid_loss : 0.1341729313135147 / 0.11560408771038055
discriminator_loss : 0.018399400636553764
generator_discriminator_loss : 0.017990823835134506
train_f1_loss : 0.11673340946435928
valid_f1_score : 0.8253092765808105
current/min total_f1_loss = 0.2763771414756775 / 0.23233750462532043


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:01 ETA:   0:00:00{
2021-05-25 03:28:41,187 - train - INFO - 
Epoch : 7/200 (7)
Discriminator_acces : 0.011429725021302416
Mean generator loss : 10.617371559143066
Max generator loss : 24.355449676513672
Min generator loss : 6.315324306488037
Generator loss decrease : 2.0576982498168945(0.8376578092575073)
Current lowest generator loss : 12.675069808959961
Current Learning_rate : 0.00016
Elapsed_time : 2:55:49.192996
}


valid_loss : 0.11560408771038055 / 0.09453044086694717
discriminator_loss : 0.018455103039741516
generator_discriminator_loss : 0.019506538286805153
train_f1_loss : 0.09675264358520508
valid_f1_score : 0.838786780834198
current/min total_f1_loss = 0.23233750462532043 / 0.19128307700157166


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:06 ETA:   0:00:00{
2021-05-25 03:51:04,608 - train - INFO - 
Epoch : 8/200 (8)
Discriminator_acces : 0.004529877749141448
Mean generator loss : 9.06817626953125
Max generator loss : 39.547584533691406
Min generator loss : 4.891292572021484
Generator loss decrease : 1.5491952896118164(0.8540886044502258)
Current lowest generator loss : 10.617371559143066
Current Learning_rate : 0.00018
Elapsed_time : 3:18:12.613658
}


valid_loss : 0.09453044086694717 / 0.09337566792964935
discriminator_loss : 0.01867670565843582
generator_discriminator_loss : 0.018016405403614044
train_f1_loss : 0.0817929059267044
valid_f1_score : 0.8363083004951477
current/min total_f1_loss = 0.19128307700157166 / 0.17516857385635376


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:07 ETA:   0:00:00{
2021-05-25 04:13:28,500 - train - INFO - 
Epoch : 9/200 (9)
Discriminator_acces : 0.00156096687151415
Mean generator loss : 7.692677021026611
Max generator loss : 20.54585838317871
Min generator loss : 3.8367795944213867
Generator loss decrease : 1.3754992485046387(0.8483157753944397)
Current lowest generator loss : 9.06817626953125
Current Learning_rate : 0.0002
Elapsed_time : 3:40:36.506036
}


valid_loss : 0.09337566792964935 / 0.07019001245498657
discriminator_loss : 0.018362147733569145
generator_discriminator_loss : 0.015347701497375965
train_f1_loss : 0.06891079246997833
valid_f1_score : 0.8586054444313049
current/min total_f1_loss = 0.17516857385635376 / 0.1391008049249649


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:03 ETA:   0:00:00{
2021-05-25 04:35:49,158 - train - INFO - 
Epoch : 10/200 (10)
Discriminator_acces : 0.0
Mean generator loss : 6.955921649932861
Max generator loss : 22.73141098022461
Min generator loss : 3.0451669692993164
Generator loss decrease : 0.73675537109375(0.9042264223098755)
Current lowest generator loss : 7.692677021026611
Current Learning_rate : 0.00022000000000000003
Elapsed_time : 4:02:57.164155
}


valid_loss : 0.07019001245498657 / 0.06586799770593643
discriminator_loss : 0.01520868856459856
generator_discriminator_loss : 0.015819001942873
train_f1_loss : 0.06163215637207031
valid_f1_score : 0.8617302775382996
current/min total_f1_loss = 0.1391008049249649 / 0.12750014662742615


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:52 ETA:   0:00:00{
2021-05-25 04:57:58,590 - train - INFO - 
Epoch : 11/200 (11)
Discriminator_acces : 0.0
Mean generator loss : 6.065694808959961
Max generator loss : 19.772611618041992
Min generator loss : 2.426065683364868
Generator loss decrease : 0.8902268409729004(0.8720188736915588)
Current lowest generator loss : 6.955921649932861
Current Learning_rate : 0.00024
Elapsed_time : 4:25:06.596181
}


valid_loss : 0.06586799770593643 / 0.05059114098548889
discriminator_loss : 0.016402311623096466
generator_discriminator_loss : 0.01693112961947918
train_f1_loss : 0.05337156355381012
valid_f1_score : 0.8780003786087036
current/min total_f1_loss = 0.12750014662742615 / 0.10396270453929901


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:54 ETA:   0:00:00{
2021-05-25 05:20:09,690 - train - INFO - 
Epoch : 12/200 (12)
Discriminator_acces : 0.0
Mean generator loss : 5.6050944328308105
Max generator loss : 24.202049255371094
Min generator loss : 2.038116216659546
Generator loss decrease : 0.4606003761291504(0.9240646958351135)
Current lowest generator loss : 6.065694808959961
Current Learning_rate : 0.00026000000000000003
Elapsed_time : 4:47:17.695945
}


valid_loss : 0.05059114098548889 / 0.04999896511435509
discriminator_loss : 0.017408600077033043
generator_discriminator_loss : 0.017837999388575554
train_f1_loss : 0.048732656985521317
valid_f1_score : 0.8593195676803589
current/min total_f1_loss = 0.10396270453929901 / 0.0987316220998764


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:57 ETA:   0:00:00{
2021-05-25 05:42:24,353 - train - INFO - 
Epoch : 13/200 (13)
Discriminator_acces : 0.0
Mean generator loss : 5.197718143463135
Max generator loss : 19.346214294433594
Min generator loss : 1.6883213520050049
Generator loss decrease : 0.4073762893676758(0.9273203611373901)
Current lowest generator loss : 5.6050944328308105
Current Learning_rate : 0.00028
Elapsed_time : 5:09:32.359237
}


valid_loss : 0.04999896511435509 / 0.04331599548459053
discriminator_loss : 0.018212610855698586
generator_discriminator_loss : 0.01854214072227478
train_f1_loss : 0.045048389583826065
valid_f1_score : 0.8609709739685059
current/min total_f1_loss = 0.0987316220998764 / 0.0883643850684166


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:03 ETA:   0:00:00{
2021-05-25 06:04:44,306 - train - INFO - 
Epoch : 14/200 (14)
Discriminator_acces : 0.0
Mean generator loss : 5.014472961425781
Max generator loss : 24.533451080322266
Min generator loss : 1.4505391120910645
Generator loss decrease : 0.18324518203735352(0.964745044708252)
Current lowest generator loss : 5.197718143463135
Current Learning_rate : 0.00030000000000000003
Elapsed_time : 5:31:52.312262
}


valid_loss : 0.04331599548459053 / 0.03923119232058525
discriminator_loss : 0.0188214760273695
generator_discriminator_loss : 0.019065964967012405
train_f1_loss : 0.04335169121623039
valid_f1_score : 0.8658732175827026
current/min total_f1_loss = 0.0883643850684166 / 0.08258288353681564


N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

save weights


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:03 ETA:   0:00:00{
2021-05-25 06:27:04,706 - train - INFO - 
Epoch : 15/200 (15)
Discriminator_acces : 0.0
Mean generator loss : 4.623013973236084
Max generator loss : 18.605180740356445
Min generator loss : 1.2283093929290771
Generator loss decrease : 0.39145898818969727(0.921934187412262)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00032
Elapsed_time : 5:54:12.712435
}
N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

valid_loss : 0.03923119232058525 / 0.039316460490226746
discriminator_loss : 0.019261490553617477
generator_discriminator_loss : 0.019497813656926155
train_f1_loss : 0.03918106108903885
valid_f1_score : 0.8805503845214844
current/min total_f1_loss = 0.08258288353681564 / 0.0784975215792656


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:01 ETA:   0:00:00{
2021-05-25 06:49:21,934 - train - INFO - 
Epoch : 16/200 (16)
Discriminator_acces : 0.0
Mean generator loss : 4.5887885093688965
Max generator loss : 19.51127815246582
Min generator loss : 1.1665582656860352
Generator loss decrease : 0.42568445205688477(0.915108859539032)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00034
Elapsed_time : 6:16:29.939571
}


valid_loss : 0.03923119232058525 / 0.04689062386751175
loss decrease.


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:59 ETA:   0:00:00{
2021-05-25 07:11:37,694 - train - INFO - 
Epoch : 17/200 (16)
Discriminator_acces : 0.0
Mean generator loss : 4.553504467010498
Max generator loss : 19.158157348632812
Min generator loss : 1.1643446683883667
Generator loss decrease : 0.4609684944152832(0.9080724120140076)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00034
Elapsed_time : 6:38:45.700202
}
N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

valid_loss : 0.03923119232058525 / 0.03966193273663521
discriminator_loss : 0.019272224977612495
generator_discriminator_loss : 0.01953153684735298
train_f1_loss : 0.038754310458898544
valid_f1_score : 0.8630224466323853
current/min total_f1_loss = 0.0784975215792656 / 0.07841624319553375


 99% (4840 of 4841) |################### | Elapsed Time: 0:21:59 ETA:   0:00:00{
2021-05-25 07:33:53,009 - train - INFO - 
Epoch : 18/200 (17)
Discriminator_acces : 0.0
Mean generator loss : 4.474262237548828
Max generator loss : 21.896692276000977
Min generator loss : 1.03338623046875
Generator loss decrease : 0.5402107238769531(0.8922696709632874)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00036
Elapsed_time : 7:01:01.014844
}


valid_loss : 0.03923119232058525 / 0.04442928358912468
loss decrease.


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:03 ETA:   0:00:00{
2021-05-25 07:56:12,824 - train - INFO - 
Epoch : 19/200 (17)
Discriminator_acces : 0.0
Mean generator loss : 4.545515060424805
Max generator loss : 18.4688720703125
Min generator loss : 1.1667490005493164
Generator loss decrease : 0.46895790100097656(0.9064791202545166)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00036
Elapsed_time : 7:23:20.830387
}


valid_loss : 0.03923119232058525 / 0.049393050372600555
loss decrease.


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:04 ETA:   0:00:00{
2021-05-25 08:18:34,282 - train - INFO - 
Epoch : 20/200 (17)
Discriminator_acces : 0.0
Mean generator loss : 4.628517150878906
Max generator loss : 20.06355094909668
Min generator loss : 1.1771266460418701
Generator loss decrease : 0.385955810546875(0.9230316281318665)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00036
Elapsed_time : 7:45:42.287558
}
N/A% (0 of 4841) |                       | Elapsed Time: 0:00:00 ETA:  --:--:--

valid_loss : 0.03923119232058525 / 0.040360480546951294
discriminator_loss : 0.019282732158899307
generator_discriminator_loss : 0.019426453858613968
train_f1_loss : 0.03938048705458641
valid_f1_score : 0.8800050020217896
current/min total_f1_loss = 0.07841624319553375 / 0.079740971326828


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:03 ETA:   0:00:00{
2021-05-25 08:40:53,767 - train - INFO - 
Epoch : 21/200 (18)
Discriminator_acces : 0.0
Mean generator loss : 4.481828212738037
Max generator loss : 23.06241798400879
Min generator loss : 1.0959899425506592
Generator loss decrease : 0.5326447486877441(0.8937785029411316)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00038
Elapsed_time : 8:08:01.771962
}


valid_loss : 0.03923119232058525 / 0.05044741556048393
loss decrease.


 99% (4840 of 4841) |################### | Elapsed Time: 0:24:24 ETA:   0:00:00{
2021-05-25 09:05:35,457 - train - INFO - 
Epoch : 22/200 (18)
Discriminator_acces : 0.0
Mean generator loss : 4.608340740203857
Max generator loss : 21.4058895111084
Min generator loss : 1.1663881540298462
Generator loss decrease : 0.40613222122192383(0.9190080165863037)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00038
Elapsed_time : 8:32:43.463462
}


valid_loss : 0.03923119232058525 / 0.06692881882190704
loss decrease.


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:04 ETA:   0:00:00{
2021-05-25 09:27:57,408 - train - INFO - 
Epoch : 23/200 (18)
Discriminator_acces : 0.0
Mean generator loss : 4.772523403167725
Max generator loss : 18.412843704223633
Min generator loss : 1.2619774341583252
Generator loss decrease : 0.24194955825805664(0.9517497420310974)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00038
Elapsed_time : 8:55:05.414143
}


valid_loss : 0.03923119232058525 / 0.06253954023122787
loss decrease.


 99% (4840 of 4841) |################### | Elapsed Time: 0:22:02 ETA:   0:00:00{
2021-05-25 09:50:16,432 - train - INFO - 
Epoch : 24/200 (18)
Discriminator_acces : 0.0
Mean generator loss : 7.387022972106934
Max generator loss : 59.05607986450195
Min generator loss : 1.52970290184021
Generator loss decrease : -2.3725500106811523(1.4731404781341553)
Current lowest generator loss : 5.014472961425781
Current Learning_rate : 0.00038
Elapsed_time : 9:17:24.438278
}


valid_loss : 0.03923119232058525 / 0.11710887402296066
loss decrease.


  3% (188 of 4841) |                     | Elapsed Time: 0:00:52 ETA:   0:21:08ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "C:\Users\gr300\anaconda3\lib\site-packages\IPython\core\interactiveshell.py", line 3418, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-3-3ceeaa4d0883>", line 3, in <module>
    gan.train(epochs=200, batch_size=batch_size, epoch_shuffle_term=50)
  File "<ipython-input-1-e72b4841d7ec>", line 306, in train
    discriminator_loss = self.discriminator.train_on_batch([original_img, masked_img], valid_patch)
  File "C:\Users\gr300\anaconda3\lib\site-packages\tensorflow\python\keras\engine\training.py", line 1695, in train_on_batch
    logs = train_function(iterator)
  File "C:\Users\gr300\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 780, in __call__
    result = self._call(*args, **kwds)
  File "C:\Users\gr300\anaconda3\lib\site-packages\tensorflow\python\eager\def_function.py", line 807, in _call
    return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
  Fil

TypeError: object of type 'NoneType' has no len()

In [None]:
# gan.load_study_info()
# gan.start_epoch = 8
gan.train(epochs=200, batch_size=batch_size, epoch_shuffle_term=50)
# gan.train(epochs=20, batch_size=batch_size, epoch_shuffle_term=100)

In [None]:
gan.generator_losses[-4:]

In [None]:
gan.generator_losses

In [None]:
valid_patch = np.ones(
    (gan.batch_size, *gan.disc_patch), dtype=np.float32)

generator_loss = gan.combined.train_on_batch(
    gan.original_img,
    [valid_patch, gan.masked_img],
    class_weight={0: 0.1, 1: 0.9}
)

In [None]:
batch_i=0
batch_index = gan.loaded_data_index["train"][batch_i: batch_i +
                                              gan.batch_size]

In [None]:
np.mean(gan.generator_losses[batch_index]) > np.quantile(gan.generator_losses, 0.1)

In [None]:
# gan.load_study_info()
#gan.start_epoch = 30
# gan.start_epoch = 2 
gan.train(epochs=325, batch_size=batch_size, epoch_shuffle_term=100)
# gan.train(epochs=20, batch_size=batch_size, epoch_shuffle_term=100)

In [None]:
temp_weight = gan.generator.weights

In [None]:
for weight in temp_weight:
    if "dense" in weight.name:
        if "bias" in weight.name:
            print(weight.name)
            print(weight.numpy())


In [None]:
temp_weight = gan.discriminator.weights

In [None]:
for weight in temp_weight:
    if "dense" in weight.name:
        if "bias" in weight.name:
            print(weight.name)
            print(weight.numpy())


In [None]:
gan.original_img

In [None]:
np.max(gan.original_img.numpy())

In [None]:
np.min(gan.original_img.numpy())

In [None]:
(gan.original_img.numpy() + 1) * 127.5

In [None]:
image = gan.original_img.numpy()
predicted = gan.generator.predict_on_batch(image)
mask = gan.masked_img.numpy()

image = ((image + 1) * 127.5).astype('uint8')
predicted = ((predicted + 1) * 127.5).astype('uint8')
mask = ((mask + 1) * 127.5).astype('uint8')

In [None]:
index = 0

plt.figure(figsize=(15, 15))

plt.subplot(131)
plt.imshow(image[index])

plt.subplot(132)
plt.imshow(predicted[index])

plt.subplot(133)
plt.imshow(mask[index])


In [None]:
print(np.max(gan.original_img))
print(np.min(gan.original_img))
print(np.max(temp))
print(np.min(temp))

In [None]:
temp = gan.data_loader.loaded_data_object["train"].values()

for index, (input_img, output_img) in enumerate(zip(*temp)):
    print(index)
    if index > 40:
        break
    print(index)
    print(input_img.shape)
    print(output_img.shape)
    

In [None]:
import time

temp_source = gan.original_img
temp_mask = gan.masked_img

start_time = time.time()
gan.generator.train_on_batch(temp_source, temp_mask)
print(f"elapsed time : {time.time() - start_time}")

temp_source = tf.convert_to_tensor(temp_source)
temp_mask = tf.convert_to_tensor(temp_mask)

start_time = time.time()
gan.generator.train_on_batch(temp_source, temp_mask)
print(f"elapsed time : {time.time() - start_time}")

# Iterator : 260초
# Queue Iterator : 200초

In [None]:
import time
import threading
from queue import Queue

ITER_NUM = 620
batch_size = 10

gan.generator.compile(
    loss=sm.losses.BinaryFocalLoss(),
    optimizer=Nadam(gan.generator_learning_rate),
    metrics=["accuracy"],
)

def batch_setter(queue):
    batch_i = 0
    count = 0
    while batch_i + gan.batch_size <= gan.data_loader.train_data_length and count < ITER_NUM:
        
        batch_index = gan.train_loaded_data_index[batch_i: batch_i +
                                                   gan.batch_size]        
        
        batch_tuple = gan.data_loader.get_data(
        data_mode="train", index=batch_index)

        queue.put(batch_tuple)
        queue.join()
        count += 1
    
def batch_getter(queue):
    
    original_img, masked_img = queue.get()
    tensor_original_img = tf.convert_to_tensor(original_img)
    tensor_masked_img = tf.convert_to_tensor(masked_img)
    queue.task_done()
    
    return tensor_original_img, tensor_masked_img
    
def batch_trainer(original_img, masked_img):
    
    gan.generator.train_on_batch(temp_source, temp_mask)

q = Queue()

setter = threading.Thread(target=batch_setter, args=(q,),daemon=True)
setter.start()
start_time = time.time()
for i in range(ITER_NUM):
    tensor_original_img, tensor_masked_img = batch_getter(q)
    
    gan.generator.train_on_batch(tensor_original_img, tensor_masked_img)
print(f"elapsed time : {time.time() - start_time}")

In [None]:
start_time = time.time()
batch_i = 0
count = 0
while batch_i + gan.batch_size <= gan.data_loader.train_data_length and count < ITER_NUM:

    batch_index = gan.train_loaded_data_index[batch_i: batch_i +
                                               gan.batch_size]        
    batch_tuple = gan.data_loader.get_data(
    data_mode="train", index=batch_index)
    
    gan.generator.train_on_batch(*batch_tuple)
    
    count += 1
print(f"elapsed time : {time.time() - start_time}")

In [None]:
import cv2

temp = tensor_masked_img
print(type(temp))
print(type(temp.numpy()))

In [None]:
isinstance(temp.numpy(), tf.Tensor)

In [None]:
from gan_module.model.build_model import build_dual_discriminator

temp = build_dual_discriminator(
            input_img_shape=(512,512,3),
            output_img_shape=(512,512,1),
            discriminator_power=1,
)

In [None]:
import tensorflow as tf
from tensorflow.keras.optimizers import Nadam

temp.compile(
    loss=[
        tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1),
        tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1)
    ],
    optimizer=Nadam(),
    metrics=["accuracy"],
)

In [None]:
import numpy as np

image_mockup = np.ones((1,512,512,3))
mask_mockup = np.ones((1,512,512,1))
patch_mockup = np.ones((1,8,8,1))

temp.test_on_batch([image_mockup, mask_mockup], [patch_mockup, patch_mockup])