In [151]:
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Conv2D, Input, BatchNormalization, Dropout, Activation, Flatten, Dense, Reshape, UpSampling2D
import tensorflow_datasets as tfds
import numpy as np
import os

In [74]:
print(tf.__version__)

2.0.0


In [134]:
def main():
    # disctiminator
    d_inputs = Input((512, 512, 3))
    
    d_x = Conv2D(filters=10, kernel_size=4, strides=2, padding='same', name='discriminator_conv_1')(d_inputs)
    d_x = BatchNormalization(momentum=0.8)(d_x)
    d_x = Activation('relu')(d_x)
    
    d_x = Dropout(rate=0.4)(d_x)

    d_x = Conv2D(filters=20, kernel_size=4, strides=2, padding='same', name='discriminator_conv_2')(d_x)
    d_x = BatchNormalization(momentum=0.8)(d_x)
    d_x = Activation('relu')(d_x)
    d_x = Dropout(rate=0.4)(d_x)

    d_x = Conv2D(filters=30, kernel_size=4, strides=2, padding='same', name='discriminator_conv_3')(d_x)
    d_x = BatchNormalization(momentum=0.8)(d_x)
    d_x = Activation('relu')(d_x)
    d_x = Dropout(rate=0.4)(d_x)

    d_x = Conv2D(filters=40, kernel_size=4, strides=2, padding='same', name='discriminator_conv_4')(d_x)
    d_x = BatchNormalization(momentum=0.8)(d_x)
    d_x = Activation('relu')(d_x)
    d_x = Dropout(rate=0.4)(d_x)

    d_x = Conv2D(filters=50, kernel_size=4, strides=2, padding='same', name='discriminator_conv_5')(d_x)
    d_x = BatchNormalization(momentum=0.8)(d_x)
    d_x = Activation('relu')(d_x)
    d_x = Dropout(rate=0.4)(d_x)
    
    d_x = Conv2D(filters=60, kernel_size=4, strides=2, padding='same', name='discriminator_conv_6')(d_x)
    d_x = BatchNormalization(momentum=0.8)(d_x)
    d_x = Activation('relu')(d_x)
    d_x = Dropout(rate=0.4)(d_x)

    d_x = Flatten()(d_x)
    d_x = Dense(units=10)(d_x)
    d_x = BatchNormalization(momentum=0.8)(d_x)
    d_outputs = Activation('sigmoid')(d_x)
    
    discriminator = Model(inputs=d_inputs, outputs=d_outputs)
    discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
    discriminator.summary()
    # generator
    g_inputs = Input((100, 100))
    g_x = Dense(8 * 8 * 128)(g_inputs)
    g_x = BatchNormalization(momentum=0.8)(g_x)
    g_x = Activation('relu')(g_x)
    g_x = Dropout(rate=0.4)(g_x)
    
    g_x = Reshape((8, 8, 128))(g_x)
    
    g_x = UpSampling2D()(g_x)
    g_x = Conv2D(filters=64, kernel_size=4, strides=1, padding='same')(g_x)
    g_x = BatchNormalization(momentum=0.8)(g_x)
    g_x = Activation('relu')(g_x)
    g_x = Dropout(rate=0.4)(g_x)
    
    g_x = UpSampling2D()(g_x)
    g_x = Conv2D(filters=32, kernel_size=4, strides=1, padding='same')(g_x)
    g_x = BatchNormalization(momentum=0.8)(g_x)
    g_x = Activation('relu')(g_x)
    g_x = Dropout(rate=0.4)(g_x)
    
    g_x = UpSampling2D()(g_x)
    g_x = Conv2D(filters=16, kernel_size=4, strides=1, padding='same')(g_x)
    g_x = BatchNormalization(momentum=0.8)(g_x)
    g_x = Activation('relu')(g_x)
    g_x = Dropout(rate=0.4)(g_x)
    
    g_x = UpSampling2D()(g_x)
    g_x = Conv2D(filters=8, kernel_size=4, strides=1, padding='same')(g_x)
    g_x = BatchNormalization(momentum=0.8)(g_x)
    g_x = Activation('relu')(g_x)
    g_x = Dropout(rate=0.4)(g_x)

    g_x = UpSampling2D()(g_x)
    g_x = Conv2D(filters=4, kernel_size=4, strides=1, padding='same')(g_x)
    g_x = BatchNormalization(momentum=0.8)(g_x)
    g_x = Activation('relu')(g_x)
    g_x = Dropout(rate=0.4)(g_x)

    g_x = UpSampling2D()(g_x)
    g_x = Conv2D(filters=3, kernel_size=4, strides=1, padding='same')(g_x)
    g_outputs = BatchNormalization(momentum=0.8)(g_x)
    
    discriminator.trainable = False
    generator = Model(inputs=g_inputs, outputs=g_outputs)
    generator.summary()
    
    def loss_func(discriminator):
        def calc_loss(y_real, y_pred):
            return discriminator.predict(y_pred)
        return calc_loss

    g_loss = loss_func(discriminator)

    generator.compile(loss=g_loss, optimizer='adam', metrics=['accuracy'])
    
    epochs = 10
    batch_size = 32
    for _ in epochs:
        for _ in 5:
            real = get_real_images(batch_size)
            fake = generator.predict(np.zeros(batch_size, 28, 28, 3))
            discriminator.fit(real, fake)
        generator.fit(np.zeros(100))
    
    print("done")

In [195]:
datasets, info = tfds.load('celeb_a_hq/512', with_info=True, data_dir=os.path.abspath(''))

[1mDownloading and preparing dataset celeb_a_hq (?? GiB) to /Users/ryo/Projects/papers_reproduce/wgan_gp/celeb_a_hq/512/0.1.0...[0m


AssertionError: Manual directory /Users/ryo/Projects/papers_reproduce/wgan_gp/downloads/manual/celeb_a_hq does not exist. Create it and download/extract dataset artifacts in there.

In [194]:
print(datasets['train'])

TypeError: 'CelebAHq' object is not subscriptable

In [185]:
# os.makedirs("/Users/ryo/Projects/papers_reproduce/wgan_gp/downloads/manual/celeb_a_hq", exist_ok=True)
datasets.download_and_prepare()

[1mDownloading and preparing dataset celeb_a_hq (?? GiB) to /Users/ryo/Projects/papers_reproduce/wgan_gp/celeb_a_hq/1024/0.1.0...[0m


AssertionError: Manual directory /Users/ryo/Projects/papers_reproduce/wgan_gp/downloads/manual/celeb_a_hq does not exist. Create it and download/extract dataset artifacts in there.

In [135]:
if __name__=='__main__':
    main()

Model: "model_37"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_59 (InputLayer)        [(None, 512, 512, 3)]     0         
_________________________________________________________________
discriminator_conv_1 (Conv2D (None, 256, 256, 10)      490       
_________________________________________________________________
batch_normalization_312 (Bat (None, 256, 256, 10)      40        
_________________________________________________________________
activation_297 (Activation)  (None, 256, 256, 10)      0         
_________________________________________________________________
dropout_270 (Dropout)        (None, 256, 256, 10)      0         
_________________________________________________________________
discriminator_conv_2 (Conv2D (None, 128, 128, 20)      3220      
_________________________________________________________________
batch_normalization_313 (Bat (None, 128, 128, 20)      80 

ValueError: When using data tensors as input to a model, you should specify the `steps` argument.