# generator

In [None]:
def generator(loss, optimizer, metrics, input_shape=(300, 300, 3), num_residual_blocks=8):
  """
  --------------------------------------------
  Function description
  --------------------------------------------
  Build and compile the CartoonGAN generator.

  input : resized photo
  output : cartoonized photo

  Stages (layers are created in exactly this order):
    flat-convolution : 7x7 conv (64) -> BatchNorm -> ReLU
    down-convolution : 3x3 conv (128, stride 2) -> 3x3 conv (128) -> BatchNorm -> ReLU
                       3x3 conv (256, stride 2) -> 3x3 conv (256) -> BatchNorm -> ReLU
    residual blocks  : `num_residual_blocks` identical blocks with 256 filters
    up-convolution   : 3x3 transposed conv (128, stride 2) -> 3x3 conv (128) -> BatchNorm -> ReLU
                       3x3 transposed conv (64, stride 2)  -> 3x3 conv (64)  -> BatchNorm -> LeakyReLU
    output           : 7x7 conv (3 channels), no activation

  Args:
    loss, optimizer, metrics : passed unchanged to `model.compile`.
    input_shape : (height, width, channels) of the input image, default
      (300, 300, 3). Height and width should be divisible by 4 so that the two
      stride-2 up-convolutions restore the input size exactly.
    num_residual_blocks : number of residual blocks between the down- and
      up-convolution stages, default 8.

  Returns:
    A compiled keras `Model` named "generator". The model is left trainable:
    Keras captures the `trainable` flag when a model is compiled, and the
    combined CartoonGAN model (compiled afterwards) must be able to update
    the generator's weights.
  --------------------------------------------
  """

  # ----------------------------------------------------------------------------------------------
  def generator_residual_block(x):
    """
    --------------------------------------------
    Function description
    --------------------------------------------
    One residual block: conv -> BatchNorm -> ReLU -> conv -> BatchNorm, then the
    block input is added back (identity shortcut). Factored out because the
    generator repeats this identical structure several times.
    --------------------------------------------
    """
    shortcut = x
    x = Conv2D(kernel_size = 3, filters = 256, strides = 1, padding = "same")(x)
    x = BatchNormalization()(x)
    x = ReLU()(x)
    x = Conv2D(kernel_size = 3, filters = 256, strides = 1, padding = "same")(x)
    x = BatchNormalization()(x)
    # Element-wise sum with the identity shortcut; the incoming tensor must
    # already have 256 channels (the down-convolution stage guarantees this).
    x = layers.Add()([x, shortcut])

    return x

  # ----------------------------------------------------------------------------------------------
  # flat-convolution stage
  input_layer = Input(shape = input_shape)
  net = Conv2D(kernel_size = 7, filters = 64, strides = 1, padding = "same")(input_layer)
  net = BatchNormalization()(net)
  net = ReLU()(net)

  # down-convolution stage: each step halves height/width (stride 2), then refines (stride 1)
  net = Conv2D(kernel_size = 3, filters = 128, strides = 2, padding = "same")(net)
  net = Conv2D(kernel_size = 3, filters = 128, strides = 1, padding = "same")(net)
  net = BatchNormalization()(net)
  net = ReLU()(net)

  net = Conv2D(kernel_size = 3, filters = 256, strides = 2, padding = "same")(net)
  net = Conv2D(kernel_size = 3, filters = 256, strides = 1, padding = "same")(net)
  net = BatchNormalization()(net)
  net = ReLU()(net)

  # residual stage: with the default of 8 this creates exactly the same layers,
  # in the same order, as eight explicit calls would.
  for _ in range(num_residual_blocks):
    net = generator_residual_block(net)

  # up-convolution stage: each stride-2 transposed conv doubles height/width
  net = Conv2DTranspose(kernel_size = 3, filters = 128, strides = (2, 2), padding = "same")(net)
  net = Conv2D(kernel_size = 3, filters = 128, strides = 1, padding = "same")(net)
  net = BatchNormalization()(net)
  net = ReLU()(net)

  net = Conv2DTranspose(kernel_size = 3, filters = 64, strides = (2, 2), padding = "same")(net)
  net = Conv2D(kernel_size = 3, filters = 64, strides = 1, padding = "same")(net)
  net = BatchNormalization()(net)
  # NOTE(review): every other stage of this generator uses ReLU; confirm that
  # LeakyReLU here is intentional.
  net = LeakyReLU()(net)

  # output layer: 3-channel image, no activation
  net = Conv2D(kernel_size = 7, filters = 3, strides = 1, padding = "same")(net)

  # assemble the generator
  model = Model(inputs = input_layer, outputs = net, name = "generator")

  # compile
  model.compile(loss=loss, optimizer=optimizer, metrics=metrics)
  # Intentionally NOT setting `model.trainable = False` here. The flag is picked
  # up by the next model that gets compiled, so freezing the generator would
  # leave the combined CartoonGAN model unable to train it.

  return model

# discriminator

In [None]:
def discriminator(loss, optimizer, metrics, input_shape=(300, 300, 3)):
  """
  --------------------------------------------
  Function description
  --------------------------------------------
  Build and compile the CartoonGAN discriminator.

  input : image of shape `input_shape` (default 300x300x3)
  output : single-channel score map with no final activation. The two stride-2
           convolutions reduce height/width by 4x each way, so a 300x300 input
           yields a 75x75x1 map.

  Args:
    loss, optimizer, metrics : passed unchanged to `model.compile`.
    input_shape : (height, width, channels) of the input image, default (300, 300, 3).

  Returns:
    A compiled keras `Model` named "discriminator". After compiling, its
    `trainable` flag is set to False so that a combined model compiled
    afterwards does not update the discriminator. Keras documents that
    compile() captures the trainable state, so this model's own training
    is unaffected until it is re-compiled.
  --------------------------------------------
  """

  def conv3x3(x, filters, strides = 1):
    # Every convolution in the discriminator is 3x3 with "same" padding.
    return Conv2D(kernel_size = 3, filters = filters, strides = strides, padding = "same")(x)

  input_layer = Input(input_shape)

  # flat-convolution
  net = conv3x3(input_layer, 32)
  net = LeakyReLU(alpha=0.2)(net)

  # first down-convolution: halve height/width, then widen to 128 channels
  net = conv3x3(net, 64, strides = 2)
  net = LeakyReLU(alpha=0.2)(net)
  net = conv3x3(net, 128)
  net = BatchNormalization()(net)
  net = LeakyReLU(alpha=0.2)(net)

  # second down-convolution: halve height/width again, then widen to 256 channels
  net = conv3x3(net, 128, strides = 2)
  net = LeakyReLU(alpha=0.2)(net)
  net = conv3x3(net, 256)
  net = BatchNormalization()(net)
  net = LeakyReLU(alpha=0.2)(net)

  # final 256-channel conv block
  net = conv3x3(net, 256)
  net = BatchNormalization()(net)
  net = LeakyReLU(alpha=0.2)(net)

  # one raw score per spatial location (no activation)
  net = conv3x3(net, 1)

  # assemble the discriminator
  model = Model(inputs = input_layer, outputs = net, name = "discriminator")

  # compile
  model.compile(loss = loss, optimizer = optimizer, metrics = metrics)
  # Freeze after compiling: a combined model compiled later will not update these weights.
  model.trainable = False

  return model

# CartoonGAN

In [None]:
def CartoonGAN(generator, discriminator, loss, optimizer, metrics):
  """
  ------------------------------------
  Function description
  ------------------------------------
  CartoonGAN is the model formed by chaining the generator and the
  discriminator: the generator's output is fed into the discriminator.

  Keras captures each layer's `trainable` flag when a model is compiled.
  Before compiling, this function therefore explicitly marks the generator
  trainable and the discriminator frozen, so training the combined model
  updates only the generator. It no longer relies on whatever flags the
  builder functions left behind.

  Args:
    generator : generator model; its output must match the discriminator's input.
    discriminator : discriminator model.
    loss, optimizer, metrics : passed unchanged to `compile`.

  Returns:
    The compiled combined `Sequential` model.
  ------------------------------------
  """
  # Set trainable flags BEFORE compile; they are captured at compile time.
  generator.trainable = True
  discriminator.trainable = False

  # assemble the combined model (named `combined` to avoid shadowing this function)
  combined = Sequential()
  combined.add(generator)
  combined.add(discriminator)

  # compile
  combined.compile(loss = loss, optimizer = optimizer, metrics = metrics)
  # No `combined.trainable = False` afterwards. Setting it on the container
  # recursively freezes the nested generator too, so any later re-compile
  # would have no trainable weights.

  return combined

# generator_web

In [None]:
def generator_web():
  """
  -------------------------------
  Function description
  -------------------------------
  Intended to load previously trained generator weights from a pickle file
  and use them. Not implemented yet: it currently does nothing and returns None.
  -------------------------------
  """
  # TODO: implement loading of the pickled pre-trained weights.
  # NOTE(review): only unpickle files from a trusted source.
  return None