In [4]:
# TBD 1 : logger 추가
# TBD 2: flask github 참고, method, class, 파일의 맨 윗단 마다 pydoc 형식으로 달기
# TBD 3: 축약어를 자제할것 (특히 변수)

# -------------------------
#   To-do
# -------------------------

# 0. add data-setter, receiver system use python queue.Queue() class
# this will resolve i/o bottleneck
# 1. add logger
# 2. make image drawer overlay mask on image
# 3. make iterable
# 4. make verbose turn on and off
# 5. write pydoc

# python basic Module
import os
import sys
import types
import progressbar
from datetime import datetime
from shutil import copy
from pickle import dump, load

# math, image, plot Module
import numpy as np
import cv2
import matplotlib.pyplot as plt  # TBD

# tensorflow Module
import tensorflow as tf
from tensorflow.keras import backend as keras_backend
from tensorflow.keras.layers import GaussianNoise
from tensorflow.keras.layers import Input, Concatenate
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Nadam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras import losses

# keras segmentaion third-party Moudle
import segmentation_models as sm

# custom Module
from gan_module.data_loader.medical_segmentation_data_loader import DataLoader
from gan_module.data_loader.manage_batch import BatchQueueManager

from gan_module.model.build_model import build_generator, build_discriminator
from gan_module.util.draw_images import ImageDrawer
from gan_module import custom_loss
from gan_module.custom_loss import f1_loss_for_training, f1_score, dice_loss_for_training
from gan_module.util.manage_learning_rate import learning_rate_scheduler
from gan_module.config import CONFIG

custom_loss.AXIS = [1, 2, 3]

USE_GPU = True

if USE_GPU:
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    gpu_devices = tf.config.experimental.list_physical_devices("GPU")
    for device in gpu_devices:
        tf.config.experimental.set_memory_growth(device, True)
else:
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

KERNEL_INITIALIZER = RandomNormal(mean=0.0, stddev=0.02)


class Pix2PixSegmentation:
    def __init__(
        self,
        generator_power=32,
        discriminator_power=32,
        generator_learning_rate=1e-4,
        discriminator_learning_rate=1e-4,
        temp_weights_path=".",
        draw_images=True,
        on_memory=True,
        code_test=False
    ):
        # Input shape
        img_shape = CONFIG["img_shape"]
        input_channels = CONFIG["input_channels"]
        output_channels = CONFIG["output_channels"]

        input_img_shape = (*img_shape, input_channels)
        output_img_shape = (*img_shape, output_channels)
        # set parameter
        self.start_epoch = None
        self.history = {"generator_loss": [],
                        "f1_loss_train": [], "f1_score_train": [],
                        "f1_loss_valid": [], "f1_score_valid": []}
        self.temp_weights_path = temp_weights_path

        # Configure data loader
        self.dataset_name = "tumor"
        self.data_loader = DataLoader(
            dataset_name=self.dataset_name,
            config_dict=CONFIG,
            on_memory=on_memory, code_test=code_test
        )

        self.loaded_data_index = {
            "train": np.arange(self.data_loader.data_length["train"]),
            "valid": np.arange(self.data_loader.data_length["valid"])
        }

        # Configure Image Drawer
        self.draw_images = draw_images
        self.image_drawer = ImageDrawer(
            dataset_name=self.dataset_name, data_loader=self.data_loader
        )
        self.discriminator_loss_ratio = keras_backend.variable(2.75)
        self.f1_loss_ratio = keras_backend.variable(97.25)
        self.discriminator_acc_previous = 0.5
        self.discriminator_acces = np.array(
            [0.5 for _ in range(self.data_loader.data_length["train"])])
        self.discriminator_acces_previous = self.discriminator_acces.copy()
        self.generator_losses = np.array(
            [1 for _ in range(self.data_loader.data_length["train"])], dtype=np.float32)
        self.generator_losses_previous = self.generator_losses.copy()
        self.generator_loss_min = 500
        self.generator_loss_previous = 100
        self.generator_loss_max_previous = 1000
        self.generator_loss_max_min = 1000
        self.generator_loss_min_min = 1000
        self.weight_save_stack = False
        self.training_end_stack = 0
        # Calculate output shape of D (PatchGAN)
        patch = 2 ** 3
        self.disc_patch = (patch, patch, 1)
        # Number of filters in the first layer of G and D
        self.generator_learning_rate = generator_learning_rate
        self.discriminator_learning_rate = discriminator_learning_rate
        generator_optimizer = Nadam(self.generator_learning_rate)
        discriminator_optimizer = Nadam(self.discriminator_learning_rate)
        
        # Build the generator
        self.generator = build_generator(
            input_img_shape=input_img_shape,
            output_channels=output_channels,
            generator_power=generator_power,
            kernel_initializer=KERNEL_INITIALIZER,
        )
        # Build and compile the discriminator
        self.discriminator = build_discriminator(
            input_img_shape=input_img_shape,
            output_img_shape=output_img_shape,
            discriminator_power=discriminator_power,
            kernel_initializer=KERNEL_INITIALIZER,
        )
        # self.discriminator = self.build_discriminator()
        # 'mse' or tf.keras.losses.Huber() tf.keras.losses.LogCosh()
        self.discriminator.compile(
            loss=tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1),
            optimizer=discriminator_optimizer,
            metrics=["accuracy"],
        )

        # -------------------------
        # Construct Computational
        #   Graph of Generator
        # -------------------------

        # Input images and their conditioning images
        original_img = Input(shape=input_img_shape)
        masked_img = Input(shape=output_img_shape)
        # generate image from original_img for target masked_img
        model_masked_img = self.generator(original_img)

        # For the combined model we will only train the generator
        self.discriminator.trainable = False
        # Discriminators determines validity of translated images / condition pairs
        model_validity = self.discriminator([original_img, model_masked_img])
        # give score by
        # 1. how generator trick discriminator
        # 2. how generator's image same as real photo in pixel
        # 3. if you want change loss, see doc https://keras.io/api/losses/
        # 4. 'mse', 'mae', tf.keras.losses.LogCosh(),  tf.keras.losses.Huber()
        self.combined = Model(
            inputs=[original_img, masked_img],
            outputs=[model_validity, model_masked_img],
        )
        
        self.combined.compile(
            loss=[
                tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1),
                dice_loss_for_training
            ],
            loss_weights=[2.75, 97.25],
            optimizer=generator_optimizer,
        )

    def train(self, epochs, batch_size=1, epoch_shuffle_term=10):

        start_time = datetime.now()

        # Adversarial loss ground truths
        self.training_end_stack = 0
        self.batch_size = batch_size
        valid_patch = np.ones(
            (self.batch_size, *self.disc_patch), dtype=np.float32)
        fake_patch = np.zeros(
            (self.batch_size, *self.disc_patch), dtype=np.float32)
        # TBD : move batch_queue_manager to __init__
        self.batch_queue_manager = BatchQueueManager(self)

        if self.start_epoch is None:
            self.start_epoch = 0
        for epoch in range(self.start_epoch, epochs):
            batch_i = 0

            discriminator_losses = []
            generator_loss_max_in_epoch = 0
            generator_loss_min_in_epoch = 1000

            # shffle data maybe
            if epoch % epoch_shuffle_term == 0:
                np.random.shuffle(self.loaded_data_index["train"])

            if self.discriminator_acc_previous < 0.75:
                discriminator_learning = True
                print("discriminator_learning is True")
            else:
                discriminator_learning = False
                print("discriminator_learning is False")
                
            generator_current_learning_rate = learning_rate_scheduler(
                self.generator_learning_rate,
                epoch,
            )
            discriminator_current_learning_rate = learning_rate_scheduler(
                self.discriminator_learning_rate,
                epoch,
            ) * (1 - self.discriminator_acc_previous)
            keras_backend.set_value(
                self.discriminator.optimizer.learning_rate,
                discriminator_current_learning_rate,
            )
            keras_backend.set_value(
                self.discriminator.optimizer.learning_rate,
                discriminator_current_learning_rate,
            )
            keras_backend.set_value(
                self.discriminator_loss_ratio,
                keras_backend.variable(0.5) + 4.5 * self.discriminator_acc_previous,
            )
            keras_backend.set_value(
                self.f1_loss_ratio,
                keras_backend.variable(99.5) - 4.5 * self.discriminator_acc_previous
            )
            bar = progressbar.ProgressBar(
                maxval=self.data_loader.data_length["train"]).start()
            
            while batch_i + self.batch_size <= self.data_loader.data_length["train"]:
                bar.update(batch_i)

                batch_index = self.loaded_data_index["train"][batch_i: batch_i +
                                                              self.batch_size]

                original_img, masked_img = self.batch_queue_manager.get_batch(
                    data_mode="train")
                model_masked_img = self.generator.predict_on_batch(
                    original_img)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                # Train Discriminator for valid image if it failed to detect fake image

                if discriminator_learning:
                    self.discriminator.train_on_batch(
                        [original_img, masked_img], valid_patch)

                batch_discriminator_acc_previous = np.mean(
                    self.discriminator_acces_previous[batch_index])

                # -----------------
                #  Train Generator
                # -----------------

                # Train the generators

                generator_loss = self.combined.train_on_batch(
                    [original_img, masked_img],
                    [valid_patch, masked_img]
                )
                # train discriminator for fake image if it failed to detect fake image
                if (batch_discriminator_acc_previous <= 0.5 or epoch == 0) and discriminator_learning:
                    discriminator_loss = self.discriminator.train_on_batch(
                        [original_img, model_masked_img], fake_patch)
                else:
                    discriminator_loss = self.discriminator.test_on_batch(
                        [original_img, model_masked_img], fake_patch)

                self.discriminator_acces[batch_index] = discriminator_loss[1]
                self.generator_losses[batch_index] = generator_loss[0]
                

                # 한 배치 끝
                batch_i += self.batch_size
            # training batch 사이클 끝
            print(f"epoch: {epoch}/{epochs}")
            print(f"discriminator_acces : {np.mean(self.discriminator_acces)}")
            print(
                f"Mean generator_loss : {np.mean(self.generator_losses)}")
            print(f"Max generator_loss : {np.max(self.generator_losses)}")
            print(f"Min generator_loss : {np.min(self.generator_losses)}")
            print(
                f"generator loss decrease : {self.generator_loss_min - np.mean(self.generator_losses)}"
            )
            print(
                f"generator loss decrease ratio : ({np.mean(self.generator_losses) / self.generator_loss_min})"
            )
            print(
                f"Max generator loss decrease : {self.generator_loss_max_previous - np.max(self.generator_losses)}"
            )
            print(
                f"current lowest generator loss : {self.generator_loss_min}")
            print(
                f"current Learning_rate : {generator_current_learning_rate}")
            print(f"elapsed_time : {datetime.now() - start_time}")
            self.image_drawer.sample_images(
                self.generator, epoch)
            
            if np.mean(self.generator_losses) / self.generator_loss_min < 1.1:
                if self.generator_loss_min > np.mean(self.generator_losses):
                    self.generator_loss_min = np.mean(self.generator_losses)
                    self.generator_loss_max_min = generator_loss_max_in_epoch
                    self.generator_loss_min_min = generator_loss_min_in_epoch
                    self.weight_save_stack = True
                    self.save_study_info()
                    print("save weights")

                train_f1_loss_list = []
                train_f1_score_list = []
                for index in range(0, self.data_loader.data_length["train"], self.batch_size):

                    train_source_img, train_masked_img = self.batch_queue_manager.get_batch(
                        data_mode="train")

                    train_model_masked_img = self.generator.predict_on_batch(
                        train_source_img)

                    train_f1_loss = f1_loss_for_training(
                        train_masked_img, train_model_masked_img)
                    train_f1_score = f1_score(
                        train_masked_img, train_model_masked_img)
                    train_f1_loss_list.append(train_f1_loss)
                    train_f1_score_list.append(train_f1_score)

                print(
                    f"train_f1_loss : {np.mean(train_f1_loss_list) * self.f1_loss_ratio}")
                print(f"train_f1_score : {1 - np.mean(train_f1_loss_list)}")
                print(
                    f"train_f1_rounded_score : {np.mean(train_f1_score_list)}")

                valid_f1_loss_list = []
                valid_f1_score_list = []
                for index in range(0, self.data_loader.data_length["valid"], self.batch_size):

                    valid_source_img, valid_masked_img = self.batch_queue_manager.get_batch(
                        data_mode="valid")

                    valid_model_masked_img = self.generator.predict_on_batch(
                        valid_source_img)

                    valid_f1_loss = f1_loss_for_training(
                        valid_masked_img, valid_model_masked_img)
                    valid_f1_score = f1_score(
                        valid_masked_img, valid_model_masked_img)
                    valid_f1_loss_list.append(valid_f1_loss)
                    valid_f1_score_list.append(valid_f1_score)

                print(
                    f"valid_f1_loss : {np.mean(valid_f1_loss_list) * self.f1_loss_ratio}")
                print(f"valid_f1_score : {1 - np.mean(valid_f1_loss_list)}")
                print(
                    f"valid_f1_rounded_score : {np.mean(valid_f1_score_list)}")
            else:
                print("loss decrease.")
                self.load_best_weights()

            # previous generator_loss 갱신
            self.generator_loss_previous = np.mean(self.generator_losses)
            self.generator_loss_max_previous = generator_loss_max_in_epoch

            if epoch >= 10 and self.weight_save_stack:
                copy(
                    "generator.h5",
                    "./generator_weights/generator_"
                    + str(round(self.generator_loss_min, 5))
                    + "_"
                    + str(round(self.generator_loss_max_min, 5))
                    + ".h5",
                )
                self.weight_save_stack = False

            self.discriminator_acc_previous = np.mean(self.discriminator_acces)
            self.discriminator_acces_previous = self.discriminator_acces.copy()
            self.generator_losses_previous = self.generator_losses.copy()
            # TBD: add epoch bigger than history length
            self.history["generator_loss"].append(
                np.mean(self.generator_losses))
            self.history["f1_loss_train"].append(
                np.mean(train_f1_loss_list))
            self.history["f1_score_train"].append(
                np.mean(train_f1_score_list))
            self.history["f1_loss_valid"].append(
                np.mean(valid_f1_loss_list))
            self.history["f1_score_valid"].append(
                np.mean(valid_f1_score_list))

    def get_info_folderPath(self):
        return (
            str(round(self.generator_loss_min, 5))
            + "_"
            + str(round(self.generator_loss_max_min, 5))
        )

    def save_study_info(self, path=None):

        if path is None:
            path = self.temp_weights_path

        generator_weigth_path = os.path.join(path, "generator.h5")
        discriminator_weigth_path = os.path.join(path, "discriminator.h5")
        combined_weigth_path = os.path.join(path, "combined.h5")

        self.generator.save_weights(generator_weigth_path)
        self.discriminator.save_weights(discriminator_weigth_path)
        self.combined.save_weights(combined_weigth_path)

        study_info = {}
        study_info["start_epoch"] = self.start_epoch
        study_info["train_loaded_data_index"] = self.loaded_data_index["train"]
        study_info["generator_loss_min"] = self.generator_loss_min
        study_info["generator_loss_max_min"] = self.generator_loss_max_min
        study_info["generator_loss_min_min"] = self.generator_loss_min_min
        study_info["generator_losses_previous"] = self.generator_losses_previous
        study_info["discriminator_acces"] = self.discriminator_acces
        study_info["history"] = self.history
        file = open(path + "/study_info.pkl", "wb")
        dump(study_info, file)
        file.close()

    def load_study_info(self):

        self.generator.load_weights("generator.h5")
        self.discriminator.load_weights("discriminator.h5")
        self.combined.load_weights("combined.h5")

        if os.path.isfile("study_info.pkl"):
            file = open("study_info.pkl", "rb")
            study_info = load(file)
            file.close()
            self.start_epoch = study_info["start_epoch"]
            self.loaded_data_index["train"] = study_info["train_loaded_data_index"]
            self.generator_loss_min = study_info["generator_loss_min"]
            self.generator_loss_max_min = study_info["generator_loss_max_min"]
            self.generator_loss_min_min = study_info["generator_loss_min_min"]
            self.generator_losses_previous = study_info["generator_losses_previous"]
            self.discriminator_acces = study_info["discriminator_acces"]
            self.history = study_info["history"]
        else:
            print("No info pkl file!")

    def load_best_weights(self):
        self.generator.load_weights(self.temp_weights_path + "/generator.h5")
        self.discriminator.load_weights(
            self.temp_weights_path + "/discriminator.h5")
        self.combined.load_weights(self.temp_weights_path + "/combined.h5")


In [5]:
generator_lr = 1e-3
discriminator_lr = 5e-4
batch_size = 4
g_lr = generator_lr * batch_size
d_lr = discriminator_lr * batch_size
gan = Pix2PixSegmentation(generator_power=4, discriminator_power=4, 
                          generator_learning_rate=g_lr, discriminator_learning_rate=d_lr,
                          on_memory=True, code_test=False, draw_images=False)

In [None]:
#gan.load_study_info()
gan.train(epochs=325, batch_size=batch_size, epoch_shuffle_term=50)

  0% |                                                                        |

discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 74.94190216064453
Max generator_loss : 98.4444351196289
Min generator_loss : 33.604515075683594
generator loss decrease : 425.05809783935547
generator loss decrease ratio : (0.14988380432128906)
Max generator loss decrease : 901.5555648803711
current lowest generator loss : 500
current Learning_rate = 0.0004
elapsed_time = 0:14:21.545612
save weights
train_f1_loss : 69.32411193847656
train_f1_score : 0.2871556878089905
train_f1_rounded_score : 0.45159637928009033


  0% |                                                                        |

valid_f1_loss : 70.49111938476562
valid_f1_score : 0.27515560388565063
valid_f1_rounded_score : 0.4749564230442047
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.6967741935483871
Mean generator_loss : 65.20829772949219
Max generator_loss : 97.87315368652344
Min generator_loss : 29.088171005249023
generator loss decrease : 9.733604431152344
generator loss decrease ratio : (0.8701180219650269)
Max generator loss decrease : -97.87315368652344
current lowest generator loss : 74.94190216064453
current Learning_rate = 0.0008
elapsed_time = 0:30:00.369420
save weights
train_f1_loss : 57.73124694824219
train_f1_score : 0.4197864532470703
train_f1_rounded_score : 0.5426164865493774


  0% |                                                                        |

valid_f1_loss : 53.79901123046875
valid_f1_score : 0.45930641889572144
valid_f1_rounded_score : 0.5908359885215759
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.30709677419354836
Mean generator_loss : 53.720401763916016
Max generator_loss : 95.91558074951172
Min generator_loss : 21.729328155517578
generator loss decrease : 11.487895965576172
generator loss decrease ratio : (0.8238276839256287)
Max generator loss decrease : -95.91558074951172
current lowest generator loss : 65.20829772949219
current Learning_rate = 0.0012
elapsed_time = 0:44:26.378938
save weights
train_f1_loss : 54.547996520996094
train_f1_score : 0.4339410662651062
train_f1_rounded_score : 0.5150406360626221


  0% |                                                                        |

valid_f1_loss : 50.20939254760742
valid_f1_score : 0.4789639115333557
valid_f1_rounded_score : 0.5612741708755493
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 46.4517707824707
Max generator_loss : 91.82311248779297
Min generator_loss : 18.01472282409668
generator loss decrease : 7.2686309814453125
generator loss decrease ratio : (0.864695131778717)
Max generator loss decrease : -91.82311248779297
current lowest generator loss : 53.720401763916016
current Learning_rate = 0.0016
elapsed_time = 0:59:30.068289
save weights
train_f1_loss : 46.740234375
train_f1_score : 0.5236327350139618
train_f1_rounded_score : 0.5337313413619995


  0% |                                                                        |

valid_f1_loss : 49.28897476196289
valid_f1_score : 0.4976564645767212
valid_f1_rounded_score : 0.5075386166572571
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 43.04558563232422
Max generator_loss : 88.73362731933594
Min generator_loss : 16.87168312072754
generator loss decrease : 3.4061851501464844
generator loss decrease ratio : (0.926672637462616)
Max generator loss decrease : -88.73362731933594
current lowest generator loss : 46.4517707824707
current Learning_rate = 0.002
elapsed_time = 1:15:06.478829
save weights
train_f1_loss : 44.0322380065918
train_f1_score : 0.5574649572372437
train_f1_rounded_score : 0.5695021748542786


  0% |                                                                        |

valid_f1_loss : 43.02777862548828
valid_f1_score : 0.5675600171089172
valid_f1_rounded_score : 0.5796874761581421
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 41.48855972290039
Max generator_loss : 85.21238708496094
Min generator_loss : 17.278059005737305
generator loss decrease : 1.5570259094238281
generator loss decrease ratio : (0.963828444480896)
Max generator loss decrease : -85.21238708496094
current lowest generator loss : 43.04558563232422
current Learning_rate = 0.0024
elapsed_time = 1:30:41.999366
save weights
train_f1_loss : 40.948177337646484
train_f1_score : 0.5884605348110199
train_f1_rounded_score : 0.6027809381484985


  0% |                                                                        |

valid_f1_loss : 38.86347961425781
valid_f1_score : 0.609412282705307
valid_f1_rounded_score : 0.6236271262168884
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 40.5056266784668
Max generator_loss : 84.1563491821289
Min generator_loss : 16.4763126373291
generator loss decrease : 0.9829330444335938
generator loss decrease ratio : (0.9763083457946777)
Max generator loss decrease : -84.1563491821289
current lowest generator loss : 41.48855972290039
current Learning_rate = 0.0028
elapsed_time = 1:46:12.435424
save weights
train_f1_loss : 37.40164566040039
train_f1_score : 0.6241040527820587
train_f1_rounded_score : 0.6315369009971619


  0% |                                                                        |

valid_f1_loss : 37.47148513793945
valid_f1_score : 0.6234021782875061
valid_f1_rounded_score : 0.6314181685447693
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.46
Mean generator_loss : 40.062278747558594
Max generator_loss : 83.97616577148438
Min generator_loss : 16.56574058532715
generator loss decrease : 0.4433479309082031
generator loss decrease ratio : (0.9890546798706055)
Max generator loss decrease : -83.97616577148438
current lowest generator loss : 40.5056266784668
current Learning_rate = 0.0032
elapsed_time = 2:01:48.861369
save weights
train_f1_loss : 38.89118957519531
train_f1_score : 0.609133780002594
train_f1_rounded_score : 0.6137192845344543


  0% |                                                                        |

valid_f1_loss : 39.149295806884766
valid_f1_score : 0.6065397560596466
valid_f1_rounded_score : 0.6119461059570312
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.5432258064516129
Mean generator_loss : 39.23699188232422
Max generator_loss : 82.61573028564453
Min generator_loss : 16.774850845336914
generator loss decrease : 0.825286865234375
generator loss decrease ratio : (0.9793999195098877)
Max generator loss decrease : -82.61573028564453
current lowest generator loss : 40.062278747558594
current Learning_rate = 0.0036000000000000003
elapsed_time = 2:16:36.847722
save weights
train_f1_loss : 37.114871978759766
train_f1_score : 0.6190611720085144
train_f1_rounded_score : 0.6222674250602722


  0% |                                                                        |

valid_f1_loss : 36.333648681640625
valid_f1_score : 0.6270794570446014
valid_f1_rounded_score : 0.6306325197219849
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 38.10550308227539
Max generator_loss : 83.02185821533203
Min generator_loss : 15.221029281616211
generator loss decrease : 1.1314888000488281
generator loss decrease ratio : (0.9711626768112183)
Max generator loss decrease : -83.02185821533203
current lowest generator loss : 39.23699188232422
current Learning_rate = 0.004
elapsed_time = 2:31:17.039243
save weights
train_f1_loss : 42.891300201416016
train_f1_score : 0.558074414730072
train_f1_rounded_score : 0.5607362389564514


  0% |                                                                        |

valid_f1_loss : 43.48058319091797
valid_f1_score : 0.5520028173923492
valid_f1_rounded_score : 0.5550512671470642
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 38.84111404418945
Max generator_loss : 88.78291320800781
Min generator_loss : 14.766106605529785
generator loss decrease : -0.7356109619140625
generator loss decrease ratio : (1.019304633140564)
Max generator loss decrease : -88.78291320800781
current lowest generator loss : 38.10550308227539
current Learning_rate = 0.004
elapsed_time = 2:46:56.403585
train_f1_loss : 38.48203659057617
train_f1_score : 0.6132458746433258
train_f1_rounded_score : 0.6148419976234436


  0% |                                                                        |

valid_f1_loss : 36.97917175292969
valid_f1_score : 0.6283500492572784
valid_f1_rounded_score : 0.6303502321243286
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 38.49127197265625
Max generator_loss : 81.8115005493164
Min generator_loss : 16.080625534057617
generator loss decrease : -0.3857688903808594
generator loss decrease ratio : (1.0101237297058105)
Max generator loss decrease : -81.8115005493164
current lowest generator loss : 38.10550308227539
current Learning_rate = 0.004
elapsed_time = 3:02:33.339622
train_f1_loss : 36.887577056884766
train_f1_score : 0.6292705833911896
train_f1_rounded_score : 0.6358424425125122


  0% |                                                                        |

valid_f1_loss : 36.565513610839844
valid_f1_score : 0.6325074136257172
valid_f1_rounded_score : 0.6394422650337219
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.5496774193548387
Mean generator_loss : 38.46792221069336
Max generator_loss : 81.98648071289062
Min generator_loss : 14.758145332336426
generator loss decrease : -0.36241912841796875
generator loss decrease ratio : (1.0095109939575195)
Max generator loss decrease : -81.98648071289062
current lowest generator loss : 38.10550308227539
current Learning_rate = 0.004
elapsed_time = 3:18:22.577530
train_f1_loss : 36.89860916137695
train_f1_score : 0.629159688949585
train_f1_rounded_score : 0.6312313079833984


  0% |                                                                        |

valid_f1_loss : 37.088924407958984
valid_f1_score : 0.6272470057010651
valid_f1_rounded_score : 0.629781186580658
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.4541935483870968
Mean generator_loss : 37.730690002441406
Max generator_loss : 82.279296875
Min generator_loss : 13.89442253112793
generator loss decrease : 0.3748130798339844
generator loss decrease ratio : (0.9901638031005859)
Max generator loss decrease : -82.279296875
current lowest generator loss : 38.10550308227539
current Learning_rate = 0.004
elapsed_time = 3:32:57.758637
save weights
train_f1_loss : 37.67360305786133
train_f1_score : 0.6117182075977325
train_f1_rounded_score : 0.6173651814460754


  0% |                                                                        |

valid_f1_loss : 38.531890869140625
valid_f1_score : 0.6028723120689392
valid_f1_rounded_score : 0.609767735004425
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 37.06138610839844
Max generator_loss : 79.67748260498047
Min generator_loss : 12.878314971923828
generator loss decrease : 0.6693038940429688
generator loss decrease ratio : (0.9822610020637512)
Max generator loss decrease : -79.67748260498047
current lowest generator loss : 37.730690002441406
current Learning_rate = 0.004
elapsed_time = 3:47:45.474044
save weights
train_f1_loss : 37.17893600463867
train_f1_score : 0.6185059249401093
train_f1_rounded_score : 0.6218627691268921


  0% |                                                                        |

valid_f1_loss : 35.476722717285156
valid_f1_score : 0.6359723806381226
valid_f1_rounded_score : 0.6396198868751526
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 37.77935028076172
Max generator_loss : 82.53418731689453
Min generator_loss : 13.623322486877441
generator loss decrease : -0.7179641723632812
generator loss decrease ratio : (1.0193723440170288)
Max generator loss decrease : -82.53418731689453
current lowest generator loss : 37.06138610839844
current Learning_rate = 0.004
elapsed_time = 4:03:16.491581
train_f1_loss : 38.15963363647461
train_f1_score : 0.6164861023426056
train_f1_rounded_score : 0.61678546667099


  0% |                                                                        |

valid_f1_loss : 34.76260757446289
valid_f1_score : 0.6506270468235016
valid_f1_rounded_score : 0.6513795852661133
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.0
Mean generator_loss : 37.814212799072266
Max generator_loss : 80.75830078125
Min generator_loss : 14.256397247314453
generator loss decrease : -0.7528266906738281
generator loss decrease ratio : (1.020313024520874)
Max generator loss decrease : -80.75830078125
current lowest generator loss : 37.06138610839844
current Learning_rate = 0.004
elapsed_time = 4:18:45.328652
train_f1_loss : 35.217227935791016
train_f1_score : 0.6460579931735992
train_f1_rounded_score : 0.6477833390235901


  0% |                                                                        |

valid_f1_loss : 35.145912170410156
valid_f1_score : 0.6467747390270233
valid_f1_rounded_score : 0.6487274169921875
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.6432258064516129
Mean generator_loss : 37.7579231262207
Max generator_loss : 80.37793731689453
Min generator_loss : 14.923700332641602
generator loss decrease : -0.6965370178222656
generator loss decrease ratio : (1.0187941789627075)
Max generator loss decrease : -80.37793731689453
current lowest generator loss : 37.06138610839844
current Learning_rate = 0.004
elapsed_time = 4:34:13.908293
train_f1_loss : 37.05307388305664
train_f1_score : 0.6276073157787323
train_f1_rounded_score : 0.6301535964012146


  0% |                                                                        |

valid_f1_loss : 36.20487594604492
valid_f1_score : 0.6361319124698639
valid_f1_rounded_score : 0.6391090154647827
discriminator_learning is True


 99% |####################################################################### |

discriminator_acces : 0.36129032258064514
Mean generator_loss : 37.18083190917969
Max generator_loss : 81.75773620605469
Min generator_loss : 13.317173957824707
generator loss decrease : -0.11944580078125
generator loss decrease ratio : (1.003222942352295)
Max generator loss decrease : -81.75773620605469
current lowest generator loss : 37.06138610839844
current Learning_rate = 0.004
elapsed_time = 4:48:39.588826
train_f1_loss : 37.68098068237305
train_f1_score : 0.6099498867988586
train_f1_rounded_score : 0.611587405204773


  0% |                                                                        |

valid_f1_loss : 37.35661697387695
valid_f1_score : 0.6133075058460236
valid_f1_rounded_score : 0.6152362823486328
discriminator_learning is True


 54% |#######################################                                 |

In [5]:
original_img = np.ones((1,512,512,3))
masked_img = np.ones((1,512,512,1))
valid_patch = np.ones((1,8,8,1))

gan.combined.test_on_batch(
    [original_img, masked_img],
    [valid_patch, masked_img],
)

[60.188560485839844, 0.6891670823097229, 0.5914608836174011]

In [9]:
gan.f1_loss_ratio

<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=100.0>

In [10]:
gan.discriminator_loss_ratio

<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.5>

In [None]:
[60.188560485839844, 0.6891670823097229, 0.5914608836174011]

In [13]:
0.6891670823097229 * 0.5 + 0.5914608836174011 * 99.5

59.19494146108627

In [1]:
import tensorflow as tf

print(tf.__version__)

2.3.1


In [35]:
gan.discriminator.optimizer.

SyntaxError: invalid syntax (<ipython-input-35-f92eaba7459e>, line 1)

In [5]:
import time

temp_source = gan.original_img
temp_mask = gan.masked_img

start_time = time.time()
gan.generator.train_on_batch(temp_source, temp_mask)
print(f"elapsed time : {time.time() - start_time}")

temp_source = tf.convert_to_tensor(temp_source)
temp_mask = tf.convert_to_tensor(temp_mask)

start_time = time.time()
gan.generator.train_on_batch(temp_source, temp_mask)
print(f"elapsed time : {time.time() - start_time}")





elapsed time : 4.586500883102417
elapsed time : 0.29304075241088867


In [10]:
batch_i = 0

while batch_i + gan.batch_size <= gan.data_loader.train_data_length:

    batch_index = gan.train_loaded_data_index[batch_i: batch_i +
                                               gan.batch_size]
    original_img, masked_img = gan.data_loader.get_data(
        data_mode="train", index=batch_index)
    

KeyboardInterrupt: 

In [4]:
gan.data_loader.train_data_length

20

# Iterator : 260초
# Queue Iterator : 200초

In [5]:
import time
import threading
from queue import Queue

ITER_NUM = 620
batch_size = 10

gan.generator.compile(
    loss=sm.losses.BinaryFocalLoss(),
    optimizer=Nadam(gan.generator_learning_rate),
    metrics=["accuracy"],
)

def batch_setter(queue):
    batch_i = 0
    count = 0
    while batch_i + gan.batch_size <= gan.data_loader.train_data_length and count < ITER_NUM:
        
        batch_index = gan.train_loaded_data_index[batch_i: batch_i +
                                                   gan.batch_size]        
        
        batch_tuple = gan.data_loader.get_data(
        data_mode="train", index=batch_index)

        queue.put(batch_tuple)
        queue.join()
        count += 1
    
def batch_getter(queue):
    
    original_img, masked_img = queue.get()
    tensor_original_img = tf.convert_to_tensor(original_img)
    tensor_masked_img = tf.convert_to_tensor(masked_img)
    queue.task_done()
    
    return tensor_original_img, tensor_masked_img
    
def batch_trainer(original_img, masked_img):
    
    gan.generator.train_on_batch(temp_source, temp_mask)

q = Queue()

setter = threading.Thread(target=batch_setter, args=(q,),daemon=True)
setter.start()
start_time = time.time()
for i in range(ITER_NUM):
    tensor_original_img, tensor_masked_img = batch_getter(q)
    
    gan.generator.train_on_batch(tensor_original_img, tensor_masked_img)
print(f"elapsed time : {time.time() - start_time}")

elapsed time : 204.91554951667786


In [29]:
start_time = time.time()
batch_i = 0
count = 0
while batch_i + gan.batch_size <= gan.data_loader.train_data_length and count < ITER_NUM:

    batch_index = gan.train_loaded_data_index[batch_i: batch_i +
                                               gan.batch_size]        
    batch_tuple = gan.data_loader.get_data(
    data_mode="train", index=batch_index)
    
    gan.generator.train_on_batch(*batch_tuple)
    
    count += 1
print(f"elapsed time : {time.time() - start_time}")

elapsed time : 260.5594081878662


In [10]:
import cv2

temp = tensor_masked_img
print(type(temp))
print(type(temp.numpy()))

<class 'tensorflow.python.framework.ops.EagerTensor'>
<class 'numpy.ndarray'>


In [11]:
isinstance(temp.numpy(), tf.Tensor)

False

In [1]:
from gan_module.model.build_model import build_dual_discriminator

temp = build_dual_discriminator(
            input_img_shape=(512,512,3),
            output_img_shape=(512,512,1),
            discriminator_power=1,
)

In [2]:
import tensorflow as tf
from tensorflow.keras.optimizers import Nadam

temp.compile(
    loss=[
        tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1),
        tf.keras.losses.BinaryCrossentropy(label_smoothing=0.1)
    ],
    optimizer=Nadam(),
    metrics=["accuracy"],
)

In [3]:
import numpy as np

image_mockup = np.ones((1,512,512,3))
mask_mockup = np.ones((1,512,512,1))
patch_mockup = np.ones((1,8,8,1))

temp.test_on_batch([image_mockup, mask_mockup], [patch_mockup, patch_mockup])

UnknownError:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[node functional_1/conv2d_34/Conv2D (defined at <ipython-input-3-d8b1761e412e>:7) ]] [Op:__inference_test_function_13463]

Function call stack:
test_function
